Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CSHARP-5017: Retry KMS requests on transient errors #1541

Open
wants to merge 21 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions evergreen/evergreen.yml
Original file line number Diff line number Diff line change
Expand Up @@ -869,6 +869,15 @@ functions:
cd ${DRIVERS_TOOLS}/.evergreen/csfle
. ./activate-kmstlsvenv.sh
python -u kms_http_server.py -v --ca_file ../x509gen/ca.pem --cert_file ../x509gen/server.pem --port 8002 --require_client_cert
- command: shell.exec
params:
background: true
shell: "bash"
script: |
#server.pem client cert
cd ${DRIVERS_TOOLS}/.evergreen/csfle
. ./activate-kmstlsvenv.sh
python -u kms_failpoint_server.py --port 9003

start-kms-mock-kmip-server:
- command: shell.exec
Expand Down
2 changes: 2 additions & 0 deletions src/MongoDB.Driver.Encryption/CryptClientFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,8 @@ public static CryptClient Create(CryptOptions options)

Library.mongocrypt_setopt_use_need_kms_credentials_state(handle);

Library.mongocrypt_setopt_retry_kms(handle, true);

Library.mongocrypt_init(handle);

if (options.IsCryptSharedLibRequired)
Expand Down
14 changes: 4 additions & 10 deletions src/MongoDB.Driver.Encryption/CryptContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -156,18 +156,12 @@ public Binary FinalizeForEncryption()
}

/// <summary>
/// Gets a collection of KMS message requests to make
/// Gets the next KMS message request
/// </summary>
/// <returns>Collection of KMS Messages</returns>
public KmsRequestCollection GetKmsMessageRequests()
public KmsRequest GetNextKmsMessageRequest()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done :)

{
var requests = new List<KmsRequest>();
for (IntPtr request = Library.mongocrypt_ctx_next_kms_ctx(_handle); request != IntPtr.Zero; request = Library.mongocrypt_ctx_next_kms_ctx(_handle))
{
requests.Add(new KmsRequest(request));
}

return new KmsRequestCollection(requests, this);
var request = Library.mongocrypt_ctx_next_kms_ctx(_handle);
return request == IntPtr.Zero ? null : new KmsRequest(request);
}

/// <summary>
Expand Down
11 changes: 11 additions & 0 deletions src/MongoDB.Driver.Encryption/KmsRequest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,11 @@ public string KmsProvider
}
}

/// <summary>
/// The number of milliseconds to wait before sending this request.
/// </summary>
public int Sleep => (int)(Library.mongocrypt_kms_ctx_usleep(_id) / 1000);

/// <summary>
/// Gets the message to send to KMS.
/// </summary>
Expand All @@ -88,6 +93,12 @@ public Binary GetMessage()
return binary;
}

/// <summary>
/// Indicates a network-level failure.
/// </summary>
/// <returns>A boolean indicating whether the failed request may be retried.</returns>
public bool Fail() => Library.mongocrypt_kms_ctx_fail(_id);

/// <summary>
/// Feeds the response back to the libmongocrypt
/// </summary>
Expand Down
76 changes: 58 additions & 18 deletions src/MongoDB.Driver.Encryption/LibMongoCryptControllerBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -211,22 +211,20 @@ private SslStreamSettings GetTlsStreamSettings(string kmsProvider)

private void ProcessNeedKmsState(CryptContext context, CancellationToken cancellationToken)
{
var requests = context.GetKmsMessageRequests();
foreach (var request in requests)
while (context.GetNextKmsMessageRequest() is { } request)
{
SendKmsRequest(request, cancellationToken);
}
requests.MarkDone();
context.MarkKmsDone();
}

private async Task ProcessNeedKmsStateAsync(CryptContext context, CancellationToken cancellationToken)
{
var requests = context.GetKmsMessageRequests();
foreach (var request in requests)
while (context.GetNextKmsMessageRequest() is { } request)
{
await SendKmsRequestAsync(request, cancellationToken).ConfigureAwait(false);
}
requests.MarkDone();
context.MarkKmsDone();
}

