Skip to content

Commit

Permalink
Fixed #57 Using WhisperProcessor.ProcessAsync more than once (#58)
Browse files Browse the repository at this point in the history
* Fixed #57 Using WhisperProcessor.ProcessAsync more than once

* Passed cancellationToken to processingSemaphore Wait and called WaitAsync

* Simplified the initialization for ProcessInternalAsync
  • Loading branch information
sandrohanea authored Jun 2, 2023
1 parent 51113a8 commit 9b627f0
Show file tree
Hide file tree
Showing 3 changed files with 186 additions and 87 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
using Whisper.net.Ggml;

namespace Whisper.net.Tests;

public class ProcessorE2ETests
public class ProcessAsyncFunctionalTests
{
private string ggmlModelPath = string.Empty;

Expand All @@ -25,35 +24,6 @@ public void TearDown()
File.Delete(ggmlModelPath);
}

[Test]
public void TestHappyFlow()
{
var segments = new List<SegmentData>();
var progress = new List<int>();
var encoderBegins = new List<EncoderBeginData>();
using var factory = WhisperFactory.FromPath(ggmlModelPath);
using var processor = factory.CreateBuilder()
.WithLanguage("en")
.WithEncoderBeginHandler((e) =>
{
encoderBegins.Add(e);
return true;
})
.WithPrompt("I am Kennedy")
.WithProgressHandler(progress.Add)
.WithSegmentEventHandler(segments.Add)
.Build();

using var fileReader = File.OpenRead("kennedy.wav");
processor.Process(fileReader);

segments.Should().HaveCountGreaterThan(0);
encoderBegins.Should().HaveCount(1);
progress.Should().BeInAscendingOrder().And.HaveCountGreaterThan(1);

segments.Should().Contain(segmentData => segmentData.Text.Contains("nation should commit"));
}

[Test]
public async Task TestHappyFlowAsync()
{
Expand Down Expand Up @@ -89,54 +59,6 @@ public async Task TestHappyFlowAsync()
segments.Should().Contain(segmentData => segmentData.Text.Contains("nation should commit"));
}

[Test]
public void TestCancelEncoder()
{
var segments = new List<SegmentData>();
var encoderBegins = new List<EncoderBeginData>();
using var factory = WhisperFactory.FromPath(ggmlModelPath);
using var processor = factory.CreateBuilder()
.WithLanguage("en")
.WithEncoderBeginHandler((e) =>
{
encoderBegins.Add(e);
return false;
})
.WithSegmentEventHandler(segments.Add)
.Build();

using var fileReader = File.OpenRead("kennedy.wav");
processor.Process(fileReader);

segments.Should().HaveCount(0);
encoderBegins.Should().HaveCount(1);
}

[Test]
public async Task TestAutoDetectLanguageWithRomanian()
{
var segments = new List<SegmentData>();
var encoderBegins = new List<EncoderBeginData>();
using var factory = WhisperFactory.FromPath(ggmlModelPath);
using var processor = factory.CreateBuilder()
.WithLanguageDetection()
.WithEncoderBeginHandler((e) =>
{
encoderBegins.Add(e);
return true;
})
.Build();
using var fileReader = File.OpenRead("romana.wav");
await foreach (var segment in processor.ProcessAsync(fileReader))
{
segments.Add(segment);
}
segments.Should().HaveCountGreaterThan(0);
encoderBegins.Should().HaveCount(1);
segments.Should().AllSatisfy(s => s.Language.Should().Be("ro"));
segments.Should().Contain(segmentData => segmentData.Text.Contains("efectua"));
}

[Test]
public async Task ProcessAsync_Cancelled_WillCancellTheProcessing_AndDispose_WillWaitUntilFullyFinished()
{
Expand Down Expand Up @@ -224,19 +146,37 @@ public async Task ProcessAsync_WhenMultichannel_ProcessCorrectly()
}

