diff --git a/docs/guide/messaging/transports/signalr.md b/docs/guide/messaging/transports/signalr.md
index ff6ee6b70..281240671 100644
--- a/docs/guide/messaging/transports/signalr.md
+++ b/docs/guide/messaging/transports/signalr.md
@@ -594,6 +594,49 @@ public async Task receive_message_from_a_client()
*Conveniently enough as I write this documentation today using existing test code, Hollywood Brown had a huge
game last night. Go Chiefs!*
+### Authorization
+
+If you are connecting to a hub requiring authorization (for example using the `[Authorize]` attribute) you need to provide a token provider.
+
+
+
+```cs
+var host = await Host.CreateDefaultBuilder()
+ .UseWolverine(opts =>
+ {
+ opts.ServiceName = serviceName;
+
+ // Configure a client with an access token provider. You get an instance of `IServiceProvider`
+ // if you need access to additional services, for example accessing `IConfiguration`
+ opts.UseClientToSignalR(Port, accessTokenProvider: (sp) => () => Task.FromResult(accessToken));
+
+ opts.Publish(x =>
+ {
+ x.MessagesImplementing();
+ x.ToSignalRWithClient(Port);
+ });
+
+ opts.Publish(x =>
+ {
+ x.MessagesImplementing();
+
+ // You can also configure the access token provider when configuring
+ // the message publishing. Last configuration wins and applies to the
+ // client URL, *not* the message type
+ x.ToSignalRWithClient(Port, accessTokenProvider: (sp) => () =>
+ {
+ var configuration = sp.GetRequiredService();
+ var configuredToken = configuration.GetValue("SignalR:AccessToken")
+ // Fall back to the token passed in when testing
+ ?? accessToken;
+ return Task.FromResult(configuredToken);
+ });
+ });
+ }).StartAsync();
+```
+snippet source | anchor
+
+
## Web Socket "Sagas"
::: info
diff --git a/src/Transports/SignalR/Wolverine.SignalR.Tests/WebSocketTestContext.cs b/src/Transports/SignalR/Wolverine.SignalR.Tests/WebSocketTestContext.cs
index b971af240..5adb274c1 100644
--- a/src/Transports/SignalR/Wolverine.SignalR.Tests/WebSocketTestContext.cs
+++ b/src/Transports/SignalR/Wolverine.SignalR.Tests/WebSocketTestContext.cs
@@ -1,10 +1,13 @@
-using System.Diagnostics;
+using Microsoft.AspNetCore.Authentication;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Hosting;
-using Microsoft.AspNetCore.SignalR;
+using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
-using Wolverine.Runtime;
+using Microsoft.Extensions.Options;
+using System.Diagnostics;
+using System.Security.Claims;
+using System.Text.Encodings.Web;
using Wolverine.SignalR.Client;
using Wolverine.Util;
@@ -132,6 +135,19 @@ public async Task InitializeAsync()
opts.ListenLocalhost(Port);
});
+ builder.Services
+ .AddAuthentication("TestAuthScheme")
+ .AddScheme("TestAuthScheme", null, null);
+
+ builder.Services.AddAuthorizationCore(options =>
+ {
+ options.AddPolicy("TestToken", policyBuilder =>
+ {
+ policyBuilder.AuthenticationSchemes.Add("TestAuthScheme");
+ policyBuilder.RequireAuthenticatedUser();
+ });
+ });
+
#region sample_custom_signalr_hub
builder.Services.AddSignalR();
builder.Host.UseWolverine(opts =>
@@ -162,25 +178,42 @@ public async Task InitializeAsync()
// This starts up a new host to act as a client to the SignalR
// server for testing
- public async Task StartClientHost(string serviceName = "Client")
+ public async Task StartClientHost(string serviceName = "Client", string accessToken = "supersecrettoken")
{
+ #region sample_signalr_authentication
var host = await Host.CreateDefaultBuilder()
.UseWolverine(opts =>
{
opts.ServiceName = serviceName;
- opts.UseClientToSignalR(Port);
-
- opts.PublishMessage().ToSignalRWithClient(Port);
-
- opts.PublishMessage().ToSignalRWithClient(Port);
+ // Configure a client with an access token provider. You get an instance of `IServiceProvider`
+ // if you need access to additional services, for example accessing `IConfiguration`
+ opts.UseClientToSignalR(Port, accessTokenProvider: (sp) => () => Task.FromResult(accessToken));
opts.Publish(x =>
{
x.MessagesImplementing();
x.ToSignalRWithClient(Port);
});
+
+ opts.Publish(x =>
+ {
+ x.MessagesImplementing();
+
+ // You can also configure the access token provider when configuring
+ // the message publishing. Last configuration wins and applies to the
+ // client URL, *not* the message type
+ x.ToSignalRWithClient(Port, accessTokenProvider: (sp) => () =>
+ {
+ var configuration = sp.GetRequiredService();
+ var configuredToken = configuration.GetValue("SignalR:AccessToken")
+ // Fall back to the token passed in when testing
+ ?? accessToken;
+ return Task.FromResult(configuredToken);
+ });
+ });
}).StartAsync();
+ #endregion
_clientHosts.Add(host);
@@ -203,6 +236,9 @@ public record ToFirst(string Name) : WebSocketMessage;
public record FromFirst(string Name) : WebSocketMessage;
public record ToSecond(string Name) : WebSocketMessage;
public record FromSecond(string Name) : WebSocketMessage;
+public interface AuthenticatedWebSocketMessage : WebSocketMessage
+{
+}
public static class WebSocketMessageHandler
{
@@ -212,3 +248,30 @@ public static class WebSocketMessageHandler
public static void Handle(FromSecond m) => Debug.WriteLine("Got " + m);
}
+internal class TestAuthenticationHandler : AuthenticationHandler
+{
+ public TestAuthenticationHandler(
+ IOptionsMonitor options,
+ Microsoft.Extensions.Logging.ILoggerFactory logger,
+ UrlEncoder encoder)
+ : base(options, logger, encoder)
+ {
+ }
+
+ protected override Task HandleAuthenticateAsync()
+ {
+ var authToken = Context.Request.Headers.Authorization.ToString().Split(" ").Last();
+ if (authToken != "supersecrettoken")
+ return Task.FromResult(AuthenticateResult.Fail("Invalid token"));
+
+ var identity = new ClaimsIdentity([new Claim(ClaimTypes.NameIdentifier, "wolverine")], Scheme.Name);
+ var principal = new ClaimsPrincipal(identity);
+ var ticket = new AuthenticationTicket(principal, Scheme.Name);
+
+ return Task.FromResult(AuthenticateResult.Success(ticket));
+ }
+}
+
+public class TestAuthenticationOptions : AuthenticationSchemeOptions
+{
+}
diff --git a/src/Transports/SignalR/Wolverine.SignalR.Tests/custom_hub.cs b/src/Transports/SignalR/Wolverine.SignalR.Tests/custom_hub.cs
index 39ec94fbd..6e5b39ff6 100644
--- a/src/Transports/SignalR/Wolverine.SignalR.Tests/custom_hub.cs
+++ b/src/Transports/SignalR/Wolverine.SignalR.Tests/custom_hub.cs
@@ -1,5 +1,5 @@
using JasperFx.Core;
-using Microsoft.AspNetCore.SignalR;
+using Microsoft.AspNetCore.Authorization;
using Microsoft.Extensions.Logging;
using Shouldly;
using Wolverine.SignalR.Client;
@@ -10,7 +10,7 @@ namespace Wolverine.SignalR.Tests;
public class custom_hub : WebSocketTestContextWithCustomHub
{
- public static List ReceivedJson = new();
+ public static List ReceivedJson { get; } = new();
public override async Task DisposeAsync()
{
@@ -64,6 +64,98 @@ public async Task publishing_from_client_uses_custom_hub()
}
}
+public class custom_hub_with_authentication : WebSocketTestContextWithCustomHub
+{
+ public static List ReceivedJson { get; } = new();
+
+ public override async Task DisposeAsync()
+ {
+ ReceivedJson.Clear();
+ await base.DisposeAsync();
+ }
+
+ [Fact]
+ public async Task client_can_receive_from_authenticated_hub()
+ {
+ var client = await StartClientHost();
+
+ var tracked = await theWebApp
+ .TrackActivity()
+ .IncludeExternalTransports()
+ .AlsoTrack(client)
+ .Timeout(10.Seconds())
+ .SendMessageAndWaitAsync(new FromSecond("Hollywood Brown"));
+
+ var record = tracked.Received.SingleRecord();
+ record.ServiceName.ShouldBe("Client");
+ record.Envelope.ShouldNotBeNull();
+ record.Envelope.Destination.ShouldBe(new Uri($"signalr-client://localhost:{Port}/messages"));
+ record.Message.ShouldBeOfType()
+ .Name.ShouldBe("Hollywood Brown");
+ }
+
+ [Fact]
+ public async Task client_can_publish_using_authenticated_hub()
+ {
+ // This is an IHost that has the SignalR Client
+ // transport configured to connect to a SignalR
+ // server in the "theWebApp" IHost
+ using var client = await StartClientHost();
+
+ var tracked = await client
+ .TrackActivity()
+ .IncludeExternalTransports()
+ .AlsoTrack(theWebApp)
+ .Timeout(10.Seconds())
+ .ExecuteAndWaitAsync(c => c.SendViaSignalRClient(clientUri, new ToSecond("Hollywood Brown")));
+
+ var record = tracked.Received.SingleRecord();
+ record.ServiceName.ShouldBe("Server");
+ record.Envelope.ShouldNotBeNull();
+ record.Envelope.Destination.ShouldBe(new Uri("signalr://wolverine"));
+ record.Message.ShouldBeOfType()
+ .Name.ShouldBe("Hollywood Brown");
+
+ ReceivedJson.Count.ShouldBe(1);
+ }
+
+ [Fact]
+ public async Task client_with_invalid_token_cannot_connect()
+ {
+ // This is an IHost that has the SignalR Client
+ // transport configured to connect to a SignalR
+ // server in the "theWebApp" IHost
+ using var client = await StartClientHost(accessToken: "last-years-token");
+
+ var tracked = await client
+ .TrackActivity()
+ .IncludeExternalTransports()
+ .AlsoTrack(theWebApp)
+ .Timeout(10.Seconds())
+ .ExecuteAndWaitAsync(c => c.SendViaSignalRClient(clientUri, new ToSecond("Hollywood Brown")));
+
+ tracked.Received.Envelopes().ShouldBeEmpty();
+ ReceivedJson.ShouldBeEmpty();
+ }
+
+ [Fact]
+ public async Task client_with_invalid_token_cannot_receive()
+ {
+ var client = await StartClientHost(accessToken: "last-years-token");
+
+ var tracked = await theWebApp
+ .TrackActivity()
+ .IncludeExternalTransports()
+ .AlsoTrack(client)
+ .Timeout(100.Milliseconds())
+ .DoNotAssertOnExceptionsDetected() // We're not supposed to be able to receive, so don't throw
+ .PublishMessageAndWaitAsync(new FromSecond("Hollywood Brown"));
+
+ tracked.Received.Envelopes().ShouldBeEmpty();
+ ReceivedJson.ShouldBeEmpty();
+ }
+}
+
public class CustomWolverineHub(SignalRTransport endpoint, ILogger logger) : WolverineHub(endpoint)
{
public override Task OnConnectedAsync()
@@ -78,3 +170,19 @@ public override Task ReceiveMessage(string json)
return base.ReceiveMessage(json);
}
}
+
+[Authorize]
+public class AuthenticatedWolverineHub(SignalRTransport endpoint, ILogger logger) : WolverineHub(endpoint)
+{
+ public override Task OnConnectedAsync()
+ {
+ logger.LogInformation("Client authenticated with ID {ConnectionId}", Context.ConnectionId);
+ return base.OnConnectedAsync();
+ }
+
+ public override Task ReceiveMessage(string json)
+ {
+ custom_hub_with_authentication.ReceivedJson.Add(json);
+ return base.ReceiveMessage(json);
+ }
+}
diff --git a/src/Transports/SignalR/Wolverine.SignalR/Client/SignalRClientEndpoint.cs b/src/Transports/SignalR/Wolverine.SignalR/Client/SignalRClientEndpoint.cs
index a9e0323fe..902ee4bf7 100644
--- a/src/Transports/SignalR/Wolverine.SignalR/Client/SignalRClientEndpoint.cs
+++ b/src/Transports/SignalR/Wolverine.SignalR/Client/SignalRClientEndpoint.cs
@@ -39,18 +39,34 @@ public SignalRClientEndpoint(Uri uri, SignalRClientTransport parent) : base(Tran
public JsonSerializerOptions JsonOptions { get; set; }
+ public Func>> AccessTokenProvider { get; set; }
+
public Uri SignalRUri { get; }
public override async ValueTask BuildListenerAsync(IWolverineRuntime runtime, IReceiver receiver)
{
Receiver = receiver;
Pipeline = runtime.Pipeline;
- _connection ??= new HubConnectionBuilder().WithAutomaticReconnect().WithUrl(SignalRUri).Build();
+ _connection ??= new HubConnectionBuilder()
+ .WithAutomaticReconnect()
+ .WithUrl(SignalRUri, opts =>
+ {
+ opts.AccessTokenProvider = AccessTokenProvider?.Invoke(runtime.Services);
+ })
+ .Build();
_mapper ??= BuildCloudEventsMapper(runtime, JsonOptions);
Logger = runtime.LoggerFactory.CreateLogger();
- await _connection.StartAsync();
+ try
+ {
+ await _connection.StartAsync();
+ }
+ catch (HttpRequestException ex) when (ex.StatusCode == System.Net.HttpStatusCode.Unauthorized)
+ {
+ Logger.LogError(ex, "Unable to connect to SignalR. Hub returned Unauthorized");
+ //throw; // FIXME: Should probably have better handling for this
+ }
_connection.On(SignalRTransport.DefaultOperation, [typeof(string)], (args =>
{
@@ -93,7 +109,12 @@ internal async Task ReceiveAsync(string json)
protected override ISender CreateSender(IWolverineRuntime runtime)
{
- _connection ??= new HubConnectionBuilder().WithUrl(SignalRUri).Build();
+ _connection ??= new HubConnectionBuilder()
+ .WithUrl(SignalRUri, opts =>
+ {
+ opts.AccessTokenProvider = AccessTokenProvider?.Invoke(runtime.Services);
+ })
+ .Build();
_mapper ??= BuildCloudEventsMapper(runtime, JsonOptions);
return this;
}
@@ -133,6 +154,7 @@ async ValueTask IListener.StopAsync()
bool ISender.SupportsNativeScheduledSend => false;
Uri ISender.Destination => Uri;
+
public Task PingAsync()
{
return Task.FromResult(true);
diff --git a/src/Transports/SignalR/Wolverine.SignalR/Client/SignalRClientExtensions.cs b/src/Transports/SignalR/Wolverine.SignalR/Client/SignalRClientExtensions.cs
index c732eaadd..120c93719 100644
--- a/src/Transports/SignalR/Wolverine.SignalR/Client/SignalRClientExtensions.cs
+++ b/src/Transports/SignalR/Wolverine.SignalR/Client/SignalRClientExtensions.cs
@@ -15,9 +15,11 @@ public static class SignalRClientExtensions
///
///
///
+ /// Optionally provide a token to use for authentication against the SignalR hub
///
public static Uri UseClientToSignalR(this WolverineOptions options, string url,
- JsonSerializerOptions? jsonOptions = null)
+ JsonSerializerOptions? jsonOptions = null,
+ Func>>? accessTokenProvider = null)
{
var transport = options.Transports.GetOrCreate();
@@ -26,6 +28,10 @@ public static Uri UseClientToSignalR(this WolverineOptions options, string url,
{
endpoint.JsonOptions = jsonOptions;
}
+ if (accessTokenProvider != null)
+ {
+ endpoint.AccessTokenProvider = accessTokenProvider;
+ }
return endpoint.Uri;
}
@@ -37,9 +43,11 @@ public static Uri UseClientToSignalR(this WolverineOptions options, string url,
///
///
/// Default is messages. Route pattern where you have mapped the WolverineHub
+ /// Optionally provide a token to use for authentication against the SignalR hub
///
- public static Uri UseClientToSignalR(this WolverineOptions options, int port, string route = "messages")
- => options.UseClientToSignalR($"http://localhost:{port}/{route}");
+ public static Uri UseClientToSignalR(this WolverineOptions options, int port, string route = "messages",
+ Func>>? accessTokenProvider = null)
+ => options.UseClientToSignalR($"http://localhost:{port}/{route}", accessTokenProvider: accessTokenProvider);
///
/// Send a message via a SignalR Client for the given server Uri in the format "http://localhost:[port]/[hub url]"
@@ -73,10 +81,11 @@ public static ValueTask SendViaSignalRClient(this IMessageBus bus, string server
///
///
///
- public static void ToSignalRWithClient(this IPublishToExpression publishing, int port, string relativeUrl = "messages")
+ /// Optionally provide a token to use for authentication against the SignalR hub
+ public static void ToSignalRWithClient(this IPublishToExpression publishing, int port, string relativeUrl = "messages", Func>>? accessTokenProvider = null)
{
var url = $"http://localhost:{port}/{relativeUrl}";
- publishing.ToSignalRWithClient(url);
+ publishing.ToSignalRWithClient(url, accessTokenProvider);
}
///
@@ -84,7 +93,8 @@ public static void ToSignalRWithClient(this IPublishToExpression publishing, int
///
///
///
- public static void ToSignalRWithClient(this IPublishToExpression publishing, string url)
+ /// Optionally provide a token to use for authentication against the SignalR hub
+ public static void ToSignalRWithClient(this IPublishToExpression publishing, string url, Func>>? accessTokenProvider = null)
{
var rawUri = new Uri(url);
if (!rawUri.IsAbsoluteUri)
@@ -97,6 +107,11 @@ public static void ToSignalRWithClient(this IPublishToExpression publishing, str
var transport = transports.GetOrCreate();
var endpoint = transport.Clients[uri];
+ if (accessTokenProvider != null)
+ {
+ endpoint.AccessTokenProvider = accessTokenProvider;
+ }
+
publishing.To(uri);
}