private void ProcessNeedMongoKeysState(CryptContext context, CancellationToken cancellationToken)
Expand Down Expand Up @@ -278,48 +276,90 @@ private static byte[] ProcessReadyState(CryptContext context)

private void SendKmsRequest(KmsRequest request, CancellationToken cancellation)
{
var endpoint = CreateKmsEndPoint(request.Endpoint);

var tlsStreamSettings = GetTlsStreamSettings(request.KmsProvider);
var sslStreamFactory = new SslStreamFactory(tlsStreamSettings, _networkStreamFactory);
using (var sslStream = sslStreamFactory.CreateStream(endpoint, cancellation))
using (var binary = request.GetMessage())
try
{
var endpoint = CreateKmsEndPoint(request.Endpoint);

var tlsStreamSettings = GetTlsStreamSettings(request.KmsProvider);
var sslStreamFactory = new SslStreamFactory(tlsStreamSettings, _networkStreamFactory);
using var sslStream = sslStreamFactory.CreateStream(endpoint, cancellation);

var sleepMs = request.Sleep;
if (sleepMs > 0)
{
Thread.Sleep(sleepMs);
}

using var binary = request.GetMessage();
var requestBytes = binary.ToArray();
sslStream.Write(requestBytes, 0, requestBytes.Length);

while (request.BytesNeeded > 0)
{
var buffer = new byte[request.BytesNeeded]; // BytesNeeded is the maximum number of bytes that libmongocrypt wants to receive.
var count = sslStream.Read(buffer, 0, buffer.Length);

if (count == 0)
{
throw new IOException("Unexpected end of stream. No data was read from the SSL stream.");
}

var responseBytes = new byte[count];
Buffer.BlockCopy(buffer, 0, responseBytes, 0, count);
request.Feed(responseBytes);
}
}
catch (Exception ex) when (ex is IOException or SocketException)
{
if (!request.Fail())
{
throw;
}
}
}

private async Task SendKmsRequestAsync(KmsRequest request, CancellationToken cancellation)
{
var endpoint = CreateKmsEndPoint(request.Endpoint);

var tlsStreamSettings = GetTlsStreamSettings(request.KmsProvider);
var sslStreamFactory = new SslStreamFactory(tlsStreamSettings, _networkStreamFactory);
using (var sslStream = await sslStreamFactory.CreateStreamAsync(endpoint, cancellation).ConfigureAwait(false))
using (var binary = request.GetMessage())
try
{
var endpoint = CreateKmsEndPoint(request.Endpoint);

var tlsStreamSettings = GetTlsStreamSettings(request.KmsProvider);
var sslStreamFactory = new SslStreamFactory(tlsStreamSettings, _networkStreamFactory);
using var sslStream = await sslStreamFactory.CreateStreamAsync(endpoint, cancellation).ConfigureAwait(false);

var sleepMs = request.Sleep;
if (sleepMs > 0)
{
await Task.Delay(sleepMs, cancellation).ConfigureAwait(false);
}

using var binary = request.GetMessage();
var requestBytes = binary.ToArray();
await sslStream.WriteAsync(requestBytes, 0, requestBytes.Length).ConfigureAwait(false);

while (request.BytesNeeded > 0)
{
var buffer = new byte[request.BytesNeeded]; // BytesNeeded is the maximum number of bytes that libmongocrypt wants to receive.
var count = await sslStream.ReadAsync(buffer, 0, buffer.Length).ConfigureAwait(false);

if (count == 0)
{
throw new IOException("Unexpected end of stream. No data was read from the SSL stream.");
}

var responseBytes = new byte[count];
Buffer.BlockCopy(buffer, 0, responseBytes, 0, count);
request.Feed(responseBytes);
}
}
catch (Exception ex) when (ex is IOException or SocketException)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure here which other kind of network-errors related exceptions we could get

Copy link
Contributor

@adelinowona adelinowona Nov 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I saw some comment on the Drivers ticket for this feature that it only pertains to HTTP errors so for socket-level errors do we also retry on those? Other comments in the ticket seems to suggest we'll need to tell libmongocrypt to reset its state if we encounter socket errors while reading a kms response. Actually are http errors thrown as socket or IOException errors?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I think that comment is a little bit misleading. So, as far as I have understood the situation is like this:

  • The HTTP error are actually handled internally by libmongocrypt when we "feed" the response bytes to it. It reads the response and understands if it's an HTTP response with an error code.
  • The network errors instead need to be recognized by us. When we do, we need to call mongocrypt_kms_ctx_fail (that is used inside the new Fail method) to notify libmongocrypt of the network error.
  • Internally, the handling for both errors is the same, as it retries the request a certain number of times

{
if (!request.Fail())
{
throw;
}
}
}

