Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 108 additions & 77 deletions Whisper.net/WhisperProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@ namespace Whisper.net;
/// </summary>
public sealed class WhisperProcessor : IAsyncDisposable, IDisposable
{
private static readonly ConcurrentDictionary<long, WhisperProcessor> processorInstances = new();
private static long currentProcessorId;
private const byte trueByte = 1;
private const byte falseByte = 0;

Expand All @@ -33,18 +31,11 @@ public sealed class WhisperProcessor : IAsyncDisposable, IDisposable
private IntPtr? suppressRegex;
private bool isDisposed;
private int segmentIndex;
private CancellationToken? currentCancellationToken;

// ID is used to identify the current instance when calling the callbacks from C++
private readonly long myId;

internal WhisperProcessor(WhisperProcessorOptions options, INativeWhisper nativeWhisper)
{
this.options = options;
this.nativeWhisper = nativeWhisper;
myId = Interlocked.Increment(ref currentProcessorId);

processorInstances[myId] = this;

currentWhisperContext = options.ContextHandle;
whisperParams = GetWhisperParams();
Expand Down Expand Up @@ -225,10 +216,18 @@ public unsafe void Process(ReadOnlySpan<float> samples)
processingSemaphore.Wait();
segmentIndex = 0;

var result = nativeWhisper.Whisper_Full_With_State(currentWhisperContext, state, whisperParams, (IntPtr)pData, samples.Length);
if (result != 0)
var processingContextHandle = CreateProcessingContext(CancellationToken.None, out var processingParams);
try
{
throw new WhisperProcessingException(result);
var result = nativeWhisper.Whisper_Full_With_State(currentWhisperContext, state, processingParams, (IntPtr)pData, samples.Length);
if (result != 0)
{
throw new WhisperProcessingException(result);
}
}
finally
{
processingContextHandle.Free();
}
}
finally
Expand Down Expand Up @@ -272,37 +271,38 @@ void OnSegmentHandler(SegmentData segmentData)
resetEvent!.Set();
}

bool OnWhisperAbortHandler()
{
if (currentCancellationToken.HasValue && currentCancellationToken.Value.IsCancellationRequested)
{
return true;
}

return false;
}

