Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions docs/guide/messaging/transports/signalr.md
Original file line number Diff line number Diff line change
Expand Up @@ -594,6 +594,49 @@ public async Task receive_message_from_a_client()
*Conveniently enough as I write this documentation today using existing test code, Hollywood Brown had a huge
game last night. Go Chiefs!*

### Authorization

If you are connecting to a hub requiring authorization (for example using the `[Authorize]` attribute) you need to provide a token provider.

<!-- snippet: sample_signalr_authentication -->
<a id='snippet-sample_signalr_authentication'></a>
```cs
var host = await Host.CreateDefaultBuilder()
.UseWolverine(opts =>
{
opts.ServiceName = serviceName;

// Configure a client with an access token provider. You get an instance of `IServiceProvider`
// if you need access to additional services, for example accessing `IConfiguration`
opts.UseClientToSignalR(Port, accessTokenProvider: (sp) => () => Task.FromResult<string?>(accessToken));

opts.Publish(x =>
{
x.MessagesImplementing<WebSocketMessage>();
x.ToSignalRWithClient(Port);
});

opts.Publish(x =>
{
x.MessagesImplementing<AuthenticatedWebSocketMessage>();

// You can also configure the access token provider when configuring
// the message publishing. Last configuration wins and applies to the
// client URL, *not* the message type
x.ToSignalRWithClient(Port, accessTokenProvider: (sp) => () =>
{
var configuration = sp.GetRequiredService<IConfiguration>();
var configuredToken = configuration.GetValue<string?>("SignalR:AccessToken")
// Fall back to the token passed in when testing
?? accessToken;
return Task.FromResult<string?>(configuredToken);
});
});
}).StartAsync();
```
<sup><a href='https://github.com/JasperFx/wolverine/blob/main/src/Transports/SignalR/Wolverine.SignalR.Tests/WebSocketTestContext.cs#L183-L216' title='Snippet source file'>snippet source</a> | <a href='#snippet-sample_signalr_authentication' title='Start of snippet'>anchor</a></sup>
<!-- endSnippet -->

## Web Socket "Sagas"

::: info
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
using System.Diagnostics;
using Microsoft.AspNetCore.Authentication;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.SignalR;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
using Wolverine.Runtime;
using Microsoft.Extensions.Options;
using System.Diagnostics;
using System.Security.Claims;
using System.Text.Encodings.Web;
using Wolverine.SignalR.Client;
using Wolverine.Util;

Expand Down Expand Up @@ -132,6 +135,19 @@ public async Task InitializeAsync()
opts.ListenLocalhost(Port);
});

builder.Services
.AddAuthentication("TestAuthScheme")
.AddScheme<TestAuthenticationOptions, TestAuthenticationHandler>("TestAuthScheme", null, null);

builder.Services.AddAuthorizationCore(options =>
{
options.AddPolicy("TestToken", policyBuilder =>
{
policyBuilder.AuthenticationSchemes.Add("TestAuthScheme");
policyBuilder.RequireAuthenticatedUser();
});
});