// nested type
Expand Down
24 changes: 24 additions & 0 deletions src/MongoDB.Driver.Encryption/Library.cs
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,9 @@ static Library()
_mongocrypt_ctx_setopt_query_type = new Lazy<Delegates.mongocrypt_ctx_setopt_query_type>(
() => __loader.Value.GetFunction<Delegates.mongocrypt_ctx_setopt_query_type>(
("mongocrypt_ctx_setopt_query_type")), true);
_mongocrypt_setopt_retry_kms = new Lazy<Delegates.mongocrypt_setopt_retry_kms>(
() => __loader.Value.GetFunction<Delegates.mongocrypt_setopt_retry_kms>(
("mongocrypt_setopt_retry_kms")), true);

_mongocrypt_ctx_status = new Lazy<Delegates.mongocrypt_ctx_status>(
() => __loader.Value.GetFunction<Delegates.mongocrypt_ctx_status>(("mongocrypt_ctx_status")), true);
Expand Down Expand Up @@ -210,6 +213,11 @@ static Library()
() => __loader.Value.GetFunction<Delegates.mongocrypt_ctx_destroy>(("mongocrypt_ctx_destroy")), true);
_mongocrypt_kms_ctx_get_kms_provider = new Lazy<Delegates.mongocrypt_kms_ctx_get_kms_provider>(
() => __loader.Value.GetFunction<Delegates.mongocrypt_kms_ctx_get_kms_provider>(("mongocrypt_kms_ctx_get_kms_provider")), true);

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For these and the following methods/variables, is there a specific ordering? It does not seem to be completely alphabetical so for now I've put the new things on the bottom

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There isn't a specific order for the methods, but many of them are grouped together. For instance, you'll notice that most of the setopt methods are clustered. For now, I recommend placing the mongocrypt_setopt_retry_kms method alongside the other setopt methods and leaving the other new methods at the bottom. You can add a TODO comment above the static constructor to organize the methods later (so you don't have unrelated changes to your PR). I'll handle that when addressing technical debt in Libmongocrypt later this quarter.

_mongocrypt_kms_ctx_usleep = new Lazy<Delegates.mongocrypt_kms_ctx_usleep>(
() => __loader.Value.GetFunction<Delegates.mongocrypt_kms_ctx_usleep>(("mongocrypt_kms_ctx_usleep")), true);
_mongocrypt_kms_ctx_fail = new Lazy<Delegates.mongocrypt_kms_ctx_fail>(
() => __loader.Value.GetFunction<Delegates.mongocrypt_kms_ctx_fail>(("mongocrypt_kms_ctx_fail")), true);
}

/// <summary>
Expand Down Expand Up @@ -287,6 +295,7 @@ public static string Version
internal static Delegates.mongocrypt_ctx_setopt_algorithm_range mongocrypt_ctx_setopt_algorithm_range => _mongocrypt_ctx_setopt_algorithm_range.Value;
internal static Delegates.mongocrypt_ctx_setopt_contention_factor mongocrypt_ctx_setopt_contention_factor => _mongocrypt_ctx_setopt_contention_factor.Value;
internal static Delegates.mongocrypt_ctx_setopt_query_type mongocrypt_ctx_setopt_query_type => _mongocrypt_ctx_setopt_query_type.Value;
internal static Delegates.mongocrypt_setopt_retry_kms mongocrypt_setopt_retry_kms => _mongocrypt_setopt_retry_kms.Value;

