Skip to content

Commit

Permalink
Update to new API, create connection in callback, tests
Browse files Browse the repository at this point in the history
  • Loading branch information
JamesNK committed Jul 21, 2022
1 parent e324e48 commit 355dd72
Show file tree
Hide file tree
Showing 11 changed files with 176 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@ Microsoft.AspNetCore.Connections.TlsConnectionCallbackContext.CancellationToken.
Microsoft.AspNetCore.Connections.TlsConnectionCallbackContext.CancellationToken.set -> void
Microsoft.AspNetCore.Connections.TlsConnectionCallbackContext.ClientHelloInfo.get -> System.Net.Security.SslClientHelloInfo
Microsoft.AspNetCore.Connections.TlsConnectionCallbackContext.ClientHelloInfo.set -> void
Microsoft.AspNetCore.Connections.TlsConnectionCallbackContext.Connection.get -> Microsoft.AspNetCore.Connections.ConnectionContext!
Microsoft.AspNetCore.Connections.TlsConnectionCallbackContext.Connection.get -> Microsoft.AspNetCore.Connections.BaseConnectionContext!
Microsoft.AspNetCore.Connections.TlsConnectionCallbackContext.Connection.set -> void
Microsoft.AspNetCore.Connections.TlsConnectionCallbackContext.State.get -> object?
Microsoft.AspNetCore.Connections.TlsConnectionCallbackContext.State.set -> void
Microsoft.AspNetCore.Connections.TlsConnectionCallbackContext.TlsConnectionCallbackContext() -> void
Microsoft.AspNetCore.Connections.TlsConnectionCallbackOptions
Microsoft.AspNetCore.Connections.TlsConnectionCallbackOptions.ApplicationProtocols.get -> System.Collections.Generic.List<System.Net.Security.SslApplicationProtocol>!
Microsoft.AspNetCore.Connections.TlsConnectionCallbackOptions.ApplicationProtocols.set -> void
Microsoft.AspNetCore.Connections.TlsConnectionCallbackOptions.OnConnection.get -> System.Func<Microsoft.AspNetCore.Connections.TlsConnectionCallbackContext!, System.Threading.Tasks.ValueTask<System.Net.Security.SslServerAuthenticationOptions!>>!
Microsoft.AspNetCore.Connections.TlsConnectionCallbackOptions.OnConnection.get -> System.Func<Microsoft.AspNetCore.Connections.TlsConnectionCallbackContext!, System.Threading.CancellationToken, System.Threading.Tasks.ValueTask<System.Net.Security.SslServerAuthenticationOptions!>>!
Microsoft.AspNetCore.Connections.TlsConnectionCallbackOptions.OnConnection.set -> void
Microsoft.AspNetCore.Connections.TlsConnectionCallbackOptions.OnConnectionState.get -> object?
Microsoft.AspNetCore.Connections.TlsConnectionCallbackOptions.OnConnectionState.set -> void
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,9 @@ public class TlsConnectionCallbackContext
/// </summary>
public object? State { get; set; }

/// <summary>
/// The token to monitor for cancellation requests.
/// </summary>
public CancellationToken CancellationToken { get; set; }

/// <summary>
/// Information about an individual connection.
/// </summary>
public ConnectionContext Connection { get; set; } = default!;
public BaseConnectionContext Connection { get; set; } = default!;
}
#endif
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System;
using System.Collections.Generic;
using System.Net.Security;
using System.Threading;
using System.Threading.Tasks;

namespace Microsoft.AspNetCore.Connections;
Expand All @@ -17,7 +18,7 @@ public class TlsConnectionCallbackOptions
/// <summary>
/// The callback to invoke per connection. This property is required.
/// </summary>
public Func<TlsConnectionCallbackContext, ValueTask<SslServerAuthenticationOptions>> OnConnection { get; set; } = default!;
public Func<TlsConnectionCallbackContext, CancellationToken, ValueTask<SslServerAuthenticationOptions>> OnConnection { get; set; } = default!;

/// <summary>
/// Optional application state to flow to the <see cref="OnConnection"/> callback.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#nullable enable