#region sample_custom_signalr_hub
builder.Services.AddSignalR();
builder.Host.UseWolverine(opts =>
Expand Down Expand Up @@ -162,25 +178,42 @@ public async Task InitializeAsync()

// This starts up a new host to act as a client to the SignalR
// server for testing
public async Task<IHost> StartClientHost(string serviceName = "Client")
public async Task<IHost> StartClientHost(string serviceName = "Client", string accessToken = "supersecrettoken")
{
#region sample_signalr_authentication
var host = await Host.CreateDefaultBuilder()
.UseWolverine(opts =>
{
opts.ServiceName = serviceName;

opts.UseClientToSignalR(Port);

opts.PublishMessage<ToFirst>().ToSignalRWithClient(Port);

opts.PublishMessage<RequiresResponse>().ToSignalRWithClient(Port);
// Configure a client with an access token provider. You get an instance of `IServiceProvider`
// if you need access to additional services, for example accessing `IConfiguration`
opts.UseClientToSignalR(Port, accessTokenProvider: (sp) => () => Task.FromResult<string?>(accessToken));

opts.Publish(x =>
{
x.MessagesImplementing<WebSocketMessage>();
x.ToSignalRWithClient(Port);
});

opts.Publish(x =>
{
x.MessagesImplementing<AuthenticatedWebSocketMessage>();

// You can also configure the access token provider when configuring
// the message publishing. Last configuration wins and applies to the
// client URL, *not* the message type
x.ToSignalRWithClient(Port, accessTokenProvider: (sp) => () =>
{
var configuration = sp.GetRequiredService<IConfiguration>();
var configuredToken = configuration.GetValue<string?>("SignalR:AccessToken")
// Fall back to the token passed in when testing
?? accessToken;
return Task.FromResult<string?>(configuredToken);
});
});
}).StartAsync();
#endregion

_clientHosts.Add(host);

Expand All @@ -203,6 +236,9 @@ public record ToFirst(string Name) : WebSocketMessage;
public record FromFirst(string Name) : WebSocketMessage;
public record ToSecond(string Name) : WebSocketMessage;
public record FromSecond(string Name) : WebSocketMessage;
public interface AuthenticatedWebSocketMessage : WebSocketMessage
{
}

public static class WebSocketMessageHandler
{
Expand All @@ -212,3 +248,30 @@ public static class WebSocketMessageHandler
public static void Handle(FromSecond m) => Debug.WriteLine("Got " + m);
}

internal class TestAuthenticationHandler : AuthenticationHandler<TestAuthenticationOptions>
{
public TestAuthenticationHandler(
IOptionsMonitor<TestAuthenticationOptions> options,
Microsoft.Extensions.Logging.ILoggerFactory logger,
UrlEncoder encoder)
: base(options, logger, encoder)
{
}

protected override Task<AuthenticateResult> HandleAuthenticateAsync()
{
var authToken = Context.Request.Headers.Authorization.ToString().Split(" ").Last();
if (authToken != "supersecrettoken")
return Task.FromResult(AuthenticateResult.Fail("Invalid token"));

var identity = new ClaimsIdentity([new Claim(ClaimTypes.NameIdentifier, "wolverine")], Scheme.Name);
var principal = new ClaimsPrincipal(identity);
var ticket = new AuthenticationTicket(principal, Scheme.Name);

return Task.FromResult(AuthenticateResult.Success(ticket));
}
}

public class TestAuthenticationOptions : AuthenticationSchemeOptions
{
}
112 changes: 110 additions & 2 deletions src/Transports/SignalR/Wolverine.SignalR.Tests/custom_hub.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using JasperFx.Core;
using Microsoft.AspNetCore.SignalR;
using Microsoft.AspNetCore.Authorization;
using Microsoft.Extensions.Logging;
using Shouldly;
using Wolverine.SignalR.Client;
Expand All @@ -10,7 +10,7 @@ namespace Wolverine.SignalR.Tests;

public class custom_hub : WebSocketTestContextWithCustomHub<CustomWolverineHub>
{
public static List<string> ReceivedJson = new();
public static List<string> ReceivedJson { get; } = new();

public override async Task DisposeAsync()
{
Expand Down Expand Up @@ -64,6 +64,98 @@ public async Task publishing_from_client_uses_custom_hub()
}
}

public class custom_hub_with_authentication : WebSocketTestContextWithCustomHub<AuthenticatedWolverineHub>
{
public static List<string> ReceivedJson { get; } = new();

public override async Task DisposeAsync()
{
ReceivedJson.Clear();
await base.DisposeAsync();
}

[Fact]
public async Task client_can_receive_from_authenticated_hub()
{
var client = await StartClientHost();

var tracked = await theWebApp
.TrackActivity()
.IncludeExternalTransports()
.AlsoTrack(client)
.Timeout(10.Seconds())
.SendMessageAndWaitAsync(new FromSecond("Hollywood Brown"));

var record = tracked.Received.SingleRecord<FromSecond>();
record.ServiceName.ShouldBe("Client");
record.Envelope.ShouldNotBeNull();
record.Envelope.Destination.ShouldBe(new Uri($"signalr-client://localhost:{Port}/messages"));
record.Message.ShouldBeOfType<FromSecond>()
.Name.ShouldBe("Hollywood Brown");
}

[Fact]
public async Task client_can_publish_using_authenticated_hub()
{
// This is an IHost that has the SignalR Client
// transport configured to connect to a SignalR
// server in the "theWebApp" IHost
using var client = await StartClientHost();

var tracked = await client
.TrackActivity()
.IncludeExternalTransports()
.AlsoTrack(theWebApp)
.Timeout(10.Seconds())
.ExecuteAndWaitAsync(c => c.SendViaSignalRClient(clientUri, new ToSecond("Hollywood Brown")));

var record = tracked.Received.SingleRecord<ToSecond>();
record.ServiceName.ShouldBe("Server");
record.Envelope.ShouldNotBeNull();
record.Envelope.Destination.ShouldBe(new Uri("signalr://wolverine"));
record.Message.ShouldBeOfType<ToSecond>()
.Name.ShouldBe("Hollywood Brown");

ReceivedJson.Count.ShouldBe(1);
}

[Fact]
public async Task client_with_invalid_token_cannot_connect()
{
// This is an IHost that has the SignalR Client
// transport configured to connect to a SignalR
// server in the "theWebApp" IHost
using var client = await StartClientHost(accessToken: "last-years-token");

var tracked = await client
.TrackActivity()
.IncludeExternalTransports()
.AlsoTrack(theWebApp)
.Timeout(10.Seconds())
.ExecuteAndWaitAsync(c => c.SendViaSignalRClient(clientUri, new ToSecond("Hollywood Brown")));

tracked.Received.Envelopes().ShouldBeEmpty();
ReceivedJson.ShouldBeEmpty();
}

[Fact]
public async Task client_with_invalid_token_cannot_receive()
{
var client = await StartClientHost(accessToken: "last-years-token");

var tracked = await theWebApp
.TrackActivity()
.IncludeExternalTransports()
.AlsoTrack(client)
.Timeout(100.Milliseconds())
.DoNotAssertOnExceptionsDetected() // We're not supposed to be able to receive, so don't throw
.PublishMessageAndWaitAsync(new FromSecond("Hollywood Brown"));

tracked.Received.Envelopes().ShouldBeEmpty();
ReceivedJson.ShouldBeEmpty();
}
}

public class CustomWolverineHub(SignalRTransport endpoint, ILogger<CustomWolverineHub> logger) : WolverineHub(endpoint)
{
public override Task OnConnectedAsync()
Expand All @@ -78,3 +170,19 @@ public override Task ReceiveMessage(string json)
return base.ReceiveMessage(json);
}
}

[Authorize]
public class AuthenticatedWolverineHub(SignalRTransport endpoint, ILogger<CustomWolverineHub> logger) : WolverineHub(endpoint)
{
public override Task OnConnectedAsync()
{
logger.LogInformation("Client authenticated with ID {ConnectionId}", Context.ConnectionId);
return base.OnConnectedAsync();
}

public override Task ReceiveMessage(string json)
{
custom_hub_with_authentication.ReceivedJson.Add(json);
return base.ReceiveMessage(json);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,34 @@ public SignalRClientEndpoint(Uri uri, SignalRClientTransport parent) : base(Tran

public JsonSerializerOptions JsonOptions { get; set; }

public Func<IServiceProvider, Func<Task<string?>>> AccessTokenProvider { get; set; }

public Uri SignalRUri { get; }

public override async ValueTask<IListener> BuildListenerAsync(IWolverineRuntime runtime, IReceiver receiver)
{
Receiver = receiver;
Pipeline = runtime.Pipeline;
_connection ??= new HubConnectionBuilder().WithAutomaticReconnect().WithUrl(SignalRUri).Build();
_connection ??= new HubConnectionBuilder()
.WithAutomaticReconnect()
.WithUrl(SignalRUri, opts =>
{
opts.AccessTokenProvider = AccessTokenProvider?.Invoke(runtime.Services);
})
.Build();
_mapper ??= BuildCloudEventsMapper(runtime, JsonOptions);

Logger = runtime.LoggerFactory.CreateLogger<SignalRClientEndpoint>();

await _connection.StartAsync();
try
{
await _connection.StartAsync();
}
catch (HttpRequestException ex) when (ex.StatusCode == System.Net.HttpStatusCode.Unauthorized)
{
Logger.LogError(ex, "Unable to connect to SignalR. Hub returned Unauthorized");
//throw; // FIXME: Should probably have better handling for this
}

_connection.On(SignalRTransport.DefaultOperation, [typeof(string)], (args =>
{
Expand Down Expand Up @@ -93,7 +109,12 @@ internal async Task ReceiveAsync(string json)

protected override ISender CreateSender(IWolverineRuntime runtime)
{
_connection ??= new HubConnectionBuilder().WithUrl(SignalRUri).Build();
_connection ??= new HubConnectionBuilder()
.WithUrl(SignalRUri, opts =>
{
opts.AccessTokenProvider = AccessTokenProvider?.Invoke(runtime.Services);
})
.Build();
_mapper ??= BuildCloudEventsMapper(runtime, JsonOptions);
return this;
}
Expand Down Expand Up @@ -133,6 +154,7 @@ async ValueTask IListener.StopAsync()

bool ISender.SupportsNativeScheduledSend => false;
Uri ISender.Destination => Uri;

public Task<bool> PingAsync()
{
return Task.FromResult(true);
Expand Down
Loading
Loading