Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support TLS Resume with client certificates on Linux #102656

Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,39 @@ internal static partial class OpenSsl
private const string TlsCacheSizeCtxName = "System.Net.Security.TlsCacheSize";
private const string TlsCacheSizeEnvironmentVariable = "DOTNET_SYSTEM_NET_SECURITY_TLSCACHESIZE";
private const SslProtocols FakeAlpnSslProtocol = (SslProtocols)1; // used to distinguish server sessions with ALPN
private static readonly ConcurrentDictionary<SslProtocols, SafeSslContextHandle> s_clientSslContexts = new ConcurrentDictionary<SslProtocols, SafeSslContextHandle>();
private static readonly ConcurrentDictionary<SslContextCacheKey, SafeSslContextHandle> s_clientSslContexts = new ConcurrentDictionary<SslContextCacheKey, SafeSslContextHandle>();

internal readonly struct SslContextCacheKey : IEquatable<SslContextCacheKey>
{
public readonly byte[]? CertificateThumbprint;
public readonly SslProtocols SslProtocols;

public SslContextCacheKey(SslProtocols sslProtocols, byte[]? certificateThumbprint)
{
SslProtocols = sslProtocols;
CertificateThumbprint = certificateThumbprint;
}

public override bool Equals(object? obj) => obj is SslContextCacheKey key && Equals(key);

public bool Equals(SslContextCacheKey other) =>
wfurt marked this conversation as resolved.
Show resolved Hide resolved
SslProtocols == other.SslProtocols &&
(CertificateThumbprint == null && other.CertificateThumbprint == null ||
CertificateThumbprint != null && other.CertificateThumbprint != null && CertificateThumbprint.AsSpan().SequenceEqual(other.CertificateThumbprint));

public override int GetHashCode()
rzikm marked this conversation as resolved.
Show resolved Hide resolved
{
HashCode hash = default;

hash.Add(SslProtocols);
if (CertificateThumbprint != null)
{
hash.AddBytes(CertificateThumbprint);
}

return hash.ToHashCode();
}
}

#region internal methods
internal static SafeChannelBindingHandle? QueryChannelBinding(SafeSslHandle context, ChannelBindingKind bindingType)
Expand Down Expand Up @@ -188,7 +220,7 @@ internal static unsafe SafeSslContextHandle AllocateSslContext(SslAuthentication
Interop.Ssl.SslCtxSetAlpnSelectCb(sslCtx, &AlpnServerSelectCallback, IntPtr.Zero);
}

