diff --git a/src/Orleans.Core.Abstractions/Runtime/AsyncEnumerableRequest.cs b/src/Orleans.Core.Abstractions/Runtime/AsyncEnumerableRequest.cs index ab265fc910..8dd2263351 100644 --- a/src/Orleans.Core.Abstractions/Runtime/AsyncEnumerableRequest.cs +++ b/src/Orleans.Core.Abstractions/Runtime/AsyncEnumerableRequest.cs @@ -49,6 +49,11 @@ public enum EnumerationResult /// Error = 1 << 5, + /// + /// Enumeration was canceled. + /// + Canceled = 1 << 6, + /// /// This result indicates that enumeration has completed and that no further results will be produced. /// @@ -244,26 +249,26 @@ public async ValueTask MoveNextAsync() (EnumerationResult Status, object Value) result; while (true) { - if (_cancellationToken.IsCancellationRequested) - { - _current = default; - return false; - } + _cancellationToken.ThrowIfCancellationRequested(); if (!_initialized) { - result = await _target.StartEnumeration(_requestId, _request); + result = await _target.StartEnumeration(_requestId, _request).AsTask().WaitAsync(_cancellationToken); _initialized = true; } else { - result = await _target.MoveNext(_requestId); + result = await _target.MoveNext(_requestId).AsTask().WaitAsync(_cancellationToken); } if (result.Status is EnumerationResult.Error) { ExceptionDispatchInfo.Capture((Exception)result.Value).Throw(); } + else if (result.Status is EnumerationResult.Canceled) + { + throw new OperationCanceledException(); + } if (result.Status is not EnumerationResult.Heartbeat) { @@ -274,7 +279,7 @@ public async ValueTask MoveNextAsync() if (result.Status is EnumerationResult.MissingEnumeratorError) { throw new EnumerationAbortedException("Enumeration aborted: the remote target does not have a record of this enumerator." - + " This likely indicates that the remote grain was deactivated since enumeration begun."); + + " This likely indicates that the remote grain was deactivated since enumeration begun or that the enumerator was idle for longer than the expiration period."); } Debug.Assert((result.Status & (EnumerationResult.Element | EnumerationResult.Batch | EnumerationResult.Completed)) != 0); diff --git a/src/Orleans.Core/Configuration/Options/MessagingOptions.cs b/src/Orleans.Core/Configuration/Options/MessagingOptions.cs index fdc47a9eac..07e8e00515 100644 --- a/src/Orleans.Core/Configuration/Options/MessagingOptions.cs +++ b/src/Orleans.Core/Configuration/Options/MessagingOptions.cs @@ -55,5 +55,10 @@ public TimeSpan ResponseTimeout /// /// The maximum message body size is 100 MB by default. public int MaxMessageBodySize { get; set; } = 100 * 1024 * 1024; + + /// + /// Gets the response timeout underlying the property, without debugger checks. + /// + internal TimeSpan ConfiguredResponseTimeout => _responseTimeout; } } diff --git a/src/Orleans.Core/Runtime/AsyncEnumerableGrainExtension.cs b/src/Orleans.Core/Runtime/AsyncEnumerableGrainExtension.cs index 49a9661282..1077f5541f 100644 --- a/src/Orleans.Core/Runtime/AsyncEnumerableGrainExtension.cs +++ b/src/Orleans.Core/Runtime/AsyncEnumerableGrainExtension.cs @@ -2,7 +2,6 @@ using System.Collections.Generic; using System.Diagnostics; using System.Runtime.CompilerServices; -using System.Runtime.ExceptionServices; using System.Runtime.InteropServices; using System.Threading; using System.Threading.Tasks; @@ -19,32 +18,42 @@ namespace Orleans.Runtime; /// internal sealed class AsyncEnumerableGrainExtension : IAsyncEnumerableGrainExtension, IAsyncDisposable, IDisposable { - private const long EnumeratorExpirationMilliseconds = 10_000; + private static readonly DiagnosticListener DiagnosticListener = new("Orleans.Runtime.AsyncEnumerableGrainExtension"); private readonly Dictionary _enumerators = []; - private readonly IGrainContext _grainContext; + private readonly ILogger _logger; private readonly MessagingOptions _messagingOptions; - private readonly IDisposable _timer; + + // Internal for testing + internal IGrainTimer Timer { get; } + internal IGrainContext GrainContext { get; } /// /// Initializes a new instance. /// /// The grain which this extension is attached to. - public AsyncEnumerableGrainExtension(IGrainContext grainContext, IOptions messagingOptions) + public AsyncEnumerableGrainExtension( + IGrainContext grainContext, + IOptions messagingOptions, + ILogger logger) { - _grainContext = grainContext; + _logger = logger; + GrainContext = grainContext; + _messagingOptions = messagingOptions.Value; - var registry = _grainContext.GetComponent(); - _timer = registry.RegisterGrainTimer( - _grainContext, + var registry = GrainContext.GetComponent(); + var cleanupPeriod = messagingOptions.Value.ResponseTimeout; + Timer = registry.RegisterGrainTimer( + GrainContext, static async (state, cancellationToken) => await state.RemoveExpiredAsync(cancellationToken), this, new() { - DueTime = TimeSpan.FromSeconds(EnumeratorExpirationMilliseconds), - Period = TimeSpan.FromSeconds(EnumeratorExpirationMilliseconds), + DueTime = cleanupPeriod, + Period = cleanupPeriod, Interleave = true, KeepAlive = false }); + OnAsyncEnumeratorGrainExtensionCreated(this); } /// @@ -55,12 +64,23 @@ private async ValueTask RemoveExpiredAsync(CancellationToken cancellationToken) List toRemove = default; foreach (var (requestId, state) in _enumerators) { - if (state.LastSeenTimer.ElapsedMilliseconds > EnumeratorExpirationMilliseconds - && state.MoveNextTask is null or { IsCompleted: true }) + if (MarkAndCheck(requestId)) { toRemove ??= []; toRemove.Add(requestId); } + + bool MarkAndCheck(Guid requestId) + { + ref var state = ref CollectionsMarshal.GetValueRefOrNullRef(_enumerators, requestId); + if (Unsafe.IsNullRef(ref state)) + { + return false; + } + + // Returns true if no flags were set. + return state.ClearSeen(); + } } List tasks = default; @@ -81,12 +101,14 @@ private async ValueTask RemoveExpiredAsync(CancellationToken cancellationToken) { await Task.WhenAll(tasks).WaitAsync(cancellationToken); } + + OnEnumeratorCleanupCompleted(this); } /// public ValueTask<(EnumerationResult Status, object Value)> StartEnumeration(Guid requestId, [Immutable] IAsyncEnumerableRequest request) { - request.SetTarget(_grainContext); + request.SetTarget(GrainContext); var enumerable = request.InvokeImplementation(); ref var entry = ref CollectionsMarshal.GetValueRefOrAddDefault(_enumerators, requestId, out bool exists); if (exists) @@ -97,11 +119,10 @@ private async ValueTask RemoveExpiredAsync(CancellationToken cancellationToken) var cts = new CancellationTokenSource(); var enumerator = enumerable.GetAsyncEnumerator(cts.Token); entry.Enumerator = enumerator; - entry.LastSeenTimer.Restart(); entry.MaxBatchSize = request.MaxBatchSize; entry.CancellationTokenSource = cts; Debug.Assert(entry.MaxBatchSize > 0, "Max batch size must be positive."); - return MoveNextAsync(ref entry, requestId, enumerator); + return MoveNextCore(ref entry, requestId, enumerator); static ValueTask<(EnumerationResult Status, object Value)> ThrowAlreadyExists() => ValueTask.FromException<(EnumerationResult Status, object Value)>(new InvalidOperationException("An enumerator with the same id already exists.")); } @@ -115,27 +136,28 @@ private async ValueTask RemoveExpiredAsync(CancellationToken cancellationToken) return new((EnumerationResult.MissingEnumeratorError, default)); } - entry.LastSeenTimer.Restart(); if (entry.Enumerator is not IAsyncEnumerator typedEnumerator) { throw new InvalidCastException("Attempted to access an enumerator of the wrong type."); } - return MoveNextAsync(ref entry, requestId, typedEnumerator); + return MoveNextCore(ref entry, requestId, typedEnumerator); } - private ValueTask<(EnumerationResult Status, object Value)> MoveNextAsync( + private ValueTask<(EnumerationResult Status, object Value)> MoveNextCore( ref EnumeratorState entry, Guid requestId, IAsyncEnumerator typedEnumerator) { Debug.Assert(entry.MaxBatchSize > 0, "Max batch size must be positive."); + entry.SetSeen(); + try { + var currentBatchSize = 0; if (entry.MoveNextTask is null) { ValueTask moveNextValueTask; - var currentBatchSize = 0; object result = null; do { @@ -168,14 +190,14 @@ private async ValueTask RemoveExpiredAsync(CancellationToken cancellationToken) // Completed successfully, possibly with some final elements. if (currentBatchSize == 0) { - return OnComplete(requestId, typedEnumerator); + return OnTerminateAsync(requestId, EnumerationResult.Completed, default); } else if (currentBatchSize == 1) { - return new((EnumerationResult.CompletedWithElement, result)); + return OnTerminateAsync(requestId, EnumerationResult.CompletedWithElement, result); } - return new((EnumerationResult.CompletedWithBatch, result)); + return OnTerminateAsync(requestId, EnumerationResult.CompletedWithBatch, result); } } else @@ -199,11 +221,13 @@ private async ValueTask RemoveExpiredAsync(CancellationToken cancellationToken) // There are no elements, so wait for the pending operation to complete. } + // Prevent the enumerator from being collected while we are enumerating it. + entry.SetBusy(); return AwaitMoveNextAsync(requestId, typedEnumerator, entry.MoveNextTask); } catch (Exception exception) { - return OnError(requestId, typedEnumerator, exception); + return OnTerminateAsync(requestId, EnumerationResult.Error, exception); } } @@ -211,15 +235,32 @@ private async ValueTask RemoveExpiredAsync(CancellationToken cancellationToken) { try { - // Wait up to half the response timeout for the MoveNextAsync task to complete. - using var cancellation = new CancellationTokenSource(_messagingOptions.ResponseTimeout / 2); + // Wait for either the MoveNextAsync task to complete or the polling timeout to elapse. + var longPollingTimeout = _messagingOptions.ConfiguredResponseTimeout / 2; + await moveNextTask.WaitAsync(longPollingTimeout).SuppressThrowing(); + + // Update the enumerator state to indicate that we are not currently waiting for MoveNextAsync to complete. + // If the MoveNextAsync task completed then clear that now, too. + UpdateEnumeratorState(requestId, clearMoveNextTask: moveNextTask.IsCompleted); + void UpdateEnumeratorState(Guid requestId, bool clearMoveNextTask) + { + ref var state = ref CollectionsMarshal.GetValueRefOrNullRef(_enumerators, requestId); + if (Unsafe.IsNullRef(ref state)) + { + return; + } + + state.ClearBusy(); + if (clearMoveNextTask) + { + state.MoveNextTask = null; + } + } - // Wait for either the MoveNextAsync task to complete or the cancellation token to be cancelled. - await moveNextTask.WaitAsync(cancellation.Token).SuppressThrowing(); if (moveNextTask.IsCompletedSuccessfully) { - OnMoveNext(requestId); var hasValue = moveNextTask.GetAwaiter().GetResult(); + if (hasValue) { return (EnumerationResult.Element, typedEnumerator.Current); @@ -227,26 +268,28 @@ private async ValueTask RemoveExpiredAsync(CancellationToken cancellationToken) else { await RemoveEnumeratorAsync(requestId); - await typedEnumerator.DisposeAsync(); return (EnumerationResult.Completed, default); } } + else if (moveNextTask.IsCanceled) + { + await RemoveEnumeratorAsync(requestId); + return (EnumerationResult.Canceled, default); + } else if (moveNextTask.Exception is { } moveNextException) { // Completed, but not successfully. var exception = moveNextException.InnerExceptions.Count == 1 ? moveNextException.InnerException : moveNextException; await RemoveEnumeratorAsync(requestId); - await typedEnumerator.DisposeAsync(); return (EnumerationResult.Error, exception); } return (EnumerationResult.Heartbeat, default); } - catch + catch (Exception exception) { await RemoveEnumeratorAsync(requestId); - await typedEnumerator.DisposeAsync(); - throw; + return (EnumerationResult.Error, exception); } } @@ -258,31 +301,12 @@ private async ValueTask RemoveEnumeratorAsync(Guid requestId) } } - private async ValueTask<(EnumerationResult Status, object Value)> OnComplete(Guid requestId, IAsyncEnumerator enumerator) + private async ValueTask<(EnumerationResult Status, object Value)> OnTerminateAsync(Guid requestId, EnumerationResult status, object value) { await RemoveEnumeratorAsync(requestId); - return (EnumerationResult.Completed, default); + return (status, value); } - private async ValueTask<(EnumerationResult Status, object Value)> OnError(Guid requestId, IAsyncEnumerator enumerator, Exception exception) - { - await RemoveEnumeratorAsync(requestId); - ExceptionDispatchInfo.Throw(exception); - return default; - } - - private void OnMoveNext(Guid requestId) - { - ref var state = ref CollectionsMarshal.GetValueRefOrNullRef(_enumerators, requestId); - if (Unsafe.IsNullRef(ref state)) - { - return; - } - - state.LastSeenTimer.Restart(); - state.MoveNextTask = null; - } - /// public async ValueTask DisposeAsync() { @@ -297,7 +321,7 @@ public async ValueTask DisposeAsync() } } - _timer.Dispose(); + Timer.Dispose(); } private async ValueTask DisposeEnumeratorAsync(EnumeratorState enumerator) @@ -308,40 +332,68 @@ private async ValueTask DisposeEnumeratorAsync(EnumeratorState enumerator) } catch (Exception exception) { - var logger = _grainContext.GetComponent(); - logger?.LogWarning(exception, "Error cancelling enumerator."); + _logger.LogWarning(exception, "Error cancelling enumerator."); } try { + using var cts = new CancellationTokenSource(_messagingOptions.ResponseTimeout); if (enumerator.MoveNextTask is { } task) { - if (enumerator.Enumerator is { } value) - { - await task.SuppressThrowing(); - await value.DisposeAsync(); - } + await task.WaitAsync(cts.Token).SuppressThrowing(); + } + + if (enumerator.MoveNextTask is null or { IsCompleted: true } && enumerator.Enumerator is { } value) + { + await value.DisposeAsync().AsTask().WaitAsync(cts.Token).SuppressThrowing(); } } catch (Exception exception) { - var logger = _grainContext.GetComponent(); - logger?.LogWarning(exception, "Error disposing enumerator."); + _logger.LogWarning(exception, "Error disposing enumerator."); } } /// public void Dispose() { - _timer.Dispose(); + Timer.Dispose(); + } + + private static void OnAsyncEnumeratorGrainExtensionCreated(AsyncEnumerableGrainExtension extension) + { + if (DiagnosticListener.IsEnabled()) + { + DiagnosticListener.Write(nameof(OnAsyncEnumeratorGrainExtensionCreated), extension); + } + } + + private static void OnEnumeratorCleanupCompleted(AsyncEnumerableGrainExtension extension) + { + if (DiagnosticListener.IsEnabled()) + { + DiagnosticListener.Write(nameof(OnEnumeratorCleanupCompleted), extension); + } } private struct EnumeratorState { + private const int SeenFlag = 0x01; + private const int BusyFlag = 0x10; + private int _flags; public IAsyncDisposable Enumerator; public Task MoveNextTask; - public CoarseStopwatch LastSeenTimer; public int MaxBatchSize; internal CancellationTokenSource CancellationTokenSource; + public void SetSeen() => _flags |= SeenFlag; + public void SetBusy() => _flags |= BusyFlag | SeenFlag; + public void ClearBusy() => _flags = SeenFlag; // Clear the 'Busy' flag, but set the 'Seen' flag. + public bool ClearSeen() + { + // Clear the 'Seen' flag and check if any flags were set previously. + var isExpired = _flags == 0; + _flags &= ~SeenFlag; + return isExpired; + } } } diff --git a/test/DefaultCluster.Tests/AsyncEnumerableGrainCallTests.cs b/test/DefaultCluster.Tests/AsyncEnumerableGrainCallTests.cs index 9d3bc8510c..e1c7d8fab2 100644 --- a/test/DefaultCluster.Tests/AsyncEnumerableGrainCallTests.cs +++ b/test/DefaultCluster.Tests/AsyncEnumerableGrainCallTests.cs @@ -1,6 +1,7 @@ -using Microsoft.Extensions.Logging; +#nullable enable +using System.Diagnostics; +using Microsoft.Extensions.Logging; using Orleans.Internal; -using Orleans.Runtime; using TestExtensions; using UnitTests.GrainInterfaces; using Xunit; @@ -83,6 +84,47 @@ public async Task ObservableGrain_AsyncEnumerable_Throws(int errorIndex, bool wa Assert.Contains(grainCalls, c => c.InterfaceName.Contains(nameof(IAsyncEnumerableGrainExtension)) && c.MethodName.Contains(nameof(IAsyncDisposable.DisposeAsync))); } + [Theory, TestCategory("BVT"), TestCategory("Observable")] + [InlineData(0, false)] + [InlineData(0, true)] + [InlineData(1, false)] + [InlineData(1, true)] + [InlineData(9, false)] + [InlineData(9, true)] + [InlineData(10, false)] + [InlineData(10, true)] + [InlineData(11, false)] + [InlineData(11, true)] + public async Task ObservableGrain_AsyncEnumerable_Cancellation(int errorIndex, bool waitAfterYield) + { + // This special error message is interpreted to indicate that cancellation + // should occur when the index is reached. + const string ErrorMessage = "cancel"; + var grain = GrainFactory.GetGrain(Guid.NewGuid()); + + var values = new List(); + try + { + await foreach (var entry in grain.GetValuesWithError(errorIndex, waitAfterYield, ErrorMessage).WithBatchSize(10)) + { + values.Add(entry); + Logger.LogInformation("ObservableGrain_AsyncEnumerable: {Entry}", entry); + } + } + catch (OperationCanceledException oce) + { + var expectedMessage = new OperationCanceledException().Message; + Assert.Equal(expectedMessage, oce.Message); + } + + Assert.Equal(errorIndex, values.Count); + + // Check that the enumerator is disposed + var grainCalls = await grain.GetIncomingCalls(); + Assert.Contains(grainCalls, c => c.InterfaceName.Contains(nameof(IAsyncEnumerableGrainExtension)) && c.MethodName.Contains(nameof(IAsyncDisposable.DisposeAsync))); + } + + [Fact, TestCategory("BVT"), TestCategory("Observable")] public async Task ObservableGrain_AsyncEnumerable_Batch() { @@ -196,15 +238,24 @@ public async Task ObservableGrain_AsyncEnumerable_WithCancellation() var values = new List(); using var cts = new CancellationTokenSource(); - await foreach (var entry in grain.GetValues().WithCancellation(cts.Token)) + try { - values.Add(entry); - if (values.Count == 3) + await foreach (var entry in grain.GetValues().WithCancellation(cts.Token)) { - cts.Cancel(); + values.Add(entry); + if (values.Count == 3) + { + cts.Cancel(); + } + + Logger.LogInformation("ObservableGrain_AsyncEnumerable: {Entry}", entry); } - Logger.LogInformation("ObservableGrain_AsyncEnumerable: {Entry}", entry); + Assert.Fail("Expected an exception to be thrown"); + } + catch (OperationCanceledException) + { + // Expected } Assert.Equal(3, values.Count); @@ -249,6 +300,100 @@ public async Task ObservableGrain_AsyncEnumerable_SlowProducer() Assert.Contains(grainCalls, c => c.InterfaceName.Contains(nameof(IAsyncEnumerableGrainExtension)) && c.MethodName.Contains(nameof(IAsyncDisposable.DisposeAsync))); } + [Fact, TestCategory("BVT"), TestCategory("Observable")] + public async Task ObservableGrain_AsyncEnumerable_SlowConsumer() + { + using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(30)); + var cleanupInterval = TimeSpan.FromMilliseconds(100); + var grain = GrainFactory.GetGrain(Guid.NewGuid()); + using var listener = new AsyncEnumerableGrainExtensionListener(grain.GetGrainId(), cleanupInterval); + + var producer = Task.Run(async () => + { + foreach (var value in Enumerable.Range(0, 5)) + { + await grain.OnNext(value.ToString()); + } + + await grain.Complete(); + }); + + var values = new List(); + await foreach (var entry in grain.GetValues().WithBatchSize(1)) + { + values.Add(entry); + + // Sleep for 1 cycle before reading the next value. + // The enumerator should not be cleaned up. + var initialCleanupCount = listener.CleanupCount; + while (listener.CleanupCount == initialCleanupCount) + { + await Task.Delay(cleanupInterval / 10, cts.Token); + } + + Logger.LogInformation("ObservableGrain_AsyncEnumerable: {Entry}", entry); + } + + Assert.Equal(5, values.Count); + + // Check that the enumerator is disposed + var grainCalls = await grain.GetIncomingCalls(); + Assert.Contains(grainCalls, c => c.InterfaceName.Contains(nameof(IAsyncEnumerableGrainExtension)) && c.MethodName.Contains(nameof(IAsyncDisposable.DisposeAsync))); + } + + [Fact, TestCategory("BVT"), TestCategory("Observable")] + public async Task ObservableGrain_AsyncEnumerable_SlowConsumer_Evicted() + { + using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(30)); + var cleanupInterval = TimeSpan.FromMilliseconds(100); + var grain = GrainFactory.GetGrain(Guid.NewGuid()); + using var listener = new AsyncEnumerableGrainExtensionListener(grain.GetGrainId(), cleanupInterval); + + var producer = Task.Run(async () => + { + foreach (var value in Enumerable.Range(0, 5)) + { + await grain.OnNext(value.ToString()); + } + + await grain.Complete(); + }); + + var values = new List(); + try + { + await foreach (var entry in grain.GetValues().WithBatchSize(1)) + { + values.Add(entry); + + // After the 3rd iteration, sleep for longer than the cleanup duration + // and wait for the enumerator to be cleaned up. + if (values.Count >= 3) + { + var initialCleanupCount = listener.CleanupCount; + while (listener.CleanupCount < initialCleanupCount + 2) + { + await Task.Delay(cleanupInterval, cts.Token); + } + } + + Logger.LogInformation("ObservableGrain_AsyncEnumerable: {Entry}", entry); + } + + Assert.Fail("Expected an exception to be thrown"); + } + catch (EnumerationAbortedException ex) + { + Assert.Contains("the remote target does not have a record of this enumerator", ex.Message); + } + + Assert.Equal(3, values.Count); + + // Check that the enumerator is disposed + var grainCalls = await grain.GetIncomingCalls(); + Assert.Contains(grainCalls, c => c.InterfaceName.Contains(nameof(IAsyncEnumerableGrainExtension)) && c.MethodName.Contains(nameof(IAsyncDisposable.DisposeAsync))); + } + [Fact, TestCategory("BVT"), TestCategory("Observable")] public async Task ObservableGrain_AsyncEnumerable_Deactivate() { @@ -277,4 +422,65 @@ await Assert.ThrowsAsync(async () => Assert.Equal(2, values.Count); } + + private sealed class AsyncEnumerableGrainExtensionListener : IObserver>, IObserver, IDisposable + { + private readonly IDisposable _allListenersSubscription; + private readonly GrainId _targetGrainId; + private readonly TimeSpan _enumeratorCleanupInterval; + private IDisposable? _instanceSubscription; + + public AsyncEnumerableGrainExtensionListener(GrainId targetGrainId, TimeSpan enumeratorCleanupInterval) + { + _allListenersSubscription = DiagnosticListener.AllListeners.Subscribe(this); + _targetGrainId = targetGrainId; + _enumeratorCleanupInterval = enumeratorCleanupInterval; + } + + public int CleanupCount { get; private set; } + + void IObserver>.OnCompleted() + { + _instanceSubscription?.Dispose(); + } + + void IObserver>.OnError(Exception error) + { + } + + void IObserver>.OnNext(KeyValuePair value) + { + var extension = (AsyncEnumerableGrainExtension)value.Value!; + if (extension.GrainContext.GrainId != _targetGrainId) + { + return; + } + + if (value.Key == "OnAsyncEnumeratorGrainExtensionCreated") + { + extension.Timer.Change(_enumeratorCleanupInterval, _enumeratorCleanupInterval); + } + + if (value.Key == "OnEnumeratorCleanupCompleted") + { + ++CleanupCount; + } + } + + void IObserver.OnCompleted() { } + void IObserver.OnError(Exception error) { } + void IObserver.OnNext(DiagnosticListener value) + { + if (value.Name == "Orleans.Runtime.AsyncEnumerableGrainExtension") + { + _instanceSubscription = value.Subscribe(this); + } + } + + public void Dispose() + { + _allListenersSubscription.Dispose(); + _instanceSubscription?.Dispose(); + } + } } diff --git a/test/Grains/TestGrainInterfaces/IObservableGrain.cs b/test/Grains/TestGrainInterfaces/IObservableGrain.cs index c2f05c5d1c..eac7c2e0c3 100644 --- a/test/Grains/TestGrainInterfaces/IObservableGrain.cs +++ b/test/Grains/TestGrainInterfaces/IObservableGrain.cs @@ -1,4 +1,4 @@ -namespace UnitTests.GrainInterfaces +namespace UnitTests.GrainInterfaces { /// /// A grain which returns IAsyncEnumerable diff --git a/test/Grains/TestInternalGrains/ObservableGrain.cs b/test/Grains/TestInternalGrains/ObservableGrain.cs index 8e6ce36679..efac022bd1 100644 --- a/test/Grains/TestInternalGrains/ObservableGrain.cs +++ b/test/Grains/TestInternalGrains/ObservableGrain.cs @@ -1,4 +1,4 @@ -using System.Threading.Channels; +using System.Threading.Channels; using UnitTests.GrainInterfaces; namespace UnitTests.Grains @@ -17,6 +17,11 @@ public async IAsyncEnumerable GetValuesWithError(int errorIndex, bool waitA { if (i == errorIndex) { + if (errorMessage == "cancel") + { + throw new OperationCanceledException(errorMessage); + } + throw new InvalidOperationException(errorMessage); }