Skip to content

Commit

Permalink
Initialization: Fixes the SDK to retry if the initialization fails (#…
Browse files Browse the repository at this point in the history
…3027)

The SDK has an initialization task that gets the account information and other info. If this task failed with a DocumentClientException, it was never recreated. This is a problem because if a 408 is thrown, or some other transient problem occurs, the SDK will always return the cached failure and will not actually retry the request.

Solution:
A new initialization function factory was created. This allows the initialization task to be recreated. EnsureValidClientAsync now calls a new thread-safe method that returns the existing task, or creates a new one if the existing one failed.
  • Loading branch information
j82w authored Feb 14, 2022
1 parent d7b9546 commit 83c2c73
Show file tree
Hide file tree
Showing 6 changed files with 187 additions and 74 deletions.
106 changes: 46 additions & 60 deletions Microsoft.Azure.Cosmos/src/DocumentClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@ namespace Microsoft.Azure.Cosmos
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Net.Http.Headers;
using System.Runtime.CompilerServices;
using System.Security;
using System.Text;
using System.Threading;
Expand Down Expand Up @@ -109,6 +107,7 @@ internal partial class DocumentClient : IDisposable, IAuthorizationTokenProvider
private const int DefaultRntbdReceiveHangDetectionTimeSeconds = 65;
private const int DefaultRntbdSendHangDetectionTimeSeconds = 10;
private const bool DefaultEnableCpuMonitor = true;
private const string DefaultInitTaskKey = "InitTaskKey";

//Auth
private readonly AuthorizationTokenProvider cosmosAuthorization;
Expand Down Expand Up @@ -144,7 +143,6 @@ internal partial class DocumentClient : IDisposable, IAuthorizationTokenProvider
//Private state.
private bool isSuccessfullyInitialized;
private bool isDisposed;
private object initializationSyncLock; // guards initializeTask

// creator of TransportClient is responsible for disposing it.
private IStoreClientFactory storeClientFactory;
Expand All @@ -166,7 +164,8 @@ internal partial class DocumentClient : IDisposable, IAuthorizationTokenProvider
private AsyncLazy<QueryPartitionProvider> queryPartitionProvider;

private DocumentClientEventSource eventSource;
internal Task initializeTask;
private Func<Task<bool>> initializeTaskFactory;
internal AsyncCacheNonBlocking<string, bool> initTaskCache = new AsyncCacheNonBlocking<string, bool>();

