Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rysweet 4541 fix persistence of subscription grain was related to 4491 #4570

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
18 changes: 8 additions & 10 deletions dotnet/src/Microsoft.AutoGen/Agents/Services/Grpc/GrpcGateway.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,19 @@
using Microsoft.Extensions.Logging;

namespace Microsoft.AutoGen.Agents;

public sealed class GrpcGateway : BackgroundService, IGateway
{
private static readonly TimeSpan s_agentResponseTimeout = TimeSpan.FromSeconds(30);
private readonly ILogger<GrpcGateway> _logger;
private readonly IClusterClient _clusterClient;
private readonly ConcurrentDictionary<string, AgentState> _agentState = new();
private readonly IRegistryGrain _gatewayRegistry;
private readonly ISubscriptionsGrain _subscriptions;
private readonly ISubscriptionsGrain _subscriptionsGrain;
private readonly IGateway _reference;
// The agents supported by each worker process.
private SubscriptionsState _subscriptionsState = new();
private readonly ConcurrentDictionary<string, List<GrpcWorkerConnection>> _supportedAgentTypes = [];
public readonly ConcurrentDictionary<IConnection, IConnection> _workers = new();
private readonly ConcurrentDictionary<string, Subscription> _subscriptionsByAgentType = new();
private readonly ConcurrentDictionary<string, List<string>> _subscriptionsByTopic = new();

// The mapping from agent id to worker process.
private readonly ConcurrentDictionary<(string Type, string Key), GrpcWorkerConnection> _agentDirectory = new();
Expand All @@ -36,7 +34,7 @@ public GrpcGateway(IClusterClient clusterClient, ILogger<GrpcGateway> logger)
_clusterClient = clusterClient;
_reference = clusterClient.CreateObjectReference<IGateway>(this);
_gatewayRegistry = clusterClient.GetGrain<IRegistryGrain>(0);
_subscriptions = clusterClient.GetGrain<ISubscriptionsGrain>(0);
_subscriptionsGrain = clusterClient.GetGrain<ISubscriptionsGrain>(0);
}
public async ValueTask BroadcastEvent(CloudEvent evt)
{
Expand Down Expand Up @@ -135,10 +133,9 @@ private async ValueTask AddSubscriptionAsync(GrpcWorkerConnection connection, Ad
topic = request.Subscription.TypeSubscription.TopicType;
agentType = request.Subscription.TypeSubscription.AgentType;
}
_subscriptionsByAgentType[agentType] = request.Subscription;
_subscriptionsByTopic.GetOrAdd(topic, _ => []).Add(agentType);
await _subscriptions.SubscribeAsync(topic, agentType);
//var response = new AddSubscriptionResponse { RequestId = request.RequestId, Error = "", Success = true };
await _subscriptionsGrain.SubscribeAsync(topic, agentType).ConfigureAwait(true);
_subscriptionsState = await _subscriptionsGrain.GetSubscriptionsStateAsync().ConfigureAwait(true);

Message response = new()
{
AddSubscriptionResponse = new()
Expand Down Expand Up @@ -169,12 +166,13 @@ private async ValueTask RegisterAgentTypeAsync(GrpcWorkerConnection connection,
}
private async ValueTask DispatchEventAsync(CloudEvent evt)
{
var _subscriptionsByTopic = await _subscriptionsGrain.GetSubscriptionsByTopicAsync().ConfigureAwait(true);
// get the event type and then send to all agents that are subscribed to that event type
var eventType = evt.Type;
// ensure that we get agentTypes as an async enumerable list - try to get the value of agentTypes by topic and then cast it to an async enumerable list
if (_subscriptionsByTopic.TryGetValue(eventType, out var agentTypes))
{
await DispatchEventToAgentsAsync(agentTypes, evt);
await DispatchEventToAgentsAsync(agentTypes, evt).ConfigureAwait(false);
}
// instead of an exact match, we can also check for a prefix match where key starts with the eventType
else if (_subscriptionsByTopic.Keys.Any(key => key.StartsWith(eventType)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

namespace Microsoft.AutoGen.Agents;

internal sealed class GrpcWorkerConnection : IAsyncDisposable, IConnection
public sealed class GrpcWorkerConnection : IAsyncDisposable, IConnection
{
private static long s_nextConnectionId;
private readonly Task _readTask;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ public async ValueTask<string> WriteStateAsync(AgentState newState, string eTag,
if ((string.IsNullOrEmpty(state.Etag)) || (string.IsNullOrEmpty(eTag)) || (string.Equals(state.Etag, eTag, StringComparison.Ordinal)))
{
state.State = newState;
await state.WriteStateAsync().ConfigureAwait(false);
await state.WriteStateAsync().ConfigureAwait(true);
}
else
{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,24 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// ISubscriptionsGrain.cs

using System.Collections.Concurrent;
using Microsoft.AutoGen.Abstractions;

namespace Microsoft.AutoGen.Agents;

[Alias("Microsoft.AutoGen.Agents.ISubscriptionsGrain")]
public interface ISubscriptionsGrain : IGrainWithIntegerKey
{
[Alias("SubscribeAsync")]
ValueTask SubscribeAsync(string agentType, string topic);
[Alias("UnsubscribeAsync")]
ValueTask UnsubscribeAsync(string agentType, string topic);
ValueTask<Dictionary<string, List<string>>> GetSubscriptions(string agentType);
[Alias("GetSubscriptionsAsync")]
ValueTask<ConcurrentDictionary<string, List<string>>> GetSubscriptionsByAgentTypeAsync(string? agentType = null);
[Alias ("GetSubscriptionsByTopicAsync")]
ValueTask<ConcurrentDictionary<string, List<string>>> GetSubscriptionsByTopicAsync(string? topic = null);
[Alias("GetSubscriptionsByAgentTypeAsync")]
ValueTask<SubscriptionsState> GetSubscriptionsStateAsync();
[Alias("WriteSubscriptionsStateAsync")]
ValueTask WriteSubscriptionsStateAsync(SubscriptionsState subscriptionsState);
}
Original file line number Diff line number Diff line change
@@ -1,50 +1,74 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// SubscriptionsGrain.cs
using System.Collections.Concurrent;
using Microsoft.AutoGen.Abstractions;

namespace Microsoft.AutoGen.Agents;

internal sealed class SubscriptionsGrain([PersistentState("state", "PubSubStore")] IPersistentState<SubscriptionsState> state) : Grain, ISubscriptionsGrain
internal sealed class SubscriptionsGrain([PersistentState("state", "PubSubStore")] IPersistentState<SubscriptionsState> subscriptionsState) : Grain, ISubscriptionsGrain
{
private readonly Dictionary<string, List<string>> _subscriptions = new();
public ValueTask<Dictionary<string, List<string>>> GetSubscriptions(string? agentType = null)
private readonly IPersistentState<SubscriptionsState> _subscriptionsState = subscriptionsState;

public ValueTask<ConcurrentDictionary<string, List<string>>> GetSubscriptionsByAgentTypeAsync(string? agentType = null)
{
var _subscriptions = _subscriptionsState.State.SubscriptionsByAgentType;
//if agentType is null, return all subscriptions else filter on agentType
if (agentType != null)
{
return new ValueTask<Dictionary<string, List<string>>>(_subscriptions.Where(x => x.Value.Contains(agentType)).ToDictionary(x => x.Key, x => x.Value));
var filteredSubscriptions = _subscriptions.Where(x => x.Value.Contains(agentType));
return ValueTask.FromResult<ConcurrentDictionary<string, List<string>>>((ConcurrentDictionary<string, List<string>>)filteredSubscriptions);
}
return new ValueTask<Dictionary<string, List<string>>>(_subscriptions);
return ValueTask.FromResult<ConcurrentDictionary<string, List<string>>>(_subscriptions);
}
public async ValueTask SubscribeAsync(string agentType, string topic)
public ValueTask<ConcurrentDictionary<string, List<string>>> GetSubscriptionsByTopicAsync(string? topic = null)
{
if (!_subscriptions.TryGetValue(topic, out var subscriptions))
{
subscriptions = _subscriptions[topic] = [];
}
if (!subscriptions.Contains(agentType))
var _subscriptions = _subscriptionsState.State.SubscriptionsByTopic;
//if topic is null, return all subscriptions else filter on topic
if (topic != null)
{
subscriptions.Add(agentType);
var filteredSubscriptions = _subscriptions.Where(x => x.Key == topic);
return ValueTask.FromResult<ConcurrentDictionary<string, List<string>>>((ConcurrentDictionary<string, List<string>>)filteredSubscriptions);
}
_subscriptions[topic] = subscriptions;
state.State.Subscriptions = _subscriptions;
await state.WriteStateAsync().ConfigureAwait(false);
return ValueTask.FromResult<ConcurrentDictionary<string, List<string>>>(_subscriptions);
}
public ValueTask<SubscriptionsState> GetSubscriptionsStateAsync() => ValueTask.FromResult(_subscriptionsState.State);

public async ValueTask SubscribeAsync(string agentType, string topic)
{
await WriteSubscriptionsAsync(agentType: agentType, topic: topic, subscribe: true).ConfigureAwait(false);
}
public async ValueTask UnsubscribeAsync(string agentType, string topic)
{
if (!_subscriptions.TryGetValue(topic, out var subscriptions))
await WriteSubscriptionsAsync(agentType: agentType, topic: topic, subscribe: false).ConfigureAwait(false);
}
public async ValueTask WriteSubscriptionsStateAsync(SubscriptionsState subscriptionsState)
{
_subscriptionsState.State = subscriptionsState;
await _subscriptionsState.WriteStateAsync().ConfigureAwait(true);
}

private async ValueTask WriteSubscriptionsAsync(string agentType, string topic, bool subscribe=true)
{
var _subscriptions = await GetSubscriptionsByAgentTypeAsync().ConfigureAwait(true);
if (!_subscriptions.TryGetValue(topic, out var agentTypes))
{
subscriptions = _subscriptions[topic] = [];
agentTypes = _subscriptions[topic] = [];
}
if (!subscriptions.Contains(agentType))
if (!agentTypes.Contains(agentType))
{
subscriptions.Remove(agentType);
if (subscribe)
{
agentTypes.Add(agentType);
}
else
{
agentTypes.Remove(agentType);
}
}
_subscriptions[topic] = subscriptions;
state.State.Subscriptions = _subscriptions;
await state.WriteStateAsync();
_subscriptionsState.State.SubscriptionsByAgentType = _subscriptions;
var _subsByTopic = await GetSubscriptionsByTopicAsync().ConfigureAwait(true);
_subsByTopic.GetOrAdd(topic, _ => []).Add(agentType);
_subscriptionsState.State.SubscriptionsByTopic = _subsByTopic;
await _subscriptionsState.WriteStateAsync().ConfigureAwait(false);
}
}
public sealed class SubscriptionsState
{
public Dictionary<string, List<string>> Subscriptions { get; set; } = new();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// SubscriptionsState.cs
using System.Collections.Concurrent;

namespace Microsoft.AutoGen.Abstractions;
[GenerateSerializer]
[Serializable]
public sealed class SubscriptionsState
{
public ConcurrentDictionary<string, List<string>> SubscriptionsByTopic = new();
public ConcurrentDictionary<string, List<string>> SubscriptionsByAgentType { get; set; } = new();
}
102 changes: 102 additions & 0 deletions dotnet/test/Microsoft.AutoGen.Agents.Tests/ISubscriptionsGrainTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// ISubscriptionsGrainTests.cs

using System.Collections.Concurrent;
using Microsoft.AutoGen.Abstractions;
using Moq;
using Xunit;

namespace Microsoft.AutoGen.Agents.Tests;

public class ISubscriptionsGrainTests
{
private readonly Mock<ISubscriptionsGrain> _mockSubscriptionsGrain;

public ISubscriptionsGrainTests()
{
_mockSubscriptionsGrain = new Mock<ISubscriptionsGrain>();
}

[Fact]
public async Task GetSubscriptionsStateAsync_ReturnsCorrectState()
{
// Arrange
var subscriptionsState = new SubscriptionsState
{
SubscriptionsByAgentType = new ConcurrentDictionary<string, List<string>>
{
["topic1"] = ["agentType1"],
["topic2"] = ["agentType2"]
}
};
_mockSubscriptionsGrain.Setup(grain => grain.GetSubscriptionsStateAsync()).ReturnsAsync(subscriptionsState);

// Act
var result = await _mockSubscriptionsGrain.Object.GetSubscriptionsStateAsync();

// Assert
Assert.Equal(2, result.SubscriptionsByAgentType.Count);
Assert.Contains("topic1", result.SubscriptionsByAgentType.Keys);
Assert.Contains("topic2", result.SubscriptionsByAgentType.Keys);
}

[Fact]
public async Task GetSubscriptions_ReturnsAllSubscriptions_WhenAgentTypeIsNull()
{
// Arrange
var subscriptions = new ConcurrentDictionary<string, List<string>>();
subscriptions.TryAdd("topic1", new List<string> { "agentType1" });
subscriptions.TryAdd("topic2", new List<string> { "agentType2" });
_mockSubscriptionsGrain.Setup(grain => grain.GetSubscriptionsByAgentTypeAsync(null)).ReturnsAsync(subscriptions);

// Act
var result = await _mockSubscriptionsGrain.Object.GetSubscriptionsByAgentTypeAsync();

// Assert
Assert.Equal(2, result.Count);
Assert.Contains("topic1", result.Keys);
Assert.Contains("topic2", result.Keys);
}

[Fact]
public async Task GetSubscriptions_ReturnsFilteredSubscriptions_WhenAgentTypeIsNotNull()
{
// Arrange
var subscriptions = new ConcurrentDictionary<string, List<string>>();
subscriptions.TryAdd("topic1", new List<string> { "agentType1" });
_mockSubscriptionsGrain.Setup(grain => grain.GetSubscriptionsByAgentTypeAsync("agentType1")).ReturnsAsync(subscriptions);

// Act
var result = await _mockSubscriptionsGrain.Object.GetSubscriptionsByAgentTypeAsync("agentType1");

// Assert
Assert.Single(result);
Assert.Contains("topic1", result.Keys);
}

[Fact]
public async Task SubscribeAsync_AddsSubscription()
{
// Arrange
_mockSubscriptionsGrain.Setup(grain => grain.SubscribeAsync("agentType1", "topic1")).Returns(ValueTask.CompletedTask);

// Act
await _mockSubscriptionsGrain.Object.SubscribeAsync("agentType1", "topic1");

// Assert
_mockSubscriptionsGrain.Verify(grain => grain.SubscribeAsync("agentType1", "topic1"), Times.Once);
}

[Fact]
public async Task UnsubscribeAsync_RemovesSubscription()
{
// Arrange
_mockSubscriptionsGrain.Setup(grain => grain.UnsubscribeAsync("agentType1", "topic1")).Returns(ValueTask.CompletedTask);

// Act
await _mockSubscriptionsGrain.Object.UnsubscribeAsync("agentType1", "topic1");

// Assert
_mockSubscriptionsGrain.Verify(grain => grain.UnsubscribeAsync("agentType1", "topic1"), Times.Once);
}
}
Loading