diff --git a/docs/guide/messaging/transports/signalr.md b/docs/guide/messaging/transports/signalr.md index ff6ee6b70..281240671 100644 --- a/docs/guide/messaging/transports/signalr.md +++ b/docs/guide/messaging/transports/signalr.md @@ -594,6 +594,49 @@ public async Task receive_message_from_a_client() *Conveniently enough as I write this documentation today using existing test code, Hollywood Brown had a huge game last night. Go Chiefs!* +### Authorization + +If you are connecting to a hub requiring authorization (for example using the `[Authorize]` attribute) you need to provide a token provider. + + + +```cs +var host = await Host.CreateDefaultBuilder() + .UseWolverine(opts => + { + opts.ServiceName = serviceName; + + // Configure a client with an access token provider. You get an instance of `IServiceProvider` + // if you need access to additional services, for example accessing `IConfiguration` + opts.UseClientToSignalR(Port, accessTokenProvider: (sp) => () => Task.FromResult(accessToken)); + + opts.Publish(x => + { + x.MessagesImplementing(); + x.ToSignalRWithClient(Port); + }); + + opts.Publish(x => + { + x.MessagesImplementing(); + + // You can also configure the access token provider when configuring + // the message publishing. Last configuration wins and applies to the + // client URL, *not* the message type + x.ToSignalRWithClient(Port, accessTokenProvider: (sp) => () => + { + var configuration = sp.GetRequiredService(); + var configuredToken = configuration.GetValue("SignalR:AccessToken") + // Fall back to the token passed in when testing + ?? accessToken; + return Task.FromResult(configuredToken); + }); + }); + }).StartAsync(); +``` +snippet source | anchor + + ## Web Socket "Sagas" ::: info diff --git a/src/Transports/SignalR/Wolverine.SignalR.Tests/WebSocketTestContext.cs b/src/Transports/SignalR/Wolverine.SignalR.Tests/WebSocketTestContext.cs index b971af240..5adb274c1 100644 --- a/src/Transports/SignalR/Wolverine.SignalR.Tests/WebSocketTestContext.cs +++ b/src/Transports/SignalR/Wolverine.SignalR.Tests/WebSocketTestContext.cs @@ -1,10 +1,13 @@ -using System.Diagnostics; +using Microsoft.AspNetCore.Authentication; using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Hosting; -using Microsoft.AspNetCore.SignalR; +using Microsoft.Extensions.Configuration; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Hosting; -using Wolverine.Runtime; +using Microsoft.Extensions.Options; +using System.Diagnostics; +using System.Security.Claims; +using System.Text.Encodings.Web; using Wolverine.SignalR.Client; using Wolverine.Util; @@ -132,6 +135,19 @@ public async Task InitializeAsync() opts.ListenLocalhost(Port); }); + builder.Services + .AddAuthentication("TestAuthScheme") + .AddScheme("TestAuthScheme", null, null); + + builder.Services.AddAuthorizationCore(options => + { + options.AddPolicy("TestToken", policyBuilder => + { + policyBuilder.AuthenticationSchemes.Add("TestAuthScheme"); + policyBuilder.RequireAuthenticatedUser(); + }); + }); + #region sample_custom_signalr_hub builder.Services.AddSignalR(); builder.Host.UseWolverine(opts => @@ -162,25 +178,42 @@ public async Task InitializeAsync() // This starts up a new host to act as a client to the SignalR // server for testing - public async Task StartClientHost(string serviceName = "Client") + public async Task StartClientHost(string serviceName = "Client", string accessToken = "supersecrettoken") { + #region sample_signalr_authentication var host = await Host.CreateDefaultBuilder() .UseWolverine(opts => { opts.ServiceName = serviceName; - opts.UseClientToSignalR(Port); - - opts.PublishMessage().ToSignalRWithClient(Port); - - opts.PublishMessage().ToSignalRWithClient(Port); + // Configure a client with an access token provider. You get an instance of `IServiceProvider` + // if you need access to additional services, for example accessing `IConfiguration` + opts.UseClientToSignalR(Port, accessTokenProvider: (sp) => () => Task.FromResult(accessToken)); opts.Publish(x => { x.MessagesImplementing(); x.ToSignalRWithClient(Port); }); + + opts.Publish(x => + { + x.MessagesImplementing(); + + // You can also configure the access token provider when configuring + // the message publishing. Last configuration wins and applies to the + // client URL, *not* the message type + x.ToSignalRWithClient(Port, accessTokenProvider: (sp) => () => + { + var configuration = sp.GetRequiredService(); + var configuredToken = configuration.GetValue("SignalR:AccessToken") + // Fall back to the token passed in when testing + ?? accessToken; + return Task.FromResult(configuredToken); + }); + }); }).StartAsync(); + #endregion _clientHosts.Add(host); @@ -203,6 +236,9 @@ public record ToFirst(string Name) : WebSocketMessage; public record FromFirst(string Name) : WebSocketMessage; public record ToSecond(string Name) : WebSocketMessage; public record FromSecond(string Name) : WebSocketMessage; +public interface AuthenticatedWebSocketMessage : WebSocketMessage +{ +} public static class WebSocketMessageHandler { @@ -212,3 +248,30 @@ public static class WebSocketMessageHandler public static void Handle(FromSecond m) => Debug.WriteLine("Got " + m); } +internal class TestAuthenticationHandler : AuthenticationHandler +{ + public TestAuthenticationHandler( + IOptionsMonitor options, + Microsoft.Extensions.Logging.ILoggerFactory logger, + UrlEncoder encoder) + : base(options, logger, encoder) + { + } + + protected override Task HandleAuthenticateAsync() + { + var authToken = Context.Request.Headers.Authorization.ToString().Split(" ").Last(); + if (authToken != "supersecrettoken") + return Task.FromResult(AuthenticateResult.Fail("Invalid token")); + + var identity = new ClaimsIdentity([new Claim(ClaimTypes.NameIdentifier, "wolverine")], Scheme.Name); + var principal = new ClaimsPrincipal(identity); + var ticket = new AuthenticationTicket(principal, Scheme.Name); + + return Task.FromResult(AuthenticateResult.Success(ticket)); + } +} + +public class TestAuthenticationOptions : AuthenticationSchemeOptions +{ +} diff --git a/src/Transports/SignalR/Wolverine.SignalR.Tests/custom_hub.cs b/src/Transports/SignalR/Wolverine.SignalR.Tests/custom_hub.cs index 39ec94fbd..6e5b39ff6 100644 --- a/src/Transports/SignalR/Wolverine.SignalR.Tests/custom_hub.cs +++ b/src/Transports/SignalR/Wolverine.SignalR.Tests/custom_hub.cs @@ -1,5 +1,5 @@ using JasperFx.Core; -using Microsoft.AspNetCore.SignalR; +using Microsoft.AspNetCore.Authorization; using Microsoft.Extensions.Logging; using Shouldly; using Wolverine.SignalR.Client; @@ -10,7 +10,7 @@ namespace Wolverine.SignalR.Tests; public class custom_hub : WebSocketTestContextWithCustomHub { - public static List ReceivedJson = new(); + public static List ReceivedJson { get; } = new(); public override async Task DisposeAsync() { @@ -64,6 +64,98 @@ public async Task publishing_from_client_uses_custom_hub() } } +public class custom_hub_with_authentication : WebSocketTestContextWithCustomHub +{ + public static List ReceivedJson { get; } = new(); + + public override async Task DisposeAsync() + { + ReceivedJson.Clear(); + await base.DisposeAsync(); + } + + [Fact] + public async Task client_can_receive_from_authenticated_hub() + { + var client = await StartClientHost(); + + var tracked = await theWebApp + .TrackActivity() + .IncludeExternalTransports() + .AlsoTrack(client) + .Timeout(10.Seconds()) + .SendMessageAndWaitAsync(new FromSecond("Hollywood Brown")); + + var record = tracked.Received.SingleRecord(); + record.ServiceName.ShouldBe("Client"); + record.Envelope.ShouldNotBeNull(); + record.Envelope.Destination.ShouldBe(new Uri($"signalr-client://localhost:{Port}/messages")); + record.Message.ShouldBeOfType() + .Name.ShouldBe("Hollywood Brown"); + } + + [Fact] + public async Task client_can_publish_using_authenticated_hub() + { + // This is an IHost that has the SignalR Client + // transport configured to connect to a SignalR + // server in the "theWebApp" IHost + using var client = await StartClientHost(); + + var tracked = await client + .TrackActivity() + .IncludeExternalTransports() + .AlsoTrack(theWebApp) + .Timeout(10.Seconds()) + .ExecuteAndWaitAsync(c => c.SendViaSignalRClient(clientUri, new ToSecond("Hollywood Brown"))); + + var record = tracked.Received.SingleRecord(); + record.ServiceName.ShouldBe("Server"); + record.Envelope.ShouldNotBeNull(); + record.Envelope.Destination.ShouldBe(new Uri("signalr://wolverine")); + record.Message.ShouldBeOfType() + .Name.ShouldBe("Hollywood Brown"); + + ReceivedJson.Count.ShouldBe(1); + } + + [Fact] + public async Task client_with_invalid_token_cannot_connect() + { + // This is an IHost that has the SignalR Client + // transport configured to connect to a SignalR + // server in the "theWebApp" IHost + using var client = await StartClientHost(accessToken: "last-years-token"); + + var tracked = await client + .TrackActivity() + .IncludeExternalTransports() + .AlsoTrack(theWebApp) + .Timeout(10.Seconds()) + .ExecuteAndWaitAsync(c => c.SendViaSignalRClient(clientUri, new ToSecond("Hollywood Brown"))); + + tracked.Received.Envelopes().ShouldBeEmpty(); + ReceivedJson.ShouldBeEmpty(); + } + + [Fact] + public async Task client_with_invalid_token_cannot_receive() + { + var client = await StartClientHost(accessToken: "last-years-token"); + + var tracked = await theWebApp + .TrackActivity() + .IncludeExternalTransports() + .AlsoTrack(client) + .Timeout(100.Milliseconds()) + .DoNotAssertOnExceptionsDetected() // We're not supposed to be able to receive, so don't throw + .PublishMessageAndWaitAsync(new FromSecond("Hollywood Brown")); + + tracked.Received.Envelopes().ShouldBeEmpty(); + ReceivedJson.ShouldBeEmpty(); + } +} + public class CustomWolverineHub(SignalRTransport endpoint, ILogger logger) : WolverineHub(endpoint) { public override Task OnConnectedAsync() @@ -78,3 +170,19 @@ public override Task ReceiveMessage(string json) return base.ReceiveMessage(json); } } + +[Authorize] +public class AuthenticatedWolverineHub(SignalRTransport endpoint, ILogger logger) : WolverineHub(endpoint) +{ + public override Task OnConnectedAsync() + { + logger.LogInformation("Client authenticated with ID {ConnectionId}", Context.ConnectionId); + return base.OnConnectedAsync(); + } + + public override Task ReceiveMessage(string json) + { + custom_hub_with_authentication.ReceivedJson.Add(json); + return base.ReceiveMessage(json); + } +} diff --git a/src/Transports/SignalR/Wolverine.SignalR/Client/SignalRClientEndpoint.cs b/src/Transports/SignalR/Wolverine.SignalR/Client/SignalRClientEndpoint.cs index a9e0323fe..902ee4bf7 100644 --- a/src/Transports/SignalR/Wolverine.SignalR/Client/SignalRClientEndpoint.cs +++ b/src/Transports/SignalR/Wolverine.SignalR/Client/SignalRClientEndpoint.cs @@ -39,18 +39,34 @@ public SignalRClientEndpoint(Uri uri, SignalRClientTransport parent) : base(Tran public JsonSerializerOptions JsonOptions { get; set; } + public Func>> AccessTokenProvider { get; set; } + public Uri SignalRUri { get; } public override async ValueTask BuildListenerAsync(IWolverineRuntime runtime, IReceiver receiver) { Receiver = receiver; Pipeline = runtime.Pipeline; - _connection ??= new HubConnectionBuilder().WithAutomaticReconnect().WithUrl(SignalRUri).Build(); + _connection ??= new HubConnectionBuilder() + .WithAutomaticReconnect() + .WithUrl(SignalRUri, opts => + { + opts.AccessTokenProvider = AccessTokenProvider?.Invoke(runtime.Services); + }) + .Build(); _mapper ??= BuildCloudEventsMapper(runtime, JsonOptions); Logger = runtime.LoggerFactory.CreateLogger(); - await _connection.StartAsync(); + try + { + await _connection.StartAsync(); + } + catch (HttpRequestException ex) when (ex.StatusCode == System.Net.HttpStatusCode.Unauthorized) + { + Logger.LogError(ex, "Unable to connect to SignalR. Hub returned Unauthorized"); + //throw; // FIXME: Should probably have better handling for this + } _connection.On(SignalRTransport.DefaultOperation, [typeof(string)], (args => { @@ -93,7 +109,12 @@ internal async Task ReceiveAsync(string json) protected override ISender CreateSender(IWolverineRuntime runtime) { - _connection ??= new HubConnectionBuilder().WithUrl(SignalRUri).Build(); + _connection ??= new HubConnectionBuilder() + .WithUrl(SignalRUri, opts => + { + opts.AccessTokenProvider = AccessTokenProvider?.Invoke(runtime.Services); + }) + .Build(); _mapper ??= BuildCloudEventsMapper(runtime, JsonOptions); return this; } @@ -133,6 +154,7 @@ async ValueTask IListener.StopAsync() bool ISender.SupportsNativeScheduledSend => false; Uri ISender.Destination => Uri; + public Task PingAsync() { return Task.FromResult(true); diff --git a/src/Transports/SignalR/Wolverine.SignalR/Client/SignalRClientExtensions.cs b/src/Transports/SignalR/Wolverine.SignalR/Client/SignalRClientExtensions.cs index c732eaadd..120c93719 100644 --- a/src/Transports/SignalR/Wolverine.SignalR/Client/SignalRClientExtensions.cs +++ b/src/Transports/SignalR/Wolverine.SignalR/Client/SignalRClientExtensions.cs @@ -15,9 +15,11 @@ public static class SignalRClientExtensions /// /// /// + /// Optionally provide a token to use for authentication against the SignalR hub /// public static Uri UseClientToSignalR(this WolverineOptions options, string url, - JsonSerializerOptions? jsonOptions = null) + JsonSerializerOptions? jsonOptions = null, + Func>>? accessTokenProvider = null) { var transport = options.Transports.GetOrCreate(); @@ -26,6 +28,10 @@ public static Uri UseClientToSignalR(this WolverineOptions options, string url, { endpoint.JsonOptions = jsonOptions; } + if (accessTokenProvider != null) + { + endpoint.AccessTokenProvider = accessTokenProvider; + } return endpoint.Uri; } @@ -37,9 +43,11 @@ public static Uri UseClientToSignalR(this WolverineOptions options, string url, /// /// /// Default is messages. Route pattern where you have mapped the WolverineHub + /// Optionally provide a token to use for authentication against the SignalR hub /// - public static Uri UseClientToSignalR(this WolverineOptions options, int port, string route = "messages") - => options.UseClientToSignalR($"http://localhost:{port}/{route}"); + public static Uri UseClientToSignalR(this WolverineOptions options, int port, string route = "messages", + Func>>? accessTokenProvider = null) + => options.UseClientToSignalR($"http://localhost:{port}/{route}", accessTokenProvider: accessTokenProvider); /// /// Send a message via a SignalR Client for the given server Uri in the format "http://localhost:[port]/[hub url]" @@ -73,10 +81,11 @@ public static ValueTask SendViaSignalRClient(this IMessageBus bus, string server /// /// /// - public static void ToSignalRWithClient(this IPublishToExpression publishing, int port, string relativeUrl = "messages") + /// Optionally provide a token to use for authentication against the SignalR hub + public static void ToSignalRWithClient(this IPublishToExpression publishing, int port, string relativeUrl = "messages", Func>>? accessTokenProvider = null) { var url = $"http://localhost:{port}/{relativeUrl}"; - publishing.ToSignalRWithClient(url); + publishing.ToSignalRWithClient(url, accessTokenProvider); } /// @@ -84,7 +93,8 @@ public static void ToSignalRWithClient(this IPublishToExpression publishing, int /// /// /// - public static void ToSignalRWithClient(this IPublishToExpression publishing, string url) + /// Optionally provide a token to use for authentication against the SignalR hub + public static void ToSignalRWithClient(this IPublishToExpression publishing, string url, Func>>? accessTokenProvider = null) { var rawUri = new Uri(url); if (!rawUri.IsAbsoluteUri) @@ -97,6 +107,11 @@ public static void ToSignalRWithClient(this IPublishToExpression publishing, str var transport = transports.GetOrCreate(); var endpoint = transport.Clients[uri]; + if (accessTokenProvider != null) + { + endpoint.AccessTokenProvider = accessTokenProvider; + } + publishing.To(uri); }