[Test]
public async Task Process_WhenMultichannel_ProcessCorrectly()
public async Task ProcessAsync_CalledMultipleTimes_Serially_WillCompleteEverytime()
{
var segments = new List<SegmentData>();

var segments1 = new List<SegmentData>();
var segments2 = new List<SegmentData>();
var segments3 = new List<SegmentData>();

using var factory = WhisperFactory.FromPath(ggmlModelPath);
await using var processor = factory.CreateBuilder()
.WithLanguage("en")
.WithSegmentEventHandler(segments.Add)
.Build();

using var fileReader = File.OpenRead("multichannel.wav");
processor.Process(fileReader);
using var fileReader = File.OpenRead("kennedy.wav");
await foreach (var segment in processor.ProcessAsync(fileReader))
{
segments1.Add(segment);
}

segments.Should().HaveCountGreaterThanOrEqualTo(1);
using var fileReader2 = File.OpenRead("kennedy.wav");
await foreach (var segment in processor.ProcessAsync(fileReader2))
{
segments2.Add(segment);
}

using var fileReader3 = File.OpenRead("kennedy.wav");
await foreach (var segment in processor.ProcessAsync(fileReader3))
{
segments3.Add(segment);
}

segments1.Should().BeEquivalentTo(segments2);
segments2.Should().BeEquivalentTo(segments3);
}
}
154 changes: 154 additions & 0 deletions Whisper.net.Tests/ProcessFunctionalTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
// Licensed under the MIT license: https://opensource.org/licenses/MIT

using FluentAssertions;
using NUnit.Framework;
using Whisper.net.Ggml;

namespace Whisper.net.Tests;

/// <summary>
/// Functional tests that drive a <c>WhisperProcessor</c> built from the Tiny ggml model
/// against the sample audio files referenced by each test.
/// </summary>
public class ProcessFunctionalTests
{
    private string ggmlModelPath = string.Empty;

    /// <summary>
    /// Downloads the Tiny ggml model once for the whole fixture and stores it in a temporary file
    /// whose path is kept in <see cref="ggmlModelPath"/>.
    /// </summary>
    [OneTimeSetUp]
    public async Task SetupAsync()
    {
        ggmlModelPath = Path.GetTempFileName();

        // Dispose the downloaded model stream too, not only the file writer, so the
        // underlying resources are released as soon as the copy is done.
        using var model = await WhisperGgmlDownloader.GetGgmlModelAsync(GgmlType.Tiny);
        using var fileWriter = File.OpenWrite(ggmlModelPath);
        await model.CopyToAsync(fileWriter);
    }

    /// <summary>
    /// Removes the temporary model file created by <see cref="SetupAsync"/>.
    /// </summary>
    [OneTimeTearDown]
    public void TearDown()
    {
        File.Delete(ggmlModelPath);
    }

    /// <summary>
    /// Processes kennedy.wav synchronously with every handler registered and checks that
    /// segments, progress updates and exactly one encoder-begin notification were received.
    /// </summary>
    [Test]
    public void TestHappyFlow()
    {
        var segments = new List<SegmentData>();
        var progress = new List<int>();
        var encoderBegins = new List<EncoderBeginData>();

        using var factory = WhisperFactory.FromPath(ggmlModelPath);
        using var processor = factory.CreateBuilder()
            .WithLanguage("en")
            .WithEncoderBeginHandler((e) =>
            {
                encoderBegins.Add(e);
                return true;
            })
            .WithPrompt("I am Kennedy")
            .WithProgressHandler(progress.Add)
            .WithSegmentEventHandler(segments.Add)
            .Build();

        using var fileReader = File.OpenRead("kennedy.wav");
        processor.Process(fileReader);

        segments.Should().HaveCountGreaterThan(0);
        encoderBegins.Should().HaveCount(1);
        progress.Should().BeInAscendingOrder().And.HaveCountGreaterThan(1);

        segments.Should().Contain(segmentData => segmentData.Text.Contains("nation should commit"));
    }

    /// <summary>
    /// Returns <c>false</c> from the encoder-begin handler and checks that the handler was
    /// invoked once and that no segment was reported afterwards.
    /// </summary>
    [Test]
    public void TestCancelEncoder()
    {
        var segments = new List<SegmentData>();
        var encoderBegins = new List<EncoderBeginData>();

        using var factory = WhisperFactory.FromPath(ggmlModelPath);
        using var processor = factory.CreateBuilder()
            .WithLanguage("en")
            .WithEncoderBeginHandler((e) =>
            {
                encoderBegins.Add(e);
                return false;
            })
            .WithSegmentEventHandler(segments.Add)
            .Build();

        using var fileReader = File.OpenRead("kennedy.wav");
        processor.Process(fileReader);

        segments.Should().HaveCount(0);
        encoderBegins.Should().HaveCount(1);
    }