try
{
lock (options.OnSegmentEventHandlers)
{
options.OnSegmentEventHandlers.Add(OnSegmentHandler);
}

options.WhisperAbortEventHandler = OnWhisperAbortHandler;

currentCancellationToken = cancellationToken;
var processingTask = ProcessInternalAsync(samples, cancellationToken);
var whisperTask = processingTask.ContinueWith(_ => resetEvent.Set(), cancellationToken, TaskContinuationOptions.None, TaskScheduler.Default);
_ = processingTask.ContinueWith(
static (task, state) =>
{
_ = task.Exception;
((AsyncAutoResetEvent)state!).Set();
},
resetEvent,
CancellationToken.None,
TaskContinuationOptions.ExecuteSynchronously,
TaskScheduler.Default);

using var cancellationRegistration = cancellationToken.Register(
static state => ((AsyncAutoResetEvent)state!).Set(),
resetEvent);

while (!processingTask.IsCompleted || !buffer.IsEmpty)
{
cancellationToken.ThrowIfCancellationRequested();
ThrowTaskCanceledIfCancellationRequested(cancellationToken);

if (buffer.IsEmpty)
{
await Task.WhenAny(processingTask, resetEvent.WaitAsync())
.ConfigureAwait(false);
ThrowTaskCanceledIfCancellationRequested(cancellationToken);
}

while (!buffer.IsEmpty && buffer.TryDequeue(out var segmentData))
Expand All @@ -312,10 +312,7 @@ await Task.WhenAny(processingTask, resetEvent.WaitAsync())
}

await processingTask.ConfigureAwait(false);
if (cancellationToken.IsCancellationRequested)
{
throw new TaskCanceledException();
}
ThrowTaskCanceledIfCancellationRequested(cancellationToken);

while (buffer.TryDequeue(out var segmentData))
{
Expand Down Expand Up @@ -365,7 +362,6 @@ public void Dispose()
throw new Exception("Cannot dispose while processing, please use DisposeAsync instead.");
}

processorInstances.TryRemove(myId, out _);
MarshalUtils.TryReleaseStringHGlobal(language);
language = null;
MarshalUtils.TryReleaseStringHGlobal(initialPromptText);
Expand Down Expand Up @@ -394,22 +390,37 @@ private unsafe Task ProcessInternalAsync(ReadOnlyMemory<float> samples, Cancella
{
fixed (float* pData = samples.Span)
{
processingSemaphore.Wait();
segmentIndex = 0;

var state = GetWhisperState();
processingSemaphore.Wait(cancellationToken);
var state = IntPtr.Zero;
var processingContextHandle = default(GCHandle);

try
{
var result = nativeWhisper.Whisper_Full_With_State(currentWhisperContext, state, whisperParams, (IntPtr)pData, samples.Length);
segmentIndex = 0;
state = GetWhisperState();
processingContextHandle = CreateProcessingContext(cancellationToken, out var processingParams);

var result = nativeWhisper.Whisper_Full_With_State(currentWhisperContext, state, processingParams, (IntPtr)pData, samples.Length);
if (result != 0)
{
ThrowTaskCanceledIfCancellationRequested(cancellationToken);
throw new WhisperProcessingException(result);
}

ThrowTaskCanceledIfCancellationRequested(cancellationToken);
}
finally
{
nativeWhisper.Whisper_Free_State(state);
if (processingContextHandle.IsAllocated)
{
processingContextHandle.Free();
}

if (state != IntPtr.Zero)
{
nativeWhisper.Whisper_Free_State(state);
}

processingSemaphore.Release();
}
}
Expand Down Expand Up @@ -445,6 +456,14 @@ private IntPtr GetWhisperState()
return state;
}

/// <summary>
/// Throws a <see cref="TaskCanceledException"/> when cancellation has already been requested on
/// <paramref name="cancellationToken"/>; otherwise returns without doing anything.
/// </summary>
/// <param name="cancellationToken">The token whose cancellation state is inspected.</param>
/// <exception cref="TaskCanceledException">Cancellation was requested on the token.</exception>
private static void ThrowTaskCanceledIfCancellationRequested(CancellationToken cancellationToken)
{
    if (!cancellationToken.IsCancellationRequested)
    {
        return;
    }

    // Deliberately TaskCanceledException (created without a token) rather than
    // CancellationToken.ThrowIfCancellationRequested(), which would throw OperationCanceledException.
    throw new TaskCanceledException();
}

private WhisperFullParams GetWhisperParams()
{
var strategy = options.SamplingStrategy.GetNativeStrategy();
Expand Down Expand Up @@ -625,11 +644,6 @@ private WhisperFullParams GetWhisperParams()
}
}

var myIntPtrId = new IntPtr(myId);
whisperParams.OnNewSegmentUserData = myIntPtrId;
whisperParams.OnEncoderBeginUserData = myIntPtrId;
whisperParams.OnAbortUserData = myIntPtrId;

#if NETSTANDARD
// For netframework, we don't have `UnmanagedCallersOnlyAttribute` so we need to use a delegate wrapped with a GC handle
var onNewSegmentDelegate = new WhisperNewSegmentCallback(OnNewSegmentStatic);
Expand All @@ -653,7 +667,6 @@ private WhisperFullParams GetWhisperParams()
gcHandle = GCHandle.Alloc(onProgressDelegate);
gcHandles.Add(gcHandle);
whisperParams.OnProgressCallback = Marshal.GetFunctionPointerForDelegate(onProgressDelegate);
whisperParams.OnProgressCallbackUserData = myIntPtrId;
}
#else
unsafe
Expand All @@ -671,25 +684,48 @@ private WhisperFullParams GetWhisperParams()
{
delegate* unmanaged[Cdecl]<IntPtr, IntPtr, int, IntPtr, void> onProgressDelegate = &OnProgressStatic;
whisperParams.OnProgressCallback = (IntPtr)onProgressDelegate;
whisperParams.OnProgressCallbackUserData = myIntPtrId;
}
}
#endif

return whisperParams;
}

/// <summary>
/// Builds the per-call state handed to the native callbacks: a <see cref="ProcessingContext"/> that pairs
/// this processor with <paramref name="cancellationToken"/>, held by a <see cref="GCHandle"/>.
/// </summary>
/// <param name="cancellationToken">The token stored in the context for this processing call.</param>
/// <param name="processingParams">
/// Receives the value of <c>whisperParams</c> with the new-segment, encoder-begin, abort and progress
/// user-data fields set to the pointer form of the returned handle.
/// </param>
/// <returns>The allocated handle; the caller is expected to call <see cref="GCHandle.Free"/> on it.</returns>
private GCHandle CreateProcessingContext(CancellationToken cancellationToken, out WhisperFullParams processingParams)
{
    var contextHandle = GCHandle.Alloc(new ProcessingContext(this, cancellationToken));
    var userData = GCHandle.ToIntPtr(contextHandle);

    // Every callback receives the same user-data pointer, so each one can resolve the same context.
    var callParams = whisperParams;
    callParams.OnNewSegmentUserData = userData;
    callParams.OnEncoderBeginUserData = userData;
    callParams.OnAbortUserData = userData;
    callParams.OnProgressCallbackUserData = userData;

    processingParams = callParams;
    return contextHandle;
}

/// <summary>
/// Resolves the <see cref="ProcessingContext"/> behind a callback's user-data pointer.
/// </summary>
/// <param name="userData">Pointer form of a <see cref="GCHandle"/>, as produced by <see cref="GCHandle.ToIntPtr"/>.</param>
/// <returns>The context object the handle targets.</returns>
/// <exception cref="Exception">
/// <paramref name="userData"/> is <see cref="IntPtr.Zero"/>, or the handle's target is not a <see cref="ProcessingContext"/>.
/// </exception>
private static ProcessingContext GetProcessingContext(IntPtr userData)
{
    // The zero check comes first so GCHandle.FromIntPtr is never called with a null pointer.
    if (userData != IntPtr.Zero && GCHandle.FromIntPtr(userData).Target is ProcessingContext context)
    {
        return context;
    }

    throw new Exception("Couldn't find processing context for user data");
}