internal static Delegates.mongocrypt_ctx_state mongocrypt_ctx_state => _mongocrypt_ctx_state.Value;
internal static Delegates.mongocrypt_ctx_mongo_op mongocrypt_ctx_mongo_op => _mongocrypt_ctx_mongo_op.Value;
Expand All @@ -305,6 +314,9 @@ public static string Version
internal static Delegates.mongocrypt_ctx_destroy mongocrypt_ctx_destroy => _mongocrypt_ctx_destroy.Value;
internal static Delegates.mongocrypt_kms_ctx_get_kms_provider mongocrypt_kms_ctx_get_kms_provider => _mongocrypt_kms_ctx_get_kms_provider.Value;

internal static Delegates.mongocrypt_kms_ctx_usleep mongocrypt_kms_ctx_usleep => _mongocrypt_kms_ctx_usleep.Value;
internal static Delegates.mongocrypt_kms_ctx_fail mongocrypt_kms_ctx_fail => _mongocrypt_kms_ctx_fail.Value;

private static readonly Lazy<LibraryLoader> __loader = new Lazy<LibraryLoader>(
() => new LibraryLoader(), true);
private static readonly Lazy<Delegates.mongocrypt_version> _mongocrypt_version;
Expand Down Expand Up @@ -392,6 +404,10 @@ public static string Version
private static readonly Lazy<Delegates.mongocrypt_ctx_destroy> _mongocrypt_ctx_destroy;
private static readonly Lazy<Delegates.mongocrypt_kms_ctx_get_kms_provider> _mongocrypt_kms_ctx_get_kms_provider;

private static readonly Lazy<Delegates.mongocrypt_kms_ctx_usleep> _mongocrypt_kms_ctx_usleep;
private static readonly Lazy<Delegates.mongocrypt_kms_ctx_fail> _mongocrypt_kms_ctx_fail;
private static readonly Lazy<Delegates.mongocrypt_setopt_retry_kms> _mongocrypt_setopt_retry_kms;

// nested types
internal enum StatusType
{
Expand Down Expand Up @@ -640,6 +656,9 @@ public delegate bool
[return: MarshalAs(UnmanagedType.I1)]
public delegate bool mongocrypt_ctx_setopt_query_type(ContextSafeHandle ctx, [MarshalAs(UnmanagedType.LPStr)] string query_type, int length);

[return: MarshalAs(UnmanagedType.I1)]
public delegate bool mongocrypt_setopt_retry_kms(MongoCryptSafeHandle handle, bool enable);

public delegate CryptContext.StateCode mongocrypt_ctx_state(ContextSafeHandle handle);

[return: MarshalAs(UnmanagedType.I1)]
Expand Down Expand Up @@ -681,6 +700,11 @@ public delegate bool

public delegate void mongocrypt_ctx_destroy(IntPtr ptr);
public delegate IntPtr mongocrypt_kms_ctx_get_kms_provider(IntPtr handle, out uint length);

public delegate long mongocrypt_kms_ctx_usleep(IntPtr handle);

[return: MarshalAs(UnmanagedType.I1)]
public delegate bool mongocrypt_kms_ctx_fail(IntPtr handle);
}
}
}
10 changes: 5 additions & 5 deletions src/MongoDB.Driver.Encryption/MongoDB.Driver.Encryption.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@
</PropertyGroup>

