diff --git a/src/Dapr.Actors.AspNetCore/ActorsEndpointRouteBuilderExtensions.cs b/src/Dapr.Actors.AspNetCore/ActorsEndpointRouteBuilderExtensions.cs index 55d161d9a..574a172a8 100644 --- a/src/Dapr.Actors.AspNetCore/ActorsEndpointRouteBuilderExtensions.cs +++ b/src/Dapr.Actors.AspNetCore/ActorsEndpointRouteBuilderExtensions.cs @@ -103,7 +103,7 @@ private static IEndpointConventionBuilder MapActorMethodEndpoint(this IEndpointR try { - var (header, body) = await runtime.DispatchWithRemotingAsync(actorTypeName, actorId, methodName, daprActorheader, context.Request.Body); + var (header, body) = await runtime.DispatchWithRemotingAsync(actorTypeName, actorId, methodName, daprActorheader, context.Request.Body, context.RequestAborted); // Item 1 is header , Item 2 is body if (header != string.Empty) @@ -112,14 +112,14 @@ private static IEndpointConventionBuilder MapActorMethodEndpoint(this IEndpointR context.Response.Headers[Constants.ErrorResponseHeaderName] = header; // add error header } - await context.Response.Body.WriteAsync(body, 0, body.Length); // add response message body + await context.Response.Body.WriteAsync(body, 0, body.Length, context.RequestAborted); // add response message body } catch (Exception ex) { var (header, body) = CreateExceptionResponseMessage(ex); context.Response.Headers[Constants.ErrorResponseHeaderName] = header; - await context.Response.Body.WriteAsync(body, 0, body.Length); + await context.Response.Body.WriteAsync(body, 0, body.Length, context.RequestAborted); } finally { @@ -130,7 +130,7 @@ private static IEndpointConventionBuilder MapActorMethodEndpoint(this IEndpointR { try { - await runtime.DispatchWithoutRemotingAsync(actorTypeName, actorId, methodName, context.Request.Body, context.Response.Body); + await runtime.DispatchWithoutRemotingAsync(actorTypeName, actorId, methodName, context.Request.Body, context.Response.Body, context.RequestAborted); } finally { diff --git a/src/Dapr.Actors/Runtime/ActorManager.cs b/src/Dapr.Actors/Runtime/ActorManager.cs index b7ee3bf3e..d766cd485 100644 --- a/src/Dapr.Actors/Runtime/ActorManager.cs +++ b/src/Dapr.Actors/Runtime/ActorManager.cs @@ -148,16 +148,16 @@ async Task RequestFunc(Actor actor, CancellationToken ct) var parameters = methodInfo.GetParameters(); dynamic awaitable; - if (parameters.Length == 0) + if (parameters.Length == 0 || (parameters.Length == 1 && parameters[0].ParameterType == typeof(CancellationToken))) { - awaitable = methodInfo.Invoke(actor, null); + awaitable = methodInfo.Invoke(actor, parameters.Length == 0 ? null : new object[] { ct }); } - else if (parameters.Length == 1) + else if (parameters.Length == 1 || (parameters.Length == 2 && parameters[1].ParameterType == typeof(CancellationToken))) { // deserialize using stream. var type = parameters[0].ParameterType; var deserializedType = await JsonSerializer.DeserializeAsync(requestBodyStream, type, jsonSerializerOptions); - awaitable = methodInfo.Invoke(actor, new object[] { deserializedType }); + awaitable = methodInfo.Invoke(actor, parameters.Length == 1 ? new object[] { deserializedType } : new object[] { deserializedType, ct }); } else { diff --git a/test/Dapr.Actors.Test/Runtime/ActorRuntimeTests.cs b/test/Dapr.Actors.Test/Runtime/ActorRuntimeTests.cs index 52ae4aa7b..c74d0b754 100644 --- a/test/Dapr.Actors.Test/Runtime/ActorRuntimeTests.cs +++ b/test/Dapr.Actors.Test/Runtime/ActorRuntimeTests.cs @@ -27,6 +27,7 @@ namespace Dapr.Actors.Test using Xunit; using Dapr.Actors.Client; using System.Reflection; + using System.Threading; public sealed class ActorRuntimeTests { @@ -109,6 +110,111 @@ public async Task NoActivateMessageFromRuntime() Assert.Contains(actorType.Name, runtime.RegisteredActors.Select(a => a.Type.ActorTypeName), StringComparer.InvariantCulture); } + public interface INotRemotedActor : IActor + { + Task NoArgumentsAsync(); + + Task NoArgumentsWithCancellationAsync(CancellationToken cancellationToken = default); + + Task SingleArgumentAsync(bool arg); + + Task SingleArgumentWithCancellationAsync(bool arg, CancellationToken cancellationToken = default); + } + + public sealed class NotRemotedActor : Actor, INotRemotedActor + { + public NotRemotedActor(ActorHost host) + : base(host) + { + } + + public Task NoArgumentsAsync() + { + return Task.FromResult(nameof(NoArgumentsAsync)); + } + + public Task NoArgumentsWithCancellationAsync(CancellationToken cancellationToken = default) + { + return Task.FromResult(nameof(NoArgumentsWithCancellationAsync)); + } + + public Task SingleArgumentAsync(bool arg) + { + return Task.FromResult(nameof(SingleArgumentAsync)); + } + + public Task SingleArgumentWithCancellationAsync(bool arg, CancellationToken cancellationToken = default) + { + return Task.FromResult(nameof(SingleArgumentWithCancellationAsync)); + } + } + + public async Task InvokeMethod(string methodName, object arg = null) where T : Actor + { + var options = new ActorRuntimeOptions(); + + options.Actors.RegisterActor(); + + var runtime = new ActorRuntime(options, loggerFactory, activatorFactory, proxyFactory); + + using var input = new MemoryStream(); + + if (arg is not null) + { + JsonSerializer.Serialize(input, arg); + + input.Seek(0, SeekOrigin.Begin); + } + + using var output = new MemoryStream(); + + await runtime.DispatchWithoutRemotingAsync(typeof(T).Name, ActorId.CreateRandom().ToString(), methodName, input, output); + + output.Seek(0, SeekOrigin.Begin); + + return JsonSerializer.Deserialize(output); + } + + [Fact] + public async Task NoRemotingMethodWithNoArguments() + { + string methodName = nameof(INotRemotedActor.NoArgumentsAsync); + + string result = await InvokeMethod(methodName); + + Assert.Equal(methodName, result); + } + + [Fact] + public async Task NoRemotingMethodWithNoArgumentsWithCancellation() + { + string methodName = nameof(INotRemotedActor.NoArgumentsWithCancellationAsync); + + string result = await InvokeMethod(methodName); + + Assert.Equal(methodName, result); + } + + [Fact] + public async Task NoRemotingMethodWithSingleArgument() + { + string methodName = nameof(INotRemotedActor.SingleArgumentAsync); + + string result = await InvokeMethod(methodName, true); + + Assert.Equal(methodName, result); + } + + [Fact] + public async Task NoRemotingMethodWithSingleArgumentWithCancellation() + { + string methodName = nameof(INotRemotedActor.SingleArgumentWithCancellationAsync); + + string result = await InvokeMethod(methodName, true); + + Assert.Equal(methodName, result); + } + [Fact] public async Task Actor_UsesCustomActivator() {