#if !NETSTANDARD
[UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])]
#endif
private static byte OnWhisperAbortStatic(IntPtr userData)
{
if (!processorInstances.TryGetValue(userData.ToInt64(), out var processor))
{
throw new Exception("Couldn't find processor instance for user data");
}

var shouldCancel = processor.options.WhisperAbortEventHandler?.Invoke() ?? false;
var processingContext = GetProcessingContext(userData);
var shouldCancel = processingContext.CancellationToken.IsCancellationRequested
|| (processingContext.Processor.options.WhisperAbortEventHandler?.Invoke() ?? false);
return shouldCancel ? trueByte : falseByte;
}

Expand All @@ -698,60 +734,48 @@ private static byte OnWhisperAbortStatic(IntPtr userData)
#endif
private static void OnNewSegmentStatic(IntPtr ctx, IntPtr state, int nNew, IntPtr userData)
{
if (!processorInstances.TryGetValue(userData.ToInt64(), out var processor))
{
throw new Exception("Couldn't find processor instance for user data");
}

processor.OnNewSegment(state);
var processingContext = GetProcessingContext(userData);
processingContext.Processor.OnNewSegment(state, processingContext.CancellationToken);
}

#if !NETSTANDARD
[UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])]
#endif
private static byte OnEncoderBeginStatic(IntPtr ctx, IntPtr state, IntPtr userData)
{
if (!processorInstances.TryGetValue(userData.ToInt64(), out var processor))
{
throw new Exception("Couldn't find processor instance for user data");
}

return processor.OnEncoderBegin() ? trueByte : falseByte;
var processingContext = GetProcessingContext(userData);
return processingContext.Processor.OnEncoderBegin(processingContext.CancellationToken) ? trueByte : falseByte;
}

#if !NETSTANDARD
[UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])]
#endif
private static void OnProgressStatic(IntPtr ctx, IntPtr state, int progress, IntPtr userData)
{
if (!processorInstances.TryGetValue(userData.ToInt64(), out var processor))
{
throw new Exception("Couldn't find processor instance for user data");
}

processor.OnProgress(progress);
var processingContext = GetProcessingContext(userData);
processingContext.Processor.OnProgress(progress, processingContext.CancellationToken);
}

private void OnProgress(int progress)
private void OnProgress(int progress, CancellationToken cancellationToken)
{
if (currentCancellationToken.HasValue && currentCancellationToken.Value.IsCancellationRequested)
if (cancellationToken.IsCancellationRequested)
{
return;
}

foreach (var handler in options.OnProgressHandlers)
{
handler?.Invoke(progress);
if (currentCancellationToken.HasValue && currentCancellationToken.Value.IsCancellationRequested)
if (cancellationToken.IsCancellationRequested)
{
return;
}
}
}

private bool OnEncoderBegin()
private bool OnEncoderBegin(CancellationToken cancellationToken)
{
if (currentCancellationToken.HasValue && currentCancellationToken.Value.IsCancellationRequested)
if (cancellationToken.IsCancellationRequested)
{
return false;
}
Expand All @@ -769,9 +793,9 @@ private bool OnEncoderBegin()
return true;
}

private void OnNewSegment(IntPtr state)
private void OnNewSegment(IntPtr state, CancellationToken cancellationToken)
{
if (currentCancellationToken.HasValue && currentCancellationToken.Value.IsCancellationRequested)
if (cancellationToken.IsCancellationRequested)
{
return;
}
Expand Down Expand Up @@ -859,7 +883,7 @@ private void OnNewSegment(IntPtr state)
foreach (var handler in handlers)
{
handler?.Invoke(eventHandlerArgs);
if (currentCancellationToken.HasValue && currentCancellationToken.Value.IsCancellationRequested)
if (cancellationToken.IsCancellationRequested)
{
return;
}
Expand Down Expand Up @@ -891,4 +915,11 @@ public async ValueTask DisposeAsync()
processingSemaphore.Release();
Dispose();
}

/// <summary>
/// Immutable pairing of a <see cref="WhisperProcessor"/> with the <see cref="System.Threading.CancellationToken"/>
/// supplied alongside it at construction.
/// </summary>
private sealed class ProcessingContext
{
    public ProcessingContext(WhisperProcessor processor, CancellationToken cancellationToken)
    {
        Processor = processor;
        CancellationToken = cancellationToken;
    }

    /// <summary>The processor passed to the constructor.</summary>
    public WhisperProcessor Processor { get; }

    /// <summary>The cancellation token passed to the constructor.</summary>
    public CancellationToken CancellationToken { get; }
}
}
Loading
Loading