<PropertyGroup>
<LibMongoCryptMacOsSourceUrl>https://mciuploads.s3.amazonaws.com/libmongocrypt-release/macos/r1.11/9a88ac5698e8e3ffcd6580b98c247f0126f26c40/libmongocrypt.tar.gz</LibMongoCryptMacOsSourceUrl>
<LibMongoCryptUbuntuX64SourceUrl>https://mciuploads.s3.amazonaws.com/libmongocrypt-release/ubuntu1804-64/r1.11/9a88ac5698e8e3ffcd6580b98c247f0126f26c40/libmongocrypt.tar.gz</LibMongoCryptUbuntuX64SourceUrl>
<LibMongoCryptUbuntuARM64SourceUrl>https://mciuploads.s3.amazonaws.com/libmongocrypt-release/ubuntu1804-arm64/r1.11/9a88ac5698e8e3ffcd6580b98c247f0126f26c40/libmongocrypt.tar.gz</LibMongoCryptUbuntuARM64SourceUrl>
<LibMongoCryptAlpineSourceUrl>https://mciuploads.s3.amazonaws.com/libmongocrypt-release/alpine-arm64-earthly/r1.11/9a88ac5698e8e3ffcd6580b98c247f0126f26c40/libmongocrypt.tar.gz</LibMongoCryptAlpineSourceUrl>
<LibMongoCryptWindowsSourceUrl>https://mciuploads.s3.amazonaws.com/libmongocrypt-release/windows-test/r1.11/9a88ac5698e8e3ffcd6580b98c247f0126f26c40/libmongocrypt.tar.gz</LibMongoCryptWindowsSourceUrl>
<LibMongoCryptMacOsSourceUrl>https://mciuploads.s3.amazonaws.com/libmongocrypt-release/macos/r1.12/085a0ce6538a28179da6bfd2927aea106924443a/libmongocrypt.tar.gz</LibMongoCryptMacOsSourceUrl>
<LibMongoCryptUbuntuX64SourceUrl>https://mciuploads.s3.amazonaws.com/libmongocrypt-release/ubuntu1804-64/r1.12/085a0ce6538a28179da6bfd2927aea106924443a/libmongocrypt.tar.gz</LibMongoCryptUbuntuX64SourceUrl>
<LibMongoCryptUbuntuARM64SourceUrl>https://mciuploads.s3.amazonaws.com/libmongocrypt-release/ubuntu1804-arm64/r1.12/085a0ce6538a28179da6bfd2927aea106924443a/libmongocrypt.tar.gz</LibMongoCryptUbuntuARM64SourceUrl>
<LibMongoCryptAlpineSourceUrl>https://mciuploads.s3.amazonaws.com/libmongocrypt-release/alpine-arm64-earthly/r1.12/085a0ce6538a28179da6bfd2927aea106924443a/libmongocrypt.tar.gz</LibMongoCryptAlpineSourceUrl>
<LibMongoCryptWindowsSourceUrl>https://mciuploads.s3.amazonaws.com/libmongocrypt-release/windows-test/r1.12/085a0ce6538a28179da6bfd2927aea106924443a/libmongocrypt.tar.gz</LibMongoCryptWindowsSourceUrl>
</PropertyGroup>

<Target Name="DownloadNativeBinaries_linux_x64" BeforeTargets="BeforeBuild" Condition="!Exists('$(MSBuildProjectDirectory)/linux/x64/libmongocrypt.so')">
Expand Down
13 changes: 6 additions & 7 deletions tests/MongoDB.Driver.Encryption.Tests/BasicTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ public void TestGetKmsProviderName(string kmsName)
using (var cryptClient = CryptClientFactory.Create(cryptOptions))
using (var context = cryptClient.StartCreateDataKeyContext(keyId))
{
var request = context.GetKmsMessageRequests().Single();
var request = context.GetNextKmsMessageRequest();
request.KmsProvider.Should().Be(kmsName);
}
}
Expand Down Expand Up @@ -634,22 +634,21 @@ private static (CryptContext.StateCode stateProcessed, Binary binaryProduced, Bs

case CryptContext.StateCode.MONGOCRYPT_CTX_NEED_KMS:
{
var requests = context.GetKmsMessageRequests();
foreach (var req in requests)
while (context.GetNextKmsMessageRequest() is { } request)
{
using var binary = req.GetMessage();
using var binary = request.GetMessage();
_output.WriteLine("Key Document: " + binary);
var postRequest = binary.ToString();
// TODO: add different hosts handling
postRequest.Should().Contain("Host:kms.us-east-1.amazonaws.com"); // only AWS

var reply = ReadHttpTestFile(isKmsDecrypt ? "kms-decrypt-reply.txt" : "kms-encrypt-reply.txt");
_output.WriteLine("Reply: " + reply);
req.Feed(Encoding.UTF8.GetBytes(reply));
req.BytesNeeded.Should().Be(0);
request.Feed(Encoding.UTF8.GetBytes(reply));
request.BytesNeeded.Should().Be(0);
}

requests.MarkDone();
context.MarkKmsDone();
return (CryptContext.StateCode.MONGOCRYPT_CTX_NEED_KMS, null, null);
}

Expand Down
Loading