diff --git a/src/Orleans.Core.Abstractions/Runtime/AsyncEnumerableRequest.cs b/src/Orleans.Core.Abstractions/Runtime/AsyncEnumerableRequest.cs
index ab265fc910..8dd2263351 100644
--- a/src/Orleans.Core.Abstractions/Runtime/AsyncEnumerableRequest.cs
+++ b/src/Orleans.Core.Abstractions/Runtime/AsyncEnumerableRequest.cs
@@ -49,6 +49,11 @@ public enum EnumerationResult
///
Error = 1 << 5,
+ ///
+ /// Enumeration was canceled.
+ ///
+ Canceled = 1 << 6,
+
///
/// This result indicates that enumeration has completed and that no further results will be produced.
///
@@ -244,26 +249,26 @@ public async ValueTask MoveNextAsync()
(EnumerationResult Status, object Value) result;
while (true)
{
- if (_cancellationToken.IsCancellationRequested)
- {
- _current = default;
- return false;
- }
+ _cancellationToken.ThrowIfCancellationRequested();
if (!_initialized)
{
- result = await _target.StartEnumeration(_requestId, _request);
+ result = await _target.StartEnumeration(_requestId, _request).AsTask().WaitAsync(_cancellationToken);
_initialized = true;
}
else
{
- result = await _target.MoveNext(_requestId);
+ result = await _target.MoveNext(_requestId).AsTask().WaitAsync(_cancellationToken);
}
if (result.Status is EnumerationResult.Error)
{
ExceptionDispatchInfo.Capture((Exception)result.Value).Throw();
}
+ else if (result.Status is EnumerationResult.Canceled)
+ {
+ throw new OperationCanceledException();
+ }
if (result.Status is not EnumerationResult.Heartbeat)
{
@@ -274,7 +279,7 @@ public async ValueTask MoveNextAsync()
if (result.Status is EnumerationResult.MissingEnumeratorError)
{
throw new EnumerationAbortedException("Enumeration aborted: the remote target does not have a record of this enumerator."
- + " This likely indicates that the remote grain was deactivated since enumeration begun.");
+ + " This likely indicates that the remote grain was deactivated since enumeration begun or that the enumerator was idle for longer than the expiration period.");
}
Debug.Assert((result.Status & (EnumerationResult.Element | EnumerationResult.Batch | EnumerationResult.Completed)) != 0);
diff --git a/src/Orleans.Core/Configuration/Options/MessagingOptions.cs b/src/Orleans.Core/Configuration/Options/MessagingOptions.cs
index fdc47a9eac..07e8e00515 100644
--- a/src/Orleans.Core/Configuration/Options/MessagingOptions.cs
+++ b/src/Orleans.Core/Configuration/Options/MessagingOptions.cs
@@ -55,5 +55,10 @@ public TimeSpan ResponseTimeout
///
/// The maximum message body size is 100 MB by default.
public int MaxMessageBodySize { get; set; } = 100 * 1024 * 1024;
+
+ ///
+ /// Gets the response timeout underlying the property, without debugger checks.
+ ///
+ internal TimeSpan ConfiguredResponseTimeout => _responseTimeout;
}
}
diff --git a/src/Orleans.Core/Runtime/AsyncEnumerableGrainExtension.cs b/src/Orleans.Core/Runtime/AsyncEnumerableGrainExtension.cs
index 49a9661282..1077f5541f 100644
--- a/src/Orleans.Core/Runtime/AsyncEnumerableGrainExtension.cs
+++ b/src/Orleans.Core/Runtime/AsyncEnumerableGrainExtension.cs
@@ -2,7 +2,6 @@
using System.Collections.Generic;
using System.Diagnostics;
using System.Runtime.CompilerServices;
-using System.Runtime.ExceptionServices;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;
@@ -19,32 +18,42 @@ namespace Orleans.Runtime;
///
internal sealed class AsyncEnumerableGrainExtension : IAsyncEnumerableGrainExtension, IAsyncDisposable, IDisposable
{
- private const long EnumeratorExpirationMilliseconds = 10_000;
+ private static readonly DiagnosticListener DiagnosticListener = new("Orleans.Runtime.AsyncEnumerableGrainExtension");
private readonly Dictionary _enumerators = [];
- private readonly IGrainContext _grainContext;
+ private readonly ILogger _logger;
private readonly MessagingOptions _messagingOptions;
- private readonly IDisposable _timer;
+
+ // Internal for testing
+ internal IGrainTimer Timer { get; }
+ internal IGrainContext GrainContext { get; }
///
/// Initializes a new instance.
///
/// The grain which this extension is attached to.
- public AsyncEnumerableGrainExtension(IGrainContext grainContext, IOptions messagingOptions)
+ public AsyncEnumerableGrainExtension(
+ IGrainContext grainContext,
+ IOptions messagingOptions,
+ ILogger logger)
{
- _grainContext = grainContext;
+ _logger = logger;
+ GrainContext = grainContext;
+
_messagingOptions = messagingOptions.Value;
- var registry = _grainContext.GetComponent();
- _timer = registry.RegisterGrainTimer(
- _grainContext,
+ var registry = GrainContext.GetComponent();
+ var cleanupPeriod = messagingOptions.Value.ResponseTimeout;
+ Timer = registry.RegisterGrainTimer(
+ GrainContext,
static async (state, cancellationToken) => await state.RemoveExpiredAsync(cancellationToken),
this,
new()
{
- DueTime = TimeSpan.FromSeconds(EnumeratorExpirationMilliseconds),
- Period = TimeSpan.FromSeconds(EnumeratorExpirationMilliseconds),
+ DueTime = cleanupPeriod,
+ Period = cleanupPeriod,
Interleave = true,
KeepAlive = false
});
+ OnAsyncEnumeratorGrainExtensionCreated(this);
}
///
@@ -55,12 +64,23 @@ private async ValueTask RemoveExpiredAsync(CancellationToken cancellationToken)
List toRemove = default;
foreach (var (requestId, state) in _enumerators)
{
- if (state.LastSeenTimer.ElapsedMilliseconds > EnumeratorExpirationMilliseconds
- && state.MoveNextTask is null or { IsCompleted: true })
+ if (MarkAndCheck(requestId))
{
toRemove ??= [];
toRemove.Add(requestId);
}
+
+ bool MarkAndCheck(Guid requestId)
+ {
+ ref var state = ref CollectionsMarshal.GetValueRefOrNullRef(_enumerators, requestId);
+ if (Unsafe.IsNullRef(ref state))
+ {
+ return false;
+ }
+
+ // Returns true if no flags were set.
+ return state.ClearSeen();
+ }
}
List tasks = default;
@@ -81,12 +101,14 @@ private async ValueTask RemoveExpiredAsync(CancellationToken cancellationToken)
{
await Task.WhenAll(tasks).WaitAsync(cancellationToken);
}
+
+ OnEnumeratorCleanupCompleted(this);
}
///
public ValueTask<(EnumerationResult Status, object Value)> StartEnumeration(Guid requestId, [Immutable] IAsyncEnumerableRequest request)
{
- request.SetTarget(_grainContext);
+ request.SetTarget(GrainContext);
var enumerable = request.InvokeImplementation();
ref var entry = ref CollectionsMarshal.GetValueRefOrAddDefault(_enumerators, requestId, out bool exists);
if (exists)
@@ -97,11 +119,10 @@ private async ValueTask RemoveExpiredAsync(CancellationToken cancellationToken)
var cts = new CancellationTokenSource();
var enumerator = enumerable.GetAsyncEnumerator(cts.Token);
entry.Enumerator = enumerator;
- entry.LastSeenTimer.Restart();
entry.MaxBatchSize = request.MaxBatchSize;
entry.CancellationTokenSource = cts;
Debug.Assert(entry.MaxBatchSize > 0, "Max batch size must be positive.");
- return MoveNextAsync(ref entry, requestId, enumerator);
+ return MoveNextCore(ref entry, requestId, enumerator);
static ValueTask<(EnumerationResult Status, object Value)> ThrowAlreadyExists() => ValueTask.FromException<(EnumerationResult Status, object Value)>(new InvalidOperationException("An enumerator with the same id already exists."));
}
@@ -115,27 +136,28 @@ private async ValueTask RemoveExpiredAsync(CancellationToken cancellationToken)
return new((EnumerationResult.MissingEnumeratorError, default));
}
- entry.LastSeenTimer.Restart();
if (entry.Enumerator is not IAsyncEnumerator typedEnumerator)
{
throw new InvalidCastException("Attempted to access an enumerator of the wrong type.");
}
- return MoveNextAsync(ref entry, requestId, typedEnumerator);
+ return MoveNextCore(ref entry, requestId, typedEnumerator);
}
- private ValueTask<(EnumerationResult Status, object Value)> MoveNextAsync(
+ private ValueTask<(EnumerationResult Status, object Value)> MoveNextCore(
ref EnumeratorState entry,
Guid requestId,
IAsyncEnumerator typedEnumerator)
{
Debug.Assert(entry.MaxBatchSize > 0, "Max batch size must be positive.");
+ entry.SetSeen();
+
try
{
+ var currentBatchSize = 0;
if (entry.MoveNextTask is null)
{
ValueTask moveNextValueTask;
- var currentBatchSize = 0;
object result = null;
do
{
@@ -168,14 +190,14 @@ private async ValueTask RemoveExpiredAsync(CancellationToken cancellationToken)
// Completed successfully, possibly with some final elements.
if (currentBatchSize == 0)
{
- return OnComplete(requestId, typedEnumerator);
+ return OnTerminateAsync(requestId, EnumerationResult.Completed, default);
}
else if (currentBatchSize == 1)
{
- return new((EnumerationResult.CompletedWithElement, result));
+ return OnTerminateAsync(requestId, EnumerationResult.CompletedWithElement, result);
}
- return new((EnumerationResult.CompletedWithBatch, result));
+ return OnTerminateAsync(requestId, EnumerationResult.CompletedWithBatch, result);
}
}
else
@@ -199,11 +221,13 @@ private async ValueTask RemoveExpiredAsync(CancellationToken cancellationToken)
// There are no elements, so wait for the pending operation to complete.
}
+ // Prevent the enumerator from being collected while we are enumerating it.
+ entry.SetBusy();
return AwaitMoveNextAsync(requestId, typedEnumerator, entry.MoveNextTask);
}
catch (Exception exception)
{
- return OnError(requestId, typedEnumerator, exception);
+ return OnTerminateAsync(requestId, EnumerationResult.Error, exception);
}
}
@@ -211,15 +235,32 @@ private async ValueTask RemoveExpiredAsync(CancellationToken cancellationToken)
{
try
{
- // Wait up to half the response timeout for the MoveNextAsync task to complete.
- using var cancellation = new CancellationTokenSource(_messagingOptions.ResponseTimeout / 2);
+ // Wait for either the MoveNextAsync task to complete or the polling timeout to elapse.
+ var longPollingTimeout = _messagingOptions.ConfiguredResponseTimeout / 2;
+ await moveNextTask.WaitAsync(longPollingTimeout).SuppressThrowing();
+
+ // Update the enumerator state to indicate that we are not currently waiting for MoveNextAsync to complete.
+ // If the MoveNextAsync task completed then clear that now, too.
+ UpdateEnumeratorState(requestId, clearMoveNextTask: moveNextTask.IsCompleted);
+ void UpdateEnumeratorState(Guid requestId, bool clearMoveNextTask)
+ {
+ ref var state = ref CollectionsMarshal.GetValueRefOrNullRef(_enumerators, requestId);
+ if (Unsafe.IsNullRef(ref state))
+ {
+ return;
+ }
+
+ state.ClearBusy();
+ if (clearMoveNextTask)
+ {
+ state.MoveNextTask = null;
+ }
+ }
- // Wait for either the MoveNextAsync task to complete or the cancellation token to be cancelled.
- await moveNextTask.WaitAsync(cancellation.Token).SuppressThrowing();
if (moveNextTask.IsCompletedSuccessfully)
{
- OnMoveNext(requestId);
var hasValue = moveNextTask.GetAwaiter().GetResult();
+
if (hasValue)
{
return (EnumerationResult.Element, typedEnumerator.Current);
@@ -227,26 +268,28 @@ private async ValueTask RemoveExpiredAsync(CancellationToken cancellationToken)
else
{
await RemoveEnumeratorAsync(requestId);
- await typedEnumerator.DisposeAsync();
return (EnumerationResult.Completed, default);
}
}
+ else if (moveNextTask.IsCanceled)
+ {
+ await RemoveEnumeratorAsync(requestId);
+ return (EnumerationResult.Canceled, default);
+ }
else if (moveNextTask.Exception is { } moveNextException)
{
// Completed, but not successfully.
var exception = moveNextException.InnerExceptions.Count == 1 ? moveNextException.InnerException : moveNextException;
await RemoveEnumeratorAsync(requestId);
- await typedEnumerator.DisposeAsync();
return (EnumerationResult.Error, exception);
}
return (EnumerationResult.Heartbeat, default);
}
- catch
+ catch (Exception exception)
{
await RemoveEnumeratorAsync(requestId);
- await typedEnumerator.DisposeAsync();
- throw;
+ return (EnumerationResult.Error, exception);
}
}
@@ -258,31 +301,12 @@ private async ValueTask RemoveEnumeratorAsync(Guid requestId)
}
}
- private async ValueTask<(EnumerationResult Status, object Value)> OnComplete(Guid requestId, IAsyncEnumerator enumerator)
+ private async ValueTask<(EnumerationResult Status, object Value)> OnTerminateAsync(Guid requestId, EnumerationResult status, object value)
{
await RemoveEnumeratorAsync(requestId);
- return (EnumerationResult.Completed, default);
+ return (status, value);
}
- private async ValueTask<(EnumerationResult Status, object Value)> OnError(Guid requestId, IAsyncEnumerator enumerator, Exception exception)
- {
- await RemoveEnumeratorAsync(requestId);
- ExceptionDispatchInfo.Throw(exception);
- return default;
- }
-
- private void OnMoveNext(Guid requestId)
- {
- ref var state = ref CollectionsMarshal.GetValueRefOrNullRef(_enumerators, requestId);
- if (Unsafe.IsNullRef(ref state))
- {
- return;
- }
-
- state.LastSeenTimer.Restart();
- state.MoveNextTask = null;
- }
-
///
public async ValueTask DisposeAsync()
{
@@ -297,7 +321,7 @@ public async ValueTask DisposeAsync()
}
}
- _timer.Dispose();
+ Timer.Dispose();
}
private async ValueTask DisposeEnumeratorAsync(EnumeratorState enumerator)
@@ -308,40 +332,68 @@ private async ValueTask DisposeEnumeratorAsync(EnumeratorState enumerator)
}
catch (Exception exception)
{
- var logger = _grainContext.GetComponent();
- logger?.LogWarning(exception, "Error cancelling enumerator.");
+ _logger.LogWarning(exception, "Error cancelling enumerator.");
}
try
{
+ using var cts = new CancellationTokenSource(_messagingOptions.ResponseTimeout);
if (enumerator.MoveNextTask is { } task)
{
- if (enumerator.Enumerator is { } value)
- {
- await task.SuppressThrowing();
- await value.DisposeAsync();
- }
+ await task.WaitAsync(cts.Token).SuppressThrowing();
+ }
+
+ if (enumerator.MoveNextTask is null or { IsCompleted: true } && enumerator.Enumerator is { } value)
+ {
+ await value.DisposeAsync().AsTask().WaitAsync(cts.Token).SuppressThrowing();
}
}
catch (Exception exception)
{
- var logger = _grainContext.GetComponent();
- logger?.LogWarning(exception, "Error disposing enumerator.");
+ _logger.LogWarning(exception, "Error disposing enumerator.");
}
}
///
public void Dispose()
{
- _timer.Dispose();
+ Timer.Dispose();
+ }
+
+ private static void OnAsyncEnumeratorGrainExtensionCreated(AsyncEnumerableGrainExtension extension)
+ {
+ if (DiagnosticListener.IsEnabled())
+ {
+ DiagnosticListener.Write(nameof(OnAsyncEnumeratorGrainExtensionCreated), extension);
+ }
+ }
+
+ private static void OnEnumeratorCleanupCompleted(AsyncEnumerableGrainExtension extension)
+ {
+ if (DiagnosticListener.IsEnabled())
+ {
+ DiagnosticListener.Write(nameof(OnEnumeratorCleanupCompleted), extension);
+ }
}
private struct EnumeratorState
{
+ private const int SeenFlag = 0x01;
+ private const int BusyFlag = 0x10;
+ private int _flags;
public IAsyncDisposable Enumerator;
public Task MoveNextTask;
- public CoarseStopwatch LastSeenTimer;
public int MaxBatchSize;
internal CancellationTokenSource CancellationTokenSource;
+ public void SetSeen() => _flags |= SeenFlag;
+ public void SetBusy() => _flags |= BusyFlag | SeenFlag;
+ public void ClearBusy() => _flags = SeenFlag; // Clear the 'Busy' flag, but set the 'Seen' flag.
+ public bool ClearSeen()
+ {
+ // Clear the 'Seen' flag and check if any flags were set previously.
+ var isExpired = _flags == 0;
+ _flags &= ~SeenFlag;
+ return isExpired;
+ }
}
}
diff --git a/test/DefaultCluster.Tests/AsyncEnumerableGrainCallTests.cs b/test/DefaultCluster.Tests/AsyncEnumerableGrainCallTests.cs
index 9d3bc8510c..e1c7d8fab2 100644
--- a/test/DefaultCluster.Tests/AsyncEnumerableGrainCallTests.cs
+++ b/test/DefaultCluster.Tests/AsyncEnumerableGrainCallTests.cs
@@ -1,6 +1,7 @@
-using Microsoft.Extensions.Logging;
+#nullable enable
+using System.Diagnostics;
+using Microsoft.Extensions.Logging;
using Orleans.Internal;
-using Orleans.Runtime;
using TestExtensions;
using UnitTests.GrainInterfaces;
using Xunit;
@@ -83,6 +84,47 @@ public async Task ObservableGrain_AsyncEnumerable_Throws(int errorIndex, bool wa
Assert.Contains(grainCalls, c => c.InterfaceName.Contains(nameof(IAsyncEnumerableGrainExtension)) && c.MethodName.Contains(nameof(IAsyncDisposable.DisposeAsync)));
}
+ [Theory, TestCategory("BVT"), TestCategory("Observable")]
+ [InlineData(0, false)]
+ [InlineData(0, true)]
+ [InlineData(1, false)]
+ [InlineData(1, true)]
+ [InlineData(9, false)]
+ [InlineData(9, true)]
+ [InlineData(10, false)]
+ [InlineData(10, true)]
+ [InlineData(11, false)]
+ [InlineData(11, true)]
+ public async Task ObservableGrain_AsyncEnumerable_Cancellation(int errorIndex, bool waitAfterYield)
+ {
+ // This special error message is interpreted to indicate that cancellation
+ // should occur when the index is reached.
+ const string ErrorMessage = "cancel";
+ var grain = GrainFactory.GetGrain(Guid.NewGuid());
+
+ var values = new List();
+ try
+ {
+ await foreach (var entry in grain.GetValuesWithError(errorIndex, waitAfterYield, ErrorMessage).WithBatchSize(10))
+ {
+ values.Add(entry);
+ Logger.LogInformation("ObservableGrain_AsyncEnumerable: {Entry}", entry);
+ }
+ }
+ catch (OperationCanceledException oce)
+ {
+ var expectedMessage = new OperationCanceledException().Message;
+ Assert.Equal(expectedMessage, oce.Message);
+ }
+
+ Assert.Equal(errorIndex, values.Count);
+
+ // Check that the enumerator is disposed
+ var grainCalls = await grain.GetIncomingCalls();
+ Assert.Contains(grainCalls, c => c.InterfaceName.Contains(nameof(IAsyncEnumerableGrainExtension)) && c.MethodName.Contains(nameof(IAsyncDisposable.DisposeAsync)));
+ }
+
+
[Fact, TestCategory("BVT"), TestCategory("Observable")]
public async Task ObservableGrain_AsyncEnumerable_Batch()
{
@@ -196,15 +238,24 @@ public async Task ObservableGrain_AsyncEnumerable_WithCancellation()
var values = new List();
using var cts = new CancellationTokenSource();
- await foreach (var entry in grain.GetValues().WithCancellation(cts.Token))
+ try
{
- values.Add(entry);
- if (values.Count == 3)
+ await foreach (var entry in grain.GetValues().WithCancellation(cts.Token))
{
- cts.Cancel();
+ values.Add(entry);
+ if (values.Count == 3)
+ {
+ cts.Cancel();
+ }
+
+ Logger.LogInformation("ObservableGrain_AsyncEnumerable: {Entry}", entry);
}
- Logger.LogInformation("ObservableGrain_AsyncEnumerable: {Entry}", entry);
+ Assert.Fail("Expected an exception to be thrown");
+ }
+ catch (OperationCanceledException)
+ {
+ // Expected
}
Assert.Equal(3, values.Count);
@@ -249,6 +300,100 @@ public async Task ObservableGrain_AsyncEnumerable_SlowProducer()
Assert.Contains(grainCalls, c => c.InterfaceName.Contains(nameof(IAsyncEnumerableGrainExtension)) && c.MethodName.Contains(nameof(IAsyncDisposable.DisposeAsync)));
}
+ [Fact, TestCategory("BVT"), TestCategory("Observable")]
+ public async Task ObservableGrain_AsyncEnumerable_SlowConsumer()
+ {
+ using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(30));
+ var cleanupInterval = TimeSpan.FromMilliseconds(100);
+ var grain = GrainFactory.GetGrain(Guid.NewGuid());
+ using var listener = new AsyncEnumerableGrainExtensionListener(grain.GetGrainId(), cleanupInterval);
+
+ var producer = Task.Run(async () =>
+ {
+ foreach (var value in Enumerable.Range(0, 5))
+ {
+ await grain.OnNext(value.ToString());
+ }
+
+ await grain.Complete();
+ });
+
+ var values = new List();
+ await foreach (var entry in grain.GetValues().WithBatchSize(1))
+ {
+ values.Add(entry);
+
+ // Sleep for 1 cycle before reading the next value.
+ // The enumerator should not be cleaned up.
+ var initialCleanupCount = listener.CleanupCount;
+ while (listener.CleanupCount == initialCleanupCount)
+ {
+ await Task.Delay(cleanupInterval / 10, cts.Token);
+ }
+
+ Logger.LogInformation("ObservableGrain_AsyncEnumerable: {Entry}", entry);
+ }
+
+ Assert.Equal(5, values.Count);
+
+ // Check that the enumerator is disposed
+ var grainCalls = await grain.GetIncomingCalls();
+ Assert.Contains(grainCalls, c => c.InterfaceName.Contains(nameof(IAsyncEnumerableGrainExtension)) && c.MethodName.Contains(nameof(IAsyncDisposable.DisposeAsync)));
+ }
+
+ [Fact, TestCategory("BVT"), TestCategory("Observable")]
+ public async Task ObservableGrain_AsyncEnumerable_SlowConsumer_Evicted()
+ {
+ using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(30));
+ var cleanupInterval = TimeSpan.FromMilliseconds(100);
+ var grain = GrainFactory.GetGrain(Guid.NewGuid());
+ using var listener = new AsyncEnumerableGrainExtensionListener(grain.GetGrainId(), cleanupInterval);
+
+ var producer = Task.Run(async () =>
+ {
+ foreach (var value in Enumerable.Range(0, 5))
+ {
+ await grain.OnNext(value.ToString());
+ }
+
+ await grain.Complete();
+ });
+
+ var values = new List();
+ try
+ {
+ await foreach (var entry in grain.GetValues().WithBatchSize(1))
+ {
+ values.Add(entry);
+
+ // After the 3rd iteration, sleep for longer than the cleanup duration
+ // and wait for the enumerator to be cleaned up.
+ if (values.Count >= 3)
+ {
+ var initialCleanupCount = listener.CleanupCount;
+ while (listener.CleanupCount < initialCleanupCount + 2)
+ {
+ await Task.Delay(cleanupInterval, cts.Token);
+ }
+ }
+
+ Logger.LogInformation("ObservableGrain_AsyncEnumerable: {Entry}", entry);
+ }
+
+ Assert.Fail("Expected an exception to be thrown");
+ }
+ catch (EnumerationAbortedException ex)
+ {
+ Assert.Contains("the remote target does not have a record of this enumerator", ex.Message);
+ }
+
+ Assert.Equal(3, values.Count);
+
+ // Check that the enumerator is disposed
+ var grainCalls = await grain.GetIncomingCalls();
+ Assert.Contains(grainCalls, c => c.InterfaceName.Contains(nameof(IAsyncEnumerableGrainExtension)) && c.MethodName.Contains(nameof(IAsyncDisposable.DisposeAsync)));
+ }
+
[Fact, TestCategory("BVT"), TestCategory("Observable")]
public async Task ObservableGrain_AsyncEnumerable_Deactivate()
{
@@ -277,4 +422,65 @@ await Assert.ThrowsAsync(async () =>
Assert.Equal(2, values.Count);
}
+
+ private sealed class AsyncEnumerableGrainExtensionListener : IObserver>, IObserver, IDisposable
+ {
+ private readonly IDisposable _allListenersSubscription;
+ private readonly GrainId _targetGrainId;
+ private readonly TimeSpan _enumeratorCleanupInterval;
+ private IDisposable? _instanceSubscription;
+
+ public AsyncEnumerableGrainExtensionListener(GrainId targetGrainId, TimeSpan enumeratorCleanupInterval)
+ {
+ _allListenersSubscription = DiagnosticListener.AllListeners.Subscribe(this);
+ _targetGrainId = targetGrainId;
+ _enumeratorCleanupInterval = enumeratorCleanupInterval;
+ }
+
+ public int CleanupCount { get; private set; }
+
+ void IObserver>.OnCompleted()
+ {
+ _instanceSubscription?.Dispose();
+ }
+
+ void IObserver>.OnError(Exception error)
+ {
+ }
+
+ void IObserver>.OnNext(KeyValuePair value)
+ {
+ var extension = (AsyncEnumerableGrainExtension)value.Value!;
+ if (extension.GrainContext.GrainId != _targetGrainId)
+ {
+ return;
+ }
+
+ if (value.Key == "OnAsyncEnumeratorGrainExtensionCreated")
+ {
+ extension.Timer.Change(_enumeratorCleanupInterval, _enumeratorCleanupInterval);
+ }
+
+ if (value.Key == "OnEnumeratorCleanupCompleted")
+ {
+ ++CleanupCount;
+ }
+ }
+
+ void IObserver.OnCompleted() { }
+ void IObserver.OnError(Exception error) { }
+ void IObserver.OnNext(DiagnosticListener value)
+ {
+ if (value.Name == "Orleans.Runtime.AsyncEnumerableGrainExtension")
+ {
+ _instanceSubscription = value.Subscribe(this);
+ }
+ }
+
+ public void Dispose()
+ {
+ _allListenersSubscription.Dispose();
+ _instanceSubscription?.Dispose();
+ }
+ }
}
diff --git a/test/Grains/TestGrainInterfaces/IObservableGrain.cs b/test/Grains/TestGrainInterfaces/IObservableGrain.cs
index c2f05c5d1c..eac7c2e0c3 100644
--- a/test/Grains/TestGrainInterfaces/IObservableGrain.cs
+++ b/test/Grains/TestGrainInterfaces/IObservableGrain.cs
@@ -1,4 +1,4 @@
-namespace UnitTests.GrainInterfaces
+namespace UnitTests.GrainInterfaces
{
///
/// A grain which returns IAsyncEnumerable
diff --git a/test/Grains/TestInternalGrains/ObservableGrain.cs b/test/Grains/TestInternalGrains/ObservableGrain.cs
index 8e6ce36679..efac022bd1 100644
--- a/test/Grains/TestInternalGrains/ObservableGrain.cs
+++ b/test/Grains/TestInternalGrains/ObservableGrain.cs
@@ -1,4 +1,4 @@
-using System.Threading.Channels;
+using System.Threading.Channels;
using UnitTests.GrainInterfaces;
namespace UnitTests.Grains
@@ -17,6 +17,11 @@ public async IAsyncEnumerable GetValuesWithError(int errorIndex, bool waitA
{
if (i == errorIndex)
{
+ if (errorMessage == "cancel")
+ {
+ throw new OperationCanceledException(errorMessage);
+ }
+
throw new InvalidOperationException(errorMessage);
}