diff --git a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs index 24c675a5aa8f..5e914b31e29e 100644 --- a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs +++ b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs @@ -726,7 +726,26 @@ private void ReplaceArguments(HubMethodDescriptor descriptor, HubMethodInvocatio var hubInvocationArgumentPointer = 0; for (var parameterPointer = 0; parameterPointer < arguments.Length; parameterPointer++) { - if (hubMethodInvocationMessage.Arguments?.Length > hubInvocationArgumentPointer && + // populate the synthetic arguments first + if (descriptor.IsServiceArgument(parameterPointer)) + { + arguments[parameterPointer] = descriptor.GetService(scope.ServiceProvider, parameterPointer, descriptor.OriginalParameterTypes[parameterPointer]); + } + else if (descriptor.OriginalParameterTypes[parameterPointer] == typeof(CancellationToken)) + { + cts = CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAborted); + arguments[parameterPointer] = cts.Token; + } + else if (isStreamCall && ReflectionHelper.IsStreamingType(descriptor.OriginalParameterTypes[parameterPointer], mustBeDirectType: true)) + { + Log.StartingParameterStream(_logger, hubMethodInvocationMessage.StreamIds![streamPointer]); + var itemType = descriptor.StreamingParameters![streamPointer]; + arguments[parameterPointer] = connection.StreamTracker.AddStream(hubMethodInvocationMessage.StreamIds[streamPointer], + itemType, descriptor.OriginalParameterTypes[parameterPointer]); + + streamPointer++; + } + else if (hubMethodInvocationMessage.Arguments?.Length > hubInvocationArgumentPointer && (hubMethodInvocationMessage.Arguments[hubInvocationArgumentPointer] == null || descriptor.OriginalParameterTypes[parameterPointer].IsAssignableFrom(hubMethodInvocationMessage.Arguments[hubInvocationArgumentPointer]?.GetType()))) { @@ -736,29 +755,8 @@ private void ReplaceArguments(HubMethodDescriptor descriptor, HubMethodInvocatio } else { - if (descriptor.OriginalParameterTypes[parameterPointer] == typeof(CancellationToken)) - { - cts = CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAborted); - arguments[parameterPointer] = cts.Token; - } - else if (descriptor.IsServiceArgument(parameterPointer)) - { - arguments[parameterPointer] = descriptor.GetService(scope.ServiceProvider, parameterPointer, descriptor.OriginalParameterTypes[parameterPointer]); - } - else if (isStreamCall && ReflectionHelper.IsStreamingType(descriptor.OriginalParameterTypes[parameterPointer], mustBeDirectType: true)) - { - Log.StartingParameterStream(_logger, hubMethodInvocationMessage.StreamIds![streamPointer]); - var itemType = descriptor.StreamingParameters![streamPointer]; - arguments[parameterPointer] = connection.StreamTracker.AddStream(hubMethodInvocationMessage.StreamIds[streamPointer], - itemType, descriptor.OriginalParameterTypes[parameterPointer]); - - streamPointer++; - } - else - { - // This should never happen - Debug.Assert(false, $"Failed to bind argument of type '{descriptor.OriginalParameterTypes[parameterPointer].Name}' for hub method '{descriptor.MethodExecutor.MethodInfo.Name}'."); - } + // This should never happen + Debug.Assert(false, $"Failed to bind argument of type '{descriptor.OriginalParameterTypes[parameterPointer].Name}' for hub method '{descriptor.MethodExecutor.MethodInfo.Name}'."); } } } @@ -896,4 +894,4 @@ private static void SetActivityError(Activity? activity, Exception ex) activity?.SetTag("error.type", ex.GetType().FullName); activity?.SetStatus(ActivityStatusCode.Error); } -} +} \ No newline at end of file diff --git a/src/SignalR/server/SignalR/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTestUtils/Hubs.cs b/src/SignalR/server/SignalR/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTestUtils/Hubs.cs index 091ffabaad9f..90fabf90265b 100644 --- a/src/SignalR/server/SignalR/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTestUtils/Hubs.cs +++ b/src/SignalR/server/SignalR/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTestUtils/Hubs.cs @@ -1381,6 +1381,11 @@ public async Task ServicesAndParams(int value, [FromService] Service1 servi } return total + value; } + + public int ServiceWithStringAttribute([FromService] Service1 service, string value) + { + return 115; + } public int ServiceWithoutAttribute(Service1 service) { @@ -1464,4 +1469,4 @@ public override async Task OnConnectedAsync() await Clients.Client(id).SendAsync("Test", 1); } } -} +} \ No newline at end of file diff --git a/src/SignalR/server/SignalR/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTests.cs b/src/SignalR/server/SignalR/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTests.cs index b6e37225ef4e..486f6c53ea02 100644 --- a/src/SignalR/server/SignalR/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTests.cs +++ b/src/SignalR/server/SignalR/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTests.cs @@ -4679,6 +4679,24 @@ public async Task HubMethodCanInjectService() Assert.True(Assert.IsType(res.Result)); } } + + // Regression test for https://github.com/dotnet/aspnetcore/issues/61491 + [Fact] + public async Task HubMethodCanInjectServiceWithNullParameter() + { + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(provider => + { + provider.AddSingleton(); + }); + var connectionHandler = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout(); + var res = await client.InvokeAsync(nameof(ServicesHub.ServiceWithStringAttribute),(string)null ).DefaultTimeout(); + Assert.Equal(115L, res.Result); + } + } [Fact] public async Task HubMethodCanInjectMultipleServices() @@ -5459,4 +5477,4 @@ public static async Task> ReadAllAsync(this IAsyncEnumerable