Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,54 +3,61 @@
using System.Collections.Generic;
using System.Linq;
using System.Text.Json.Serialization;
using System.Threading;
using Microsoft.Agents.AI.Workflows.Checkpointing;

namespace Microsoft.Agents.AI.Workflows.Execution;

internal sealed class FanInEdgeState
{
private List<PortableMessageEnvelope> _pendingMessages;
private readonly object _syncLock = new();

public FanInEdgeState(FanInEdgeData fanInEdge)
{
this.SourceIds = fanInEdge.SourceIds.ToArray();
this.Unseen = [.. this.SourceIds];

this._pendingMessages = [];
this.PendingMessages = [];
}

public string[] SourceIds { get; }
public HashSet<string> Unseen { get; private set; }
public List<PortableMessageEnvelope> PendingMessages => this._pendingMessages;
public List<PortableMessageEnvelope> PendingMessages { get; private set; }

[JsonConstructor]
public FanInEdgeState(string[] sourceIds, HashSet<string> unseen, List<PortableMessageEnvelope> pendingMessages)
{
this.SourceIds = sourceIds;
this.Unseen = unseen;

this._pendingMessages = pendingMessages;
this.PendingMessages = pendingMessages;
}

public IEnumerable<IGrouping<ExecutorIdentity, MessageEnvelope>>? ProcessMessage(string sourceId, MessageEnvelope envelope)
{
this.PendingMessages.Add(new(envelope));
this.Unseen.Remove(sourceId);
List<PortableMessageEnvelope>? takenMessages = null;

if (this.Unseen.Count == 0)
// Serialize concurrent calls from parallel executor tasks during superstep execution.
// NOTE - IMPORTANT: If this ProcessMessage method ever becomes async, replace this lock with an async friendly solution to avoid deadlocks.
lock (this._syncLock)
Comment thread
lokitoth marked this conversation as resolved.
{
List<PortableMessageEnvelope> takenMessages = Interlocked.Exchange(ref this._pendingMessages, []);
this.Unseen = [.. this.SourceIds];
this.PendingMessages.Add(new(envelope));
this.Unseen.Remove(sourceId);

if (takenMessages.Count == 0)
if (this.Unseen.Count == 0)
{
return null;
takenMessages = this.PendingMessages;
this.PendingMessages = [];
this.Unseen = [.. this.SourceIds];
}
}

return takenMessages.Select(portable => portable.ToMessageEnvelope())
.GroupBy(keySelector: messageEnvelope => messageEnvelope.Source);
if (takenMessages is null || takenMessages.Count == 0)
{
return null;
}

return null;
return takenMessages
.Select(portable => portable.ToMessageEnvelope())
.GroupBy(messageEnvelope => messageEnvelope.Source);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using FluentAssertions;
Expand Down Expand Up @@ -199,4 +200,43 @@ async ValueTask RunIterationAsync()
mapping.CheckDeliveries(["executor3"], ["part1", "part2", "final part"]);
}
}

[Fact]
public async Task Test_FanInEdgeRunner_ConcurrentProcessingAsync()
{
// Arrange
const int SourceCount = 4;
const int Iterations = 50;

string[] sourceIds = Enumerable.Range(0, SourceCount).Select(i => $"source{i}").ToArray();
const string SinkId = "sink";

TestRunContext runContext = new();
List<Executor> executors = [.. sourceIds.Select(id => (Executor)new ForwardMessageExecutor<string>(id)), new ForwardMessageExecutor<string>(SinkId)];
runContext.ConfigureExecutors(executors);

FanInEdgeData edgeData = new(sourceIds.ToList(), SinkId, new EdgeId(0), null);
FanInEdgeRunner runner = new(runContext, edgeData);

for (int iteration = 0; iteration < Iterations; iteration++)
{
// Act: send messages from all sources concurrently
using Barrier barrier = new(SourceCount);
Task<DeliveryMapping?>[] tasks = sourceIds.Select(sourceId => Task.Run(async () =>
{
barrier.SignalAndWait();
return await runner.ChaseEdgeAsync(new($"msg-from-{sourceId}", sourceId), stepTracer: null, CancellationToken.None);
})).ToArray();

DeliveryMapping?[] results = await Task.WhenAll(tasks);

// Assert: exactly one task should return a non-null mapping with all messages
DeliveryMapping?[] nonNullResults = results.Where(r => r is not null).ToArray();
nonNullResults.Should().HaveCount(1, $"iteration {iteration}: exactly one thread should release the batch");

DeliveryMapping mapping = nonNullResults[0]!;
HashSet<object> expectedMessages = [.. sourceIds.Select(id => (object)$"msg-from-{id}")];
mapping.CheckDeliveries([SinkId], expectedMessages);
}
}
}
Loading