-
Notifications
You must be signed in to change notification settings - Fork 1.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
CSHARP-5017: Retry KMS requests on transient errors #1541
base: main
Are you sure you want to change the base?
Changes from all commits
6791f8c
477d618
23a9930
b298ec7
f0f1253
4211941
82bbd98
b24a6c6
e8a92f4
25fc1a5
17fc930
d83cafa
cfed997
9023982
5e39fbd
7835423
dcaf30f
955be75
c322e46
9af1b54
0761f1a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -211,22 +211,20 @@ private SslStreamSettings GetTlsStreamSettings(string kmsProvider) | |
|
||
private void ProcessNeedKmsState(CryptContext context, CancellationToken cancellationToken) | ||
{ | ||
var requests = context.GetKmsMessageRequests(); | ||
foreach (var request in requests) | ||
while (context.GetNextKmsMessageRequest() is { } request) | ||
{ | ||
SendKmsRequest(request, cancellationToken); | ||
} | ||
requests.MarkDone(); | ||
context.MarkKmsDone(); | ||
} | ||
|
||
private async Task ProcessNeedKmsStateAsync(CryptContext context, CancellationToken cancellationToken) | ||
{ | ||
var requests = context.GetKmsMessageRequests(); | ||
foreach (var request in requests) | ||
while (context.GetNextKmsMessageRequest() is { } request) | ||
{ | ||
await SendKmsRequestAsync(request, cancellationToken).ConfigureAwait(false); | ||
} | ||
requests.MarkDone(); | ||
context.MarkKmsDone(); | ||
} | ||
|
||
private void ProcessNeedMongoKeysState(CryptContext context, CancellationToken cancellationToken) | ||
|
@@ -278,48 +276,90 @@ private static byte[] ProcessReadyState(CryptContext context) | |
|
||
private void SendKmsRequest(KmsRequest request, CancellationToken cancellation) | ||
{ | ||
var endpoint = CreateKmsEndPoint(request.Endpoint); | ||
|
||
var tlsStreamSettings = GetTlsStreamSettings(request.KmsProvider); | ||
var sslStreamFactory = new SslStreamFactory(tlsStreamSettings, _networkStreamFactory); | ||
using (var sslStream = sslStreamFactory.CreateStream(endpoint, cancellation)) | ||
using (var binary = request.GetMessage()) | ||
try | ||
{ | ||
var endpoint = CreateKmsEndPoint(request.Endpoint); | ||
|
||
var tlsStreamSettings = GetTlsStreamSettings(request.KmsProvider); | ||
var sslStreamFactory = new SslStreamFactory(tlsStreamSettings, _networkStreamFactory); | ||
using var sslStream = sslStreamFactory.CreateStream(endpoint, cancellation); | ||
|
||
var sleepMs = request.Sleep; | ||
if (sleepMs > 0) | ||
{ | ||
Thread.Sleep(sleepMs); | ||
} | ||
|
||
using var binary = request.GetMessage(); | ||
var requestBytes = binary.ToArray(); | ||
sslStream.Write(requestBytes, 0, requestBytes.Length); | ||
|
||
while (request.BytesNeeded > 0) | ||
{ | ||
var buffer = new byte[request.BytesNeeded]; // BytesNeeded is the maximum number of bytes that libmongocrypt wants to receive. | ||
var count = sslStream.Read(buffer, 0, buffer.Length); | ||
|
||
if (count == 0) | ||
{ | ||
throw new IOException("Unexpected end of stream. No data was read from the SSL stream."); | ||
} | ||
|
||
var responseBytes = new byte[count]; | ||
Buffer.BlockCopy(buffer, 0, responseBytes, 0, count); | ||
request.Feed(responseBytes); | ||
} | ||
} | ||
catch (Exception ex) when (ex is IOException or SocketException) | ||
{ | ||
if (!request.Fail()) | ||
{ | ||
throw; | ||
} | ||
} | ||
} | ||
|
||
private async Task SendKmsRequestAsync(KmsRequest request, CancellationToken cancellation) | ||
{ | ||
var endpoint = CreateKmsEndPoint(request.Endpoint); | ||
|
||
var tlsStreamSettings = GetTlsStreamSettings(request.KmsProvider); | ||
var sslStreamFactory = new SslStreamFactory(tlsStreamSettings, _networkStreamFactory); | ||
using (var sslStream = await sslStreamFactory.CreateStreamAsync(endpoint, cancellation).ConfigureAwait(false)) | ||
using (var binary = request.GetMessage()) | ||
try | ||
{ | ||
var endpoint = CreateKmsEndPoint(request.Endpoint); | ||
|
||
var tlsStreamSettings = GetTlsStreamSettings(request.KmsProvider); | ||
var sslStreamFactory = new SslStreamFactory(tlsStreamSettings, _networkStreamFactory); | ||
using var sslStream = await sslStreamFactory.CreateStreamAsync(endpoint, cancellation).ConfigureAwait(false); | ||
|
||
var sleepMs = request.Sleep; | ||
if (sleepMs > 0) | ||
{ | ||
await Task.Delay(sleepMs, cancellation).ConfigureAwait(false); | ||
} | ||
|
||
using var binary = request.GetMessage(); | ||
var requestBytes = binary.ToArray(); | ||
await sslStream.WriteAsync(requestBytes, 0, requestBytes.Length).ConfigureAwait(false); | ||
|
||
while (request.BytesNeeded > 0) | ||
{ | ||
var buffer = new byte[request.BytesNeeded]; // BytesNeeded is the maximum number of bytes that libmongocrypt wants to receive. | ||
var count = await sslStream.ReadAsync(buffer, 0, buffer.Length).ConfigureAwait(false); | ||
|
||
if (count == 0) | ||
{ | ||
throw new IOException("Unexpected end of stream. No data was read from the SSL stream."); | ||
} | ||
|
||
var responseBytes = new byte[count]; | ||
Buffer.BlockCopy(buffer, 0, responseBytes, 0, count); | ||
request.Feed(responseBytes); | ||
} | ||
} | ||
catch (Exception ex) when (ex is IOException or SocketException) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Not sure which other kinds of network-error-related exceptions we could get here. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I saw a comment on the Drivers ticket for this feature saying that it only pertains to HTTP errors, so do we also retry on socket-level errors? Other comments in the ticket seem to suggest we'll need to tell libmongocrypt to reset its state if we encounter socket errors while reading a KMS response. Actually, are HTTP errors thrown as socket or IOException errors? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes, I think that comment is a little bit misleading. So, as far as I have understood, the situation is like this:
|
||
{ | ||
if (!request.Fail()) | ||
{ | ||
throw; | ||
} | ||
} | ||
} | ||
|
||
// nested type | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -147,6 +147,9 @@ static Library() | |
_mongocrypt_ctx_setopt_query_type = new Lazy<Delegates.mongocrypt_ctx_setopt_query_type>( | ||
() => __loader.Value.GetFunction<Delegates.mongocrypt_ctx_setopt_query_type>( | ||
("mongocrypt_ctx_setopt_query_type")), true); | ||
_mongocrypt_setopt_retry_kms = new Lazy<Delegates.mongocrypt_setopt_retry_kms>( | ||
() => __loader.Value.GetFunction<Delegates.mongocrypt_setopt_retry_kms>( | ||
("mongocrypt_setopt_retry_kms")), true); | ||
|
||
_mongocrypt_ctx_status = new Lazy<Delegates.mongocrypt_ctx_status>( | ||
() => __loader.Value.GetFunction<Delegates.mongocrypt_ctx_status>(("mongocrypt_ctx_status")), true); | ||
|
@@ -210,6 +213,11 @@ static Library() | |
() => __loader.Value.GetFunction<Delegates.mongocrypt_ctx_destroy>(("mongocrypt_ctx_destroy")), true); | ||
_mongocrypt_kms_ctx_get_kms_provider = new Lazy<Delegates.mongocrypt_kms_ctx_get_kms_provider>( | ||
() => __loader.Value.GetFunction<Delegates.mongocrypt_kms_ctx_get_kms_provider>(("mongocrypt_kms_ctx_get_kms_provider")), true); | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. For these and the following methods/variables, is there a specific ordering? It does not seem to be completely alphabetical, so for now I've put the new things at the bottom. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. There isn't a specific order for the methods, but many of them are grouped together. For instance, you'll notice that most of the setopt methods are clustered. For now, I recommend placing the mongocrypt_setopt_retry_kms method alongside the other setopt methods and leaving the other new methods at the bottom. You can add a TODO comment above the static constructor to organize the methods later (so you don't have unrelated changes in your PR). I'll handle that when addressing technical debt in libmongocrypt later this quarter. |
||
_mongocrypt_kms_ctx_usleep = new Lazy<Delegates.mongocrypt_kms_ctx_usleep>( | ||
() => __loader.Value.GetFunction<Delegates.mongocrypt_kms_ctx_usleep>(("mongocrypt_kms_ctx_usleep")), true); | ||
_mongocrypt_kms_ctx_fail = new Lazy<Delegates.mongocrypt_kms_ctx_fail>( | ||
() => __loader.Value.GetFunction<Delegates.mongocrypt_kms_ctx_fail>(("mongocrypt_kms_ctx_fail")), true); | ||
} | ||
|
||
/// <summary> | ||
|
@@ -287,6 +295,7 @@ public static string Version | |
internal static Delegates.mongocrypt_ctx_setopt_algorithm_range mongocrypt_ctx_setopt_algorithm_range => _mongocrypt_ctx_setopt_algorithm_range.Value; | ||
internal static Delegates.mongocrypt_ctx_setopt_contention_factor mongocrypt_ctx_setopt_contention_factor => _mongocrypt_ctx_setopt_contention_factor.Value; | ||
internal static Delegates.mongocrypt_ctx_setopt_query_type mongocrypt_ctx_setopt_query_type => _mongocrypt_ctx_setopt_query_type.Value; | ||
internal static Delegates.mongocrypt_setopt_retry_kms mongocrypt_setopt_retry_kms => _mongocrypt_setopt_retry_kms.Value; | ||
|
||
internal static Delegates.mongocrypt_ctx_state mongocrypt_ctx_state => _mongocrypt_ctx_state.Value; | ||
internal static Delegates.mongocrypt_ctx_mongo_op mongocrypt_ctx_mongo_op => _mongocrypt_ctx_mongo_op.Value; | ||
|
@@ -305,6 +314,9 @@ public static string Version | |
internal static Delegates.mongocrypt_ctx_destroy mongocrypt_ctx_destroy => _mongocrypt_ctx_destroy.Value; | ||
internal static Delegates.mongocrypt_kms_ctx_get_kms_provider mongocrypt_kms_ctx_get_kms_provider => _mongocrypt_kms_ctx_get_kms_provider.Value; | ||
|
||
internal static Delegates.mongocrypt_kms_ctx_usleep mongocrypt_kms_ctx_usleep => _mongocrypt_kms_ctx_usleep.Value; | ||
internal static Delegates.mongocrypt_kms_ctx_fail mongocrypt_kms_ctx_fail => _mongocrypt_kms_ctx_fail.Value; | ||
|
||
private static readonly Lazy<LibraryLoader> __loader = new Lazy<LibraryLoader>( | ||
() => new LibraryLoader(), true); | ||
private static readonly Lazy<Delegates.mongocrypt_version> _mongocrypt_version; | ||
|
@@ -392,6 +404,10 @@ public static string Version | |
private static readonly Lazy<Delegates.mongocrypt_ctx_destroy> _mongocrypt_ctx_destroy; | ||
private static readonly Lazy<Delegates.mongocrypt_kms_ctx_get_kms_provider> _mongocrypt_kms_ctx_get_kms_provider; | ||
|
||
private static readonly Lazy<Delegates.mongocrypt_kms_ctx_usleep> _mongocrypt_kms_ctx_usleep; | ||
private static readonly Lazy<Delegates.mongocrypt_kms_ctx_fail> _mongocrypt_kms_ctx_fail; | ||
private static readonly Lazy<Delegates.mongocrypt_setopt_retry_kms> _mongocrypt_setopt_retry_kms; | ||
|
||
// nested types | ||
internal enum StatusType | ||
{ | ||
|
@@ -640,6 +656,9 @@ public delegate bool | |
[return: MarshalAs(UnmanagedType.I1)] | ||
public delegate bool mongocrypt_ctx_setopt_query_type(ContextSafeHandle ctx, [MarshalAs(UnmanagedType.LPStr)] string query_type, int length); | ||
|
||
[return: MarshalAs(UnmanagedType.I1)] | ||
public delegate bool mongocrypt_setopt_retry_kms(MongoCryptSafeHandle handle, bool enable); | ||
|
||
public delegate CryptContext.StateCode mongocrypt_ctx_state(ContextSafeHandle handle); | ||
|
||
[return: MarshalAs(UnmanagedType.I1)] | ||
|
@@ -681,6 +700,11 @@ public delegate bool | |
|
||
public delegate void mongocrypt_ctx_destroy(IntPtr ptr); | ||
public delegate IntPtr mongocrypt_kms_ctx_get_kms_provider(IntPtr handle, out uint length); | ||
|
||
public delegate long mongocrypt_kms_ctx_usleep(IntPtr handle); | ||
|
||
[return: MarshalAs(UnmanagedType.I1)] | ||
public delegate bool mongocrypt_kms_ctx_fail(IntPtr handle); | ||
} | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agreed!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done :)