diff --git a/Whisper.net/WhisperProcessor.cs b/Whisper.net/WhisperProcessor.cs index 5935ad07..89a2b2aa 100755 --- a/Whisper.net/WhisperProcessor.cs +++ b/Whisper.net/WhisperProcessor.cs @@ -17,8 +17,6 @@ namespace Whisper.net; /// public sealed class WhisperProcessor : IAsyncDisposable, IDisposable { - private static readonly ConcurrentDictionary processorInstances = new(); - private static long currentProcessorId; private const byte trueByte = 1; private const byte falseByte = 0; @@ -33,18 +31,11 @@ public sealed class WhisperProcessor : IAsyncDisposable, IDisposable private IntPtr? suppressRegex; private bool isDisposed; private int segmentIndex; - private CancellationToken? currentCancellationToken; - - // ID is used to identify the current instance when calling the callbacks from C++ - private readonly long myId; internal WhisperProcessor(WhisperProcessorOptions options, INativeWhisper nativeWhisper) { this.options = options; this.nativeWhisper = nativeWhisper; - myId = Interlocked.Increment(ref currentProcessorId); - - processorInstances[myId] = this; currentWhisperContext = options.ContextHandle; whisperParams = GetWhisperParams(); @@ -225,10 +216,18 @@ public unsafe void Process(ReadOnlySpan samples) processingSemaphore.Wait(); segmentIndex = 0; - var result = nativeWhisper.Whisper_Full_With_State(currentWhisperContext, state, whisperParams, (IntPtr)pData, samples.Length); - if (result != 0) + var processingContextHandle = CreateProcessingContext(CancellationToken.None, out var processingParams); + try { - throw new WhisperProcessingException(result); + var result = nativeWhisper.Whisper_Full_With_State(currentWhisperContext, state, processingParams, (IntPtr)pData, samples.Length); + if (result != 0) + { + throw new WhisperProcessingException(result); + } + } + finally + { + processingContextHandle.Free(); } } finally @@ -272,16 +271,6 @@ void OnSegmentHandler(SegmentData segmentData) resetEvent!.Set(); } - bool OnWhisperAbortHandler() - { - if (currentCancellationToken.HasValue && currentCancellationToken.Value.IsCancellationRequested) - { - return true; - } - - return false; - } - try { lock (options.OnSegmentEventHandlers) @@ -289,20 +278,31 @@ bool OnWhisperAbortHandler() options.OnSegmentEventHandlers.Add(OnSegmentHandler); } - options.WhisperAbortEventHandler = OnWhisperAbortHandler; - - currentCancellationToken = cancellationToken; var processingTask = ProcessInternalAsync(samples, cancellationToken); - var whisperTask = processingTask.ContinueWith(_ => resetEvent.Set(), cancellationToken, TaskContinuationOptions.None, TaskScheduler.Default); + _ = processingTask.ContinueWith( + static (task, state) => + { + _ = task.Exception; + ((AsyncAutoResetEvent)state!).Set(); + }, + resetEvent, + CancellationToken.None, + TaskContinuationOptions.ExecuteSynchronously, + TaskScheduler.Default); + + using var cancellationRegistration = cancellationToken.Register( + static state => ((AsyncAutoResetEvent)state!).Set(), + resetEvent); while (!processingTask.IsCompleted || !buffer.IsEmpty) { - cancellationToken.ThrowIfCancellationRequested(); + ThrowTaskCanceledIfCancellationRequested(cancellationToken); if (buffer.IsEmpty) { await Task.WhenAny(processingTask, resetEvent.WaitAsync()) .ConfigureAwait(false); + ThrowTaskCanceledIfCancellationRequested(cancellationToken); } while (!buffer.IsEmpty && buffer.TryDequeue(out var segmentData)) @@ -312,10 +312,7 @@ await Task.WhenAny(processingTask, resetEvent.WaitAsync()) } await processingTask.ConfigureAwait(false); - if (cancellationToken.IsCancellationRequested) - { - throw new TaskCanceledException(); - } + ThrowTaskCanceledIfCancellationRequested(cancellationToken); while (buffer.TryDequeue(out var segmentData)) { @@ -365,7 +362,6 @@ public void Dispose() throw new Exception("Cannot dispose while processing, please use DisposeAsync instead."); } - processorInstances.TryRemove(myId, out _); MarshalUtils.TryReleaseStringHGlobal(language); language = null; MarshalUtils.TryReleaseStringHGlobal(initialPromptText); @@ -394,22 +390,37 @@ private unsafe Task ProcessInternalAsync(ReadOnlyMemory samples, Cancella { fixed (float* pData = samples.Span) { - processingSemaphore.Wait(); - segmentIndex = 0; - - var state = GetWhisperState(); + processingSemaphore.Wait(cancellationToken); + var state = IntPtr.Zero; + var processingContextHandle = default(GCHandle); try { - var result = nativeWhisper.Whisper_Full_With_State(currentWhisperContext, state, whisperParams, (IntPtr)pData, samples.Length); + segmentIndex = 0; + state = GetWhisperState(); + processingContextHandle = CreateProcessingContext(cancellationToken, out var processingParams); + + var result = nativeWhisper.Whisper_Full_With_State(currentWhisperContext, state, processingParams, (IntPtr)pData, samples.Length); if (result != 0) { + ThrowTaskCanceledIfCancellationRequested(cancellationToken); throw new WhisperProcessingException(result); } + + ThrowTaskCanceledIfCancellationRequested(cancellationToken); } finally { - nativeWhisper.Whisper_Free_State(state); + if (processingContextHandle.IsAllocated) + { + processingContextHandle.Free(); + } + + if (state != IntPtr.Zero) + { + nativeWhisper.Whisper_Free_State(state); + } + processingSemaphore.Release(); } } @@ -445,6 +456,14 @@ private IntPtr GetWhisperState() return state; } + private static void ThrowTaskCanceledIfCancellationRequested(CancellationToken cancellationToken) + { + if (cancellationToken.IsCancellationRequested) + { + throw new TaskCanceledException(); + } + } + private WhisperFullParams GetWhisperParams() { var strategy = options.SamplingStrategy.GetNativeStrategy(); @@ -625,11 +644,6 @@ private WhisperFullParams GetWhisperParams() } } - var myIntPtrId = new IntPtr(myId); - whisperParams.OnNewSegmentUserData = myIntPtrId; - whisperParams.OnEncoderBeginUserData = myIntPtrId; - whisperParams.OnAbortUserData = myIntPtrId; - #if NETSTANDARD // For netframework, we don't have `UnmanagedCallersOnlyAttribute` so we need to use a delegate wrapped with a GC handle var onNewSegmentDelegate = new WhisperNewSegmentCallback(OnNewSegmentStatic); @@ -653,7 +667,6 @@ private WhisperFullParams GetWhisperParams() gcHandle = GCHandle.Alloc(onProgressDelegate); gcHandles.Add(gcHandle); whisperParams.OnProgressCallback = Marshal.GetFunctionPointerForDelegate(onProgressDelegate); - whisperParams.OnProgressCallbackUserData = myIntPtrId; } #else unsafe @@ -671,7 +684,6 @@ private WhisperFullParams GetWhisperParams() { delegate* unmanaged[Cdecl] onProgressDelegate = &OnProgressStatic; whisperParams.OnProgressCallback = (IntPtr)onProgressDelegate; - whisperParams.OnProgressCallbackUserData = myIntPtrId; } } #endif @@ -679,17 +691,41 @@ private WhisperFullParams GetWhisperParams() return whisperParams; } + private GCHandle CreateProcessingContext(CancellationToken cancellationToken, out WhisperFullParams processingParams) + { + var processingContext = new ProcessingContext(this, cancellationToken); + var processingContextHandle = GCHandle.Alloc(processingContext); + var processingContextPtr = GCHandle.ToIntPtr(processingContextHandle); + + processingParams = whisperParams; + processingParams.OnNewSegmentUserData = processingContextPtr; + processingParams.OnEncoderBeginUserData = processingContextPtr; + processingParams.OnAbortUserData = processingContextPtr; + processingParams.OnProgressCallbackUserData = processingContextPtr; + + return processingContextHandle; + } + + private static ProcessingContext GetProcessingContext(IntPtr userData) + { + if (userData == IntPtr.Zero) + { + throw new Exception("Couldn't find processing context for user data"); + } + + var handle = GCHandle.FromIntPtr(userData); + return handle.Target as ProcessingContext + ?? throw new Exception("Couldn't find processing context for user data"); + } + #if !NETSTANDARD [UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])] #endif private static byte OnWhisperAbortStatic(IntPtr userData) { - if (!processorInstances.TryGetValue(userData.ToInt64(), out var processor)) - { - throw new Exception("Couldn't find processor instance for user data"); - } - - var shouldCancel = processor.options.WhisperAbortEventHandler?.Invoke() ?? false; + var processingContext = GetProcessingContext(userData); + var shouldCancel = processingContext.CancellationToken.IsCancellationRequested + || (processingContext.Processor.options.WhisperAbortEventHandler?.Invoke() ?? false); return shouldCancel ? trueByte : falseByte; } @@ -698,12 +734,8 @@ private static byte OnWhisperAbortStatic(IntPtr userData) #endif private static void OnNewSegmentStatic(IntPtr ctx, IntPtr state, int nNew, IntPtr userData) { - if (!processorInstances.TryGetValue(userData.ToInt64(), out var processor)) - { - throw new Exception("Couldn't find processor instance for user data"); - } - - processor.OnNewSegment(state); + var processingContext = GetProcessingContext(userData); + processingContext.Processor.OnNewSegment(state, processingContext.CancellationToken); } #if !NETSTANDARD @@ -711,12 +743,8 @@ private static void OnNewSegmentStatic(IntPtr ctx, IntPtr state, int nNew, IntPt #endif private static byte OnEncoderBeginStatic(IntPtr ctx, IntPtr state, IntPtr userData) { - if (!processorInstances.TryGetValue(userData.ToInt64(), out var processor)) - { - throw new Exception("Couldn't find processor instance for user data"); - } - - return processor.OnEncoderBegin() ? trueByte : falseByte; + var processingContext = GetProcessingContext(userData); + return processingContext.Processor.OnEncoderBegin(processingContext.CancellationToken) ? trueByte : falseByte; } #if !NETSTANDARD @@ -724,17 +752,13 @@ private static byte OnEncoderBeginStatic(IntPtr ctx, IntPtr state, IntPtr userDa #endif private static void OnProgressStatic(IntPtr ctx, IntPtr state, int progress, IntPtr userData) { - if (!processorInstances.TryGetValue(userData.ToInt64(), out var processor)) - { - throw new Exception("Couldn't find processor instance for user data"); - } - - processor.OnProgress(progress); + var processingContext = GetProcessingContext(userData); + processingContext.Processor.OnProgress(progress, processingContext.CancellationToken); } - private void OnProgress(int progress) + private void OnProgress(int progress, CancellationToken cancellationToken) { - if (currentCancellationToken.HasValue && currentCancellationToken.Value.IsCancellationRequested) + if (cancellationToken.IsCancellationRequested) { return; } @@ -742,16 +766,16 @@ private void OnProgress(int progress) foreach (var handler in options.OnProgressHandlers) { handler?.Invoke(progress); - if (currentCancellationToken.HasValue && currentCancellationToken.Value.IsCancellationRequested) + if (cancellationToken.IsCancellationRequested) { return; } } } - private bool OnEncoderBegin() + private bool OnEncoderBegin(CancellationToken cancellationToken) { - if (currentCancellationToken.HasValue && currentCancellationToken.Value.IsCancellationRequested) + if (cancellationToken.IsCancellationRequested) { return false; } @@ -769,9 +793,9 @@ private bool OnEncoderBegin() return true; } - private void OnNewSegment(IntPtr state) + private void OnNewSegment(IntPtr state, CancellationToken cancellationToken) { - if (currentCancellationToken.HasValue && currentCancellationToken.Value.IsCancellationRequested) + if (cancellationToken.IsCancellationRequested) { return; } @@ -859,7 +883,7 @@ private void OnNewSegment(IntPtr state) foreach (var handler in handlers) { handler?.Invoke(eventHandlerArgs); - if (currentCancellationToken.HasValue && currentCancellationToken.Value.IsCancellationRequested) + if (cancellationToken.IsCancellationRequested) { return; } @@ -891,4 +915,11 @@ public async ValueTask DisposeAsync() processingSemaphore.Release(); Dispose(); } + + private sealed class ProcessingContext(WhisperProcessor processor, CancellationToken cancellationToken) + { + public WhisperProcessor Processor { get; } = processor; + + public CancellationToken CancellationToken { get; } = cancellationToken; + } } diff --git a/tests/Whisper.net.Tests/ProcessingFailureTests.cs b/tests/Whisper.net.Tests/ProcessingFailureTests.cs index 3a89b949..4e0d4107 100644 --- a/tests/Whisper.net.Tests/ProcessingFailureTests.cs +++ b/tests/Whisper.net.Tests/ProcessingFailureTests.cs @@ -13,10 +13,10 @@ private sealed class FakeNativeWhisper : INativeWhisper { private readonly int _errorCode; - public FakeNativeWhisper(int errorCode) + public FakeNativeWhisper(int errorCode, INativeWhisper.whisper_full_with_state? whisperFullWithState = null) { _errorCode = errorCode; - Whisper_Full_With_State = (context, state, p, samples, n) => _errorCode; + Whisper_Full_With_State = whisperFullWithState ?? ((context, state, p, samples, n) => _errorCode); Whisper_Init_State = _ => new IntPtr(1); Whisper_Free_State = _ => { }; Whisper_Full_Default_Params_By_Ref = strategy => @@ -124,4 +124,95 @@ await Assert.ThrowsAsync(async () => } }); } + + [Fact] + public async Task ProcessAsync_WhenCancelledBeforeSegment_ShouldWakeEnumeratorAndSignalNativeAbort() + { + var nativeStarted = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var nativeAbortObserved = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var nativeFinished = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + using var allowNativeReturn = new ManualResetEventSlim(); + + using var native = new FakeNativeWhisper(-6, (_, _, parameters, _, _) => + { + try + { + nativeStarted.SetResult(null); + var abortCallback = Marshal.GetDelegateForFunctionPointer(parameters.OnAbort); + + while (abortCallback(parameters.OnAbortUserData) == 0) + { + Thread.Sleep(10); + } + + nativeAbortObserved.SetResult(null); + allowNativeReturn.Wait(TimeSpan.FromSeconds(5)); + return -6; + } + finally + { + nativeFinished.SetResult(null); + } + }); + + var options = new WhisperProcessorOptions { ContextHandle = IntPtr.Zero }; + await using var processor = new WhisperProcessor(options, native); + using var cts = new CancellationTokenSource(); + + var processingTask = Task.Run(async () => + { + await foreach (var _ in processor.ProcessAsync(new float[1], cts.Token)) + { + } + }); + + try + { + await WaitForTaskAsync(nativeStarted.Task); + cts.Cancel(); + await WaitForTaskAsync(nativeAbortObserved.Task); + + await Assert.ThrowsAsync(() => WaitForTaskAsync(processingTask, TimeSpan.FromSeconds(1))); + Assert.False(nativeFinished.Task.IsCompleted); + } + finally + { + allowNativeReturn.Set(); + } + + await WaitForTaskAsync(nativeFinished.Task); + } + + [Fact] + public async Task ProcessAsync_WhenCancelledBeforeNativeFailure_ShouldThrowTaskCanceledException() + { + using var cts = new CancellationTokenSource(); + using var native = new FakeNativeWhisper(-6, (_, _, _, _, _) => + { + cts.Cancel(); + return -6; + }); + + var options = new WhisperProcessorOptions { ContextHandle = IntPtr.Zero }; + await using var processor = new WhisperProcessor(options, native); + + await Assert.ThrowsAsync(async () => + { + await foreach (var _ in processor.ProcessAsync(new float[1], cts.Token)) + { + } + }); + } + + private static async Task WaitForTaskAsync(Task task) + { + await WaitForTaskAsync(task, TimeSpan.FromSeconds(5)); + } + + private static async Task WaitForTaskAsync(Task task, TimeSpan timeout) + { + var completedTask = await Task.WhenAny(task, Task.Delay(timeout)); + Assert.Same(task, completedTask); + await task; + } } diff --git a/tests/Whisper.net.Tests/Whisper.net.Tests.csproj b/tests/Whisper.net.Tests/Whisper.net.Tests.csproj index 6ea77f2c..7e437252 100644 --- a/tests/Whisper.net.Tests/Whisper.net.Tests.csproj +++ b/tests/Whisper.net.Tests/Whisper.net.Tests.csproj @@ -17,6 +17,8 @@ enable 13 true + + false net9.0;net10.0 xUnit1041