using System.IO.Pipelines;
using System.Linq;
using System.Net;
using System.Net.Security;
Expand Down Expand Up @@ -63,7 +64,7 @@ public async Task<EndPoint> BindAsync(EndPoint endPoint, MultiplexedConnectionDe
features.Set(new TlsConnectionCallbackOptions
{
ApplicationProtocols = sslServerAuthenticationOptions.ApplicationProtocols ?? new List<SslApplicationProtocol> { SslApplicationProtocol.Http3 },
OnConnection = context => ValueTask.FromResult(sslServerAuthenticationOptions),
OnConnection = (context, cancellationToken) => ValueTask.FromResult(sslServerAuthenticationOptions),
OnConnectionState = null,
});
}
Expand All @@ -72,14 +73,14 @@ public async Task<EndPoint> BindAsync(EndPoint endPoint, MultiplexedConnectionDe
features.Set(new TlsConnectionCallbackOptions
{
ApplicationProtocols = new List<SslApplicationProtocol> { SslApplicationProtocol.Http3 },
OnConnection = context =>
OnConnection = (context, cancellationToken) =>
{
return listenOptions.HttpsCallbackOptions.OnConnection(new TlsHandshakeCallbackContext
{
ClientHelloInfo = context.ClientHelloInfo,
CancellationToken = context.CancellationToken,
CancellationToken = cancellationToken,
State = context.State,
Connection = context.Connection,
Connection = new ConnectionContextAdapter(context.Connection),
});
},
OnConnectionState = listenOptions.HttpsCallbackOptions.OnConnectionState,
Expand All @@ -91,6 +92,49 @@ public async Task<EndPoint> BindAsync(EndPoint endPoint, MultiplexedConnectionDe
return transport.EndPoint;
}

/// <summary>
/// TlsHandshakeCallbackContext.Connection is ConnectionContext but QUIC connection only implements BaseConnectionContext.
/// </summary>
private sealed class ConnectionContextAdapter : ConnectionContext
{
private readonly BaseConnectionContext _inner;

public ConnectionContextAdapter(BaseConnectionContext inner) => _inner = inner;

public override IDuplexPipe Transport
{
get => throw new NotSupportedException("Not supported by HTTP/3 connections.");
set => throw new NotSupportedException("Not supported by HTTP/3 connections.");
}
public override string ConnectionId
{
get => _inner.ConnectionId;
set => _inner.ConnectionId = value;
}
public override IFeatureCollection Features => _inner.Features;
public override IDictionary<object, object?> Items
{
get => _inner.Items;
set => _inner.Items = value;
}
public override EndPoint? LocalEndPoint
{
get => _inner.LocalEndPoint;
set => _inner.LocalEndPoint = value;
}
public override EndPoint? RemoteEndPoint
{
get => _inner.RemoteEndPoint;
set => _inner.RemoteEndPoint = value;
}
public override CancellationToken ConnectionClosed
{
get => _inner.ConnectionClosed;
set => _inner.ConnectionClosed = value;
}
public override ValueTask DisposeAsync() => _inner.DisposeAsync();
}

private void StartAcceptLoop<T>(IConnectionListener<T> connectionListener, Func<T, Task> connectionDelegate, EndpointConfig? endpointConfig) where T : BaseConnectionContext
{
var transportConnectionManager = new TransportConnectionManager(_serviceContext.ConnectionManager);
Expand Down
13 changes: 12 additions & 1 deletion src/Servers/Kestrel/Kestrel.slnf
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,17 @@
"solution": {
"path": "..\\..\\..\\AspNetCore.sln",
"projects": [
"src\\DataProtection\\Abstractions\\src\\Microsoft.AspNetCore.DataProtection.Abstractions.csproj",
"src\\DataProtection\\Cryptography.Internal\\src\\Microsoft.AspNetCore.Cryptography.Internal.csproj",
"src\\DataProtection\\DataProtection\\src\\Microsoft.AspNetCore.DataProtection.csproj",
"src\\DefaultBuilder\\src\\Microsoft.AspNetCore.csproj",
"src\\Extensions\\Features\\src\\Microsoft.Extensions.Features.csproj",
"src\\Extensions\\Features\\test\\Microsoft.Extensions.Features.Tests.csproj",
"src\\Hosting\\Abstractions\\src\\Microsoft.AspNetCore.Hosting.Abstractions.csproj",
"src\\Hosting\\Hosting\\src\\Microsoft.AspNetCore.Hosting.csproj",
"src\\Hosting\\Server.Abstractions\\src\\Microsoft.AspNetCore.Hosting.Server.Abstractions.csproj",
"src\\Http\\Authentication.Abstractions\\src\\Microsoft.AspNetCore.Authentication.Abstractions.csproj",
"src\\Http\\Authentication.Core\\src\\Microsoft.AspNetCore.Authentication.Core.csproj",
"src\\Http\\Headers\\src\\Microsoft.Net.Http.Headers.csproj",
"src\\Http\\Http.Abstractions\\src\\Microsoft.AspNetCore.Http.Abstractions.csproj",
"src\\Http\\Http.Extensions\\src\\Microsoft.AspNetCore.Http.Extensions.csproj",
Expand All @@ -21,8 +27,12 @@
"src\\Middleware\\HostFiltering\\src\\Microsoft.AspNetCore.HostFiltering.csproj",
"src\\Middleware\\HttpOverrides\\src\\Microsoft.AspNetCore.HttpOverrides.csproj",
"src\\ObjectPool\\src\\Microsoft.Extensions.ObjectPool.csproj",
"src\\Security\\Authentication\\Core\\src\\Microsoft.AspNetCore.Authentication.csproj",
"src\\Security\\Authorization\\Core\\src\\Microsoft.AspNetCore.Authorization.csproj",
"src\\Security\\Authorization\\Policy\\src\\Microsoft.AspNetCore.Authorization.Policy.csproj",
"src\\Servers\\Connections.Abstractions\\src\\Microsoft.AspNetCore.Connections.Abstractions.csproj",
"src\\Servers\\IIS\\IISIntegration\\src\\Microsoft.AspNetCore.Server.IISIntegration.csproj",
"src\\Servers\\IIS\\IIS\\src\\Microsoft.AspNetCore.Server.IIS.csproj",
"src\\Servers\\Kestrel\\Core\\src\\Microsoft.AspNetCore.Server.Kestrel.Core.csproj",
"src\\Servers\\Kestrel\\Core\\test\\Microsoft.AspNetCore.Server.Kestrel.Core.Tests.csproj",
"src\\Servers\\Kestrel\\Kestrel\\src\\Microsoft.AspNetCore.Server.Kestrel.csproj",
Expand All @@ -47,7 +57,8 @@
"src\\Servers\\Kestrel\\test\\Sockets.BindTests\\Sockets.BindTests.csproj",
"src\\Servers\\Kestrel\\test\\Sockets.FunctionalTests\\Sockets.FunctionalTests.csproj",
"src\\Servers\\Kestrel\\tools\\CodeGenerator\\CodeGenerator.csproj",
"src\\Testing\\src\\Microsoft.AspNetCore.Testing.csproj"
"src\\Testing\\src\\Microsoft.AspNetCore.Testing.csproj",
"src\\WebEncoders\\src\\Microsoft.Extensions.WebEncoders.csproj"
]
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,11 @@ internal bool TryReturnStream(QuicStreamContext stream)
return false;
}

internal QuicConnection GetInnerConnection()
{
return _connection;
}

private void RemoveExpiredStreams()
{
lock (_poolLock)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Diagnostics;
using System.Net;
using System.Net.Quic;
using System.Net.Security;
Expand All @@ -21,6 +22,7 @@ internal sealed class QuicConnectionListener : IMultiplexedConnectionListener, I
private readonly QuicListenerOptions _quicListenerOptions;
private bool _disposed;
private QuicListener? _listener;
private QuicConnectionContext? _currentAcceptingConnection;

public QuicConnectionListener(
QuicTransportOptions options,
Expand Down Expand Up @@ -53,13 +55,15 @@ public QuicConnectionListener(
ListenBacklog = options.Backlog,
ConnectionOptionsCallback = async (connection, helloInfo, cancellationToken) =>
{
var serverAuthenticationOptions = await _tlsConnectionOptions.OnConnection(new TlsConnectionCallbackContext
_currentAcceptingConnection = new QuicConnectionContext(connection, _context);
var context = new TlsConnectionCallbackContext
{
CancellationToken = cancellationToken,
ClientHelloInfo = helloInfo,
State = _tlsConnectionOptions.OnConnectionState,
Connection = null!,
});
Connection = _currentAcceptingConnection,
};
var serverAuthenticationOptions = await _tlsConnectionOptions.OnConnection(context, cancellationToken);
// If the callback didn't set protocols then use the listener's list of protocols.
if (serverAuthenticationOptions.ApplicationProtocols == null)
Expand Down Expand Up @@ -124,7 +128,11 @@ public async ValueTask CreateListenerAsync()
try
{
var quicConnection = await _listener.AcceptConnectionAsync(cancellationToken);
var connectionContext = new QuicConnectionContext(quicConnection, _context);
var connectionContext = _currentAcceptingConnection;

// Verify the connection context was created and set correctly.
Debug.Assert(connectionContext != null);
Debug.Assert(connectionContext.GetInnerConnection() == quicConnection);

QuicLog.AcceptedConnection(_log, connectionContext);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ public async Task AcceptAsync_NoCertificateOrApplicationProtocol_Log()
new TlsConnectionCallbackOptions
{
ApplicationProtocols = new List<SslApplicationProtocol> { SslApplicationProtocol.Http3 },
OnConnection = context =>
OnConnection = (context, cancellationToken) =>
{
var options = new SslServerAuthenticationOptions();
options.ApplicationProtocols = new List<SslApplicationProtocol>();
Expand Down Expand Up @@ -149,7 +149,7 @@ public async Task AcceptAsync_NoApplicationProtocolsInCallback_DefaultToConnecti
new TlsConnectionCallbackOptions
{
ApplicationProtocols = new List<SslApplicationProtocol> { SslApplicationProtocol.Http3 },
OnConnection = context =>
OnConnection = (context, cancellationToken) =>
{
var options = new SslServerAuthenticationOptions();
options.ServerCertificate = TestResources.GetTestCertificate();
Expand All @@ -168,4 +168,37 @@ public async Task AcceptAsync_NoApplicationProtocolsInCallback_DefaultToConnecti
// Assert
Assert.Equal(SslApplicationProtocol.Http3, clientConnection.NegotiatedApplicationProtocol);
}

[ConditionalFact]
[MsQuicSupported]
public async Task AcceptAsync_TlsCallback_ConnectionContextInArguments()
{
// Arrange
BaseConnectionContext connectionContext = null;
await using var connectionListener = await QuicTestHelpers.CreateConnectionListenerFactory(
new TlsConnectionCallbackOptions
{
ApplicationProtocols = new List<SslApplicationProtocol> { SslApplicationProtocol.Http3 },
OnConnection = (context, cancellationToken) =>
{
var options = new SslServerAuthenticationOptions();
options.ServerCertificate = TestResources.GetTestCertificate();
connectionContext = context.Connection;
return ValueTask.FromResult(options);
}
},
LoggerFactory);

// Act
var acceptTask = connectionListener.AcceptAndAddFeatureAsync().DefaultTimeout();

var options = QuicTestHelpers.CreateClientConnectionOptions(connectionListener.EndPoint);

await using var clientConnection = await QuicConnection.ConnectAsync(options).DefaultTimeout();

// Assert
Assert.NotNull(connectionContext);
}
}
2 changes: 1 addition & 1 deletion src/Servers/Kestrel/Transport.Quic/test/QuicTestHelpers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ public static FeatureCollection CreateBindAsyncFeatures(bool clientCertificateRe
features.Set(new TlsConnectionCallbackOptions
{
ApplicationProtocols = sslServerAuthenticationOptions.ApplicationProtocols,
OnConnection = context => ValueTask.FromResult(sslServerAuthenticationOptions)
OnConnection = (context, cancellationToken) => ValueTask.FromResult(sslServerAuthenticationOptions)
});

return features;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@ public async Task Http3_UseHttpsNoArgsWithDefaultCertificate_UseDefaultCertifica
var sslOptions = bindFeatures.Get<TlsConnectionCallbackOptions>();
Assert.NotNull(sslOptions);

var sslServerAuthenticationOptions = await sslOptions.OnConnection(new TlsConnectionCallbackContext());
var sslServerAuthenticationOptions = await sslOptions.OnConnection(new TlsConnectionCallbackContext(), CancellationToken.None);
Assert.Equal(_x509Certificate2, sslServerAuthenticationOptions.ServerCertificate);
}

Expand Down Expand Up @@ -486,7 +486,7 @@ public async Task Http3_ConfigureHttpsDefaults_Works()
var tlsOptions = bindFeatures.Get<TlsConnectionCallbackOptions>();
Assert.NotNull(tlsOptions);

var sslServerAuthenticationOptions = await tlsOptions.OnConnection(new TlsConnectionCallbackContext());
var sslServerAuthenticationOptions = await tlsOptions.OnConnection(new TlsConnectionCallbackContext(), CancellationToken.None);
Assert.Equal(_x509Certificate2, sslServerAuthenticationOptions.ServerCertificate);
}

Expand Down Expand Up @@ -585,7 +585,7 @@ public async Task Http3_ServerOptionsSelectionCallback_Works()
listenOptions.Protocols = HttpProtocols.Http3;
listenOptions.UseHttps((SslStream stream, SslClientHelloInfo clientHelloInfo, object state, CancellationToken cancellationToken) =>
{
return ValueTask.FromResult((new SslServerAuthenticationOptions()));
return ValueTask.FromResult(new SslServerAuthenticationOptions());
}, state: testState);
});
},
Expand Down
Loading

0 comments on commit 355dd72

Please sign in to comment.