if (sslAuthenticationOptions.CertificateContext != null)
if (sslAuthenticationOptions.CertificateContext != null && sslAuthenticationOptions.IsServer)
{
SetSslCertificate(sslCtx, sslAuthenticationOptions.CertificateContext.CertificateHandle, sslAuthenticationOptions.CertificateContext.KeyHandle);

Expand Down Expand Up @@ -269,13 +301,12 @@ internal static SafeSslHandle AllocateSslHandle(SslAuthenticationOptions sslAuth
{
// We don't support client resume on old OpenSSL versions.
// We don't want to try on empty TargetName since that is our key.
// And we don't want to mess up with client authentication. It may be possible
// but it seems safe to get full new session.
// If we already have CertificateContext, then we know which cert the user wants to use and we can cache.
// The only client auth scenario where we can't cache is when user provides a cert callback and we don't know
// beforehand which cert will be used. and wan't to avoid resuming session created with different certificate.
if (!Interop.Ssl.Capabilities.Tls13Supported ||
string.IsNullOrEmpty(sslAuthenticationOptions.TargetHost) ||
sslAuthenticationOptions.CertificateContext != null ||
sslAuthenticationOptions.ClientCertificates?.Count > 0 ||
sslAuthenticationOptions.CertSelectionDelegate != null)
(sslAuthenticationOptions.CertificateContext == null && sslAuthenticationOptions.CertSelectionDelegate != null))
{
cacheSslContext = false;
}
Expand All @@ -300,8 +331,8 @@ internal static SafeSslHandle AllocateSslHandle(SslAuthenticationOptions sslAuth
}
else
{

s_clientSslContexts.TryGetValue(protocols, out sslCtxHandle);
var key = new SslContextCacheKey(protocols, sslAuthenticationOptions.CertificateContext?.TargetCertificate.GetCertHash());
rzikm marked this conversation as resolved.
Show resolved Hide resolved
s_clientSslContexts.TryGetValue(key, out sslCtxHandle);
}
}

Expand All @@ -312,9 +343,18 @@ internal static SafeSslHandle AllocateSslHandle(SslAuthenticationOptions sslAuth

if (cacheSslContext)
{
bool added = sslAuthenticationOptions.IsServer ?
sslAuthenticationOptions.CertificateContext!.SslContexts!.TryAdd(protocols | (SslProtocols)(hasAlpn ? 1 : 0), newCtxHandle) :
s_clientSslContexts.TryAdd(protocols, newCtxHandle);
bool added;

if (sslAuthenticationOptions.IsServer)
{
added = sslAuthenticationOptions.CertificateContext!.SslContexts!.TryAdd(protocols | (SslProtocols)(hasAlpn ? 1 : 0), newCtxHandle);
}
else
{
var key = new SslContextCacheKey(protocols, sslAuthenticationOptions.CertificateContext?.TargetCertificate.GetCertHash());
added = s_clientSslContexts.TryAdd(key, newCtxHandle);
rzikm marked this conversation as resolved.
Show resolved Hide resolved
}

if (added)
{
newCtxHandle = null;
Expand Down Expand Up @@ -373,7 +413,8 @@ internal static SafeSslHandle AllocateSslHandle(SslAuthenticationOptions sslAuth

// relevant to TLS 1.3 only: if user supplied a client cert or cert callback,
// advertise that we are willing to send the certificate post-handshake.
if (sslAuthenticationOptions.ClientCertificates?.Count > 0 ||
if (sslAuthenticationOptions.CertificateContext != null ||
sslAuthenticationOptions.ClientCertificates?.Count > 0 ||
sslAuthenticationOptions.CertSelectionDelegate != null)
{
Ssl.SslSetPostHandshakeAuth(sslHandle, 1);
Expand Down Expand Up @@ -708,6 +749,12 @@ private static unsafe int NewSessionCallback(IntPtr ssl, IntPtr session)
Debug.Assert(ssl != IntPtr.Zero);
Debug.Assert(session != IntPtr.Zero);

// remember if the session used a certificate, this information is used after
// session resumption, the pointer is not being dereferenced and the refcount
// is not going to be manipulated.
IntPtr cert = Interop.Ssl.SslGetCertificate(ssl);
Interop.Ssl.SslSessionSetData(session, cert);

IntPtr ptr = Ssl.SslGetData(ssl);
if (ptr != IntPtr.Zero)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,12 @@ internal static unsafe ReadOnlySpan<byte> SslGetAlpnSelected(SafeSslHandle ssl)
[LibraryImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslGetPeerCertificate")]
internal static partial IntPtr SslGetPeerCertificate(SafeSslHandle ssl);

[LibraryImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslGetCertificate")]
internal static partial IntPtr SslGetCertificate(SafeSslHandle ssl);

[LibraryImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslGetCertificate")]
internal static partial IntPtr SslGetCertificate(IntPtr ssl);

[LibraryImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslGetPeerCertChain")]
internal static partial SafeSharedX509StackHandle SslGetPeerCertChain(SafeSslHandle ssl);

Expand All @@ -129,6 +135,9 @@ internal static unsafe ReadOnlySpan<byte> SslGetAlpnSelected(SafeSslHandle ssl)
[return: MarshalAs(UnmanagedType.Bool)]
internal static partial bool SslSessionReused(SafeSslHandle ssl);

[LibraryImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslGetSession")]
internal static partial IntPtr SslGetSession(SafeSslHandle ssl);

[LibraryImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslGetClientCAList")]
private static partial SafeSharedX509NameStackHandle SslGetClientCAList_private(SafeSslHandle ssl);

Expand Down Expand Up @@ -182,6 +191,12 @@ internal static unsafe ReadOnlySpan<byte> SslGetAlpnSelected(SafeSslHandle ssl)
[LibraryImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslSessionSetHostname")]
internal static partial int SessionSetHostname(IntPtr session, IntPtr name);

[LibraryImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslSessionGetData")]
internal static partial IntPtr SslSessionGetData(IntPtr session);

[LibraryImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslSessionSetData")]
internal static partial void SslSessionSetData(IntPtr session, IntPtr val);

internal static class Capabilities
{
// needs separate type (separate static cctor) to be sure OpenSSL is initialized.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,43 @@ internal static SslPolicyErrors VerifyCertificateProperties(
return result;
}

// This is only called when we selected local client certificate.
// Currently this is only when OpenSSL needs it because peer asked.
internal static bool IsLocalCertificateUsed(SafeFreeCredentials? _1, SafeDeleteContext? _2) => true;
internal static bool IsLocalCertificateUsed(SafeFreeCredentials? _1, SafeDeleteContext? ctx)
{
if (ctx is not SafeSslHandle ssl)
{
return false;
}

if (!Interop.Ssl.SslSessionReused(ssl))
{
// Fresh session, we set the certificate on the SSL object only
// if the peer explicitly requested it.
return Interop.Ssl.SslGetCertificate(ssl) != IntPtr.Zero;
}

// resumed session, we keep the information about cert being used in the SSL_SESSION
// object's ex_data
bool addref = false;
try
{
// make sure the ssl is not freed while we accessing its SSL_SESSION
// this makes sure the `session` pointer is valid during this call
// despite not being a SafeHandle.
ssl.DangerousAddRef(ref addref);
bartonjs marked this conversation as resolved.
Show resolved Hide resolved

// the information about certificate usage is stored in the session ex data
IntPtr session = Interop.Ssl.SslGetSession(ssl);
Debug.Assert(session != IntPtr.Zero);
return Interop.Ssl.SslSessionGetData(session) != IntPtr.Zero;
}
finally
{
if (addref)
{
ssl.DangerousRelease();
}
}
}

//
// Used only by client SSL code, never returns null.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ public static Exception GetException(SecurityStatusPal status)
return status.Exception ?? new Interop.OpenSsl.SslException((int)status.ErrorCode);
}

internal const bool StartMutualAuthAsAnonymous = true;
internal const bool StartMutualAuthAsAnonymous = false;
internal const bool CanEncryptEmptyMessage = false;

public static void VerifyPackageInfo()
Expand Down Expand Up @@ -168,8 +168,8 @@ public static bool TryUpdateClintCertificate(
return true;
}

private static ProtocolToken HandshakeInternal(ref SafeDeleteSslContext? context,
ReadOnlySpan<byte> inputBuffer, SslAuthenticationOptions sslAuthenticationOptions)
private static ProtocolToken HandshakeInternal(ref SafeDeleteSslContext? context,
ReadOnlySpan<byte> inputBuffer, SslAuthenticationOptions sslAuthenticationOptions)
{
ProtocolToken token = default;
token.RentBuffer = true;
Expand All @@ -186,8 +186,20 @@ private static ProtocolToken HandshakeInternal(ref SafeDeleteSslContext? context
{
// this should happen only for clients
Debug.Assert(sslAuthenticationOptions.IsClient);
token.Status = new SecurityStatusPal(errorCode);
return token;

// if we don't have a clien certificate ready, bubble up so
rzikm marked this conversation as resolved.
Show resolved Hide resolved
// that the certificate selection routine runs again. This
// happens if the first call to LocalCertificateSelectionCallback
// returns null.
if (sslAuthenticationOptions.CertificateContext == null)
{
token.Status = new SecurityStatusPal(SecurityStatusPalErrorCode.CredentialsNeeded);
return token;
}

// set the cert and continue
TryUpdateClintCertificate(null, context, sslAuthenticationOptions);
errorCode = Interop.OpenSsl.DoSslHandshake((SafeSslHandle)context, ReadOnlySpan<byte>.Empty, ref token);
}

// sometimes during renegotiation processing message does not yield new output.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ namespace System.Net.Security.Tests
{
using Configuration = System.Net.Test.Common.Configuration;

[PlatformSpecific(TestPlatforms.Windows | TestPlatforms.Linux)]
rzikm marked this conversation as resolved.
Show resolved Hide resolved
public class SslStreamTlsResumeTests
{
private static FieldInfo connectionInfo = typeof(SslStream).GetField(
Expand All @@ -29,7 +30,6 @@ private bool CheckResumeFlag(SslStream ssl)
[ConditionalTheory]
[InlineData(true)]
[InlineData(false)]
[PlatformSpecific(TestPlatforms.Windows | TestPlatforms.Linux)]
public async Task SslStream_ClientDisableTlsResume_Succeeds(bool testClient)
{
SslServerAuthenticationOptions serverOptions = new SslServerAuthenticationOptions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ await TestConfiguration.WhenAllOrAnyFailedWithTimeout(
}
else
{
Assert.Null(server.RemoteCertificate);
Assert.Null(server.RemoteCertificate);
}
};
}
Expand Down Expand Up @@ -320,7 +320,7 @@ await TestConfiguration.WhenAllOrAnyFailedWithTimeout(
}
else
{
Assert.Null(server.RemoteCertificate);
Assert.Null(server.RemoteCertificate);
}
};
}
Expand Down Expand Up @@ -357,7 +357,7 @@ public async Task SslStream_ResumedSessionsCallbackMaybeSet_IsMutuallyAuthentica

if (expectMutualAuthentication)
{
clientOptions.LocalCertificateSelectionCallback = (s, t, l, r, a) => _clientCertificate;
clientOptions.LocalCertificateSelectionCallback = (s, t, l, r, a) => _clientCertificate;
}
else
{
Expand All @@ -378,7 +378,7 @@ await TestConfiguration.WhenAllOrAnyFailedWithTimeout(
}
else
{
Assert.Null(server.RemoteCertificate);
Assert.Null(server.RemoteCertificate);
}
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -340,14 +340,18 @@ static const Entry s_cryptoNative[] =
DllImportEntry(CryptoNative_SslGetFinished)
DllImportEntry(CryptoNative_SslGetPeerCertChain)
DllImportEntry(CryptoNative_SslGetPeerCertificate)
DllImportEntry(CryptoNative_SslGetCertificate)
DllImportEntry(CryptoNative_SslGetPeerFinished)
DllImportEntry(CryptoNative_SslGetServerName)
DllImportEntry(CryptoNative_SslGetSession)
DllImportEntry(CryptoNative_SslGetVersion)
DllImportEntry(CryptoNative_SslRead)
DllImportEntry(CryptoNative_SslSessionFree)
DllImportEntry(CryptoNative_SslSessionGetHostname)
DllImportEntry(CryptoNative_SslSessionSetHostname)
DllImportEntry(CryptoNative_SslSessionReused)
DllImportEntry(CryptoNative_SslSessionGetData)
DllImportEntry(CryptoNative_SslSessionSetData)
DllImportEntry(CryptoNative_SslSetAcceptState)
DllImportEntry(CryptoNative_SslSetAlpnProtos)
DllImportEntry(CryptoNative_SslSetBio)
Expand Down
Loading
Loading