From 9b627f0bd7fa9cf0fc8c48b2c3942514471bbd91 Mon Sep 17 00:00:00 2001 From: sandrohanea <40202887+sandrohanea@users.noreply.github.com> Date: Fri, 2 Jun 2023 19:03:47 +0200 Subject: [PATCH] Fixed #57 Using WhisperProcessor.ProcessAsync more than once (#58) * Fixed #57 Using WhisperProcessor.ProcessAsync more than once * Passed cancellationToken to processingSemaphore Wait and called WaitAsync * Simplified the initialization for ProcessInternalAsync --- ...ests.cs => ProcessAsyncFunctionalTests.cs} | 110 +++---------- Whisper.net.Tests/ProcessFunctionalTests.cs | 154 ++++++++++++++++++ Whisper.net/WhisperProcessor.cs | 9 +- 3 files changed, 186 insertions(+), 87 deletions(-) rename Whisper.net.Tests/{ProcessorE2ETests.cs => ProcessAsyncFunctionalTests.cs} (61%) create mode 100644 Whisper.net.Tests/ProcessFunctionalTests.cs diff --git a/Whisper.net.Tests/ProcessorE2ETests.cs b/Whisper.net.Tests/ProcessAsyncFunctionalTests.cs similarity index 61% rename from Whisper.net.Tests/ProcessorE2ETests.cs rename to Whisper.net.Tests/ProcessAsyncFunctionalTests.cs index a02a75b7..886678dd 100644 --- a/Whisper.net.Tests/ProcessorE2ETests.cs +++ b/Whisper.net.Tests/ProcessAsyncFunctionalTests.cs @@ -5,8 +5,7 @@ using Whisper.net.Ggml; namespace Whisper.net.Tests; - -public class ProcessorE2ETests +public class ProcessAsyncFunctionalTests { private string ggmlModelPath = string.Empty; @@ -25,35 +24,6 @@ public void TearDown() File.Delete(ggmlModelPath); } - [Test] - public void TestHappyFlow() - { - var segments = new List(); - var progress = new List(); - var encoderBegins = new List(); - using var factory = WhisperFactory.FromPath(ggmlModelPath); - using var processor = factory.CreateBuilder() - .WithLanguage("en") - .WithEncoderBeginHandler((e) => - { - encoderBegins.Add(e); - return true; - }) - .WithPrompt("I am Kennedy") - .WithProgressHandler(progress.Add) - .WithSegmentEventHandler(segments.Add) - .Build(); - - using var fileReader = File.OpenRead("kennedy.wav"); - processor.Process(fileReader); - - segments.Should().HaveCountGreaterThan(0); - encoderBegins.Should().HaveCount(1); - progress.Should().BeInAscendingOrder().And.HaveCountGreaterThan(1); - - segments.Should().Contain(segmentData => segmentData.Text.Contains("nation should commit")); - } - [Test] public async Task TestHappyFlowAsync() { @@ -89,54 +59,6 @@ public async Task TestHappyFlowAsync() segments.Should().Contain(segmentData => segmentData.Text.Contains("nation should commit")); } - [Test] - public void TestCancelEncoder() - { - var segments = new List(); - var encoderBegins = new List(); - using var factory = WhisperFactory.FromPath(ggmlModelPath); - using var processor = factory.CreateBuilder() - .WithLanguage("en") - .WithEncoderBeginHandler((e) => - { - encoderBegins.Add(e); - return false; - }) - .WithSegmentEventHandler(segments.Add) - .Build(); - - using var fileReader = File.OpenRead("kennedy.wav"); - processor.Process(fileReader); - - segments.Should().HaveCount(0); - encoderBegins.Should().HaveCount(1); - } - - [Test] - public async Task TestAutoDetectLanguageWithRomanian() - { - var segments = new List(); - var encoderBegins = new List(); - using var factory = WhisperFactory.FromPath(ggmlModelPath); - using var processor = factory.CreateBuilder() - .WithLanguageDetection() - .WithEncoderBeginHandler((e) => - { - encoderBegins.Add(e); - return true; - }) - .Build(); - using var fileReader = File.OpenRead("romana.wav"); - await foreach (var segment in processor.ProcessAsync(fileReader)) - { - segments.Add(segment); - } - segments.Should().HaveCountGreaterThan(0); - encoderBegins.Should().HaveCount(1); - segments.Should().AllSatisfy(s => s.Language.Should().Be("ro")); - segments.Should().Contain(segmentData => segmentData.Text.Contains("efectua")); - } - [Test] public async Task ProcessAsync_Cancelled_WillCancellTheProcessing_AndDispose_WillWaitUntilFullyFinished() { @@ -224,19 +146,37 @@ public async Task ProcessAsync_WhenMultichannel_ProcessCorrectly() } [Test] - public async Task Process_WhenMultichannel_ProcessCorrectly() + public async Task ProcessAsync_CalledMultipleTimes_Serially_WillCompleteEverytime() { - var segments = new List(); + + var segments1 = new List(); + var segments2 = new List(); + var segments3 = new List(); using var factory = WhisperFactory.FromPath(ggmlModelPath); await using var processor = factory.CreateBuilder() .WithLanguage("en") - .WithSegmentEventHandler(segments.Add) .Build(); - using var fileReader = File.OpenRead("multichannel.wav"); - processor.Process(fileReader); + using var fileReader = File.OpenRead("kennedy.wav"); + await foreach (var segment in processor.ProcessAsync(fileReader)) + { + segments1.Add(segment); + } - segments.Should().HaveCountGreaterThanOrEqualTo(1); + using var fileReader2 = File.OpenRead("kennedy.wav"); + await foreach (var segment in processor.ProcessAsync(fileReader2)) + { + segments2.Add(segment); + } + + using var fileReader3 = File.OpenRead("kennedy.wav"); + await foreach (var segment in processor.ProcessAsync(fileReader3)) + { + segments3.Add(segment); + } + + segments1.Should().BeEquivalentTo(segments2); + segments2.Should().BeEquivalentTo(segments3); } } diff --git a/Whisper.net.Tests/ProcessFunctionalTests.cs b/Whisper.net.Tests/ProcessFunctionalTests.cs new file mode 100644 index 00000000..d0478237 --- /dev/null +++ b/Whisper.net.Tests/ProcessFunctionalTests.cs @@ -0,0 +1,154 @@ +// Licensed under the MIT license: https://opensource.org/licenses/MIT + +using FluentAssertions; +using NUnit.Framework; +using Whisper.net.Ggml; + +namespace Whisper.net.Tests; + +public class ProcessFunctionalTests +{ + private string ggmlModelPath = string.Empty; + + [OneTimeSetUp] + public async Task SetupAsync() + { + ggmlModelPath = Path.GetTempFileName(); + var model = await WhisperGgmlDownloader.GetGgmlModelAsync(GgmlType.Tiny); + using var fileWriter = File.OpenWrite(ggmlModelPath); + await model.CopyToAsync(fileWriter); + } + + [OneTimeTearDown] + public void TearDown() + { + File.Delete(ggmlModelPath); + } + + [Test] + public void TestHappyFlow() + { + var segments = new List(); + var progress = new List(); + var encoderBegins = new List(); + using var factory = WhisperFactory.FromPath(ggmlModelPath); + using var processor = factory.CreateBuilder() + .WithLanguage("en") + .WithEncoderBeginHandler((e) => + { + encoderBegins.Add(e); + return true; + }) + .WithPrompt("I am Kennedy") + .WithProgressHandler(progress.Add) + .WithSegmentEventHandler(segments.Add) + .Build(); + + using var fileReader = File.OpenRead("kennedy.wav"); + processor.Process(fileReader); + + segments.Should().HaveCountGreaterThan(0); + encoderBegins.Should().HaveCount(1); + progress.Should().BeInAscendingOrder().And.HaveCountGreaterThan(1); + + segments.Should().Contain(segmentData => segmentData.Text.Contains("nation should commit")); + } + + [Test] + public void TestCancelEncoder() + { + var segments = new List(); + var encoderBegins = new List(); + using var factory = WhisperFactory.FromPath(ggmlModelPath); + using var processor = factory.CreateBuilder() + .WithLanguage("en") + .WithEncoderBeginHandler((e) => + { + encoderBegins.Add(e); + return false; + }) + .WithSegmentEventHandler(segments.Add) + .Build(); + + using var fileReader = File.OpenRead("kennedy.wav"); + processor.Process(fileReader); + + segments.Should().HaveCount(0); + encoderBegins.Should().HaveCount(1); + } + + [Test] + public async Task TestAutoDetectLanguageWithRomanian() + { + var segments = new List(); + var encoderBegins = new List(); + using var factory = WhisperFactory.FromPath(ggmlModelPath); + using var processor = factory.CreateBuilder() + .WithLanguageDetection() + .WithEncoderBeginHandler((e) => + { + encoderBegins.Add(e); + return true; + }) + .Build(); + using var fileReader = File.OpenRead("romana.wav"); + await foreach (var segment in processor.ProcessAsync(fileReader)) + { + segments.Add(segment); + } + segments.Should().HaveCountGreaterThan(0); + encoderBegins.Should().HaveCount(1); + segments.Should().AllSatisfy(s => s.Language.Should().Be("ro")); + segments.Should().Contain(segmentData => segmentData.Text.Contains("efectua")); + } + + [Test] + public async Task Process_WhenMultichannel_ProcessCorrectly() + { + var segments = new List(); + + using var factory = WhisperFactory.FromPath(ggmlModelPath); + await using var processor = factory.CreateBuilder() + .WithLanguage("en") + .WithSegmentEventHandler(segments.Add) + .Build(); + + using var fileReader = File.OpenRead("multichannel.wav"); + processor.Process(fileReader); + + segments.Should().HaveCountGreaterThanOrEqualTo(1); + } + + [Test] + public async Task Process_CalledMultipleTimes_Serially_WillCompleteEverytime() + { + + var segments1 = new List(); + var segments2 = new List(); + var segments3 = new List(); + + OnSegmentEventHandler onNewSegment = segments1.Add; + + using var factory = WhisperFactory.FromPath(ggmlModelPath); + await using var processor = factory.CreateBuilder() + .WithLanguage("en") + .WithSegmentEventHandler((s) => onNewSegment(s)) + .Build(); + + using var fileReader1 = File.OpenRead("kennedy.wav"); + processor.Process(fileReader1); + + onNewSegment = segments2.Add; + + using var fileReader2 = File.OpenRead("kennedy.wav"); + processor.Process(fileReader2); + + onNewSegment = segments3.Add; + + using var fileReader3 = File.OpenRead("kennedy.wav"); + processor.Process(fileReader3); + + segments1.Should().BeEquivalentTo(segments2); + segments2.Should().BeEquivalentTo(segments3); + } +} diff --git a/Whisper.net/WhisperProcessor.cs b/Whisper.net/WhisperProcessor.cs index 0e4d961a..c2d9daa2 100755 --- a/Whisper.net/WhisperProcessor.cs +++ b/Whisper.net/WhisperProcessor.cs @@ -122,6 +122,8 @@ public unsafe void Process(float[] samples) try { processingSemaphore.Wait(); + segmentIndex = 0; + NativeMethods.whisper_full_with_state(currentWhisperContext, state, whisperParams, (IntPtr)pData, samples.Length); } finally @@ -153,10 +155,10 @@ void OnSegmentHandler(SegmentData segmentData) resetEvent!.Set(); } - options.OnSegmentEventHandlers.Add(OnSegmentHandler); - try { + options.OnSegmentEventHandlers.Add(OnSegmentHandler); + currentCancellationToken = cancellationToken; var whisperTask = ProcessInternalAsync(samples, cancellationToken) .ContinueWith(_ => resetEvent.Set(), cancellationToken, TaskContinuationOptions.None, TaskScheduler.Default); @@ -225,7 +227,10 @@ private unsafe Task ProcessInternalAsync(float[] samples, CancellationToken canc fixed (float* pData = samples) { processingSemaphore.Wait(); + segmentIndex = 0; + var state = NativeMethods.whisper_init_state(currentWhisperContext); + try { NativeMethods.whisper_full_with_state(currentWhisperContext, state, whisperParams, (IntPtr)pData, samples.Length);