private JsonSerializerSettings serializerSettings;
private event EventHandler<SendingRequestEventArgs> sendingRequest;
Expand Down Expand Up @@ -914,28 +913,35 @@ internal virtual void Initialize(Uri serviceEndpoint,
// Setup the proxy to be used based on connection mode.
// For gateway: GatewayProxy.
// For direct: WFStoreProxy [set in OpenAsync()].
this.initializationSyncLock = new object();

this.eventSource = DocumentClientEventSource.Instance;

this.initializeTask = TaskHelper.InlineIfPossibleAsync(
() => this.GetInitializationTaskAsync(storeClientFactory: storeClientFactory),
new ResourceThrottleRetryPolicy(
this.ConnectionPolicy.RetryOptions.MaxRetryAttemptsOnThrottledRequests,
this.ConnectionPolicy.RetryOptions.MaxRetryWaitTimeInSeconds));

// ContinueWith on the initialization task is needed for handling the UnobservedTaskException
// if this task throws for some reason. Awaiting inside a constructor is not supported and
// even if we had to await inside GetInitializationTask to catch the exception, that will
// be a blocking call. In such cases, the recommended approach is to "handle" the
// UnobservedTaskException by using ContinueWith method w/ TaskContinuationOptions.OnlyOnFaulted
// and accessing the Exception property on the target task.
this.initializeTaskFactory = () =>
{
Task<bool> task = TaskHelper.InlineIfPossible<bool>(
() => this.GetInitializationTaskAsync(storeClientFactory: storeClientFactory),
new ResourceThrottleRetryPolicy(
this.ConnectionPolicy.RetryOptions.MaxRetryAttemptsOnThrottledRequests,
this.ConnectionPolicy.RetryOptions.MaxRetryWaitTimeInSeconds));
// ContinueWith on the initialization task is needed for handling the UnobservedTaskException
// if this task throws for some reason. Awaiting inside a constructor is not supported and
// even if we had to await inside GetInitializationTask to catch the exception, that will
// be a blocking call. In such cases, the recommended approach is to "handle" the
// UnobservedTaskException by using ContinueWith method w/ TaskContinuationOptions.OnlyOnFaulted
// and accessing the Exception property on the target task.
#pragma warning disable VSTHRD110 // Observe result of async calls
this.initializeTask.ContinueWith(t =>
task.ContinueWith(t => DefaultTrace.TraceWarning("initializeTask failed {0}", t.Exception), TaskContinuationOptions.OnlyOnFaulted);
#pragma warning restore VSTHRD110 // Observe result of async calls
{
DefaultTrace.TraceWarning("initializeTask failed {0}", t.Exception);
}, TaskContinuationOptions.OnlyOnFaulted);
return task;
};

// Create the task to start the initialize task
// Task will be awaited on in the EnsureValidClientAsync
Task t = this.initTaskCache.GetAsync(
key: DocumentClient.DefaultInitTaskKey,
singleValueInitFunc: this.initializeTaskFactory,
forceRefresh: false,
callBackOnForceRefresh: null);

this.traceId = Interlocked.Increment(ref DocumentClient.idCounter);
DefaultTrace.TraceInformation(string.Format(
Expand All @@ -951,7 +957,7 @@ internal virtual void Initialize(Uri serviceEndpoint,
}

// Always called from under the lock except when called from the Initialize method during construction.
private async Task GetInitializationTaskAsync(IStoreClientFactory storeClientFactory)
private async Task<bool> GetInitializationTaskAsync(IStoreClientFactory storeClientFactory)
{
await this.InitializeGatewayConfigurationReaderAsync();

Expand Down Expand Up @@ -984,6 +990,8 @@ private async Task GetInitializationTaskAsync(IStoreClientFactory storeClientFac
{
this.InitializeDirectConnectivity(storeClientFactory);
}

return true;
}

private async Task InitializeCachesAsync(string databaseName, DocumentCollection collection, CancellationToken cancellationToken)
Expand Down Expand Up @@ -1243,6 +1251,12 @@ public void Dispose()
this.queryPartitionProvider.Value.Dispose();
}

if (this.initTaskCache != null)
{
this.initTaskCache.Dispose();
this.initTaskCache = null;
}

DefaultTrace.TraceInformation("DocumentClient with id {0} disposed.", this.traceId);
DefaultTrace.Flush();

Expand Down Expand Up @@ -1431,18 +1445,13 @@ internal virtual async Task EnsureValidClientAsync(ITrace trace)
// client which is unusable and can resume working if it failed initialization once.
// If we have to reinitialize the client, it needs to happen in a thread-safe manner so that
// we don't re-initialize the task again for each incoming call.
Task initTask = null;

lock (this.initializationSyncLock)
{
initTask = this.initializeTask;
}

try
{
await initTask;
this.isSuccessfullyInitialized = true;
return;
this.isSuccessfullyInitialized = await this.initTaskCache.GetAsync(
key: DocumentClient.DefaultInitTaskKey,
singleValueInitFunc: this.initializeTaskFactory,
forceRefresh: false,
callBackOnForceRefresh: null);
}
catch (DocumentClientException ex)
{
Expand All @@ -1452,33 +1461,10 @@ internal virtual async Task EnsureValidClientAsync(ITrace trace)
}
catch (Exception e)
{
DefaultTrace.TraceWarning("initializeTask failed {0}", e.ToString());
childTrace.AddDatum("initializeTask failed", e.ToString());
}

lock (this.initializationSyncLock)
{
// if the task has not been updated by another caller, update it
if (object.ReferenceEquals(this.initializeTask, initTask))
{
this.initializeTask = this.GetInitializationTaskAsync(storeClientFactory: null);
}

initTask = this.initializeTask;
}

try
{
await initTask;
this.isSuccessfullyInitialized = true;
}
catch (DocumentClientException ex)
{
throw Resource.CosmosExceptions.CosmosExceptionFactory.Create(
dce: ex,
trace: trace);
DefaultTrace.TraceWarning("initializeTask failed {0}", e);
childTrace.AddDatum("initializeTask failed", e);
throw;
}

}
}

Expand Down Expand Up @@ -6722,7 +6708,7 @@ private JsonSerializerSettings GetSerializerSettingsForRequest(Documents.Client.
private INameValueCollection GetRequestHeaders(Documents.Client.RequestOptions options)
{
Debug.Assert(
this.initializeTask.IsCompleted,
this.isSuccessfullyInitialized,
"GetRequestHeaders should be called after initialization task has been awaited to avoid blocking while accessing ConsistencyLevel property");

INameValueCollection headers = new StoreRequestNameValueCollection();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,15 @@ namespace Microsoft.Azure.Cosmos.SDK.EmulatorTests
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Formats.Asn1;
using System.IO;
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Net.NetworkInformation;
using System.Reflection;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Azure.Cosmos.Query.Core;
using Microsoft.Azure.Cosmos.Services.Management.Tests.LinqProviderTests;
Expand All @@ -28,6 +31,124 @@ namespace Microsoft.Azure.Cosmos.SDK.EmulatorTests
public class ClientTests
{
[TestMethod]
public async Task ValidateExceptionOnInitTask()
{
    // Verifies that a failed client initialization is not cached: a second request
    // must issue new HTTP calls and surface a different exception instance instead
    // of replaying the first failure.

    // Incremented for every request that reaches the injected handler.
    int httpCallCount = 0;
    HttpClientHandlerHelper httpClientHandlerHelper = new HttpClientHandlerHelper()
    {
        RequestCallBack = (request, cancellToken) =>
        {
            Interlocked.Increment(ref httpCallCount);

            // NOTE(review): returning null presumably lets the helper forward the
            // request to the real handler - confirm against HttpClientHandlerHelper.
            return null;
        }
    };

    // The auth key is a freshly generated GUID, so it is not a valid account key.
    using CosmosClient cosmosClient = new CosmosClient(
        accountEndpoint: "https://localhost:8081",
        authKeyOrResourceToken: Convert.ToBase64String(Encoding.UTF8.GetBytes(Guid.NewGuid().ToString())),
        clientOptions: new CosmosClientOptions()
        {
            HttpClientFactory = () => new HttpClient(httpClientHandlerHelper),
        });

    CosmosException cosmosException1 = null;
    try
    {
        await cosmosClient.GetContainer("db", "c").ReadItemAsync<JObject>("Random", new Cosmos.PartitionKey("DoesNotExist"));

        // Without this the test would pass vacuously when no exception is thrown,
        // because every assertion lives inside the catch block.
        Assert.Fail("The first read should throw a CosmosException");
    }
    catch (CosmosException ex)
    {
        cosmosException1 = ex;
        Assert.IsTrue(httpCallCount > 0);
    }

    // Reset so the second attempt has to prove it performed its own HTTP calls.
    httpCallCount = 0;
    try
    {
        await cosmosClient.GetContainer("db", "c").ReadItemAsync<JObject>("Random2", new Cosmos.PartitionKey("DoesNotExist2"));
        Assert.Fail("The second read should throw a CosmosException");
    }
    catch (CosmosException ex)
    {
        // A different exception instance plus new HTTP traffic shows the
        // initialization was retried rather than served from a cached failure.
        Assert.IsFalse(object.ReferenceEquals(ex, cosmosException1));
        Assert.IsTrue(httpCallCount > 0);
    }
}

[TestMethod]
public async Task InitTaskThreadSafe()
{
    // Starts batches of concurrent reads against a client whose HTTP callback is
    // held open, then checks that each batch resulted in exactly one HTTP call.
    // The batch is repeated to verify that later requests trigger a new call.
    const int rounds = 3;
    const int readsPerRound = 10;

    int requestCount = 0;
    bool holdRequests = true;
    HttpClientHandlerHelper handler = new HttpClientHandlerHelper()
    {
        RequestCallBack = async (httpRequest, token) =>
        {
            Interlocked.Increment(ref requestCount);

            // Keep the request pending until the test releases it.
            while (holdRequests)
            {
                await Task.Delay(TimeSpan.FromMilliseconds(100));
            }

            return null;
        }
    };

    using CosmosClient cosmosClient = new CosmosClient(
        accountEndpoint: "https://localhost:8081",
        authKeyOrResourceToken: Convert.ToBase64String(Encoding.UTF8.GetBytes(Guid.NewGuid().ToString())),
        clientOptions: new CosmosClientOptions()
        {
            HttpClientFactory = () => new HttpClient(handler),
        });

    Container container = cosmosClient.GetContainer("db", "c");

    for (int round = 0; round < rounds; round++)
    {
        List<Task> pendingReads = new List<Task>();
        for (int started = 0; started < readsPerRound; started++)
        {
            pendingReads.Add(this.ReadNotFound(container));
        }

        // Wait (up to two seconds) until every read has reported that it started.
        Stopwatch startupTimer = Stopwatch.StartNew();
        while (this.TaskStartedCount < readsPerRound && startupTimer.Elapsed.TotalSeconds < 2)
        {
            await Task.Delay(TimeSpan.FromMilliseconds(50));
        }

        Assert.AreEqual(readsPerRound, this.TaskStartedCount, "Tasks did not start");

        // Release the held HTTP callback and let every read finish.
        holdRequests = false;
        await Task.WhenAll(pendingReads);

        Assert.AreEqual(1, requestCount, "Only the first task should do the http call. All other should wait on the first task");

        // Reset counters and retry the client to verify a new http call is done for new requests
        holdRequests = true;
        this.TaskStartedCount = 0;
        requestCount = 0;
    }
}

// Number of ReadNotFound calls that have begun; InitTaskThreadSafe polls and resets it.
private int TaskStartedCount = 0;

/// <summary>
/// Records that a read has started, issues a point read that is expected to fail,
/// and hands the resulting <see cref="CosmosException"/> back to the caller.
/// </summary>
/// <param name="container">Container to read from.</param>
/// <returns>The CosmosException raised by the read.</returns>
/// <exception cref="Exception">Thrown when the read completes without a CosmosException.</exception>
private async Task<Exception> ReadNotFound(Container container)
{
    Interlocked.Increment(ref this.TaskStartedCount);

    try
    {
        await container.ReadItemAsync<JObject>("Random", new Cosmos.PartitionKey("DoesNotExist"));
    }
    catch (CosmosException cosmosException)
    {
        return cosmosException;
    }

    // Reaching this point means the read unexpectedly succeeded.
    throw new Exception("Should throw a CosmosException 403");
}

public async Task ResourceResponseStreamingTest()
{
using (DocumentClient client = TestCommon.CreateClient(true))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ public async Task CosmosHttpClientRetryValidation()
}
catch (CosmosException rte)
{
Assert.IsTrue(handler.Count >= 6);
Assert.IsTrue(handler.Count >= 3, $"HandlerCount: {handler.Count}; Expecte 6");
string message = rte.ToString();
Assert.IsTrue(message.Contains("Start Time"), "Start Time:" + message);
Assert.IsTrue(message.Contains("Total Duration"), "Total Duration:" + message);
Expand All @@ -129,9 +129,10 @@ private class TransientHttpClientCreatorHandler : DelegatingHandler

protected override Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
{
if (this.Count++ <= 3)
this.Count++;
if (this.Count < 3)
{
throw new WebException();
throw new WebException($"Mocked WebException {this.Count}");
}

throw new TaskCanceledException();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ namespace Microsoft.Azure.Cosmos.SDK.EmulatorTests
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Azure.Cosmos.Tracing;
using Microsoft.Azure.Documents;

//Internal Test hooks.
Expand All @@ -21,7 +22,7 @@ internal static class DocumentClientExtensions
//This will lock the client instance to a particular replica Index.
public static void LockClient(this DocumentClient client, uint replicaIndex)
{
client.initializeTask.Wait();
client.EnsureValidClientAsync(NoOpTrace.Singleton).Wait();
ServerStoreModel serverStoreModel = (client.StoreModel as ServerStoreModel);
if (serverStoreModel != null)
{
Expand All @@ -31,7 +32,7 @@ public static void LockClient(this DocumentClient client, uint replicaIndex)

public static void ForceAddressRefresh(this DocumentClient client, bool forceAddressRefresh)
{
client.initializeTask.Wait();
client.EnsureValidClientAsync(NoOpTrace.Singleton).Wait();
ServerStoreModel serverStoreModel = (client.StoreModel as ServerStoreModel);
if (serverStoreModel != null)
{
Expand All @@ -42,7 +43,7 @@ public static void ForceAddressRefresh(this DocumentClient client, bool forceAdd
//Returns the address of replica.
public static string GetAddress(this DocumentClient client)
{
client.initializeTask.Wait();
client.EnsureValidClientAsync(NoOpTrace.Singleton).Wait();
return (client.StoreModel as ServerStoreModel).LastReadAddress;
}
}
Expand Down
Loading

0 comments on commit 83c2c73

Please sign in to comment.