    /// <summary>
    /// Enables language detection, processes romana.wav and checks that every returned
    /// segment is tagged "ro" and that the expected Romanian text was recognised.
    /// </summary>
    [Test]
    public async Task TestAutoDetectLanguageWithRomanian()
    {
        var segments = new List<SegmentData>();
        var encoderBegins = new List<EncoderBeginData>();

        using var factory = WhisperFactory.FromPath(ggmlModelPath);

        // Async test: dispose the processor asynchronously, like the other async tests here.
        await using var processor = factory.CreateBuilder()
            .WithLanguageDetection()
            .WithEncoderBeginHandler((e) =>
            {
                encoderBegins.Add(e);
                return true;
            })
            .Build();

        using var fileReader = File.OpenRead("romana.wav");
        await foreach (var segment in processor.ProcessAsync(fileReader))
        {
            segments.Add(segment);
        }

        segments.Should().HaveCountGreaterThan(0);
        encoderBegins.Should().HaveCount(1);
        segments.Should().AllSatisfy(s => s.Language.Should().Be("ro"));
        segments.Should().Contain(segmentData => segmentData.Text.Contains("efectua"));
    }

    /// <summary>
    /// Processes multichannel.wav synchronously and checks that at least one segment is produced.
    /// </summary>
    [Test]
    public async Task Process_WhenMultichannel_ProcessCorrectly()
    {
        var segments = new List<SegmentData>();

        using var factory = WhisperFactory.FromPath(ggmlModelPath);
        await using var processor = factory.CreateBuilder()
            .WithLanguage("en")
            .WithSegmentEventHandler(segments.Add)
            .Build();

        using var fileReader = File.OpenRead("multichannel.wav");
        processor.Process(fileReader);

        segments.Should().HaveCountGreaterThanOrEqualTo(1);
    }

    /// <summary>
    /// Runs <c>Process</c> three times on the same processor with the same input and checks
    /// that every run reports an equivalent, non-empty list of segments.
    /// </summary>
    [Test]
    public async Task Process_CalledMultipleTimes_Serially_WillCompleteEverytime()
    {
        var segments1 = new List<SegmentData>();
        var segments2 = new List<SegmentData>();
        var segments3 = new List<SegmentData>();

        // The handler is registered once at build time, so it forwards to a captured delegate
        // that is re-pointed to a different list before each run.
        OnSegmentEventHandler onNewSegment = segments1.Add;

        using var factory = WhisperFactory.FromPath(ggmlModelPath);
        await using var processor = factory.CreateBuilder()
            .WithLanguage("en")
            .WithSegmentEventHandler((s) => onNewSegment(s))
            .Build();

        using var fileReader1 = File.OpenRead("kennedy.wav");
        processor.Process(fileReader1);

        onNewSegment = segments2.Add;

        using var fileReader2 = File.OpenRead("kennedy.wav");
        processor.Process(fileReader2);

        onNewSegment = segments3.Add;

        using var fileReader3 = File.OpenRead("kennedy.wav");
        processor.Process(fileReader3);

        // Without this check three empty lists would be "equivalent" and the test would
        // pass even if no run produced any segment at all.
        segments1.Should().HaveCountGreaterThan(0);
        segments1.Should().BeEquivalentTo(segments2);
        segments2.Should().BeEquivalentTo(segments3);
    }
}
9 changes: 7 additions & 2 deletions Whisper.net/WhisperProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ public unsafe void Process(float[] samples)
try
{
processingSemaphore.Wait();
segmentIndex = 0;

NativeMethods.whisper_full_with_state(currentWhisperContext, state, whisperParams, (IntPtr)pData, samples.Length);
}
finally
Expand Down Expand Up @@ -153,10 +155,10 @@ void OnSegmentHandler(SegmentData segmentData)
resetEvent!.Set();
}

options.OnSegmentEventHandlers.Add(OnSegmentHandler);

try
{
options.OnSegmentEventHandlers.Add(OnSegmentHandler);

currentCancellationToken = cancellationToken;
var whisperTask = ProcessInternalAsync(samples, cancellationToken)
.ContinueWith(_ => resetEvent.Set(), cancellationToken, TaskContinuationOptions.None, TaskScheduler.Default);
Expand Down Expand Up @@ -225,7 +227,10 @@ private unsafe Task ProcessInternalAsync(float[] samples, CancellationToken canc
fixed (float* pData = samples)
{
processingSemaphore.Wait();
segmentIndex = 0;

var state = NativeMethods.whisper_init_state(currentWhisperContext);

try
{
NativeMethods.whisper_full_with_state(currentWhisperContext, state, whisperParams, (IntPtr)pData, samples.Length);
Expand Down

0 comments on commit 9b627f0

Please sign in to comment.