diff --git a/src/Orleans.Core.Abstractions/Runtime/AsyncEnumerableRequest.cs b/src/Orleans.Core.Abstractions/Runtime/AsyncEnumerableRequest.cs index fa6da867c3d..ab265fc910b 100644 --- a/src/Orleans.Core.Abstractions/Runtime/AsyncEnumerableRequest.cs +++ b/src/Orleans.Core.Abstractions/Runtime/AsyncEnumerableRequest.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Diagnostics; +using System.Runtime.ExceptionServices; using System.Runtime.Serialization; using System.Threading; using System.Threading.Tasks; @@ -43,6 +44,11 @@ public enum EnumerationResult /// MissingEnumeratorError = 1 << 4, + /// + /// The attempt to enumerate failed because the enumeration threw an exception. + /// + Error = 1 << 5, + /// /// This result indicates that enumeration has completed and that no further results will be produced. /// @@ -254,6 +260,11 @@ public async ValueTask MoveNextAsync() result = await _target.MoveNext(_requestId); } + if (result.Status is EnumerationResult.Error) + { + ExceptionDispatchInfo.Capture((Exception)result.Value).Throw(); + } + if (result.Status is not EnumerationResult.Heartbeat) { break; diff --git a/src/Orleans.Core/Runtime/AsyncEnumerableGrainExtension.cs b/src/Orleans.Core/Runtime/AsyncEnumerableGrainExtension.cs index 71217375c61..49a96612821 100644 --- a/src/Orleans.Core/Runtime/AsyncEnumerableGrainExtension.cs +++ b/src/Orleans.Core/Runtime/AsyncEnumerableGrainExtension.cs @@ -216,7 +216,7 @@ private async ValueTask RemoveExpiredAsync(CancellationToken cancellationToken) // Wait for either the MoveNextAsync task to complete or the cancellation token to be cancelled. await moveNextTask.WaitAsync(cancellation.Token).SuppressThrowing(); - if (moveNextTask is {IsCompletedSuccessfully: true }) + if (moveNextTask.IsCompletedSuccessfully) { OnMoveNext(requestId); var hasValue = moveNextTask.GetAwaiter().GetResult(); @@ -231,6 +231,14 @@ private async ValueTask RemoveExpiredAsync(CancellationToken cancellationToken) return (EnumerationResult.Completed, default); } } + else if (moveNextTask.Exception is { } moveNextException) + { + // Completed, but not successfully. + var exception = moveNextException.InnerExceptions.Count == 1 ? moveNextException.InnerException : moveNextException; + await RemoveEnumeratorAsync(requestId); + await typedEnumerator.DisposeAsync(); + return (EnumerationResult.Error, exception); + } return (EnumerationResult.Heartbeat, default); } diff --git a/test/DefaultCluster.Tests/AsyncEnumerableGrainCallTests.cs b/test/DefaultCluster.Tests/AsyncEnumerableGrainCallTests.cs index 20322a9b816..9d3bc8510c4 100644 --- a/test/DefaultCluster.Tests/AsyncEnumerableGrainCallTests.cs +++ b/test/DefaultCluster.Tests/AsyncEnumerableGrainCallTests.cs @@ -1,4 +1,4 @@ -using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging; using Orleans.Internal; using Orleans.Runtime; using TestExtensions; @@ -46,6 +46,43 @@ public async Task ObservableGrain_AsyncEnumerable() Assert.Contains(grainCalls, c => c.InterfaceName.Contains(nameof(IAsyncEnumerableGrainExtension)) && c.MethodName.Contains(nameof(IAsyncDisposable.DisposeAsync))); } + [Theory, TestCategory("BVT"), TestCategory("Observable")] + [InlineData(0, false)] + [InlineData(0, true)] + [InlineData(1, false)] + [InlineData(1, true)] + [InlineData(9, false)] + [InlineData(9, true)] + [InlineData(10, false)] + [InlineData(10, true)] + [InlineData(11, false)] + [InlineData(11, true)] + public async Task ObservableGrain_AsyncEnumerable_Throws(int errorIndex, bool waitAfterYield) + { + const string ErrorMessage = "This is my error!"; + var grain = GrainFactory.GetGrain(Guid.NewGuid()); + + var values = new List(); + try + { + await foreach (var entry in grain.GetValuesWithError(errorIndex, waitAfterYield, ErrorMessage).WithBatchSize(10)) + { + values.Add(entry); + Logger.LogInformation("ObservableGrain_AsyncEnumerable: {Entry}", entry); + } + } + catch (InvalidOperationException iox) + { + Assert.Equal(ErrorMessage, iox.Message); + } + + Assert.Equal(errorIndex, values.Count); + + // Check that the enumerator is disposed + var grainCalls = await grain.GetIncomingCalls(); + Assert.Contains(grainCalls, c => c.InterfaceName.Contains(nameof(IAsyncEnumerableGrainExtension)) && c.MethodName.Contains(nameof(IAsyncDisposable.DisposeAsync))); + } + [Fact, TestCategory("BVT"), TestCategory("Observable")] public async Task ObservableGrain_AsyncEnumerable_Batch() { diff --git a/test/Grains/TestGrainInterfaces/IObservableGrain.cs b/test/Grains/TestGrainInterfaces/IObservableGrain.cs index 3a80ff49df9..c2f05c5d1c5 100644 --- a/test/Grains/TestGrainInterfaces/IObservableGrain.cs +++ b/test/Grains/TestGrainInterfaces/IObservableGrain.cs @@ -1,4 +1,4 @@ -namespace UnitTests.GrainInterfaces +namespace UnitTests.GrainInterfaces { /// /// A grain which returns IAsyncEnumerable @@ -10,6 +10,7 @@ public interface IObservableGrain : IGrainWithGuidKey ValueTask Deactivate(); ValueTask OnNext(string data); IAsyncEnumerable GetValues(); + IAsyncEnumerable GetValuesWithError(int errorIndex, bool waitAfterYield, string errorMessage); ValueTask> GetIncomingCalls(); } diff --git a/test/Grains/TestInternalGrains/ObservableGrain.cs b/test/Grains/TestInternalGrains/ObservableGrain.cs index 3574e597258..8e6ce366792 100644 --- a/test/Grains/TestInternalGrains/ObservableGrain.cs +++ b/test/Grains/TestInternalGrains/ObservableGrain.cs @@ -1,4 +1,4 @@ -using System.Threading.Channels; +using System.Threading.Channels; using UnitTests.GrainInterfaces; namespace UnitTests.Grains @@ -10,6 +10,21 @@ public class ObservableGrain : Grain, IObservableGrain, IIncomingGrainCallFilter public IAsyncEnumerable GetValues() => _updates.Reader.ReadAllAsync(); + public async IAsyncEnumerable GetValuesWithError(int errorIndex, bool waitAfterYield, string errorMessage) + { + await Task.CompletedTask; + for (var i = 0; i < int.MaxValue; i++) + { + if (i == errorIndex) + { + throw new InvalidOperationException(errorMessage); + } + + yield return i; + await Task.Yield(); + } + } + public ValueTask Complete() { _updates.Writer.Complete();