diff --git a/src/Dapr.Workflow/Worker/Grpc/GrpcProtocolHandler.cs b/src/Dapr.Workflow/Worker/Grpc/GrpcProtocolHandler.cs
index f727a3315..59a164d99 100644
--- a/src/Dapr.Workflow/Worker/Grpc/GrpcProtocolHandler.cs
+++ b/src/Dapr.Workflow/Worker/Grpc/GrpcProtocolHandler.cs
@@ -35,13 +35,15 @@ internal sealed class GrpcProtocolHandler(
 {
     private static readonly TimeSpan ReconnectDelay = TimeSpan.FromSeconds(5);
     private static readonly TimeSpan KeepaliveInterval = TimeSpan.FromSeconds(30);
-    
+
     private readonly CancellationTokenSource _disposalCts = new();
     private readonly ILogger _logger = loggerFactory?.CreateLogger<GrpcProtocolHandler>() ?? throw new ArgumentNullException(nameof(loggerFactory));
     private readonly TaskHubSidecarService.TaskHubSidecarServiceClient _grpcClient = grpcClient ?? throw new ArgumentNullException(nameof(grpcClient));
     private readonly int _maxConcurrentWorkItems = maxConcurrentWorkItems > 0 ? maxConcurrentWorkItems : throw new ArgumentOutOfRangeException(nameof(maxConcurrentWorkItems));
     private readonly int _maxConcurrentActivities = maxConcurrentActivities > 0 ? maxConcurrentActivities : throw new ArgumentOutOfRangeException(nameof(maxConcurrentActivities));
+    private readonly SemaphoreSlim _orchestrationSemaphore = new(maxConcurrentWorkItems, maxConcurrentWorkItems);
+    private readonly SemaphoreSlim _activitySemaphore = new(maxConcurrentActivities, maxConcurrentActivities);
 
     private AsyncServerStreamingCall<WorkItem>? _streamingCall;
     private int _activeWorkItemCount;
 
@@ -221,6 +223,7 @@ private async Task ReceiveLoopAsync(
     private async Task ProcessWorkflowAsync(OrchestratorRequest request, string completionToken,
         Func<OrchestratorRequest, string, Task<OrchestratorResponse>> handler, CancellationToken cancellationToken)
     {
+        await _orchestrationSemaphore.WaitAsync(cancellationToken);
         var activeCount = Interlocked.Increment(ref _activeWorkItemCount);
 
         try
@@ -252,6 +255,7 @@ private async Task ProcessWorkflowAsync(OrchestratorRequest request, string comp
         }
         finally
         {
             Interlocked.Decrement(ref _activeWorkItemCount);
+            _orchestrationSemaphore.Release();
         }
     }
@@ -262,6 +266,7 @@ private async Task ProcessWorkflowAsync(OrchestratorRequest request, string comp
     private async Task ProcessActivityAsync(ActivityRequest request, string completionToken,
         Func<ActivityRequest, string, Task<ActivityResponse>> handler, CancellationToken cancellationToken)
     {
+        await _activitySemaphore.WaitAsync(cancellationToken);
         var activeCount = Interlocked.Increment(ref _activeWorkItemCount);
 
         try
@@ -296,6 +301,7 @@ private async Task ProcessActivityAsync(ActivityRequest request, string completi
         }
         finally
         {
             Interlocked.Decrement(ref _activeWorkItemCount);
+            _activitySemaphore.Release();
         }
     }
@@ -368,13 +374,18 @@ public async ValueTask DisposeAsync()
     {
         if (_disposalCts.IsCancellationRequested)
             return;
-        
+
         _logger.LogGrpcProtocolHandlerDisposing();
-        
+
         await _disposalCts.CancelAsync();
         _streamingCall?.Dispose();
         _disposalCts.Dispose();
-        
+
+        // The concurrency semaphores are intentionally not disposed here. This method does not
+        // wait for in-flight work items, and those still call Release() from their finally blocks;
+        // Release() on a disposed SemaphoreSlim throws ObjectDisposedException. Disposal only
+        // matters once AvailableWaitHandle has been used, which the code in this change never does.
+
         _logger.LogGrpcProtocolHandlerDisposed();
     }
 
diff --git a/test/Dapr.IntegrationTest.Workflow/MaxConcurrentActivitiesTests.cs b/test/Dapr.IntegrationTest.Workflow/MaxConcurrentActivitiesTests.cs
new file mode 100644
index 000000000..ae13e38ec
--- /dev/null
+++ b/test/Dapr.IntegrationTest.Workflow/MaxConcurrentActivitiesTests.cs
@@ -0,0 +1,159 @@
+// ------------------------------------------------------------------------
+// Copyright 2025 The Dapr Authors
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//     http://www.apache.org/licenses/LICENSE-2.0
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// ------------------------------------------------------------------------
+
+using Dapr.Testcontainers.Common;
+using Dapr.Testcontainers.Harnesses;
+using Dapr.Workflow;
+using Microsoft.Extensions.Configuration;
+using Microsoft.Extensions.DependencyInjection;
+
+namespace Dapr.IntegrationTest.Workflow;
+
+public sealed class MaxConcurrentActivitiesTests
+{
+    /// <summary>
+    /// Verifies that a MaxConcurrentActivities setting of 1 limits activity execution to a
+    /// single concurrent activity even when the workflow fans out more.
+    /// </summary>
+    [Fact]
+    public async Task ShouldRespectMaxConcurrentActivitiesLimitOfOne() =>
+        await AssertLimitIsRespectedAsync(limit: 1, activityCount: 5);
+
+    /// <summary>
+    /// Verifies that a MaxConcurrentActivities setting of 3 allows up to 3 concurrent
+    /// activities and that all activities complete successfully.
+    /// </summary>
+    [Fact]
+    public async Task ShouldRespectMaxConcurrentActivitiesLimitOfThree() =>
+        await AssertLimitIsRespectedAsync(limit: 3, activityCount: 10);
+
+    /// <summary>
+    /// Starts a workflow runtime capped at <paramref name="limit"/> concurrent activities, runs a
+    /// workflow that fans out <paramref name="activityCount"/> activities, and asserts that the
+    /// workflow completes without the observed activity concurrency ever exceeding the cap.
+    /// </summary>
+    private static async Task AssertLimitIsRespectedAsync(int limit, int activityCount)
+    {
+        var componentsDir = TestDirectoryManager.CreateTestDirectory("workflow-components");
+        var workflowInstanceId = Guid.NewGuid().ToString();
+
+        await using var environment = await DaprTestEnvironment.CreateWithPooledNetworkAsync(
+            needsActorState: true,
+            cancellationToken: TestContext.Current.CancellationToken);
+        await environment.StartAsync(TestContext.Current.CancellationToken);
+
+        var harness = new DaprHarnessBuilder(componentsDir)
+            .WithEnvironment(environment)
+            .BuildWorkflow();
+
+        await using var testApp = await DaprHarnessBuilder.ForHarness(harness)
+            .ConfigureServices(builder =>
+            {
+                builder.Services.AddDaprWorkflowBuilder(
+                    configureRuntime: opt =>
+                    {
+                        opt.MaxConcurrentActivities = limit;
+                        opt.RegisterWorkflow<FanOutWorkflow>();
+                        opt.RegisterActivity<ConcurrencyTrackingActivity>();
+                    },
+                    configureClient: (sp, clientBuilder) =>
+                    {
+                        var config = sp.GetRequiredService<IConfiguration>();
+                        var grpcEndpoint = config["DAPR_GRPC_ENDPOINT"];
+                        if (!string.IsNullOrEmpty(grpcEndpoint))
+                            clientBuilder.UseGrpcEndpoint(grpcEndpoint);
+                    });
+            })
+            .BuildAndStartAsync();
+
+        ConcurrencyTrackingActivity.Reset();
+
+        using var scope = testApp.CreateScope();
+        var daprWorkflowClient = scope.ServiceProvider.GetRequiredService<DaprWorkflowClient>();
+
+        await daprWorkflowClient.ScheduleNewWorkflowAsync(nameof(FanOutWorkflow), workflowInstanceId, activityCount);
+        var result = await daprWorkflowClient.WaitForWorkflowCompletionAsync(
+            workflowInstanceId, true, TestContext.Current.CancellationToken);
+
+        Assert.Equal(WorkflowRuntimeStatus.Completed, result.RuntimeStatus);
+
+        // A lower bound of 1 guards against a vacuous pass where no activity execution was recorded.
+        Assert.InRange(ConcurrencyTrackingActivity.MaxObservedConcurrency, 1, limit);
+    }
+
+    /// <summary>
+    /// Tracks the maximum number of concurrently executing activity instances using a shared
+    /// static counter. Each activity holds for a brief period so concurrent executions can
+    /// accumulate and be observed.
+    /// </summary>
+    private sealed class ConcurrencyTrackingActivity : WorkflowActivity<int, int>
+    {
+        private static int _currentConcurrent;
+        private static int _maxObservedConcurrent;
+        private static readonly object Lock = new();
+
+        public static int MaxObservedConcurrency
+        {
+            get
+            {
+                lock (Lock)
+                {
+                    return _maxObservedConcurrent;
+                }
+            }
+        }
+
+        public static void Reset()
+        {
+            lock (Lock)
+            {
+                _currentConcurrent = 0;
+                _maxObservedConcurrent = 0;
+            }
+        }
+
+        public override async Task<int> RunAsync(WorkflowActivityContext context, int input)
+        {
+            lock (Lock)
+            {
+                _currentConcurrent++;
+                if (_currentConcurrent > _maxObservedConcurrent)
+                    _maxObservedConcurrent = _currentConcurrent;
+            }
+
+            // Hold briefly so any concurrent activities can be observed accumulating.
+            await Task.Delay(TimeSpan.FromMilliseconds(300));
+
+            lock (Lock)
+            {
+                _currentConcurrent--;
+            }
+
+            return input;
+        }
+    }
+
+    private sealed class FanOutWorkflow : Workflow<int, int[]>
+    {
+        public override async Task<int[]> RunAsync(WorkflowContext context, int input)
+        {
+            var tasks = new Task<int>[input];
+            for (var i = 0; i < input; i++)
+            {
+                tasks[i] = context.CallActivityAsync<int>(nameof(ConcurrencyTrackingActivity), i);
+            }
+
+            return await Task.WhenAll(tasks);
+        }
+    }
+}
diff --git a/test/Dapr.IntegrationTest.Workflow/MaxConcurrentWorkflowsTests.cs b/test/Dapr.IntegrationTest.Workflow/MaxConcurrentWorkflowsTests.cs
new file mode 100644
index 000000000..c1e34a998
--- /dev/null
+++ b/test/Dapr.IntegrationTest.Workflow/MaxConcurrentWorkflowsTests.cs
@@ -0,0 +1,110 @@
+// ------------------------------------------------------------------------
+// Copyright 2025 The Dapr Authors
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//     http://www.apache.org/licenses/LICENSE-2.0
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// ------------------------------------------------------------------------
+
+using Dapr.Testcontainers.Common;
+using Dapr.Testcontainers.Harnesses;
+using Dapr.Workflow;
+using Microsoft.Extensions.Configuration;
+using Microsoft.Extensions.DependencyInjection;
+
+namespace Dapr.IntegrationTest.Workflow;
+
+public sealed class MaxConcurrentWorkflowsTests
+{
+    /// <summary>
+    /// Verifies that setting MaxConcurrentWorkflows to 1 does not deadlock the runtime and that
+    /// all scheduled workflows eventually complete.
+    /// </summary>
+    [Fact]
+    public async Task ShouldCompleteAllWorkflowsWhenLimitIsOne() =>
+        await AssertAllWorkflowsCompleteAsync(limit: 1, workflowCount: 3);
+
+    /// <summary>
+    /// Verifies that a custom MaxConcurrentWorkflows value greater than 1 does not deadlock the
+    /// runtime and that all scheduled workflows complete.
+    /// </summary>
+    [Fact]
+    public async Task ShouldCompleteAllWorkflowsWithCustomConcurrencyLimit() =>
+        await AssertAllWorkflowsCompleteAsync(limit: 2, workflowCount: 5);
+
+    /// <summary>
+    /// Starts a workflow runtime capped at <paramref name="limit"/> concurrent workflows, schedules
+    /// <paramref name="workflowCount"/> echo workflows at once, and asserts that every one of them
+    /// completes and returns its own instance ID.
+    /// </summary>
+    private static async Task AssertAllWorkflowsCompleteAsync(int limit, int workflowCount)
+    {
+        var componentsDir = TestDirectoryManager.CreateTestDirectory("workflow-components");
+        var workflowInstanceIds = Enumerable.Range(0, workflowCount)
+            .Select(_ => Guid.NewGuid().ToString())
+            .ToArray();
+
+        await using var environment = await DaprTestEnvironment.CreateWithPooledNetworkAsync(
+            needsActorState: true,
+            cancellationToken: TestContext.Current.CancellationToken);
+        await environment.StartAsync(TestContext.Current.CancellationToken);
+
+        var harness = new DaprHarnessBuilder(componentsDir)
+            .WithEnvironment(environment)
+            .BuildWorkflow();
+
+        await using var testApp = await DaprHarnessBuilder.ForHarness(harness)
+            .ConfigureServices(builder =>
+            {
+                builder.Services.AddDaprWorkflowBuilder(
+                    configureRuntime: opt =>
+                    {
+                        opt.MaxConcurrentWorkflows = limit;
+                        opt.RegisterWorkflow<EchoWorkflow>();
+                        opt.RegisterActivity<EchoActivity>();
+                    },
+                    configureClient: (sp, clientBuilder) =>
+                    {
+                        var config = sp.GetRequiredService<IConfiguration>();
+                        var grpcEndpoint = config["DAPR_GRPC_ENDPOINT"];
+                        if (!string.IsNullOrEmpty(grpcEndpoint))
+                            clientBuilder.UseGrpcEndpoint(grpcEndpoint);
+                    });
+            })
+            .BuildAndStartAsync();
+
+        using var scope = testApp.CreateScope();
+        var daprWorkflowClient = scope.ServiceProvider.GetRequiredService<DaprWorkflowClient>();
+
+        // Schedule all workflows concurrently.
+        await Task.WhenAll(workflowInstanceIds.Select(id =>
+            daprWorkflowClient.ScheduleNewWorkflowAsync(nameof(EchoWorkflow), id, id)));
+
+        // Wait for all to finish and assert each completed successfully.
+        var results = await Task.WhenAll(workflowInstanceIds.Select(id =>
+            daprWorkflowClient.WaitForWorkflowCompletionAsync(id, true, TestContext.Current.CancellationToken)));
+
+        foreach (var (result, id) in results.Zip(workflowInstanceIds))
+        {
+            Assert.Equal(WorkflowRuntimeStatus.Completed, result.RuntimeStatus);
+            Assert.Equal(id, result.ReadOutputAs<string>());
+        }
+    }
+
+    private sealed class EchoActivity : WorkflowActivity<string, string>
+    {
+        public override Task<string> RunAsync(WorkflowActivityContext context, string input) =>
+            Task.FromResult(input);
+    }
+
+    private sealed class EchoWorkflow : Workflow<string, string>
+    {
+        public override async Task<string> RunAsync(WorkflowContext context, string input) =>
+            await context.CallActivityAsync<string>(nameof(EchoActivity), input);
+    }
+}
diff --git a/test/Dapr.Workflow.Test/Worker/Grpc/GrpcProtocolHandlerTests.cs b/test/Dapr.Workflow.Test/Worker/Grpc/GrpcProtocolHandlerTests.cs
index a4034deaf..ef5e73829 100644
--- a/test/Dapr.Workflow.Test/Worker/Grpc/GrpcProtocolHandlerTests.cs
+++ b/test/Dapr.Workflow.Test/Worker/Grpc/GrpcProtocolHandlerTests.cs
@@ -701,6 +701,163 @@ await RunHandlerUntilAsync(
         Assert.True(Volatile.Read(ref getWorkItemsCalls) >= 1);
     }
 
+    [Fact]
+    public async Task StartAsync_ShouldNotExceedMaxConcurrentOrchestrationWorkItems()
+    {
+        const int maxConcurrent = 2;
+        const int totalItems = 3;
+
+        var grpcClientMock = CreateGrpcClientMock();
+
+        grpcClientMock
+            .Setup(x => x.GetWorkItems(It.IsAny<GetWorkItemsRequest>(), It.IsAny<CallOptions>()))
+            .Returns(CreateServerStreamingCall(Enumerable.Range(1, totalItems)
+                .Select(i => new WorkItem
+                {
+                    OrchestratorRequest = new OrchestratorRequest { InstanceId = $"i-{i}" }
+                })));
+
+        grpcClientMock
+            .Setup(x => x.CompleteOrchestratorTaskAsync(It.IsAny<OrchestratorResponse>(), It.IsAny<CallOptions>()))
+            .Returns(CreateAsyncUnaryCall(new CompleteTaskResponse()));
+
+        var activeCount = 0;
+        var completedCount = 0;
+        var overLimitCount = 0;
+        using var releaseGate = new SemaphoreSlim(0);
+        var maxConcurrentReachedTcs = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);
+
+        var handler = new GrpcProtocolHandler(
+            grpcClientMock.Object,
+            NullLoggerFactory.Instance,
+            maxConcurrentWorkItems: maxConcurrent,
+            maxConcurrentActivities: 1);
+
+        using var cts = CancellationTokenSource.CreateLinkedTokenSource(TestContext.Current.CancellationToken);
+        cts.CancelAfter(TimeSpan.FromSeconds(5));
+
+        var startTask = handler.StartAsync(
+            workflowHandler: async (req, _) =>
+            {
+                var count = Interlocked.Increment(ref activeCount);
+                if (count > maxConcurrent)
+                    Interlocked.Increment(ref overLimitCount);
+                if (count == maxConcurrent)
+                    maxConcurrentReachedTcs.TrySetResult(true);
+
+                await releaseGate.WaitAsync(TestContext.Current.CancellationToken);
+
+                Interlocked.Decrement(ref activeCount);
+                Interlocked.Increment(ref completedCount);
+                return new OrchestratorResponse { InstanceId = req.InstanceId };
+            },
+            activityHandler: (_, _) => Task.FromResult(new ActivityResponse()),
+            cancellationToken: cts.Token);
+
+        // Wait for maxConcurrent handlers to be simultaneously active
+        await maxConcurrentReachedTcs.Task.WaitAsync(TimeSpan.FromSeconds(3), TestContext.Current.CancellationToken);
+
+        // The orchestration semaphore prevents a 3rd handler from starting
+        Assert.Equal(maxConcurrent, Volatile.Read(ref activeCount));
+
+        // Release all handlers (including the one queued behind the semaphore)
+        releaseGate.Release(totalItems);
+
+        // Wait for all to finish
+        var deadline = DateTime.UtcNow.AddSeconds(3);
+        while (Volatile.Read(ref completedCount) < totalItems && DateTime.UtcNow < deadline)
+            await Task.Delay(10, TestContext.Current.CancellationToken);
+
+        cts.Cancel();
+        await startTask;
+
+        Assert.Equal(totalItems, Volatile.Read(ref completedCount));
+        // The snapshot above can run before an unthrottled 3rd handler is scheduled, so also
+        // verify that no handler ever observed more than maxConcurrent active at once.
+        Assert.Equal(0, Volatile.Read(ref overLimitCount));
+    }
+
+    [Fact]
+    public async Task StartAsync_ShouldNotExceedMaxConcurrentActivityWorkItems()
+    {
+        const int maxConcurrent = 2;
+        const int totalItems = 3;
+
+        var grpcClientMock = CreateGrpcClientMock();
+
+        grpcClientMock
+            .Setup(x => x.GetWorkItems(It.IsAny<GetWorkItemsRequest>(), It.IsAny<CallOptions>()))
+            .Returns(CreateServerStreamingCall(Enumerable.Range(1, totalItems)
+                .Select(i => new WorkItem
+                {
+                    ActivityRequest = new ActivityRequest
+                    {
+                        Name = $"act-{i}",
+                        TaskId = i,
+                        OrchestrationInstance = new OrchestrationInstance { InstanceId = "i-1" }
+                    }
+                })));
+
+        grpcClientMock
+            .Setup(x => x.CompleteActivityTaskAsync(It.IsAny<ActivityResponse>(), It.IsAny<CallOptions>()))
+            .Returns(CreateAsyncUnaryCall(new CompleteTaskResponse()));
+
+        var activeCount = 0;
+        var completedCount = 0;
+        var overLimitCount = 0;
+        using var releaseGate = new SemaphoreSlim(0);
+        var maxConcurrentReachedTcs = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);
+
+        var handler = new GrpcProtocolHandler(
+            grpcClientMock.Object,
+            NullLoggerFactory.Instance,
+            maxConcurrentWorkItems: 1,
+            maxConcurrentActivities: maxConcurrent);
+
+        using var cts = CancellationTokenSource.CreateLinkedTokenSource(TestContext.Current.CancellationToken);
+        cts.CancelAfter(TimeSpan.FromSeconds(5));
+
+        var startTask = handler.StartAsync(
+            workflowHandler: (_, _) => Task.FromResult(new OrchestratorResponse()),
+            activityHandler: async (req, _) =>
+            {
+                var count = Interlocked.Increment(ref activeCount);
+                if (count > maxConcurrent)
+                    Interlocked.Increment(ref overLimitCount);
+                if (count == maxConcurrent)
+                    maxConcurrentReachedTcs.TrySetResult(true);
+
+                await releaseGate.WaitAsync(TestContext.Current.CancellationToken);
+
+                Interlocked.Decrement(ref activeCount);
+                Interlocked.Increment(ref completedCount);
+                return new ActivityResponse { InstanceId = req.OrchestrationInstance.InstanceId, TaskId = req.TaskId };
+            },
+            cancellationToken: cts.Token);
+
+        // Wait for maxConcurrent handlers to be simultaneously active
+        await maxConcurrentReachedTcs.Task.WaitAsync(TimeSpan.FromSeconds(3), TestContext.Current.CancellationToken);
+
+        // The activity semaphore prevents a 3rd handler from starting
+        Assert.Equal(maxConcurrent, Volatile.Read(ref activeCount));
+
+        // Release all handlers (including the one queued behind the semaphore)
+        releaseGate.Release(totalItems);
+
+        // Wait for all to finish
+        var deadline = DateTime.UtcNow.AddSeconds(3);
+        while (Volatile.Read(ref completedCount) < totalItems && DateTime.UtcNow < deadline)
+            await Task.Delay(10, TestContext.Current.CancellationToken);
+
+        cts.Cancel();
+        await startTask;
+
+        Assert.Equal(totalItems, Volatile.Read(ref completedCount));
+        // The snapshot above can run before an unthrottled 3rd handler is scheduled, so also
+        // verify that no handler ever observed more than maxConcurrent active at once.
+        Assert.Equal(0, Volatile.Read(ref overLimitCount));
+    }
+
     [Fact]
     public async Task DelayOrStopAsync_ShouldSwallowCancellation_WhenTokenIsCanceled()
     {