diff --git a/src/Orleans.Core/Core/DefaultClientServices.cs b/src/Orleans.Core/Core/DefaultClientServices.cs index f2cdf597961..7abea68e6af 100644 --- a/src/Orleans.Core/Core/DefaultClientServices.cs +++ b/src/Orleans.Core/Core/DefaultClientServices.cs @@ -62,6 +62,11 @@ public static void AddDefaultServices(IClientBuilder builder) services.AddSingleton(); services.AddFromExisting, ClientOptionsLogger>(); + // Lifecycle + services.AddSingleton>(); + services.TryAddFromExisting>(); + services.AddFromExisting, ServiceLifecycle>(); + // Statistics services.AddSingleton(); #pragma warning disable 618 diff --git a/src/Orleans.Core/Lifecycle/ServiceLifecycle.cs b/src/Orleans.Core/Lifecycle/ServiceLifecycle.cs new file mode 100644 index 00000000000..eaaf8ca1d4e --- /dev/null +++ b/src/Orleans.Core/Lifecycle/ServiceLifecycle.cs @@ -0,0 +1,61 @@ +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; + +namespace Orleans; + +#nullable enable + +/// +/// Allows consumers to observe and participate in the client/silo's lifecycle. +/// +public interface IServiceLifecycle +{ + /// + /// Triggered when the client/silo has fully started and is ready to accept traffic. + /// + IServiceLifecycleStage Started { get; } + + /// + /// Triggered when the client/silo is beginning the shutdown process. + /// + IServiceLifecycleStage Stopping { get; } + + /// + /// Triggered when the client/silo has completed its shutdown process. + /// + IServiceLifecycleStage Stopped { get; } +} + +internal sealed class ServiceLifecycle(ILogger> logger) : + IServiceLifecycle, ILifecycleParticipant + where TLifecycleObservable : ILifecycleObservable +{ + private readonly ServiceLifecycleNotificationStage _started = new(logger, "Started"); + private readonly ServiceLifecycleNotificationStage _stopping = new(logger, "Stopping"); + private readonly ServiceLifecycleNotificationStage _stopped = new(logger, "Stopped"); + + public IServiceLifecycleStage Started => _started; + public IServiceLifecycleStage Stopping => _stopping; + public IServiceLifecycleStage Stopped => _stopped; + + public void Participate(TLifecycleObservable lifecycle) + { + lifecycle.Subscribe( + observerName: nameof(Started), + stage: ServiceLifecycleStage.Active, + onStart: _started.NotifyCompleted, + onStop: _ => Task.CompletedTask); + + lifecycle.Subscribe( + observerName: nameof(Stopping), + stage: ServiceLifecycleStage.Active, + onStart: _ => Task.CompletedTask, + onStop: _stopping.NotifyCompleted); + + lifecycle.Subscribe( + observerName: nameof(Stopped), + stage: ServiceLifecycleStage.RuntimeInitialize - 1, + onStart: _ => Task.CompletedTask, + onStop: _stopped.NotifyCompleted); + } +} diff --git a/src/Orleans.Core/Lifecycle/ServiceLifecycleNotificationStage.cs b/src/Orleans.Core/Lifecycle/ServiceLifecycleNotificationStage.cs new file mode 100644 index 00000000000..5e70fac52e0 --- /dev/null +++ b/src/Orleans.Core/Lifecycle/ServiceLifecycleNotificationStage.cs @@ -0,0 +1,251 @@ +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; + +namespace Orleans; + +#nullable enable + +/// +/// Represents a specific stage in the client / silo lifecycle. +/// +public interface IServiceLifecycleStage +{ + /// + /// Gets a cancellation token that is triggered when this stage completes. + /// + /// Avoid registering callbacks in this token, prefer + /// instead. + CancellationToken Token { get; } + + /// + /// Waits for this lifecycle stage to complete. + /// + /// + /// A token used to cancel the wait. This does not cancel the lifecycle stage itself! + /// + Task WaitAsync(CancellationToken cancellationToken = default); + + /// + /// Registers a callback to be executed during this lifecycle stage. + /// + /// + /// The asynchronous operation to perform. + /// Never call inside a callback, as it will result in a deadlock! + /// + /// + /// If true, the client / silo will shut down if there is a failure; + /// otherwise an error will be logged and the client / silo will continue to the next stage. + /// + /// An optional state to pass. + /// + /// Disposing the returned value removes the callback from the lifecycle stage. + /// This is useful for components that have a shorter lifespan than the client / silo to prevent holding onto the reference, + /// and ensure that cleanup logic is not executed for components that are no longer active. + /// + IDisposable Register(Func callback, object? state = null, bool terminateOnError = true); +} + +internal sealed partial class ServiceLifecycleNotificationStage(ILogger logger, string name) : IServiceLifecycleStage +{ + // We use this so that late registrations can still be executed, otherwise + // we'd need to rely on the TCS which means we'd need to set it *before* the callbacks + // have been executed, ideally we should fire the TCS only after non-late registered callbacks have completed. + private bool _isNotifyingOrHasCompleted; + + private readonly object _lock = new(); + private readonly List _participants = []; + private readonly CancellationTokenSource _cts = new(); + private readonly TaskCompletionSource _tcs = new(TaskCreationOptions.RunContinuationsAsynchronously); + + public CancellationToken Token => _cts.Token; + + public Task WaitAsync(CancellationToken cancellationToken) => _tcs.Task.WaitAsync(cancellationToken); + + public IDisposable Register(Func callback, object? state, bool terminateOnError) + { + ArgumentNullException.ThrowIfNull(callback); + + var participant = new StageParticipant(this, callback, state, terminateOnError); + + lock (_lock) + { + if (_isNotifyingOrHasCompleted) + { + LogStageAlreadyCompleted(logger, name); + + _ = Task.Run(() => ExecuteLateCallback(participant)); + + return participant; + } + + _participants.Add(participant); + } + + return participant; + + async Task ExecuteLateCallback(StageParticipant participant) + { + try + { + // The original token passed to NotifyCompleted (typically related to the silo startup/shutdown) must be "gone" by now. + // Since the stage has already completed, there is no impending timeout for this late registration, so we pass CancellationToken.None. + // For late participants we do not check for termination! + + await participant.ExecuteAsync(CancellationToken.None).ConfigureAwait(false); + } + catch (Exception ex) + { + LogLateCallbackError(logger, ex, name); + } + } + } + + public async Task NotifyCompleted(CancellationToken cancellationToken) + { + List? snapshot; + + lock (_lock) + { + if (_isNotifyingOrHasCompleted) + { + snapshot = null; + } + else + { + _isNotifyingOrHasCompleted = true; + snapshot = [.. _participants]; + } + } + + if (snapshot is null) + { + await _tcs.Task.WaitAsync(cancellationToken).ConfigureAwait(false); + return; + } + + var tasks = new List(snapshot.Count + 1) + { + CancelTokenAsync() + }; + + foreach (var participant in snapshot) + { + tasks.Add(ExecuteParticipantAsync(participant, cancellationToken)); + } + + var allTasks = Task.WhenAll(tasks); + + try + { + await allTasks.ConfigureAwait(false); + _tcs.SetResult(); + } + catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) + { + _tcs.TrySetCanceled(cancellationToken); + } + catch (Exception ex) + { + // Note that awaiting WhenAll returns only the first exception, and we want to show all, if there are multiple. + if (allTasks.Exception is { } aggEx) + { + var flattened = aggEx.Flatten(); + + if (flattened.InnerExceptions.Count == 1) + { + // For cleaner reporting in case one callback throws. + _tcs.SetException(flattened.InnerExceptions[0]); + } + else + { + // Otherwise we let the user see all failures. + _tcs.SetException(flattened); + } + } + else + { + // Unlikely but hey! + _tcs.SetException(ex); + } + + // We throw here regardless, because it's the callback participant who controls whether to TerminateOnError or not. + throw; + } + } + + private async Task CancelTokenAsync() + { + try + { + await _cts.CancelAsync().ConfigureAwait(false); + } + catch (Exception ex) + { + // Should not happen if callers respect the contract to register + // callbacks with the proper method, but it can happen! + LogCancellationCallbackError(logger, ex, name); + } + } + + private async Task ExecuteParticipantAsync(StageParticipant participant, CancellationToken cancellationToken) + { + if (cancellationToken.IsCancellationRequested) + { + return; + } + + try + { + await participant.ExecuteAsync(cancellationToken).ConfigureAwait(false); + } + catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) + { + // If the upstream token triggered this, we rethrow so WhenAll knows we stopped due to cancellation. + throw; + } + catch (Exception ex) + { + LogCallbackError(logger, ex, name); + + if (participant.TerminateOnError) + { + // This will cause WhenAll to fault, eventually triggering _tcs.SetException above. + // NotifyCompleted relies on us to throw in case TerminateOnError is set to true. + throw; + } + } + } + + private void Unregister(StageParticipant participant) + { + lock (_lock) + { + _participants.Remove(participant); + } + } + + private record StageParticipant(ServiceLifecycleNotificationStage Stage, + Func Callback, object? State, bool TerminateOnError) : IDisposable + { + public Task ExecuteAsync(CancellationToken cancellationToken) => Callback(State, cancellationToken); + void IDisposable.Dispose() => Stage.Unregister(this); + } + + [LoggerMessage(Level = LogLevel.Information, Message = "Lifecycle stage = '{StageName}' has already completed. Executing callback immediately.")] + public static partial void LogStageAlreadyCompleted(ILogger logger, string stageName); + + [LoggerMessage(Level = LogLevel.Error, Message = "Error executing late-registered callback for lifecycle stage = '{StageName}'")] + public static partial void LogLateCallbackError(ILogger logger, Exception exception, string stageName); + + [LoggerMessage(Level = LogLevel.Information, Message = "Lifecycle stage = '{StageName}' has been canceled.")] + public static partial void LogStageCanceled(ILogger logger, string stageName); + + [LoggerMessage(Level = LogLevel.Error, Message = "Error executing callback for lifecycle stage = '{StageName}'")] + public static partial void LogCallbackError(ILogger logger, Exception exception, string stageName); + + [LoggerMessage(Level = LogLevel.Error, Message = "An exception occurred inside a CancellationToken callback for lifecycle stage = '{StageName}'")] + public static partial void LogCancellationCallbackError(ILogger logger, Exception exception, string stageName); +} diff --git a/src/Orleans.Runtime/Hosting/DefaultSiloServices.cs b/src/Orleans.Runtime/Hosting/DefaultSiloServices.cs index 7e67c106b66..0fbe21b3df4 100644 --- a/src/Orleans.Runtime/Hosting/DefaultSiloServices.cs +++ b/src/Orleans.Runtime/Hosting/DefaultSiloServices.cs @@ -1,49 +1,49 @@ #nullable enable +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using Microsoft.AspNetCore.Connections; +using Microsoft.Extensions.Configuration; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.DependencyInjection.Extensions; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; using Orleans.Configuration; +using Orleans.Configuration.Internal; using Orleans.Configuration.Validators; +using Orleans.Core; +using Orleans.GrainReferences; +using Orleans.Metadata; +using Orleans.Networking.Shared; +using Orleans.Placement.Repartitioning; +using Orleans.Providers; +using Orleans.Runtime; using Orleans.Runtime.Configuration; using Orleans.Runtime.ConsistentRing; using Orleans.Runtime.GrainDirectory; using Orleans.Runtime.MembershipService; -using Orleans.Metadata; using Orleans.Runtime.Messaging; +using Orleans.Runtime.Metadata; using Orleans.Runtime.Placement; +using Orleans.Runtime.Placement.Filtering; using Orleans.Runtime.Providers; +using Orleans.Runtime.Utilities; using Orleans.Runtime.Versions; using Orleans.Runtime.Versions.Compatibility; using Orleans.Runtime.Versions.Selector; using Orleans.Serialization; +using Orleans.Serialization.Cloning; +using Orleans.Serialization.Internal; +using Orleans.Serialization.Serializers; +using Orleans.Serialization.TypeSystem; using Orleans.Statistics; +using Orleans.Storage; using Orleans.Timers; +using Orleans.Timers.Internal; using Orleans.Versions; using Orleans.Versions.Compatibility; using Orleans.Versions.Selector; -using Orleans.Providers; -using Orleans.Runtime; -using Microsoft.Extensions.Logging; -using Orleans.Runtime.Utilities; -using System; -using System.Reflection; -using System.Linq; -using Microsoft.Extensions.Options; -using Orleans.Timers.Internal; -using Microsoft.AspNetCore.Connections; -using Orleans.Networking.Shared; -using Orleans.Configuration.Internal; -using Orleans.Runtime.Metadata; -using Orleans.GrainReferences; -using Orleans.Storage; -using Orleans.Serialization.TypeSystem; -using Orleans.Serialization.Serializers; -using Orleans.Serialization.Cloning; -using System.Collections.Generic; -using Microsoft.Extensions.Configuration; -using Orleans.Serialization.Internal; -using Orleans.Core; -using Orleans.Placement.Repartitioning; -using Orleans.Runtime.Placement.Filtering; namespace Orleans.Hosting { @@ -82,6 +82,11 @@ internal static void AddDefaultServices(ISiloBuilder builder) services.AddSingleton(); services.AddFromExisting, SiloControl>(); + // Lifecycle + services.AddSingleton>(); + services.TryAddFromExisting>(); + services.AddFromExisting, ServiceLifecycle>(); + // Statistics services.AddSingleton(); #pragma warning disable 618 diff --git a/test/NonSilo.Tests/ServiceLifecycleTests.cs b/test/NonSilo.Tests/ServiceLifecycleTests.cs new file mode 100644 index 00000000000..213af240841 --- /dev/null +++ b/test/NonSilo.Tests/ServiceLifecycleTests.cs @@ -0,0 +1,367 @@ +using System.Collections.Concurrent; +using Microsoft.Extensions.Logging; +using TestExtensions; +using Xunit; +using Xunit.Abstractions; + +#nullable enable + +namespace NonSilo.Tests; + +[TestCategory("BVT"), TestCategory("Lifecycle")] +public class ServiceLifecycleTests +{ + private readonly ServiceLifecycle _lifecycle; + private readonly CancelableSiloLifecycleSubject _subject; + private static readonly TimeSpan Timeout = TimeSpan.FromSeconds(30); + + public ServiceLifecycleTests(ITestOutputHelper output) + { + var factory = new LoggerFactory([new XunitLoggerProvider(output)]); + + _subject = new CancelableSiloLifecycleSubject(factory.CreateLogger()); + _lifecycle = new ServiceLifecycle(factory.CreateLogger>()); + + _lifecycle.Participate(_subject); + } + + private static (Task Task, IDisposable Registration) RegisterCallback( + IServiceLifecycleStage stage, + Action? action = null, + object? state = null, + bool terminateOnError = true) + { + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + var registration = stage.Register((s, ct) => + { + try + { + action?.Invoke(s, ct); + tcs.TrySetResult(s); + } + catch (Exception ex) + { + // We set the exception on the TCS so the test can inspect the specific failure of this callback. + tcs.TrySetException(ex); + + // We rethrow so the LifecycleSubject behaves according to TerminateOnError. + throw; + } + return Task.CompletedTask; + }, state, terminateOnError); + + return (tcs.Task, registration); + } + + [Fact] + public async Task BasicCallbackExecution() + { + var callbackState = "test-state"; + + var (task, _) = RegisterCallback(_lifecycle.Started, (state, ct) => { }, callbackState); + + await _subject.OnStart(); + + var result = await task.WaitAsync(Timeout); + Assert.Equal(callbackState, result); + } + + [Fact] + public async Task Stage_WaitAsync() + { + var waitTask = _lifecycle.Started.WaitAsync(); + + Assert.False(waitTask.IsCompleted); + + await _subject.OnStart(); + await waitTask.WaitAsync(Timeout); + } + + [Fact] + public async Task Stage_WaitAsync_Cancellation() + { + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + using var cts = new CancellationTokenSource(); + + // We register a blocking callback to keep the stage in a "running" state. + // This forces WaitAsync to actually block, allowing us to verify that cancelling + // the token interrupts the wait as expected. + + _lifecycle.Started.Register((_, _) => tcs.Task); + + var startTask = _subject.OnStart(); + var waitTask = _lifecycle.Started.WaitAsync(cts.Token); + + Assert.False(waitTask.IsCompleted, "WaitAsync should be paused waiting for the stage to complete"); + + await cts.CancelAsync(); + await Assert.ThrowsAsync(() => waitTask); + + tcs.SetResult(); + + await startTask; + } + + [Fact] + public async Task Stage_NotifyCompleted_IsIdempotent() + { + var stage = new ServiceLifecycleNotificationStage(Microsoft.Extensions.Logging.Abstractions.NullLogger.Instance, "Started"); + var gate = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var executionCount = 0; + + stage.Register(async (_, _) => + { + Interlocked.Increment(ref executionCount); + await gate.Task; + }, state: null, terminateOnError: true); + + var first = stage.NotifyCompleted(CancellationToken.None); + var second = stage.NotifyCompleted(CancellationToken.None); + + Assert.False(second.IsCompleted); + + gate.SetResult(); + await Task.WhenAll(first, second).WaitAsync(Timeout); + + await stage.NotifyCompleted(CancellationToken.None).WaitAsync(Timeout); + + Assert.Equal(1, executionCount); + } + + [Fact] + public async Task CallbackDisposal_PreventsExecution() + { + var (task, registration) = RegisterCallback(_lifecycle.Stopping); + + registration.Dispose(); + + await _subject.OnStart(); + await _subject.OnStop(); + + Assert.False(task.IsCompleted); + } + + [Fact] + public async Task CancellationToken_TriggeredOnStageCompletion() + { + var tcs = new TaskCompletionSource(); + + using var registration = _lifecycle.Stopping.Token.Register(tcs.SetResult); + + await _subject.OnStart(); + await _subject.OnStop(); + + await tcs.Task.WaitAsync(Timeout); + } + + [Fact] + public async Task ErrorHandling_TerminateOnErrorFalse() + { + var (task, _) = RegisterCallback( + _lifecycle.Started, + (state, ct) => throw new InvalidOperationException("Test"), + terminateOnError: false); + + await _subject.OnStart(); + + var ex = await Assert.ThrowsAsync(() => task.WaitAsync(Timeout)); + Assert.Equal("Test", ex.Message); + } + + [Fact] + public async Task ErrorHandling_TerminateOnErrorTrue() + { + var (task, _) = RegisterCallback( + _lifecycle.Started, + (state, ct) => throw new InvalidOperationException("Test"), + terminateOnError: true); + + var startTask = _subject.OnStart(); + + // This ensures the callback actually executed and we aren't just catching the lifecycle aborting. + var ex = await Assert.ThrowsAsync(() => task.WaitAsync(Timeout)); + Assert.Equal("Test", ex.Message); + + // Now verify the lifecycle start failed as expected. + await Assert.ThrowsAsync(() => startTask); + } + + [Fact] + public async Task ErrorHandling_TerminateOnErrorTrue_MultipleFailures() + { + var (task1, _) = RegisterCallback( + _lifecycle.Started, + (_, _) => throw new InvalidOperationException("first"), + terminateOnError: true); + + var (task2, _) = RegisterCallback( + _lifecycle.Started, + (_, _) => throw new ArgumentException("second"), + terminateOnError: true); + + var startTask = _subject.OnStart(); + + // We swallow the start exception initially so we can inspect the individual tasks. + try + { + await startTask; + } + catch + { + // Ignore + } + + // Now we wait for both TCS signals to complete (rather 'fail') before asserting. + // This prevents racing between the OnStart exception propagation and the TCS setting. + + try { await task1.WaitAsync(Timeout); } catch { } + try { await task2.WaitAsync(Timeout); } catch { } + + await Assert.ThrowsAsync(() => task1); + await Assert.ThrowsAsync(() => task2); + } + + [Fact] + public async Task LateRegistration_ExecutedImmediately() + { + await _subject.OnStart(); + await _subject.OnStop(); + + // Registering after stage completes should run immediately + var (task, _) = RegisterCallback(_lifecycle.Stopping); + + await task.WaitAsync(TimeSpan.FromSeconds(1)); + } + + [Fact] + public async Task ConcurrentCallbacks_RegistrationSafe() + { + const int Count = 50; + + var startSignal = new ManualResetEventSlim(false); + var tasks = new Task[Count]; + var executionCount = 0; + + for (var i = 0; i < Count; i++) + { + tasks[i] = Task.Run(() => + { + startSignal.Wait(); + RegisterCallback(_lifecycle.Started, (_, _) => Interlocked.Increment(ref executionCount)); + }); + } + + startSignal.Set(); + + await Task.WhenAll(tasks); + await _subject.OnStart(); + + Assert.Equal(Count, executionCount); + } + + [Fact] + public async Task MultipleStages_ExecuteInOrder() + { + var executionOrder = new ConcurrentQueue(); + + RegisterCallback(_lifecycle.Started, (_, _) => executionOrder.Enqueue("Started")); + RegisterCallback(_lifecycle.Stopping, (_, _) => executionOrder.Enqueue("Stopping")); + + // We capture the stopped task to wait on it specifically. + var (stoppedTask, _) = RegisterCallback(_lifecycle.Stopped, (_, _) => executionOrder.Enqueue("Stopped")); + + await _subject.OnStart(); + await _subject.OnStop(); + await stoppedTask.WaitAsync(Timeout); + + var order = executionOrder.ToArray(); + + Assert.Equal(3, order.Length); + Assert.Equal("Started", order[0]); + Assert.Equal("Stopping", order[1]); + Assert.Equal("Stopped", order[2]); + } + + [Fact] + public async Task BackgroundWorker_StopsOnCancellation() + { + var workerExited = new TaskCompletionSource(); + var token = _lifecycle.Stopping.Token; + + _ = Task.Run(async () => + { + try + { + await Task.Delay(System.Threading.Timeout.InfiniteTimeSpan, token); + } + catch (OperationCanceledException) + { + workerExited.SetResult(); + } + }); + + await _subject.OnStart(); + await _subject.OnStop(); + + await workerExited.Task.WaitAsync(Timeout); + } + + [Fact] + public async Task Lifecycle_CancellationToken_PassedToCallback() + { + var tcs = new TaskCompletionSource(); + + // We manually register here because the logic is specific to CT handling inside the callback + // and returns a Task result different from the standard flow. + _lifecycle.Started.Register(async (state, ct) => + { + try + { + await Task.Delay(System.Threading.Timeout.InfiniteTimeSpan, ct); + } + catch (OperationCanceledException) + { + tcs.SetResult(true); + } + }); + + var startTask = _subject.OnStart(); + + await _subject.CancelStartAsync(); + + try + { + await startTask; + } + catch (OperationCanceledException) + { + + } + + var tokenWasCancelled = await tcs.Task.WaitAsync(Timeout); + Assert.True(tokenWasCancelled); + } + + /// + /// A simple cancelable version of the real subject to test for cancellations. + /// + public class CancelableSiloLifecycleSubject(ILogger logger) : SiloLifecycleSubject(logger) + { + private readonly CancellationTokenSource _cts = new(); + + public override Task OnStart(CancellationToken cancellationToken = default) + { + var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _cts.Token); + return base.OnStart(linkedCts.Token); + } + + public Task CancelStartAsync() + { + _cts.Cancel(); + return Task.CompletedTask; + } + } +} +