Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,6 @@ public ValueTask DisposeAsync()
{
// Remove the data directory if it exists
TestDirectoryManager.CleanUpDirectory(_testDirectory);

// if (Directory.Exists(_hostDataDir))
// Directory.Delete(_hostDataDir, true);
return _container.DisposeAsync();
}
}
33 changes: 18 additions & 15 deletions src/Dapr.Workflow/Worker/Grpc/GrpcProtocolHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ internal sealed class GrpcProtocolHandler(TaskHubSidecarService.TaskHubSidecarSe
/// <param name="activityHandler">Handler for activity work items.</param>
/// <param name="cancellationToken">Cancellation token.</param>
public async Task StartAsync(
Func<OrchestratorRequest, Task<OrchestratorResponse>> workflowHandler,
Func<ActivityRequest, Task<ActivityResponse>> activityHandler,
Func<OrchestratorRequest, string, Task<OrchestratorResponse>> workflowHandler,
Func<ActivityRequest, string, Task<ActivityResponse>> activityHandler,
CancellationToken cancellationToken)
{
using var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _disposalCts.Token);
Expand Down Expand Up @@ -126,8 +126,8 @@ private static async Task DelayOrStopAsync(TimeSpan delay, CancellationToken tok
/// </summary>
private async Task ReceiveLoopAsync(
IAsyncStreamReader<WorkItem> workItemsStream,
Func<OrchestratorRequest, Task<OrchestratorResponse>> orchestratorHandler,
Func<ActivityRequest, Task<ActivityResponse>> activityHandler,
Func<OrchestratorRequest, string, Task<OrchestratorResponse>> orchestratorHandler,
Func<ActivityRequest, string, Task<ActivityResponse>> activityHandler,
CancellationToken cancellationToken)
{
// Track active work items for proper exception handling
Expand All @@ -137,14 +137,16 @@ private async Task ReceiveLoopAsync(
{
await foreach (var workItem in workItemsStream.ReadAllAsync(cancellationToken))
{
var completionToken = workItem.CompletionToken;

// Dispatch based on work item type
var workItemTask = workItem.RequestCase switch
{
WorkItem.RequestOneofCase.OrchestratorRequest => Task.Run(
() => ProcessWorkflowAsync(workItem.OrchestratorRequest, orchestratorHandler, cancellationToken),
() => ProcessWorkflowAsync(workItem.OrchestratorRequest, completionToken, orchestratorHandler, cancellationToken),
cancellationToken),
WorkItem.RequestOneofCase.ActivityRequest => Task.Run(
() => ProcessActivityAsync(workItem.ActivityRequest, activityHandler, cancellationToken),
() => ProcessActivityAsync(workItem.ActivityRequest, completionToken, activityHandler, cancellationToken),
cancellationToken),
_ => Task.Run(
() => _logger.LogGrpcProtocolHandlerUnknownWorkItemType(workItem.RequestCase),
Expand Down Expand Up @@ -188,16 +190,16 @@ private async Task ReceiveLoopAsync(
/// <summary>
/// Processes a workflow request work item.
/// </summary>
private async Task ProcessWorkflowAsync(OrchestratorRequest request,
Func<OrchestratorRequest, Task<OrchestratorResponse>> handler, CancellationToken cancellationToken)
private async Task ProcessWorkflowAsync(OrchestratorRequest request, string completionToken,
Func<OrchestratorRequest, string, Task<OrchestratorResponse>> handler, CancellationToken cancellationToken)
{
var activeCount = Interlocked.Increment(ref _activeWorkItemCount);

try
{
_logger.LogGrpcProtocolHandlerWorkflowProcessorStart(request.InstanceId, activeCount);

var result = await handler(request);
var result = await handler(request, completionToken);

// Send the result back to Dapr
await _grpcClient.CompleteOrchestratorTaskAsync(result, cancellationToken: cancellationToken);
Expand Down Expand Up @@ -227,16 +229,16 @@ private async Task ProcessWorkflowAsync(OrchestratorRequest request,
/// <summary>
/// Processes an activity request work item.
/// </summary>
private async Task ProcessActivityAsync(ActivityRequest request,
Func<ActivityRequest, Task<ActivityResponse>> handler, CancellationToken cancellationToken)
private async Task ProcessActivityAsync(ActivityRequest request, string completionToken,
Func<ActivityRequest, string, Task<ActivityResponse>> handler, CancellationToken cancellationToken)
{
var activeCount = Interlocked.Increment(ref _activeWorkItemCount);

try
{
_logger.LogGrpcProtocolHandlerActivityProcessorStart(request.OrchestrationInstance.InstanceId, request.Name,
request.TaskId, activeCount);
var result = await handler(request);
var result = await handler(request, completionToken);

// Send the result back to Dapr
await _grpcClient.CompleteActivityTaskAsync(result, cancellationToken: cancellationToken);
Expand All @@ -252,7 +254,7 @@ private async Task ProcessActivityAsync(ActivityRequest request,

try
{
var failureResult = CreateActivityFailureResult(request, ex);
var failureResult = CreateActivityFailureResult(request, completionToken, ex);
await _grpcClient.CompleteActivityTaskAsync(failureResult, cancellationToken: cancellationToken);
}
catch (Exception resultEx)
Expand All @@ -269,11 +271,12 @@ private async Task ProcessActivityAsync(ActivityRequest request,
/// <summary>
/// Creates a failure response for an activity exception.
/// </summary>
private static ActivityResponse CreateActivityFailureResult(ActivityRequest request, Exception ex) =>
private static ActivityResponse CreateActivityFailureResult(ActivityRequest request, string completionToken, Exception ex) =>
new()
{

InstanceId = request.OrchestrationInstance.InstanceId,
TaskId = request.TaskId,
CompletionToken = completionToken,
FailureDetails = new()
{
ErrorType = ex.GetType().FullName ?? "Exception",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ internal sealed class WorkflowOrchestrationContext : WorkflowContext
private readonly List<HistoryEvent> _externalEventBuffer = [];
private readonly Dictionary<string, Queue<TaskCompletionSource<HistoryEvent>>> _externalEventSources = new(StringComparer.OrdinalIgnoreCase);
private readonly Dictionary<int, TaskCompletionSource<HistoryEvent>> _openTasks = [];
private readonly Dictionary<int, string> _taskIdToExecutionId = [];
private readonly Dictionary<string, int> _executionIdToTaskId = new(StringComparer.Ordinal);
private readonly SortedDictionary<int, OrchestratorAction> _pendingActions = [];
private readonly IWorkflowSerializer _workflowSerializer;
private readonly ILogger<WorkflowOrchestrationContext> _logger;
Expand All @@ -50,6 +52,7 @@ internal sealed class WorkflowOrchestrationContext : WorkflowContext
/// Key is taskScheduledId/timerId/etc. as provided by the history event.
/// </summary>
private readonly Dictionary<int, HistoryEvent> _unmatchedCompletions = [];
private readonly Dictionary<string, HistoryEvent> _unmatchedCompletionsByExecutionId = new(StringComparer.Ordinal);


// Parse instance ID as GUID or generate one
Expand Down Expand Up @@ -105,11 +108,13 @@ public override async Task<T> CallActivityAsync<T>(string name, object? input =
{
ArgumentException.ThrowIfNullOrWhiteSpace(name);
var taskId = _sequenceNumber++;
var taskExecutionId = CreateTaskExecutionId(taskId, name);

var router = CreateRouter(options?.TargetAppId);

// If the completion arrived before we registered the task, consume it now
if (_unmatchedCompletions.Remove(taskId, out var earlyCompletion))
if (_unmatchedCompletionsByExecutionId.Remove(taskExecutionId, out var earlyCompletion) ||
_unmatchedCompletions.Remove(taskId, out earlyCompletion))
{
_logger.LogDebug("Found early completion in buffer for task {TaskId} ({ActivityName})", taskId, name);
return await HandleHistoryMatch<T>(name, earlyCompletion, taskId);
Expand All @@ -122,13 +127,16 @@ public override async Task<T> CallActivityAsync<T>(string name, object? input =
{
Name = name,
Input = _workflowSerializer.Serialize(input),
Router = router
Router = router,
TaskExecutionId = taskExecutionId
},
Router = router
});

var tcs = new TaskCompletionSource<HistoryEvent>();
_openTasks.Add(taskId, tcs);
_taskIdToExecutionId[taskId] = taskExecutionId;
_executionIdToTaskId[taskExecutionId] = taskId;

var historyEvent = await tcs.Task;
return await HandleHistoryMatch<T>(name, historyEvent, taskId);
Expand Down Expand Up @@ -440,6 +448,17 @@ private void HandleActionCompleted(HistoryEvent historyEvent, int taskId)
{
tcs.SetResult(historyEvent);
_openTasks.Remove(taskId);
RemoveTaskExecutionMapping(taskId);
return;
}

if (TryGetTaskExecutionId(historyEvent, out var taskExecutionId) &&
_executionIdToTaskId.TryGetValue(taskExecutionId, out var executionTaskId) &&
_openTasks.TryGetValue(executionTaskId, out var executionTcs))
{
executionTcs.SetResult(historyEvent);
_openTasks.Remove(executionTaskId);
RemoveTaskExecutionMapping(executionTaskId);
return;
}

Expand All @@ -454,6 +473,21 @@ private void HandleActionCompleted(HistoryEvent historyEvent, int taskId)

// Buffer the completion so the next replay pass can consume it when the workflow schedules/awaits the task.
// Ignore duplicates (first completion wins)
if (TryGetTaskExecutionId(historyEvent, out var unmatchedExecutionId))
{
if (!_unmatchedCompletionsByExecutionId.TryAdd(unmatchedExecutionId, historyEvent))
{
_logger.LogWarning(
"Received completion for unknown taskId {TaskId} in instance {InstanceId}. OpenTasks=[{OpenTasks}] EventType={EventType}",
taskId,
InstanceId,
string.Join(",", _openTasks.Keys),
historyEvent.EventTypeCase);
}

return;
}

if (!_unmatchedCompletions.TryAdd(taskId, historyEvent))
{
// If we get here, the runtime delivered a completion for an unknown task id.
Expand Down Expand Up @@ -519,6 +553,33 @@ private Task<T> HandleFailedActivityFromHistory<T>(string activityName, TaskFail
throw CreateTaskFailedException(failed);
}

/// <summary>
/// Builds the task execution id for an activity call.
/// </summary>
/// <remarks>
/// The seed text combines the instance id, the literal "activity", the task id and the
/// activity name, so identical inputs always produce the same seed.
/// NOTE(review): presumably <c>CreateGuidFromName</c> is a deterministic name-based GUID
/// generator, which would make the id stable across replays - confirm against its definition.
/// </remarks>
/// <param name="taskId">Sequence number assigned to the scheduled task.</param>
/// <param name="name">Name of the activity being scheduled.</param>
/// <returns>The derived value formatted with the "N" format specifier.</returns>
private string CreateTaskExecutionId(int taskId, string name)
{
    var seedBytes = Encoding.UTF8.GetBytes($"{InstanceId}|activity|{taskId}|{name}");
    var executionGuid = CreateGuidFromName(_instanceGuid, seedBytes);
    return executionGuid.ToString("N");
}

/// <summary>
/// Reads the task execution id carried by a task-completed or task-failed history event.
/// </summary>
/// <param name="historyEvent">The history event to inspect.</param>
/// <param name="taskExecutionId">
/// The id taken from the event's <c>TaskCompleted</c> or <c>TaskFailed</c> payload;
/// <see cref="string.Empty"/> when the event carries neither payload.
/// </param>
/// <returns><c>true</c> when the extracted id is not null, empty or whitespace.</returns>
private static bool TryGetTaskExecutionId(HistoryEvent historyEvent, out string taskExecutionId)
{
    // TaskCompleted is checked first, then TaskFailed; anything else yields an empty id.
    if (historyEvent?.TaskCompleted is { } completedEvent)
    {
        taskExecutionId = completedEvent.TaskExecutionId;
    }
    else if (historyEvent?.TaskFailed is { } failedEvent)
    {
        taskExecutionId = failedEvent.TaskExecutionId;
    }
    else
    {
        taskExecutionId = string.Empty;
    }

    var hasExecutionId = !string.IsNullOrWhiteSpace(taskExecutionId);
    return hasExecutionId;
}

/// <summary>
/// Drops both directions of the task id / execution id mapping for the given task.
/// </summary>
/// <param name="taskId">Task id whose mapping should be removed; unknown ids are ignored.</param>
private void RemoveTaskExecutionMapping(int taskId)
{
    // Nothing registered for this task id: leave the reverse map untouched.
    if (!_taskIdToExecutionId.Remove(taskId, out var mappedExecutionId))
    {
        return;
    }

    _executionIdToTaskId.Remove(mappedExecutionId);
}

/// <summary>
/// Handles a child workflow that completed in the workflow history.
/// </summary>
Expand Down
26 changes: 20 additions & 6 deletions src/Dapr.Workflow/Worker/WorkflowWorker.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,18 @@ namespace Dapr.Workflow.Worker;
/// <summary>
/// Background service that processes workflow and activity work items from the Dapr sidecar.
/// </summary>
internal sealed class WorkflowWorker(TaskHubSidecarService.TaskHubSidecarServiceClient grpcClient, IWorkflowsFactory workflowsFactory, ILoggerFactory loggerFactory, IWorkflowSerializer workflowSerializer, IServiceProvider serviceProvider, WorkflowRuntimeOptions options) : BackgroundService
internal sealed class WorkflowWorker(
TaskHubSidecarService.TaskHubSidecarServiceClient grpcClient,
IWorkflowsFactory workflowsFactory,
ILoggerFactory loggerFactory,
IWorkflowSerializer workflowSerializer,
IServiceProvider serviceProvider,
WorkflowRuntimeOptions options) : BackgroundService
{
private readonly TaskHubSidecarService.TaskHubSidecarServiceClient _grpcClient = grpcClient ?? throw new ArgumentNullException(nameof(grpcClient));
private readonly IWorkflowsFactory _workflowsFactory = workflowsFactory ?? throw new ArgumentNullException(nameof(workflowsFactory));
private readonly IServiceProvider _serviceProvider = serviceProvider ?? throw new ArgumentNullException(nameof(serviceProvider));
private readonly ILogger<WorkflowWorker> _logger = loggerFactory?.CreateLogger<WorkflowWorker>() ?? throw new ArgumentNullException(nameof(loggerFactory));
private readonly ILogger<WorkflowWorker> _logger = loggerFactory.CreateLogger<WorkflowWorker>() ?? throw new ArgumentNullException(nameof(loggerFactory));
private readonly WorkflowRuntimeOptions _options = options ?? throw new ArgumentNullException(nameof(options));
private readonly IWorkflowSerializer _serializer = workflowSerializer ?? throw new ArgumentNullException(nameof(workflowSerializer));

Expand Down Expand Up @@ -66,7 +72,7 @@ protected override async Task ExecuteAsync(CancellationToken stoppingToken)
}
}

private async Task<OrchestratorResponse> HandleOrchestratorResponseAsync(OrchestratorRequest request)
private async Task<OrchestratorResponse> HandleOrchestratorResponseAsync(OrchestratorRequest request, string completionToken)
{
_logger.LogWorkerWorkflowHandleOrchestratorRequestStart(request.InstanceId);

Expand Down Expand Up @@ -162,7 +168,11 @@ private async Task<OrchestratorResponse> HandleOrchestratorResponseAsync(Orchest
context.ProcessEvents(request.NewEvents, false);

// Get all pending actions from the context
var response = new OrchestratorResponse { InstanceId = request.InstanceId };
var response = new OrchestratorResponse
{
InstanceId = request.InstanceId,
CompletionToken = completionToken
};

// Add all actions that were scheduled during workflow execution
response.Actions.AddRange(context.PendingActions);
Expand Down Expand Up @@ -215,6 +225,7 @@ private async Task<OrchestratorResponse> HandleOrchestratorResponseAsync(Orchest
OrchestrationStatus = OrchestrationStatus.Failed,
FailureDetails = new()
{
IsNonRetriable = true,
ErrorType = ex.GetType().FullName ?? "Exception",
ErrorMessage = ex.Message,
StackTrace = ex.StackTrace ?? string.Empty
Expand Down Expand Up @@ -253,7 +264,7 @@ private async Task<OrchestratorResponse> HandleOrchestratorResponseAsync(Orchest
}
}

private async Task<ActivityResponse> HandleActivityResponseAsync(ActivityRequest request)
private async Task<ActivityResponse> HandleActivityResponseAsync(ActivityRequest request, string completionToken)
{
_logger.LogWorkerWorkflowHandleActivityRequestStart(request.Name, request.OrchestrationInstance?.InstanceId, request.TaskId);

Expand All @@ -272,6 +283,7 @@ private async Task<ActivityResponse> HandleActivityResponseAsync(ActivityRequest
{
InstanceId = request.OrchestrationInstance?.InstanceId ?? string.Empty,
TaskId = request.TaskId,
CompletionToken = completionToken,
FailureDetails = new()
{
ErrorType = "ActivityNotFoundException",
Expand Down Expand Up @@ -310,7 +322,8 @@ private async Task<ActivityResponse> HandleActivityResponseAsync(ActivityRequest
{
InstanceId = request.OrchestrationInstance?.InstanceId ?? string.Empty,
TaskId = request.TaskId,
Result = outputJson
Result = outputJson,
CompletionToken = completionToken
};
}
catch (Exception ex)
Expand All @@ -321,6 +334,7 @@ private async Task<ActivityResponse> HandleActivityResponseAsync(ActivityRequest
{
InstanceId = request.OrchestrationInstance?.InstanceId ?? string.Empty,
TaskId = request.TaskId,
CompletionToken = completionToken,
FailureDetails = new()
{
ErrorType = ex.GetType().FullName ?? "Exception",
Expand Down
Loading
Loading