From c998e7606df81683b6bc865af7312a2fca5e9619 Mon Sep 17 00:00:00 2001 From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com> Date: Sun, 6 Apr 2025 16:50:58 +0100 Subject: [PATCH 1/9] WIP SpeechToText Impl --- Directory.Packages.props | 3 +- Whisper.net.Demo/Whisper.net.Demo.csproj | 12 +- .../WhisperSpeechToTextClient.cs | 176 ++++++++++++++++++ Whisper.net/Whisper.net.csproj | 6 +- nuget.config | 3 +- 5 files changed, 190 insertions(+), 10 deletions(-) create mode 100644 Whisper.net/SpeechToTextClient/WhisperSpeechToTextClient.cs diff --git a/Directory.Packages.props b/Directory.Packages.props index 678583a82..9bb95877f 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -4,6 +4,7 @@ + @@ -11,7 +12,7 @@ - + diff --git a/Whisper.net.Demo/Whisper.net.Demo.csproj b/Whisper.net.Demo/Whisper.net.Demo.csproj index 8f452edfe..3f6146cd1 100644 --- a/Whisper.net.Demo/Whisper.net.Demo.csproj +++ b/Whisper.net.Demo/Whisper.net.Demo.csproj @@ -1,10 +1,7 @@ - - - - + + + + Exe @@ -14,6 +11,7 @@ + diff --git a/Whisper.net/SpeechToTextClient/WhisperSpeechToTextClient.cs b/Whisper.net/SpeechToTextClient/WhisperSpeechToTextClient.cs new file mode 100644 index 000000000..fd4fa4a5e --- /dev/null +++ b/Whisper.net/SpeechToTextClient/WhisperSpeechToTextClient.cs @@ -0,0 +1,176 @@ +// Licensed under the MIT license: https://opensource.org/licenses/MIT + +using System.Runtime.CompilerServices; +using System.Text; +using Microsoft.Extensions.AI; + +#pragma warning disable MEAI001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. + +namespace Whisper.net; + +public sealed class WhisperSpeechToTextClient : ISpeechToTextClient +{ + private readonly WhisperFactory _factory; + private WhisperProcessor? _processor; + + public WhisperSpeechToTextClient(string modelFileName) + { + this._factory = WhisperFactory.FromPath(modelFileName); + } + + public void Dispose() + { + if (this._processor != null) + { + this._processor.Dispose(); + } + + if (this._factory != null) + { + this._factory.Dispose(); + } + } + + public object? GetService(Type serviceType, object? serviceKey = null) + { + throw new NotImplementedException(); + } + + public async IAsyncEnumerable GetStreamingTextAsync(Stream audioSpeechStream, SpeechToTextOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + if (audioSpeechStream is null) + { + throw new ArgumentNullException(nameof(audioSpeechStream)); + } + + PrepareProcessor(options); + + var responseId = Guid.NewGuid().ToString(); + await foreach (var segment in this._processor!.ProcessAsync(audioSpeechStream, cancellationToken)) + { + if (cancellationToken.IsCancellationRequested) + { + break; + } + + yield return new SpeechToTextResponseUpdate(segment.Text) + { + ResponseId = responseId, + Kind = SpeechToTextResponseUpdateKind.TextUpdating, + RawRepresentation = segment, + StartTime = segment.Start, + EndTime = segment.End + }; + } + } + + public async Task GetTextAsync(Stream audioSpeechStream, SpeechToTextOptions? options = null, CancellationToken cancellationToken = default) + { + SpeechToTextResponse response = new(); + PrepareProcessor(options); + + StringBuilder fullTranscription = new(); + List segments = []; + + await foreach (var segment in this._processor!.ProcessAsync(audioSpeechStream, cancellationToken)) + { + if (cancellationToken.IsCancellationRequested) + { + break; + } + + response.StartTime ??= segment.Start; + response.EndTime = segment.End; + + segments.Add(segment); + fullTranscription.Append(segment.Text); + } + + response.ResponseId = Guid.NewGuid().ToString(); + response.RawRepresentation = segments; + response.Contents = [new TextContent(fullTranscription.ToString())]; + + return response; + } + + private void PrepareProcessor(SpeechToTextOptions? options) + { + if (this._processor is not null) + { + return; + } + + var processorBuilder = this._factory.CreateBuilder(); + if (options is not null) + { + if (!string.IsNullOrWhiteSpace(options?.SpeechLanguage)) + { + processorBuilder.WithLanguage(options!.SpeechLanguage!); + } + + if (GetAdditionalProperty("AudioContextSize", options!, out var audioContextSize)) + { + processorBuilder.WithAudioContextSize(audioContextSize); + } + + if (GetAdditionalProperty("BeamSearchSamplingStrategy", options!, out var beamSearchSamplingStrategy) && beamSearchSamplingStrategy) + { + processorBuilder.WithBeamSearchSamplingStrategy(); + } + + + /* + processorBuilder.WithDuration(options?.Duration ?? TimeSpan.MinValue); + processorBuilder.WithEncoderBeginHandler(options?.EncoderBeginHandler); + processorBuilder.WithEntropyThreshold(options?.EntropyThreshold ?? 0.0f); + + if (GetAdditionalProperty("GreedySamplingStrategy", options!, out var greedySamplingStrategy) && greedySamplingStrategy) + { + processorBuilder.WithGreedySamplingStrategy(); + } + + processorBuilder.WithLanguage(options?.Language ?? string.Empty); + processorBuilder.WithLanguageDetection(options?.LanguageDetection ?? false); + processorBuilder.WithLengthPenalty(options?.LengthPenalty ?? 0.0f); + processorBuilder.WithLogProbThreshold(options?.LogProbThreshold ?? 0.0f); + processorBuilder.WithMaxInitialTs(options?.MaxInitialTs ?? 0); + processorBuilder.WithMaxSegmentLength(options?.MaxSegmentLength ?? 0); + processorBuilder.WithMaxLastTextTokens(options?.MaxLastTextTokens ?? 0); + processorBuilder.WithMaxTokensPerSegment(options?.MaxTokensPerSegment ?? 0); + processorBuilder.WithNoContext(options?.NoContext ?? false); + processorBuilder.WithNoSpeechThreshold(options?.NoSpeechThreshold ?? 0.0f); + processorBuilder.WithOffset(options?.Offset ?? 0); + processorBuilder.WithOpenVinoEncoder(options?.OpenVinoEncoderPath, options?.OpenVinoDevice, options?.OpenVinoCachePath); + processorBuilder.WithoutSuppressBlank(); + processorBuilder.WithoutStringPool(); + processorBuilder.WithPrintProgress(options?.PrintProgress ?? false); + processorBuilder.WithPrintTimestamps(options?.PrintTimestamps ?? false); + processorBuilder.WithPrintSpecialTokens(options?.PrintSpecialTokens ?? false); + processorBuilder.WithPrintResults(options?.PrintResults ?? false); + processorBuilder.WithProbabilities(options?.Probabilities ?? false); + processorBuilder.WithProgressHandler(options?.ProgressHandler); + processorBuilder.WithSegmentEventHandler(options?.SegmentEventHandler); + processorBuilder.WithStringPool(options?.StringPool ?? string.Empty); + processorBuilder.WithTemperature(options?.Temperature ?? 0.0f); + processorBuilder.WithTemperatureInc(options?.TemperatureInc ?? 0.0f); + processorBuilder.WithThreads(options?.Threads ?? 0); + processorBuilder.WithTokenTimestamps(); + processorBuilder.WithTokenTimestampsSumThreshold(options?.TokenTimestampsSumThreshold ?? 0.0f); + processorBuilder.WithTokenTimestampsThreshold(options?.TokenTimestampsThreshold ?? 0.0f); + */ + } + + this._processor = processorBuilder.Build(); + } + + private static bool GetAdditionalProperty(string propertyName, SpeechToTextOptions options, out T? value) + { + if (options.AdditionalProperties?.TryGetValue(propertyName, out value) ?? false) + { + return true; + } + + value = default; + return false; + } +} diff --git a/Whisper.net/Whisper.net.csproj b/Whisper.net/Whisper.net.csproj index 586bac394..2784aeb90 100755 --- a/Whisper.net/Whisper.net.csproj +++ b/Whisper.net/Whisper.net.csproj @@ -1,4 +1,4 @@ - + enable @@ -46,4 +46,8 @@ + + + + diff --git a/nuget.config b/nuget.config index 0cede15d4..e66c0e196 100644 --- a/nuget.config +++ b/nuget.config @@ -1,10 +1,11 @@ - + + From 212436b0f43326c3111ec7b005ef5574ab96a065 Mon Sep 17 00:00:00 2001 From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com> Date: Sun, 6 Apr 2025 17:07:12 +0100 Subject: [PATCH 2/9] Extensions Added --- Whisper.net/Internals/Native/INativeCuda.cs | 1 - .../LibraryLoader/NativeLibraryLoader.cs | 2 -- .../LibraryLoader/UniversalLibraryLoader.cs | 1 - .../SpeechToTextOptionsExtensions.cs | 34 +++++++++++++++++++ .../WhisperSpeechToTextClient.cs | 5 ++- Whisper.net/WhisperFactoryOptions.cs | 1 - .../ProcessAsyncFunctionalTests.cs | 2 -- .../Whisper.net.Tests/SegmentDataComparer.cs | 1 - .../DependencyWalker/INativeLibraryLoader.cs | 2 ++ .../DependencyWalker/NativeLibraryChecker.cs | 2 ++ 10 files changed, 40 insertions(+), 11 deletions(-) create mode 100644 Whisper.net/SpeechToTextClient/SpeechToTextOptionsExtensions.cs diff --git a/Whisper.net/Internals/Native/INativeCuda.cs b/Whisper.net/Internals/Native/INativeCuda.cs index e3bcf22f5..43b0f110c 100644 --- a/Whisper.net/Internals/Native/INativeCuda.cs +++ b/Whisper.net/Internals/Native/INativeCuda.cs @@ -9,5 +9,4 @@ internal interface INativeCuda : IDisposable public delegate int cudaGetDeviceCount(out int count); cudaGetDeviceCount CudaGetDeviceCount { get; } - } diff --git a/Whisper.net/LibraryLoader/NativeLibraryLoader.cs b/Whisper.net/LibraryLoader/NativeLibraryLoader.cs index 5884349c5..2dcac6d40 100644 --- a/Whisper.net/LibraryLoader/NativeLibraryLoader.cs +++ b/Whisper.net/LibraryLoader/NativeLibraryLoader.cs @@ -250,7 +250,6 @@ private static bool IsRuntimeSupported(RuntimeLibrary runtime, string platform, WhisperLogger.Log(WhisperLogLevel.Debug, $"Runtime directory for {library} not found in {runtimePath}"); } } - } } @@ -272,5 +271,4 @@ private static bool IsRuntimeSupported(RuntimeLibrary runtime, string platform, } #endif } - } diff --git a/Whisper.net/LibraryLoader/UniversalLibraryLoader.cs b/Whisper.net/LibraryLoader/UniversalLibraryLoader.cs index 40dd9a658..e73172052 100644 --- a/Whisper.net/LibraryLoader/UniversalLibraryLoader.cs +++ b/Whisper.net/LibraryLoader/UniversalLibraryLoader.cs @@ -1,7 +1,6 @@ // Licensed under the MIT license: https://opensource.org/licenses/MIT #if !NETSTANDARD -using System.ComponentModel; using System.Reflection; using System.Runtime.InteropServices; diff --git a/Whisper.net/SpeechToTextClient/SpeechToTextOptionsExtensions.cs b/Whisper.net/SpeechToTextClient/SpeechToTextOptionsExtensions.cs new file mode 100644 index 000000000..edbe95742 --- /dev/null +++ b/Whisper.net/SpeechToTextClient/SpeechToTextOptionsExtensions.cs @@ -0,0 +1,34 @@ +// Licensed under the MIT license: https://opensource.org/licenses/MIT + +using Microsoft.Extensions.AI; + +#pragma warning disable MEAI001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. + +namespace Whisper.net; + +public static class SpeechToTextOptionsExtensions +{ + internal const string BeamSearchSamplingStrategyKey = "BeamSearchSamplingStrategy"; + internal const string AudioContextSizeKey = "AudioContextSize"; + + public static SpeechToTextOptions WithLanguage(this SpeechToTextOptions options, string language) + { + options.SpeechLanguage = language; + return options; + } + + public static SpeechToTextOptions WithBeamSearchSamplingStrategy(this SpeechToTextOptions options) + { + options.AdditionalProperties ??= []; + options.AdditionalProperties[BeamSearchSamplingStrategyKey] = true; + + return options; + } + + public static SpeechToTextOptions WithAudioContextSize(this SpeechToTextOptions options, int audioContextSize) + { + options.AdditionalProperties ??= []; + options.AdditionalProperties[AudioContextSizeKey] = audioContextSize; + return options; + } +} diff --git a/Whisper.net/SpeechToTextClient/WhisperSpeechToTextClient.cs b/Whisper.net/SpeechToTextClient/WhisperSpeechToTextClient.cs index fd4fa4a5e..ec8a30dcb 100644 --- a/Whisper.net/SpeechToTextClient/WhisperSpeechToTextClient.cs +++ b/Whisper.net/SpeechToTextClient/WhisperSpeechToTextClient.cs @@ -108,17 +108,16 @@ private void PrepareProcessor(SpeechToTextOptions? options) processorBuilder.WithLanguage(options!.SpeechLanguage!); } - if (GetAdditionalProperty("AudioContextSize", options!, out var audioContextSize)) + if (GetAdditionalProperty(SpeechToTextOptionsExtensions.AudioContextSizeKey, options!, out var audioContextSize)) { processorBuilder.WithAudioContextSize(audioContextSize); } - if (GetAdditionalProperty("BeamSearchSamplingStrategy", options!, out var beamSearchSamplingStrategy) && beamSearchSamplingStrategy) + if (GetAdditionalProperty(SpeechToTextOptionsExtensions.BeamSearchSamplingStrategyKey, options!, out var beamSearchSamplingStrategy) && beamSearchSamplingStrategy) { processorBuilder.WithBeamSearchSamplingStrategy(); } - /* processorBuilder.WithDuration(options?.Duration ?? TimeSpan.MinValue); processorBuilder.WithEncoderBeginHandler(options?.EncoderBeginHandler); diff --git a/Whisper.net/WhisperFactoryOptions.cs b/Whisper.net/WhisperFactoryOptions.cs index 7d6f92482..dbd7e7216 100644 --- a/Whisper.net/WhisperFactoryOptions.cs +++ b/Whisper.net/WhisperFactoryOptions.cs @@ -90,5 +90,4 @@ public WhisperFactoryOptions() /// By default, it is false and the model is loaded right away. /// public bool DelayInitialization { get; set; } - } diff --git a/tests/Whisper.net.Tests/ProcessAsyncFunctionalTests.cs b/tests/Whisper.net.Tests/ProcessAsyncFunctionalTests.cs index 8080d4a06..8c01934a4 100644 --- a/tests/Whisper.net.Tests/ProcessAsyncFunctionalTests.cs +++ b/tests/Whisper.net.Tests/ProcessAsyncFunctionalTests.cs @@ -155,7 +155,5 @@ public async Task ProcessAsync_CalledMultipleTimes_Serially_WillCompleteEverytim Assert.True(segments1.SequenceEqual(segments2, new SegmentDataComparer())); Assert.True(segments2.SequenceEqual(segments3, new SegmentDataComparer())); - } - } diff --git a/tests/Whisper.net.Tests/SegmentDataComparer.cs b/tests/Whisper.net.Tests/SegmentDataComparer.cs index 80ec38ff2..e7e071693 100644 --- a/tests/Whisper.net.Tests/SegmentDataComparer.cs +++ b/tests/Whisper.net.Tests/SegmentDataComparer.cs @@ -19,5 +19,4 @@ public int GetHashCode(SegmentData obj) return obj.Text.GetHashCode(); } } - } diff --git a/tools/WhisperNetDependencyChecker/DependencyWalker/INativeLibraryLoader.cs b/tools/WhisperNetDependencyChecker/DependencyWalker/INativeLibraryLoader.cs index 30578913a..8ca6da344 100644 --- a/tools/WhisperNetDependencyChecker/DependencyWalker/INativeLibraryLoader.cs +++ b/tools/WhisperNetDependencyChecker/DependencyWalker/INativeLibraryLoader.cs @@ -1,3 +1,5 @@ +// Licensed under the MIT license: https://opensource.org/licenses/MIT + namespace WhisperNetDependencyChecker.DependencyWalker; internal interface INativeLibraryLoader diff --git a/tools/WhisperNetDependencyChecker/DependencyWalker/NativeLibraryChecker.cs b/tools/WhisperNetDependencyChecker/DependencyWalker/NativeLibraryChecker.cs index 13bbdf6e6..229f43202 100644 --- a/tools/WhisperNetDependencyChecker/DependencyWalker/NativeLibraryChecker.cs +++ b/tools/WhisperNetDependencyChecker/DependencyWalker/NativeLibraryChecker.cs @@ -1,3 +1,5 @@ +// Licensed under the MIT license: https://opensource.org/licenses/MIT + namespace WhisperNetDependencyChecker.DependencyWalker; using System; From e7a6285010210ffd6077580c48c77caff2063abb Mon Sep 17 00:00:00 2001 From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com> Date: Sun, 6 Apr 2025 17:54:21 +0100 Subject: [PATCH 3/9] Update impl for MEAI SpeechToText --- Whisper.net.Demo/Program.cs | 24 + .../SpeechToTextOptionsExtensions.cs | 783 +++++++++++++++++- .../WhisperSpeechToTextClient.cs | 114 +-- 3 files changed, 819 insertions(+), 102 deletions(-) diff --git a/Whisper.net.Demo/Program.cs b/Whisper.net.Demo/Program.cs index b923cb20c..9835d8350 100755 --- a/Whisper.net.Demo/Program.cs +++ b/Whisper.net.Demo/Program.cs @@ -1,6 +1,7 @@ // Licensed under the MIT license: https://opensource.org/licenses/MIT using CommandLine; +using Microsoft.Extensions.AI; using Whisper.net; using Whisper.net.Ggml; using Whisper.net.Wave; @@ -26,6 +27,7 @@ async Task Demo(Options opt) case "transcribe": case "translate": await FullDetection(opt); + await FullDetectionSpeechToText(opt); break; default: Console.WriteLine("Unknown command"); @@ -78,6 +80,28 @@ async Task FullDetection(Options opt) } } +#pragma warning disable MEAI001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. +async Task FullDetectionSpeechToText(Options opt) +{ + // Same factory can be used by multiple task to create processors. + using var speechToTextClient = new WhisperSpeechToTextClient(opt.ModelName); + + var options = new SpeechToTextOptions().WithLanguage(opt.Language); + + if (opt.Command == "translate") + { + options.WithTranslate(); + } + + using var fileStream = File.OpenRead(opt.FileName); + + await foreach (var segment in speechToTextClient.GetStreamingTextAsync(fileStream, options, CancellationToken.None)) + { + Console.WriteLine($"New Segment: {segment.StartTime} ==> {segment.EndTime} : {segment.Text}"); + } +} +#pragma warning restore MEAI001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. + /// /// The options for this Demo /// diff --git a/Whisper.net/SpeechToTextClient/SpeechToTextOptionsExtensions.cs b/Whisper.net/SpeechToTextClient/SpeechToTextOptionsExtensions.cs index edbe95742..29e3ab808 100644 --- a/Whisper.net/SpeechToTextClient/SpeechToTextOptionsExtensions.cs +++ b/Whisper.net/SpeechToTextClient/SpeechToTextOptionsExtensions.cs @@ -8,15 +8,77 @@ namespace Whisper.net; public static class SpeechToTextOptionsExtensions { - internal const string BeamSearchSamplingStrategyKey = "BeamSearchSamplingStrategy"; - internal const string AudioContextSizeKey = "AudioContextSize"; + private const string BeamSearchSamplingStrategyKey = "BeamSearchSamplingStrategy"; + private const string AudioContextSizeKey = "AudioContextSize"; + private const string DurationKey = "Duration"; + private const string EncoderBeginHandlerKey = "EncoderBeginHandler"; + private const string EntropyThresholdKey = "EntropyThreshold"; + private const string GreedySamplingStrategyKey = "GreedySamplingStrategy"; + private const string LanguageKey = "Language"; + private const string LanguageDetectionKey = "LanguageDetection"; + private const string LengthPenaltyKey = "LengthPenalty"; + private const string LogProbThresholdKey = "LogProbThreshold"; + private const string MaxInitialTsKey = "MaxInitialTs"; + private const string MaxSegmentLengthKey = "MaxSegmentLength"; + private const string MaxLastTextTokensKey = "MaxLastTextTokens"; + private const string MaxTokensPerSegmentKey = "MaxTokensPerSegment"; + private const string NoContextKey = "NoContext"; + private const string NoSpeechThresholdKey = "NoSpeechThreshold"; + private const string OffsetKey = "Offset"; + private const string OpenVinoEncoderPathKey = "OpenVinoEncoderPath"; + private const string OpenVinoDeviceKey = "OpenVinoDevice"; + private const string OpenVinoCachePathKey = "OpenVinoCachePath"; + private const string SuppressBlankKey = "SuppressBlank"; + private const string StringPoolKey = "StringPool"; + private const string PrintProgressKey = "PrintProgress"; + private const string PrintTimestampsKey = "PrintTimestamps"; + private const string PrintSpecialTokensKey = "PrintSpecialTokens"; + private const string PrintResultsKey = "PrintResults"; + private const string ProbabilitiesKey = "Probabilities"; + private const string ProgressHandlerKey = "ProgressHandler"; + private const string SegmentEventHandlerKey = "SegmentEventHandler"; + private const string TemperatureKey = "Temperature"; + private const string TemperatureIncKey = "TemperatureInc"; + private const string ThreadsKey = "Threads"; + private const string TokenTimestampsKey = "TokenTimestamps"; + private const string TokenTimestampsSumThresholdKey = "TokenTimestampsSumThreshold"; + private const string TokenTimestampsThresholdKey = "TokenTimestampsThreshold"; + /// + /// Configures the processor with the language to be used for detection. + /// + /// The options to configure. + /// The language (2 letters) to be used. + /// The same options instance for chaining. + /// + /// Default value is "en". + /// Example: "en", "ro" + /// public static SpeechToTextOptions WithLanguage(this SpeechToTextOptions options, string language) { options.SpeechLanguage = language; return options; } + /// + /// Configures the processor to translate the text to English. + /// + /// The options to configure. + /// The same options instance for chaining. + /// + /// If not specified, the processor will just transcribe it. + /// + public static SpeechToTextOptions WithTranslate(this SpeechToTextOptions options) + { + options.TextLanguage = "English"; + return options; + } + + /// + /// Configures the processor to use the Beam Search Sampling Strategy. + /// + /// The options to configure. + /// The same options instance for chaining. public static SpeechToTextOptions WithBeamSearchSamplingStrategy(this SpeechToTextOptions options) { options.AdditionalProperties ??= []; @@ -25,10 +87,727 @@ public static SpeechToTextOptions WithBeamSearchSamplingStrategy(this SpeechToTe return options; } + /// + /// [EXPERIMENTAL] Configures the processor to override the audio context size. + /// + /// The options to configure. + /// Audio context size to be overridden + /// The same options instance for chaining. + /// + /// Quality might be degraded while performance might be improved. + /// public static SpeechToTextOptions WithAudioContextSize(this SpeechToTextOptions options, int audioContextSize) { options.AdditionalProperties ??= []; options.AdditionalProperties[AudioContextSizeKey] = audioContextSize; return options; } + + /// + /// Configures the processor with the duration in the audio to which it processes. + /// + /// The options to configure. + /// Duration in the audio. + /// The same options instance for chaining. + /// + /// If not specified, the processing is happening until the end. + /// + public static SpeechToTextOptions WithDuration(this SpeechToTextOptions options, TimeSpan duration) + { + options.AdditionalProperties ??= []; + options.AdditionalProperties[DurationKey] = duration; + return options; + } + + /// + /// Adds a which will be called when the encoder begins processing. + /// + /// The options to configure. + /// The event handler to be added. + /// The same options instance for chaining. + public static SpeechToTextOptions WithEncoderBeginHandler(this SpeechToTextOptions options, OnEncoderBeginEventHandler handler) + { + options.AdditionalProperties ??= []; + options.AdditionalProperties[EncoderBeginHandlerKey] = handler; + return options; + } + + /// + /// Configures the processor with an entropy threshold for decoder fallback. + /// + /// The options to configure. + /// The entropy threshold. + /// The same options instance for chaining. + /// + /// Default value is 2.4f. + /// + public static SpeechToTextOptions WithEntropyThreshold(this SpeechToTextOptions options, float threshold) + { + options.AdditionalProperties ??= []; + options.AdditionalProperties[EntropyThresholdKey] = threshold; + return options; + } + + /// + /// Configures the processor to use the Greedy Sampling strategy. + /// + /// The options to configure. + /// The same options instance for chaining. + public static SpeechToTextOptions WithGreedySamplingStrategy(this SpeechToTextOptions options) + { + options.AdditionalProperties ??= []; + options.AdditionalProperties[GreedySamplingStrategyKey] = true; + return options; + } + + /// + /// Configures the processor with the language to be used for detection. + /// + /// The options to configure. + /// The language (2 letters) to be used. + /// The same options instance for chaining. + /// + /// Default value is "en". + /// Example: "en", "ro" + /// + public static SpeechToTextOptions WithWhisperLanguage(this SpeechToTextOptions options, string language) + { + options.AdditionalProperties ??= []; + options.AdditionalProperties[LanguageKey] = language; + return options; + } + + /// + /// Configures the processor to auto-detect the language based on initial samples. + /// + /// The options to configure. + /// Whether to enable language detection. + /// The same options instance for chaining. + /// + /// Note: Processing time will slightly increase. + /// + public static SpeechToTextOptions WithLanguageDetection(this SpeechToTextOptions options, bool enableDetection = true) + { + options.AdditionalProperties ??= []; + options.AdditionalProperties[LanguageDetectionKey] = enableDetection; + return options; + } + + /// + /// Configures the processor with a value indicating the length penalty (alpha). + /// + /// The options to configure. + /// The length penalty value. + /// The same options instance for chaining. + /// + /// If not specified, the processor will use simple length normalization by default. + /// More information about the length penalty can be found here: https://arxiv.org/abs/1609.08144. + /// + public static SpeechToTextOptions WithLengthPenalty(this SpeechToTextOptions options, float penalty) + { + options.AdditionalProperties ??= []; + options.AdditionalProperties[LengthPenaltyKey] = penalty; + return options; + } + + /// + /// Configures the processor with a average log probability threshold over sampled tokens. + /// + /// The options to configure. + /// The average log probability threshold. + /// The same options instance for chaining. + /// + /// Default value is -1.0f. + /// + public static SpeechToTextOptions WithLogProbThreshold(this SpeechToTextOptions options, float threshold) + { + options.AdditionalProperties ??= []; + options.AdditionalProperties[LogProbThresholdKey] = threshold; + return options; + } + + /// + /// Configures the processor with a value indicating that the initial timestamp cannot be later than this. + /// + /// The options to configure. + /// The initial max timestamp. + /// The same options instance for chaining. + /// + /// If not specified, default value is: 1f. + /// + public static SpeechToTextOptions WithMaxInitialTs(this SpeechToTextOptions options, float maxInitialTs) + { + options.AdditionalProperties ??= []; + options.AdditionalProperties[MaxInitialTsKey] = maxInitialTs; + return options; + } + + /// + /// Configures the processor with the maximum segment length in characters. + /// + /// The options to configure. + /// The maximum segment length in characters. + /// The same options instance for chaining. + public static SpeechToTextOptions WithMaxSegmentLength(this SpeechToTextOptions options, int maxLength) + { + options.AdditionalProperties ??= []; + options.AdditionalProperties[MaxSegmentLengthKey] = maxLength; + return options; + } + + /// + /// Configures the processor with the max number of tokens to be used from the previous text as prompt for the decoder. + /// + /// The options to configure. + /// The max number of tokens to be used. + /// The same options instance for chaining. + /// + /// If not specified, a number of 16384 tokens is used. + /// + public static SpeechToTextOptions WithMaxLastTextTokens(this SpeechToTextOptions options, int maxTokens) + { + options.AdditionalProperties ??= []; + options.AdditionalProperties[MaxLastTextTokensKey] = maxTokens; + return options; + } + + /// + /// Configures the processor with the maximum number of tokens per segment. + /// + /// The options to configure. + /// The maximum number of tokens per segment. + /// The same options instance for chaining. + public static SpeechToTextOptions WithMaxTokensPerSegment(this SpeechToTextOptions options, int maxTokens) + { + options.AdditionalProperties ??= []; + options.AdditionalProperties[MaxTokensPerSegmentKey] = maxTokens; + return options; + } + + /// + /// Configures the processor to not use past transformation (if any) as the initial prompt for a newer processing. + /// + /// The options to configure. + /// Whether to disable context. + /// The same options instance for chaining. + /// + /// If not specified, the processor use part transformations as initial prompt for newer processing. + /// + public static SpeechToTextOptions WithNoContext(this SpeechToTextOptions options, bool noContext = true) + { + options.AdditionalProperties ??= []; + options.AdditionalProperties[NoContextKey] = noContext; + return options; + } + + /// + /// [EXPERIMENTAL] Configures the processor with a no_speech probability. + /// + /// The options to configure. + /// The no_speech probability + /// The same options instance for chaining. + /// + /// Default value is 0.6f. + /// + public static SpeechToTextOptions WithNoSpeechThreshold(this SpeechToTextOptions options, float threshold) + { + options.AdditionalProperties ??= []; + options.AdditionalProperties[NoSpeechThresholdKey] = threshold; + return options; + } + + /// + /// Configures the processor with the start time in the audio from which it starts the processing. + /// + /// The options to configure. + /// Offset in the audio. + /// The same options instance for chaining. + /// + /// If not specified, the processing is happening from the beginning. + /// + public static SpeechToTextOptions WithOffset(this SpeechToTextOptions options, TimeSpan offset) + { + options.AdditionalProperties ??= []; + options.AdditionalProperties[OffsetKey] = offset; + return options; + } + + /// + /// Configures the processor to use OpenVINO for the encoder part. + /// + /// The options to configure. + /// Path to the OpenVINO encoder model. + /// The device to use for OpenVINO. + /// Path to the OpenVINO cache directory. + /// The same options instance for chaining. + public static SpeechToTextOptions WithOpenVinoEncoder(this SpeechToTextOptions options, string? encoderPath, string? device = null, string? cachePath = null) + { + options.AdditionalProperties ??= []; + options.AdditionalProperties[OpenVinoEncoderPathKey] = encoderPath; + if (device != null) + { + options.AdditionalProperties[OpenVinoDeviceKey] = device; + } + if (cachePath != null) + { + options.AdditionalProperties[OpenVinoCachePathKey] = cachePath; + } + return options; + } + + /// + /// Configures the processor to NOT suppress blank outputs. + /// + /// The options to configure. + /// The same options instance for chaining. + /// + /// If not specified, blanks are automatically suppressed. + /// + public static SpeechToTextOptions WithoutSuppressBlank(this SpeechToTextOptions options) + { + options.AdditionalProperties ??= []; + options.AdditionalProperties[SuppressBlankKey] = false; + return options; + } + + /// + /// Disables the string pooling. + /// + /// The options to configure. + /// The same options instance for chaining. + /// + /// This will disable the pooling of strings that are generated (have effect only if was called). + /// + public static SpeechToTextOptions WithoutStringPool(this SpeechToTextOptions options) + { + options.AdditionalProperties ??= []; + options.AdditionalProperties[StringPoolKey] = false; + return options; + } + + /// + /// Configures the processor to print progress information. + /// + /// The options to configure. + /// Whether to print progress information. + /// The same options instance for chaining. + /// + /// If not specified, the processor will not print progress information. + /// + public static SpeechToTextOptions WithPrintProgress(this SpeechToTextOptions options, bool printProgress = true) + { + options.AdditionalProperties ??= []; + options.AdditionalProperties[PrintProgressKey] = printProgress; + return options; + } + + /// + /// Configures the processor to print timestamps for each segment to stdout. + /// + /// The options to configure. + /// Whether to print timestamps. + /// The same options instance for chaining. + /// + /// This option is available only if is configured. + /// If not specified, the processor will print also timestamps. + /// + public static SpeechToTextOptions WithPrintTimestamps(this SpeechToTextOptions options, bool printTimestamps = true) + { + options.AdditionalProperties ??= []; + options.AdditionalProperties[PrintTimestampsKey] = printTimestamps; + return options; + } + + /// + /// Configures the processor to print special tokens to stdout. + /// + /// The options to configure. + /// Whether to print special tokens. + /// The same options instance for chaining. + /// + /// This option is available only if is configured. + /// + public static SpeechToTextOptions WithPrintSpecialTokens(this SpeechToTextOptions options, bool printSpecialTokens = true) + { + options.AdditionalProperties ??= []; + options.AdditionalProperties[PrintSpecialTokensKey] = printSpecialTokens; + return options; + } + + /// + /// Configures the processor to print results to stdout. + /// + /// The options to configure. + /// Whether to print results. + /// The same options instance for chaining. + /// + /// If not specified, the processor will not print results to stdout. + /// + public static SpeechToTextOptions WithPrintResults(this SpeechToTextOptions options, bool printResults = true) + { + options.AdditionalProperties ??= []; + options.AdditionalProperties[PrintResultsKey] = printResults; + return options; + } + + /// + /// Configures the processor to return probabilities during segment decoding , and . + /// + /// The options to configure. + /// Whether to enable probabilities. + /// The same options instance for chaining. + public static SpeechToTextOptions WithProbabilities(this SpeechToTextOptions options, bool enableProbabilities = true) + { + options.AdditionalProperties ??= []; + options.AdditionalProperties[ProbabilitiesKey] = enableProbabilities; + return options; + } + + /// + /// Adds a which will be called to report progress. + /// + /// The options to configure. + /// The event handler to be added. + /// The same options instance for chaining. + public static SpeechToTextOptions WithProgressHandler(this SpeechToTextOptions options, OnProgressHandler handler) + { + options.AdditionalProperties ??= []; + options.AdditionalProperties[ProgressHandlerKey] = handler; + return options; + } + + /// + /// Adds a which will be called every time a new segment is detected. + /// + /// The options to configure. + /// The event handler to be added. + /// The same options instance for chaining. + public static SpeechToTextOptions WithSegmentEventHandler(this SpeechToTextOptions options, OnSegmentEventHandler handler) + { + options.AdditionalProperties ??= []; + options.AdditionalProperties[SegmentEventHandlerKey] = handler; + return options; + } + + /// + /// Adds the functionality of pooling the strings that are generated reducing the number of allocations. + /// + /// The options to configure. + /// The string pool to use. + /// The same options instance for chaining. + /// + /// When using this option designed for high-performance use-cases, + /// ensure that you're returning the object back to the + /// using the method . + /// + public static SpeechToTextOptions WithStringPool(this SpeechToTextOptions options, IStringPool stringPool) + { + options.AdditionalProperties ??= []; + options.AdditionalProperties[StringPoolKey] = stringPool; + return options; + } + + /// + /// Configures the processor with a temperature for sampling. + /// + /// The options to configure. + /// The temperature value. + /// The same options instance for chaining. + /// + /// Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. + /// Default value is 0.0f. + /// + public static SpeechToTextOptions WithTemperature(this SpeechToTextOptions options, float temperature) + { + options.AdditionalProperties ??= []; + options.AdditionalProperties[TemperatureKey] = temperature; + return options; + } + + /// + /// Configures the processor with a temperature to increase when falling back. + /// + /// The options to configure. + /// The temperature to increase when falling back. + /// The same options instance for chaining. + /// + /// Falling back can happen when the decoding fails to meet either of the thresholds in: , or . + /// Default value is 0.2f. + /// + public static SpeechToTextOptions WithTemperatureInc(this SpeechToTextOptions options, float temperatureInc) + { + options.AdditionalProperties ??= []; + options.AdditionalProperties[TemperatureIncKey] = temperatureInc; + return options; + } + + /// + /// Configures the processor to use the specified number of threads. + /// + /// The options to configure. + /// The number of threads to be used during encoding and decoding. + /// The same options instance for chaining. + /// + /// If not specified, the same number as the hardware threads that the underlying hardware can support concurrently is used. + /// + public static SpeechToTextOptions WithThreads(this SpeechToTextOptions options, int threads) + { + options.AdditionalProperties ??= []; + options.AdditionalProperties[ThreadsKey] = threads; + return options; + } + + /// + /// [EXPERIMENTAL] Configures the processor to use token-level timestamps. + /// + /// The options to configure. + /// The same options instance for chaining. + /// + /// If not specified, the processor will not use token timestamps. + /// + public static SpeechToTextOptions WithTokenTimestamps(this SpeechToTextOptions options) + { + options.AdditionalProperties ??= []; + options.AdditionalProperties[TokenTimestampsKey] = true; + return options; + } + + /// + /// Configures the processor to use the specified SUM probability threshold for token timestamps. + /// + /// The options to configure. + /// Probability SUM threshold to be used for token-level timestamps. + /// The same options instance for chaining. + /// + /// Default value is 0.01f. + /// This option have effect only together with + /// + public static SpeechToTextOptions WithTokenTimestampsSumThreshold(this SpeechToTextOptions options, float threshold) + { + options.AdditionalProperties ??= []; + options.AdditionalProperties[TokenTimestampsSumThresholdKey] = threshold; + return options; + } + + /// + /// Configures the processor to use the specified probability threshold for token timestamps. + /// + /// The options to configure. + /// Probability threshold to be used for token-level timestamps. + /// The same options instance for chaining. + /// + /// Default value is 0.01f. + /// This option have effect only together with + /// + public static SpeechToTextOptions WithTokenTimestampsThreshold(this SpeechToTextOptions options, float threshold) + { + options.AdditionalProperties ??= []; + options.AdditionalProperties[TokenTimestampsThresholdKey] = threshold; + return options; + } + + internal static WhisperProcessor BuildWhisperProcessor(this SpeechToTextOptions? options, WhisperFactory factory) + { + var processorBuilder = factory.CreateBuilder(); + + if (options is null) + { + return processorBuilder.Build(); + } + + if (!string.IsNullOrWhiteSpace(options?.SpeechLanguage)) + { + processorBuilder.WithLanguage(options!.SpeechLanguage!); + } + + if (GetAdditionalProperty(AudioContextSizeKey, options!, out var audioContextSize)) + { + processorBuilder.WithAudioContextSize(audioContextSize); + } + + if (GetAdditionalProperty(BeamSearchSamplingStrategyKey, options!, out var beamSearchSamplingStrategy) && beamSearchSamplingStrategy) + { + processorBuilder.WithBeamSearchSamplingStrategy(); + } + + if (GetAdditionalProperty(DurationKey, options!, out var duration)) + { + processorBuilder.WithDuration(duration); + } + + if (GetAdditionalProperty(EncoderBeginHandlerKey, options!, out var encoderBeginHandler) && encoderBeginHandler != null) + { + processorBuilder.WithEncoderBeginHandler(encoderBeginHandler); + } + + if (GetAdditionalProperty(EntropyThresholdKey, options!, out var entropyThreshold)) + { + processorBuilder.WithEntropyThreshold(entropyThreshold); + } + + if (GetAdditionalProperty(GreedySamplingStrategyKey, options!, out var greedySamplingStrategy) && greedySamplingStrategy) + { + processorBuilder.WithGreedySamplingStrategy(); + } + + if (GetAdditionalProperty(LanguageKey, options!, out var language) && !string.IsNullOrEmpty(language)) + { + processorBuilder.WithLanguage(language); + } + + if (GetAdditionalProperty(LanguageDetectionKey, options!, out var languageDetection) && languageDetection) + { + processorBuilder.WithLanguageDetection(); + } + + if (GetAdditionalProperty(LengthPenaltyKey, options!, out var lengthPenalty)) + { + processorBuilder.WithLengthPenalty(lengthPenalty); + } + + if (GetAdditionalProperty(LogProbThresholdKey, options!, out var logProbThreshold)) + { + processorBuilder.WithLogProbThreshold(logProbThreshold); + } + + if (GetAdditionalProperty(MaxInitialTsKey, options!, out var maxInitialTs)) + { + processorBuilder.WithMaxInitialTs(maxInitialTs); + } + + if (GetAdditionalProperty(MaxSegmentLengthKey, options!, out var maxSegmentLength)) + { + processorBuilder.WithMaxSegmentLength(maxSegmentLength); + } + + if (GetAdditionalProperty(MaxLastTextTokensKey, options!, out var maxLastTextTokens)) + { + processorBuilder.WithMaxLastTextTokens(maxLastTextTokens); + } + + if (GetAdditionalProperty(MaxTokensPerSegmentKey, options!, out var maxTokensPerSegment)) + { + processorBuilder.WithMaxTokensPerSegment(maxTokensPerSegment); + } + + if (GetAdditionalProperty(NoContextKey, options!, out var noContext) && noContext) + { + processorBuilder.WithNoContext(); + } + + if (GetAdditionalProperty(NoSpeechThresholdKey, options!, out var noSpeechThreshold)) + { + processorBuilder.WithNoSpeechThreshold(noSpeechThreshold); + } + + if (GetAdditionalProperty(OffsetKey, options!, out var offset)) + { + processorBuilder.WithOffset(offset); + } + + if (GetAdditionalProperty(OpenVinoEncoderPathKey, options!, out var openVinoEncoderPath)) + { + GetAdditionalProperty(OpenVinoDeviceKey, options!, out var openVinoDevice); + GetAdditionalProperty(OpenVinoCachePathKey, options!, out var openVinoCachePath); + processorBuilder.WithOpenVinoEncoder(openVinoEncoderPath, openVinoDevice, openVinoCachePath); + } + + if (GetAdditionalProperty(SuppressBlankKey, options!, out var suppressBlank) && !suppressBlank) + { + processorBuilder.WithoutSuppressBlank(); + } + + if (GetAdditionalProperty(StringPoolKey, options!, out var useStringPool) && !useStringPool) + { + processorBuilder.WithoutStringPool(); + } + + if (GetAdditionalProperty(PrintProgressKey, options!, out var printProgress) && printProgress) + { + processorBuilder.WithPrintProgress(); + } + + if (GetAdditionalProperty(PrintTimestampsKey, options!, out var printTimestamps) && printTimestamps) + { + processorBuilder.WithPrintTimestamps(); + } + + if (GetAdditionalProperty(PrintSpecialTokensKey, options!, out var printSpecialTokens) && printSpecialTokens) + { + processorBuilder.WithPrintSpecialTokens(); + } + + if (GetAdditionalProperty(PrintResultsKey, options!, out var printResults) && printResults) + { + processorBuilder.WithPrintResults(); + } + + if (GetAdditionalProperty(ProbabilitiesKey, options!, out var probabilities) && probabilities) + { + processorBuilder.WithProbabilities(); + } + + if (GetAdditionalProperty(ProgressHandlerKey, options!, out var progressHandler) && progressHandler != null) + { + processorBuilder.WithProgressHandler(progressHandler); + } + + if (GetAdditionalProperty(SegmentEventHandlerKey, options!, out var segmentEventHandler) && segmentEventHandler != null) + { + processorBuilder.WithSegmentEventHandler(segmentEventHandler); + } + + if (GetAdditionalProperty(StringPoolKey, options!, out var stringPool) && stringPool != null) + { + processorBuilder.WithStringPool(stringPool); + } + + if (GetAdditionalProperty(TemperatureKey, options!, out var temperature)) + { + processorBuilder.WithTemperature(temperature); + } + + if (GetAdditionalProperty(TemperatureIncKey, options!, out var temperatureInc)) + { + processorBuilder.WithTemperatureInc(temperatureInc); + } + + if (GetAdditionalProperty(ThreadsKey, options!, out var threads)) + { + processorBuilder.WithThreads(threads); + } + + if (GetAdditionalProperty(TokenTimestampsKey, options!, out var tokenTimestamps) && tokenTimestamps) + { + processorBuilder.WithTokenTimestamps(); + } + + if (GetAdditionalProperty(TokenTimestampsSumThresholdKey, options!, out var tokenTimestampsSumThreshold)) + { + processorBuilder.WithTokenTimestampsSumThreshold(tokenTimestampsSumThreshold); + } + + if (GetAdditionalProperty(TokenTimestampsThresholdKey, options!, out var tokenTimestampsThreshold)) + { + processorBuilder.WithTokenTimestampsThreshold(tokenTimestampsThreshold); + } + + if (!string.IsNullOrWhiteSpace(options?.TextLanguage)) + { + processorBuilder.WithTranslate(); + } + + return processorBuilder.Build(); + } + + private static bool GetAdditionalProperty(string propertyName, SpeechToTextOptions options, out T? value) + { + if (options.AdditionalProperties?.TryGetValue(propertyName, out value) ?? false) + { + return true; + } + + value = default; + return false; + } } diff --git a/Whisper.net/SpeechToTextClient/WhisperSpeechToTextClient.cs b/Whisper.net/SpeechToTextClient/WhisperSpeechToTextClient.cs index ec8a30dcb..f322792f3 100644 --- a/Whisper.net/SpeechToTextClient/WhisperSpeechToTextClient.cs +++ b/Whisper.net/SpeechToTextClient/WhisperSpeechToTextClient.cs @@ -8,27 +8,15 @@ namespace Whisper.net; -public sealed class WhisperSpeechToTextClient : ISpeechToTextClient +public sealed class WhisperSpeechToTextClient(string modelFileName) : ISpeechToTextClient { - private readonly WhisperFactory _factory; + private readonly WhisperFactory _factory = WhisperFactory.FromPath(modelFileName); private WhisperProcessor? _processor; - public WhisperSpeechToTextClient(string modelFileName) - { - this._factory = WhisperFactory.FromPath(modelFileName); - } - public void Dispose() { - if (this._processor != null) - { - this._processor.Dispose(); - } - - if (this._factory != null) - { - this._factory.Dispose(); - } + _processor?.Dispose(); + _factory?.Dispose(); } public object? GetService(Type serviceType, object? serviceKey = null) @@ -43,10 +31,10 @@ public async IAsyncEnumerable GetStreamingTextAsync( throw new ArgumentNullException(nameof(audioSpeechStream)); } - PrepareProcessor(options); + this._processor ??= options.BuildWhisperProcessor(_factory); var responseId = Guid.NewGuid().ToString(); - await foreach (var segment in this._processor!.ProcessAsync(audioSpeechStream, cancellationToken)) + await foreach (var segment in _processor!.ProcessAsync(audioSpeechStream, cancellationToken)) { if (cancellationToken.IsCancellationRequested) { @@ -66,13 +54,19 @@ public async IAsyncEnumerable GetStreamingTextAsync( public async Task GetTextAsync(Stream audioSpeechStream, SpeechToTextOptions? options = null, CancellationToken cancellationToken = default) { + if (audioSpeechStream is null) + { + throw new ArgumentNullException(nameof(audioSpeechStream)); + } + SpeechToTextResponse response = new(); - PrepareProcessor(options); + + this._processor ??= options.BuildWhisperProcessor(_factory); StringBuilder fullTranscription = new(); List segments = []; - await foreach (var segment in this._processor!.ProcessAsync(audioSpeechStream, cancellationToken)) + await foreach (var segment in _processor!.ProcessAsync(audioSpeechStream, cancellationToken)) { if (cancellationToken.IsCancellationRequested) { @@ -92,84 +86,4 @@ public async Task GetTextAsync(Stream audioSpeechStream, S return response; } - - private void PrepareProcessor(SpeechToTextOptions? options) - { - if (this._processor is not null) - { - return; - } - - var processorBuilder = this._factory.CreateBuilder(); - if (options is not null) - { - if (!string.IsNullOrWhiteSpace(options?.SpeechLanguage)) - { - processorBuilder.WithLanguage(options!.SpeechLanguage!); - } - - if (GetAdditionalProperty(SpeechToTextOptionsExtensions.AudioContextSizeKey, options!, out var audioContextSize)) - { - processorBuilder.WithAudioContextSize(audioContextSize); - } - - if (GetAdditionalProperty(SpeechToTextOptionsExtensions.BeamSearchSamplingStrategyKey, options!, out var beamSearchSamplingStrategy) && beamSearchSamplingStrategy) - { - processorBuilder.WithBeamSearchSamplingStrategy(); - } - - /* - processorBuilder.WithDuration(options?.Duration ?? TimeSpan.MinValue); - processorBuilder.WithEncoderBeginHandler(options?.EncoderBeginHandler); - processorBuilder.WithEntropyThreshold(options?.EntropyThreshold ?? 0.0f); - - if (GetAdditionalProperty("GreedySamplingStrategy", options!, out var greedySamplingStrategy) && greedySamplingStrategy) - { - processorBuilder.WithGreedySamplingStrategy(); - } - - processorBuilder.WithLanguage(options?.Language ?? string.Empty); - processorBuilder.WithLanguageDetection(options?.LanguageDetection ?? false); - processorBuilder.WithLengthPenalty(options?.LengthPenalty ?? 0.0f); - processorBuilder.WithLogProbThreshold(options?.LogProbThreshold ?? 0.0f); - processorBuilder.WithMaxInitialTs(options?.MaxInitialTs ?? 0); - processorBuilder.WithMaxSegmentLength(options?.MaxSegmentLength ?? 0); - processorBuilder.WithMaxLastTextTokens(options?.MaxLastTextTokens ?? 0); - processorBuilder.WithMaxTokensPerSegment(options?.MaxTokensPerSegment ?? 0); - processorBuilder.WithNoContext(options?.NoContext ?? false); - processorBuilder.WithNoSpeechThreshold(options?.NoSpeechThreshold ?? 0.0f); - processorBuilder.WithOffset(options?.Offset ?? 0); - processorBuilder.WithOpenVinoEncoder(options?.OpenVinoEncoderPath, options?.OpenVinoDevice, options?.OpenVinoCachePath); - processorBuilder.WithoutSuppressBlank(); - processorBuilder.WithoutStringPool(); - processorBuilder.WithPrintProgress(options?.PrintProgress ?? false); - processorBuilder.WithPrintTimestamps(options?.PrintTimestamps ?? false); - processorBuilder.WithPrintSpecialTokens(options?.PrintSpecialTokens ?? false); - processorBuilder.WithPrintResults(options?.PrintResults ?? false); - processorBuilder.WithProbabilities(options?.Probabilities ?? false); - processorBuilder.WithProgressHandler(options?.ProgressHandler); - processorBuilder.WithSegmentEventHandler(options?.SegmentEventHandler); - processorBuilder.WithStringPool(options?.StringPool ?? string.Empty); - processorBuilder.WithTemperature(options?.Temperature ?? 0.0f); - processorBuilder.WithTemperatureInc(options?.TemperatureInc ?? 0.0f); - processorBuilder.WithThreads(options?.Threads ?? 0); - processorBuilder.WithTokenTimestamps(); - processorBuilder.WithTokenTimestampsSumThreshold(options?.TokenTimestampsSumThreshold ?? 0.0f); - processorBuilder.WithTokenTimestampsThreshold(options?.TokenTimestampsThreshold ?? 0.0f); - */ - } - - this._processor = processorBuilder.Build(); - } - - private static bool GetAdditionalProperty(string propertyName, SpeechToTextOptions options, out T? value) - { - if (options.AdditionalProperties?.TryGetValue(propertyName, out value) ?? false) - { - return true; - } - - value = default; - return false; - } } From 896b9923c5795ae2775bdfd294315761a255a7cc Mon Sep 17 00:00:00 2001 From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com> Date: Wed, 9 Apr 2025 13:15:40 +0100 Subject: [PATCH 4/9] Update to released package --- Directory.Packages.props | 4 ++-- Whisper.net.Demo/Program.cs | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/Directory.Packages.props b/Directory.Packages.props index 9bb95877f..776bf2370 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -4,7 +4,7 @@ - + @@ -21,4 +21,4 @@ - \ No newline at end of file + diff --git a/Whisper.net.Demo/Program.cs b/Whisper.net.Demo/Program.cs index 9835d8350..af053d32d 100755 --- a/Whisper.net.Demo/Program.cs +++ b/Whisper.net.Demo/Program.cs @@ -73,7 +73,7 @@ async Task FullDetection(Options opt) using var processor = builder.Build(); using var fileStream = File.OpenRead(opt.FileName); - + Console.WriteLine($"Using {nameof(WhisperProcessor)}:\n"); await foreach (var segment in processor.ProcessAsync(fileStream, CancellationToken.None)) { Console.WriteLine($"New Segment: {segment.Start} ==> {segment.End} : {segment.Text}"); @@ -95,6 +95,7 @@ async Task FullDetectionSpeechToText(Options opt) using var fileStream = File.OpenRead(opt.FileName); + Console.WriteLine($"\nUsing {nameof(ISpeechToTextClient)}:\n"); await foreach (var segment in speechToTextClient.GetStreamingTextAsync(fileStream, options, CancellationToken.None)) { Console.WriteLine($"New Segment: {segment.StartTime} ==> {segment.EndTime} : {segment.Text}"); From 0a579a3c5773f57c09eb1a485e23c1b408046d99 Mon Sep 17 00:00:00 2001 From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com> Date: Wed, 9 Apr 2025 13:24:22 +0100 Subject: [PATCH 5/9] Minor updates --- Whisper.net.Demo/Whisper.net.Demo.csproj | 10 ++++++---- Whisper.net.Demo/Whisper.net.Demo.sln | 24 ++++++++++++++++++++++++ 2 files changed, 30 insertions(+), 4 deletions(-) create mode 100644 Whisper.net.Demo/Whisper.net.Demo.sln diff --git a/Whisper.net.Demo/Whisper.net.Demo.csproj b/Whisper.net.Demo/Whisper.net.Demo.csproj index 3f6146cd1..df2938cf8 100644 --- a/Whisper.net.Demo/Whisper.net.Demo.csproj +++ b/Whisper.net.Demo/Whisper.net.Demo.csproj @@ -1,7 +1,10 @@  - - - + + + Exe @@ -11,7 +14,6 @@ - diff --git a/Whisper.net.Demo/Whisper.net.Demo.sln b/Whisper.net.Demo/Whisper.net.Demo.sln new file mode 100644 index 000000000..1ac45857a --- /dev/null +++ b/Whisper.net.Demo/Whisper.net.Demo.sln @@ -0,0 +1,24 @@ +Microsoft Visual Studio Solution File, Format Version 12.00 +# Visual Studio Version 17 +VisualStudioVersion = 17.5.2.0 +MinimumVisualStudioVersion = 10.0.40219.1 +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Whisper.net.Demo", "Whisper.net.Demo.csproj", "{C98BD896-29FA-3E24-8D52-DE5AE94A47CC}" +EndProject +Global + GlobalSection(SolutionConfigurationPlatforms) = preSolution + Debug|Any CPU = Debug|Any CPU + Release|Any CPU = Release|Any CPU + EndGlobalSection + GlobalSection(ProjectConfigurationPlatforms) = postSolution + {C98BD896-29FA-3E24-8D52-DE5AE94A47CC}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {C98BD896-29FA-3E24-8D52-DE5AE94A47CC}.Debug|Any CPU.Build.0 = Debug|Any CPU + {C98BD896-29FA-3E24-8D52-DE5AE94A47CC}.Release|Any CPU.ActiveCfg = Release|Any CPU + {C98BD896-29FA-3E24-8D52-DE5AE94A47CC}.Release|Any CPU.Build.0 = Release|Any CPU + EndGlobalSection + GlobalSection(SolutionProperties) = preSolution + HideSolutionNode = FALSE + EndGlobalSection + GlobalSection(ExtensibilityGlobals) = postSolution + SolutionGuid = {85152ABC-2DDD-454F-865E-8AB9B1C3908E} + EndGlobalSection +EndGlobal From f7ec8113f0edb7aacbd3eeeb947df108dfedd334 Mon Sep 17 00:00:00 2001 From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com> Date: Thu, 10 Apr 2025 10:18:28 +0100 Subject: [PATCH 6/9] Add missing Unit Tests --- .../SpeechToTextOptionsExtensionsTests.cs | 759 ++++++++++++++++++ .../WhisperSpeechToTextClientTest.cs | 206 +++++ 2 files changed, 965 insertions(+) create mode 100644 tests/Whisper.net.Tests/SpeechToText/SpeechToTextOptionsExtensionsTests.cs create mode 100644 tests/Whisper.net.Tests/SpeechToText/WhisperSpeechToTextClientTest.cs diff --git a/tests/Whisper.net.Tests/SpeechToText/SpeechToTextOptionsExtensionsTests.cs b/tests/Whisper.net.Tests/SpeechToText/SpeechToTextOptionsExtensionsTests.cs new file mode 100644 index 000000000..05bd05d8f --- /dev/null +++ b/tests/Whisper.net.Tests/SpeechToText/SpeechToTextOptionsExtensionsTests.cs @@ -0,0 +1,759 @@ +// Licensed under the MIT license: https://opensource.org/licenses/MIT + +using Microsoft.Extensions.AI; +using Xunit; + +#pragma warning disable MEAI001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. + +namespace Whisper.net.Tests; + +public class SpeechToTextOptionsExtensionsTests +{ + [Fact] + public void WithLanguage_SetsLanguageProperty() + { + // Arrange + + var options = new SpeechToTextOptions(); + + var language = "en"; + + // Act + var result = options.WithLanguage(language); + + // Assert + Assert.Equal(language, options.SpeechLanguage); + Assert.Same(options, result); + } + + [Fact] + public void WithTranslate_SetsTextLanguageToEnglish() + { + // Arrange + var options = new SpeechToTextOptions(); + + // Act + var result = options.WithTranslate(); + + // Assert + Assert.Equal("English", options.TextLanguage); + Assert.Same(options, result); + } + + [Fact] + public void WithBeamSearchSamplingStrategy_SetsAdditionalProperty() + { + // Arrange + var options = new SpeechToTextOptions(); + + // Act + var result = options.WithBeamSearchSamplingStrategy(); + + // Assert + Assert.NotNull(options.AdditionalProperties); + Assert.True(options.AdditionalProperties.TryGetValue("BeamSearchSamplingStrategy", out var value)); + Assert.True((bool)value!); + Assert.Same(options, result); + } + + [Fact] + public void WithAudioContextSize_SetsAdditionalProperty() + { + // Arrange + var options = new SpeechToTextOptions(); + var audioContextSize = 42; + + // Act + var result = options.WithAudioContextSize(audioContextSize); + + // Assert + Assert.NotNull(options.AdditionalProperties); + Assert.True(options.AdditionalProperties.TryGetValue("AudioContextSize", out var value)); + Assert.Equal(audioContextSize, value); + Assert.Same(options, result); + } + + [Fact] + public void WithDuration_SetsAdditionalProperty() + { + // Arrange + var options = new SpeechToTextOptions(); + var duration = TimeSpan.FromSeconds(10); + + // Act + var result = options.WithDuration(duration); + + // Assert + Assert.NotNull(options.AdditionalProperties); + Assert.True(options.AdditionalProperties.TryGetValue("Duration", out var value)); + Assert.Equal(duration, value); + Assert.Same(options, result); + } + + [Fact] + public void WithEncoderBeginHandler_SetsAdditionalProperty() + { + // Arrange + var options = new SpeechToTextOptions(); + OnEncoderBeginEventHandler handler = (e) => true; + + // Act + var result = options.WithEncoderBeginHandler(handler); + + // Assert + Assert.NotNull(options.AdditionalProperties); + Assert.True(options.AdditionalProperties.TryGetValue("EncoderBeginHandler", out var value)); + Assert.Same(handler, value); + Assert.Same(options, result); + } + + [Fact] + public void WithEntropyThreshold_SetsAdditionalProperty() + { + // Arrange + var options = new SpeechToTextOptions(); + var threshold = 2.4f; + + // Act + var result = options.WithEntropyThreshold(threshold); + + // Assert + Assert.NotNull(options.AdditionalProperties); + Assert.True(options.AdditionalProperties.TryGetValue("EntropyThreshold", out var value)); + Assert.Equal(threshold, value); + Assert.Same(options, result); + } + + [Fact] + public void WithGreedySamplingStrategy_SetsAdditionalProperty() + { + // Arrange + var options = new SpeechToTextOptions(); + + // Act + var result = options.WithGreedySamplingStrategy(); + + // Assert + Assert.NotNull(options.AdditionalProperties); + Assert.True(options.AdditionalProperties.TryGetValue("GreedySamplingStrategy", out var value)); + Assert.True((bool)value!); + Assert.Same(options, result); + } + + [Fact] + public void WithWhisperLanguage_SetsAdditionalProperty() + { + // Arrange + var options = new SpeechToTextOptions(); + var language = "en"; + + // Act + var result = options.WithWhisperLanguage(language); + + // Assert + Assert.NotNull(options.AdditionalProperties); + Assert.True(options.AdditionalProperties.TryGetValue("Language", out var value)); + Assert.Equal(language, value); + Assert.Same(options, result); + } + + [Fact] + public void WithLanguageDetection_SetsAdditionalProperty() + { + // Arrange + var options = new SpeechToTextOptions(); + + // Act + var result = options.WithLanguageDetection(); + + // Assert + Assert.NotNull(options.AdditionalProperties); + Assert.True(options.AdditionalProperties.TryGetValue("LanguageDetection", out var value)); + Assert.True((bool)value!); + Assert.Same(options, result); + } + + [Fact] + public void WithLanguageDetection_WithFalseParameter_SetsAdditionalPropertyToFalse() + { + // Arrange + var options = new SpeechToTextOptions(); + + // Act + var result = options.WithLanguageDetection(false); + + // Assert + Assert.NotNull(options.AdditionalProperties); + Assert.True(options.AdditionalProperties.TryGetValue("LanguageDetection", out var value)); + Assert.False((bool)value!); + Assert.Same(options, result); + } + + [Fact] + public void WithLengthPenalty_SetsAdditionalProperty() + { + // Arrange + var options = new SpeechToTextOptions(); + var penalty = 0.5f; + + // Act + var result = options.WithLengthPenalty(penalty); + + // Assert + Assert.NotNull(options.AdditionalProperties); + Assert.True(options.AdditionalProperties.TryGetValue("LengthPenalty", out var value)); + Assert.Equal(penalty, value); + Assert.Same(options, result); + } + + [Fact] + public void WithLogProbThreshold_SetsAdditionalProperty() + { + // Arrange + var options = new SpeechToTextOptions(); + var threshold = -1.0f; + + // Act + var result = options.WithLogProbThreshold(threshold); + + // Assert + Assert.NotNull(options.AdditionalProperties); + Assert.True(options.AdditionalProperties.TryGetValue("LogProbThreshold", out var value)); + Assert.Equal(threshold, value); + Assert.Same(options, result); + } + + [Fact] + public void WithMaxInitialTs_SetsAdditionalProperty() + { + // Arrange + var options = new SpeechToTextOptions(); + var maxInitialTs = 1.0f; + + // Act + var result = options.WithMaxInitialTs(maxInitialTs); + + // Assert + Assert.NotNull(options.AdditionalProperties); + Assert.True(options.AdditionalProperties.TryGetValue("MaxInitialTs", out var value)); + Assert.Equal(maxInitialTs, value); + Assert.Same(options, result); + } + + [Fact] + public void WithMaxSegmentLength_SetsAdditionalProperty() + { + // Arrange + var options = new SpeechToTextOptions(); + var maxLength = 100; + + // Act + var result = options.WithMaxSegmentLength(maxLength); + + // Assert + Assert.NotNull(options.AdditionalProperties); + Assert.True(options.AdditionalProperties.TryGetValue("MaxSegmentLength", out var value)); + Assert.Equal(maxLength, value); + Assert.Same(options, result); + } + + [Fact] + public void WithMaxLastTextTokens_SetsAdditionalProperty() + { + // Arrange + var options = new SpeechToTextOptions(); + var maxTokens = 16384; + + // Act + var result = options.WithMaxLastTextTokens(maxTokens); + + // Assert + Assert.NotNull(options.AdditionalProperties); + Assert.True(options.AdditionalProperties.TryGetValue("MaxLastTextTokens", out var value)); + Assert.Equal(maxTokens, value); + Assert.Same(options, result); + } + + [Fact] + public void WithMaxTokensPerSegment_SetsAdditionalProperty() + { + // Arrange + var options = new SpeechToTextOptions(); + var maxTokens = 50; + + // Act + var result = options.WithMaxTokensPerSegment(maxTokens); + + // Assert + Assert.NotNull(options.AdditionalProperties); + Assert.True(options.AdditionalProperties.TryGetValue("MaxTokensPerSegment", out var value)); + Assert.Equal(maxTokens, value); + Assert.Same(options, result); + } + + [Fact] + public void WithNoContext_SetsAdditionalProperty() + { + // Arrange + var options = new SpeechToTextOptions(); + + // Act + var result = options.WithNoContext(); + + // Assert + Assert.NotNull(options.AdditionalProperties); + Assert.True(options.AdditionalProperties.TryGetValue("NoContext", out var value)); + Assert.True((bool)value!); + Assert.Same(options, result); + } + + [Fact] + public void WithNoContext_WithFalseParameter_SetsAdditionalPropertyToFalse() + { + // Arrange + var options = new SpeechToTextOptions(); + + // Act + var result = options.WithNoContext(false); + + // Assert + Assert.NotNull(options.AdditionalProperties); + Assert.True(options.AdditionalProperties.TryGetValue("NoContext", out var value)); + Assert.False((bool)value!); + Assert.Same(options, result); + } + + [Fact] + public void WithNoSpeechThreshold_SetsAdditionalProperty() + { + // Arrange + var options = new SpeechToTextOptions(); + var threshold = 0.6f; + + // Act + var result = options.WithNoSpeechThreshold(threshold); + + // Assert + Assert.NotNull(options.AdditionalProperties); + Assert.True(options.AdditionalProperties.TryGetValue("NoSpeechThreshold", out var value)); + Assert.Equal(threshold, value); + Assert.Same(options, result); + } + + [Fact] + public void WithOffset_SetsAdditionalProperty() + { + // Arrange + var options = new SpeechToTextOptions(); + var offset = TimeSpan.FromSeconds(5); + + // Act + var result = options.WithOffset(offset); + + // Assert + Assert.NotNull(options.AdditionalProperties); + Assert.True(options.AdditionalProperties.TryGetValue("Offset", out var value)); + Assert.Equal(offset, value); + Assert.Same(options, result); + } + + [Fact] + public void WithOpenVinoEncoder_SetsAdditionalProperties() + { + // Arrange + var options = new SpeechToTextOptions(); + var encoderPath = "path/to/encoder"; + var device = "CPU"; + var cachePath = "path/to/cache"; + + // Act + var result = options.WithOpenVinoEncoder(encoderPath, device, cachePath); + + // Assert + Assert.NotNull(options.AdditionalProperties); + Assert.True(options.AdditionalProperties.TryGetValue("OpenVinoEncoderPath", out var pathValue)); + Assert.Equal(encoderPath, pathValue); + Assert.True(options.AdditionalProperties.TryGetValue("OpenVinoDevice", out var deviceValue)); + Assert.Equal(device, deviceValue); + Assert.True(options.AdditionalProperties.TryGetValue("OpenVinoCachePath", out var cacheValue)); + Assert.Equal(cachePath, cacheValue); + Assert.Same(options, result); + } + + [Fact] + public void WithOpenVinoEncoder_WithNullDeviceAndCache_SetsOnlyEncoderPath() + { + // Arrange + var options = new SpeechToTextOptions(); + var encoderPath = "path/to/encoder"; + + // Act + var result = options.WithOpenVinoEncoder(encoderPath); + + // Assert + Assert.NotNull(options.AdditionalProperties); + Assert.True(options.AdditionalProperties.TryGetValue("OpenVinoEncoderPath", out var pathValue)); + Assert.Equal(encoderPath, pathValue); + Assert.False(options.AdditionalProperties.ContainsKey("OpenVinoDevice")); + Assert.False(options.AdditionalProperties.ContainsKey("OpenVinoCachePath")); + Assert.Same(options, result); + } + + [Fact] + public void WithoutSuppressBlank_SetsAdditionalProperty() + { + // Arrange + var options = new SpeechToTextOptions(); + + // Act + var result = options.WithoutSuppressBlank(); + + // Assert + Assert.NotNull(options.AdditionalProperties); + Assert.True(options.AdditionalProperties.TryGetValue("SuppressBlank", out var value)); + Assert.False((bool)value!); + Assert.Same(options, result); + } + + [Fact] + public void WithoutStringPool_SetsAdditionalProperty() + { + // Arrange + var options = new SpeechToTextOptions(); + + // Act + var result = options.WithoutStringPool(); + + // Assert + Assert.NotNull(options.AdditionalProperties); + Assert.True(options.AdditionalProperties.TryGetValue("StringPool", out var value)); + Assert.False((bool)value!); + Assert.Same(options, result); + } + + [Fact] + public void WithPrintProgress_SetsAdditionalProperty() + { + // Arrange + var options = new SpeechToTextOptions(); + + // Act + var result = options.WithPrintProgress(); + + // Assert + Assert.NotNull(options.AdditionalProperties); + Assert.True(options.AdditionalProperties.TryGetValue("PrintProgress", out var value)); + Assert.True((bool)value!); + Assert.Same(options, result); + } + + [Fact] + public void WithPrintProgress_WithFalseParameter_SetsAdditionalPropertyToFalse() + { + // Arrange + var options = new SpeechToTextOptions(); + + // Act + var result = options.WithPrintProgress(false); + + // Assert + Assert.NotNull(options.AdditionalProperties); + Assert.True(options.AdditionalProperties.TryGetValue("PrintProgress", out var value)); + Assert.False((bool)value!); + Assert.Same(options, result); + } + + [Fact] + public void WithPrintTimestamps_SetsAdditionalProperty() + { + // Arrange + var options = new SpeechToTextOptions(); + + // Act + var result = options.WithPrintTimestamps(); + + // Assert + Assert.NotNull(options.AdditionalProperties); + Assert.True(options.AdditionalProperties.TryGetValue("PrintTimestamps", out var value)); + Assert.True((bool)value!); + Assert.Same(options, result); + } + + [Fact] + public void WithPrintTimestamps_WithFalseParameter_SetsAdditionalPropertyToFalse() + { + // Arrange + var options = new SpeechToTextOptions(); + + // Act + var result = options.WithPrintTimestamps(false); + + // Assert + Assert.NotNull(options.AdditionalProperties); + Assert.True(options.AdditionalProperties.TryGetValue("PrintTimestamps", out var value)); + Assert.False((bool)value!); + Assert.Same(options, result); + } + + [Fact] + public void WithPrintSpecialTokens_SetsAdditionalProperty() + { + // Arrange + var options = new SpeechToTextOptions(); + + // Act + var result = options.WithPrintSpecialTokens(); + + // Assert + Assert.NotNull(options.AdditionalProperties); + Assert.True(options.AdditionalProperties.TryGetValue("PrintSpecialTokens", out var value)); + Assert.True((bool)value!); + Assert.Same(options, result); + } + + [Fact] + public void WithPrintSpecialTokens_WithFalseParameter_SetsAdditionalPropertyToFalse() + { + // Arrange + var options = new SpeechToTextOptions(); + + // Act + var result = options.WithPrintSpecialTokens(false); + + // Assert + Assert.NotNull(options.AdditionalProperties); + Assert.True(options.AdditionalProperties.TryGetValue("PrintSpecialTokens", out var value)); + Assert.False((bool)value!); + Assert.Same(options, result); + } + + [Fact] + public void WithPrintResults_SetsAdditionalProperty() + { + // Arrange + var options = new SpeechToTextOptions(); + + // Act + var result = options.WithPrintResults(); + + // Assert + Assert.NotNull(options.AdditionalProperties); + Assert.True(options.AdditionalProperties.TryGetValue("PrintResults", out var value)); + Assert.True((bool)value!); + Assert.Same(options, result); + } + + [Fact] + public void WithPrintResults_WithFalseParameter_SetsAdditionalPropertyToFalse() + { + // Arrange + var options = new SpeechToTextOptions(); + + // Act + var result = options.WithPrintResults(false); + + // Assert + Assert.NotNull(options.AdditionalProperties); + Assert.True(options.AdditionalProperties.TryGetValue("PrintResults", out var value)); + Assert.False((bool)value!); + Assert.Same(options, result); + } + + [Fact] + public void WithProbabilities_SetsAdditionalProperty() + { + // Arrange + var options = new SpeechToTextOptions(); + + // Act + var result = options.WithProbabilities(); + + // Assert + Assert.NotNull(options.AdditionalProperties); + Assert.True(options.AdditionalProperties.TryGetValue("Probabilities", out var value)); + Assert.True((bool)value!); + Assert.Same(options, result); + } + + [Fact] + public void WithProbabilities_WithFalseParameter_SetsAdditionalPropertyToFalse() + { + // Arrange + var options = new SpeechToTextOptions(); + + // Act + var result = options.WithProbabilities(false); + + // Assert + Assert.NotNull(options.AdditionalProperties); + Assert.True(options.AdditionalProperties.TryGetValue("Probabilities", out var value)); + Assert.False((bool)value!); + Assert.Same(options, result); + } + + [Fact] + public void WithProgressHandler_SetsAdditionalProperty() + { + // Arrange + var options = new SpeechToTextOptions(); + OnProgressHandler handler = (progress) => { }; + + // Act + var result = options.WithProgressHandler(handler); + + // Assert + Assert.NotNull(options.AdditionalProperties); + Assert.True(options.AdditionalProperties.TryGetValue("ProgressHandler", out var value)); + Assert.Same(handler, value); + Assert.Same(options, result); + } + + [Fact] + public void WithSegmentEventHandler_SetsAdditionalProperty() + { + // Arrange + var options = new SpeechToTextOptions(); + OnSegmentEventHandler handler = (e) => { }; + + // Act + var result = options.WithSegmentEventHandler(handler); + + // Assert + Assert.NotNull(options.AdditionalProperties); + Assert.True(options.AdditionalProperties.TryGetValue("SegmentEventHandler", out var value)); + Assert.Same(handler, value); + Assert.Same(options, result); + } + + [Fact] + public void WithStringPool_SetsAdditionalProperty() + { + // Arrange + var options = new SpeechToTextOptions(); + var stringPool = new TestStringPool(); + + // Act + var result = options.WithStringPool(stringPool); + + // Assert + Assert.NotNull(options.AdditionalProperties); + Assert.True(options.AdditionalProperties.TryGetValue("StringPool", out var value)); + Assert.Same(stringPool, value); + Assert.Same(options, result); + } + + [Fact] + public void WithTemperature_SetsAdditionalProperty() + { + // Arrange + var options = new SpeechToTextOptions(); + var temperature = 0.8f; + + // Act + var result = options.WithTemperature(temperature); + + // Assert + Assert.NotNull(options.AdditionalProperties); + Assert.True(options.AdditionalProperties.TryGetValue("Temperature", out var value)); + Assert.Equal(temperature, value); + Assert.Same(options, result); + } + + [Fact] + public void WithTemperatureInc_SetsAdditionalProperty() + { + // Arrange + var options = new SpeechToTextOptions(); + var temperatureInc = 0.2f; + + // Act + var result = options.WithTemperatureInc(temperatureInc); + + // Assert + Assert.NotNull(options.AdditionalProperties); + Assert.True(options.AdditionalProperties.TryGetValue("TemperatureInc", out var value)); + Assert.Equal(temperatureInc, value); + Assert.Same(options, result); + } + + [Fact] + public void WithThreads_SetsAdditionalProperty() + { + // Arrange + var options = new SpeechToTextOptions(); + var threads = 4; + + // Act + var result = options.WithThreads(threads); + + // Assert + Assert.NotNull(options.AdditionalProperties); + Assert.True(options.AdditionalProperties.TryGetValue("Threads", out var value)); + Assert.Equal(threads, value); + Assert.Same(options, result); + } + + [Fact] + public void WithTokenTimestamps_SetsAdditionalProperty() + { + // Arrange + var options = new SpeechToTextOptions(); + + // Act + var result = options.WithTokenTimestamps(); + + // Assert + Assert.NotNull(options.AdditionalProperties); + Assert.True(options.AdditionalProperties.TryGetValue("TokenTimestamps", out var value)); + Assert.True((bool)value!); + Assert.Same(options, result); + } + + [Fact] + public void WithTokenTimestampsSumThreshold_SetsAdditionalProperty() + { + // Arrange + var options = new SpeechToTextOptions(); + var threshold = 0.01f; + + // Act + var result = options.WithTokenTimestampsSumThreshold(threshold); + + // Assert + Assert.NotNull(options.AdditionalProperties); + Assert.True(options.AdditionalProperties.TryGetValue("TokenTimestampsSumThreshold", out var value)); + Assert.Equal(threshold, value); + Assert.Same(options, result); + } + + [Fact] + public void WithTokenTimestampsThreshold_SetsAdditionalProperty() + { + // Arrange + var options = new SpeechToTextOptions(); + var threshold = 0.01f; + + // Act + var result = options.WithTokenTimestampsThreshold(threshold); + + // Assert + Assert.NotNull(options.AdditionalProperties); + Assert.True(options.AdditionalProperties.TryGetValue("TokenTimestampsThreshold", out var value)); + Assert.Equal(threshold, value); + Assert.Same(options, result); + } + + // Helper class for testing + private class TestStringPool : IStringPool + { + public string? GetStringUtf8(IntPtr nativeUtf8) + { + return null; + } + + public void ReturnString(string? returnedString) + { + throw new NotImplementedException(); + } + } +} diff --git a/tests/Whisper.net.Tests/SpeechToText/WhisperSpeechToTextClientTest.cs b/tests/Whisper.net.Tests/SpeechToText/WhisperSpeechToTextClientTest.cs new file mode 100644 index 000000000..461add5b3 --- /dev/null +++ b/tests/Whisper.net.Tests/SpeechToText/WhisperSpeechToTextClientTest.cs @@ -0,0 +1,206 @@ +// Licensed under the MIT license: https://opensource.org/licenses/MIT + +using Microsoft.Extensions.AI; +using Xunit; +using static Whisper.net.Tests.ProcessAsyncFunctionalTests; + +#pragma warning disable MEAI001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. + +namespace Whisper.net.Tests; +public partial class WhisperSpeechToTextClientTest(TinyModelFixture model) : IClassFixture +{ + /* + [Fact] + public async Task TestHappyFlowAsync() + { + var segments = new List(); + var segmentsEnumerated = new List(); + var progress = new List(); + + var encoderBegins = new List(); + using var factory = WhisperFactory.FromPath(model.ModelFile); + using var processor = factory.CreateBuilder() + .WithLanguage("en") + .WithEncoderBeginHandler((e) => + { + encoderBegins.Add(e); + return true; + }) + .WithProgressHandler(progress.Add) + .WithSegmentEventHandler(segments.Add) + .Build(); + + using var fileReader = await TestDataProvider.OpenFileStreamAsync("kennedy.wav"); + await foreach (var data in processor.ProcessAsync(fileReader)) + { + segmentsEnumerated.Add(data); + } + + Assert.Equal(segments, segmentsEnumerated); + Assert.True(segments.Count > 0); + Assert.True(progress.SequenceEqual(progress.OrderBy(x => x))); + Assert.True(progress.Count > 1); + Assert.Single(encoderBegins); + Assert.Contains(segments, segmentData => segmentData.Text.Contains("nation should commit")); + } + + [Fact] + public async Task ProcessAsync_Cancelled_WillCancellTheProcessing_AndDispose_WillWaitUntilFullyFinished() + { + var segments = new List(); + var segmentsEnumerated = new List(); + var cts = new CancellationTokenSource(); + TaskCanceledException? taskCanceledException = null; + + var encoderBegins = new List(); + using var factory = WhisperFactory.FromPath(model.ModelFile); + var processor = factory.CreateBuilder() + .WithLanguage("en") + .WithEncoderBeginHandler((e) => + { + encoderBegins.Add(e); + return true; + }) + .WithSegmentEventHandler(s => + { + segments.Add(s); + cts.Cancel(); + }) + .Build(); + + using var fileReader = await TestDataProvider.OpenFileStreamAsync("kennedy.wav"); + try + { + await foreach (var data in processor.ProcessAsync(fileReader, cts.Token)) + { + segmentsEnumerated.Add(data); + } + } + catch (TaskCanceledException ex) + { + taskCanceledException = ex; + } + + await processor.DisposeAsync(); + + Assert.Empty(segmentsEnumerated); + Assert.Single( segments); + Assert.Single( encoderBegins); + Assert.NotNull(taskCanceledException); + Assert.Contains(segments, segmentData => segmentData.Text.Contains("nation should commit")); + } + + [Fact] + public async Task ProcessAsync_WhenJunkChunkExists_ProcessCorrectly() + { + var segments = new List(); + + using var factory = WhisperFactory.FromPath(model.ModelFile); + await using var processor = factory.CreateBuilder() + .WithLanguage("en") + .Build(); + + using var fileReader = await TestDataProvider.OpenFileStreamAsync("junkchunk16khz.wav"); + await foreach (var segment in processor.ProcessAsync(fileReader)) + { + segments.Add(segment); + } + + Assert.True(segments.Count >= 1); + } + + [Fact] + public async Task ProcessAsync_WhenMultichannel_ProcessCorrectly() + { + var segments = new List(); + + using var factory = WhisperFactory.FromPath(model.ModelFile); + await using var processor = factory.CreateBuilder() + .WithLanguage("en") + .Build(); + + using var fileReader = await TestDataProvider.OpenFileStreamAsync("multichannel.wav"); + await foreach (var segment in processor.ProcessAsync(fileReader)) + { + segments.Add(segment); + } + + Assert.True(segments.Count >= 1); + }*/ + + [Fact] + public async Task GetStreamingTextAsync_CalledMultipleTimes_Serially_WillCompleteEverytime() + { + var updates1 = new List(); + var updates2 = new List(); + var updates3 = new List(); + + var client = new WhisperSpeechToTextClient(model.ModelFile); + var options = new SpeechToTextOptions().WithLanguage("en"); + + using var fileReader = await TestDataProvider.OpenFileStreamAsync("kennedy.wav"); + await foreach (var update in client.GetStreamingTextAsync(fileReader, options)) + { + updates1.Add(update); + } + + using var fileReader2 = await TestDataProvider.OpenFileStreamAsync("kennedy.wav"); + await foreach (var update in client.GetStreamingTextAsync(fileReader2, options)) + { + updates2.Add(update); + } + + using var fileReader3 = await TestDataProvider.OpenFileStreamAsync("kennedy.wav"); + await foreach (var update in client.GetStreamingTextAsync(fileReader3, options)) + { + updates3.Add(update); + } + + Assert.True(updates1.SequenceEqual(updates2, new UpdateDataComparer())); + Assert.True(updates2.SequenceEqual(updates3, new UpdateDataComparer())); + } + + [Fact] + public async Task GetTextAsync_CalledMultipleTimes_Serially_WillCompleteEverytime() + { + var client = new WhisperSpeechToTextClient(model.ModelFile); + var options = new SpeechToTextOptions().WithLanguage("en"); + + using var fileReader1 = await TestDataProvider.OpenFileStreamAsync("kennedy.wav"); + var result1 = await client.GetTextAsync(fileReader1, options); + var segments1 = Assert.IsAssignableFrom>(result1.RawRepresentation); + + using var fileReader2 = await TestDataProvider.OpenFileStreamAsync("kennedy.wav"); + var result2 = await client.GetTextAsync(fileReader2, options); + var segments2 = Assert.IsAssignableFrom>(result2.RawRepresentation); + + using var fileReader3 = await TestDataProvider.OpenFileStreamAsync("kennedy.wav"); + var result3 = await client.GetTextAsync(fileReader3, options); + var segments3 = Assert.IsAssignableFrom>(result3.RawRepresentation); + + + Assert.True(segments1.SequenceEqual(segments2, new SegmentDataComparer())); + Assert.True(segments2.SequenceEqual(segments3, new SegmentDataComparer())); + } + + private class UpdateDataComparer : IEqualityComparer + { + public bool Equals(SpeechToTextResponseUpdate? xUpdate, SpeechToTextResponseUpdate? yUpdate) + { + if (xUpdate == null || yUpdate == null) + { + return false; + } + + var x = (yUpdate.RawRepresentation as SegmentData)!; + var y = (yUpdate.RawRepresentation as SegmentData)!; + + return x.Text == y.Text && x.MinProbability == y.MinProbability && x.Probability == y.Probability && x.Start == y.Start && x.End == y.End; // Compare by relevant properties + } + + public int GetHashCode(SpeechToTextResponseUpdate obj) + { + return obj.Text.GetHashCode(); + } + } +} From 69a8bc5d18162690a917570b4bdc763da5e2df09 Mon Sep 17 00:00:00 2001 From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com> Date: Thu, 10 Apr 2025 16:44:42 +0100 Subject: [PATCH 7/9] Adding factory + process ctor + tests --- .../WhisperSpeechToTextClient.cs | 79 ++++++- ...isperSpeechToTextClientConstructorTests.cs | 193 ++++++++++++++++++ 2 files changed, 263 insertions(+), 9 deletions(-) create mode 100644 tests/Whisper.net.Tests/SpeechToText/WhisperSpeechToTextClientConstructorTests.cs diff --git a/Whisper.net/SpeechToTextClient/WhisperSpeechToTextClient.cs b/Whisper.net/SpeechToTextClient/WhisperSpeechToTextClient.cs index f322792f3..7fb9ea4bb 100644 --- a/Whisper.net/SpeechToTextClient/WhisperSpeechToTextClient.cs +++ b/Whisper.net/SpeechToTextClient/WhisperSpeechToTextClient.cs @@ -8,15 +8,76 @@ namespace Whisper.net; -public sealed class WhisperSpeechToTextClient(string modelFileName) : ISpeechToTextClient +/// +/// Client for speech-to-text operations using Whisper models. +/// +public sealed class WhisperSpeechToTextClient : ISpeechToTextClient { - private readonly WhisperFactory _factory = WhisperFactory.FromPath(modelFileName); - private WhisperProcessor? _processor; + private readonly Func _buildFactoryFunc; + private WhisperFactory? _factory; + private readonly object _factoryLock = new(); + + /// + /// Initializes a new instance of the class. + /// + /// The path to the model file. + public WhisperSpeechToTextClient(string modelFileName) + : this(() => WhisperFactory.FromPath(modelFileName)) + { + } + + /// + /// Initializes a new instance of the class with a factory builder function. + /// + /// A function that creates a WhisperFactory instance. + /// Thrown when the factory builder is null. + public WhisperSpeechToTextClient(Func buildFactoryFunc) + { + if (buildFactoryFunc is null) + { + throw new ArgumentNullException(nameof(buildFactoryFunc)); + } + + _buildFactoryFunc = buildFactoryFunc; + } + + /// + /// Gets the WhisperFactory instance, creating it if it doesn't exist yet. + /// + /// The WhisperFactory instance. + /// Thrown when the factory builder returns null. + private WhisperFactory GetFactory() + { + if (_factory is not null) + { + return _factory; + } + + lock (_factoryLock) + { + if (_factory is not null) + { + return _factory; + } + + _factory = _buildFactoryFunc(); + + if (_factory is null) + { + throw new ArgumentNullException(nameof(_factory)); + } + + return _factory; + } + } public void Dispose() { - _processor?.Dispose(); - _factory?.Dispose(); + lock (_factoryLock) + { + _factory?.Dispose(); + _factory = null; + } } public object? GetService(Type serviceType, object? serviceKey = null) @@ -31,10 +92,10 @@ public async IAsyncEnumerable GetStreamingTextAsync( throw new ArgumentNullException(nameof(audioSpeechStream)); } - this._processor ??= options.BuildWhisperProcessor(_factory); + using var processor = options.BuildWhisperProcessor(GetFactory()); var responseId = Guid.NewGuid().ToString(); - await foreach (var segment in _processor!.ProcessAsync(audioSpeechStream, cancellationToken)) + await foreach (var segment in processor.ProcessAsync(audioSpeechStream, cancellationToken)) { if (cancellationToken.IsCancellationRequested) { @@ -61,12 +122,12 @@ public async Task GetTextAsync(Stream audioSpeechStream, S SpeechToTextResponse response = new(); - this._processor ??= options.BuildWhisperProcessor(_factory); + using var processor = options?.BuildWhisperProcessor(GetFactory()) ?? GetFactory().CreateBuilder().Build(); StringBuilder fullTranscription = new(); List segments = []; - await foreach (var segment in _processor!.ProcessAsync(audioSpeechStream, cancellationToken)) + await foreach (var segment in processor.ProcessAsync(audioSpeechStream, cancellationToken)) { if (cancellationToken.IsCancellationRequested) { diff --git a/tests/Whisper.net.Tests/SpeechToText/WhisperSpeechToTextClientConstructorTests.cs b/tests/Whisper.net.Tests/SpeechToText/WhisperSpeechToTextClientConstructorTests.cs new file mode 100644 index 000000000..d33d9f225 --- /dev/null +++ b/tests/Whisper.net.Tests/SpeechToText/WhisperSpeechToTextClientConstructorTests.cs @@ -0,0 +1,193 @@ +// Licensed under the MIT license: https://opensource.org/licenses/MIT + +using System.Reflection; +using Microsoft.Extensions.AI; +using Xunit; + +#pragma warning disable MEAI001 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. + +namespace Whisper.net.Tests; + +public class WhisperSpeechToTextClientConstructorTests : IClassFixture +{ + private readonly TinyModelFixture _model; + + public WhisperSpeechToTextClientConstructorTests(TinyModelFixture model) + { + _model = model; + } + + [Fact] + public void Constructor_WithModelFileName_ShouldNotCreateFactoryImmediately() + { + // Arrange & Act + using var client = new WhisperSpeechToTextClient(_model.ModelFile); + + // Assert + var factoryField = GetFactoryField(client); + Assert.Null(factoryField); + } + + [Fact] + public void Constructor_WithFactoryBuilder_ShouldNotCreateFactoryImmediately() + { + // Arrange + bool factoryBuilderCalled = false; + WhisperFactory FactoryBuilder() + { + factoryBuilderCalled = true; + return WhisperFactory.FromPath(_model.ModelFile); + } + + // Act + using var client = new WhisperSpeechToTextClient(FactoryBuilder); + + // Assert + var factoryField = GetFactoryField(client); + Assert.Null(factoryField); + Assert.False(factoryBuilderCalled, "Factory builder should not be called until needed"); + } + + [Fact] + public void Constructor_WithNullFactoryBuilder_ShouldThrowArgumentNullException() + { + // Arrange & Act & Assert + Assert.Throws(() => new WhisperSpeechToTextClient((Func)null!)); + } + + [Fact] + public async Task GetTextAsync_ShouldCreateFactoryLazily() + { + // Arrange + var factoryBuilderCalled = false; + WhisperFactory FactoryBuilder() + { + factoryBuilderCalled = true; + return WhisperFactory.FromPath(_model.ModelFile); + } + + using var client = new WhisperSpeechToTextClient(FactoryBuilder); + + // Assert - Before + Assert.Null(GetFactoryField(client)); + Assert.False(factoryBuilderCalled, "Factory builder should not be called until needed"); + + // Act + using var fileReader = await TestDataProvider.OpenFileStreamAsync("kennedy.wav"); + var options = new SpeechToTextOptions().WithLanguage("en"); + await client.GetTextAsync(fileReader, options); + + // Assert - After + Assert.NotNull(GetFactoryField(client)); + Assert.True(factoryBuilderCalled, "Factory builder should be called when GetTextAsync is invoked"); + } + + [Fact] + public async Task GetStreamingTextAsync_ShouldCreateFactoryLazily() + { + // Arrange + var factoryBuilderCalled = false; + WhisperFactory FactoryBuilder() + { + factoryBuilderCalled = true; + return WhisperFactory.FromPath(_model.ModelFile); + } + + using var client = new WhisperSpeechToTextClient(FactoryBuilder); + + // Assert - Before + Assert.Null(GetFactoryField(client)); + Assert.False(factoryBuilderCalled, "Factory builder should not be called until needed"); + + // Act + using var fileReader = await TestDataProvider.OpenFileStreamAsync("kennedy.wav"); + var options = new SpeechToTextOptions().WithLanguage("en"); + await foreach (var _ in client.GetStreamingTextAsync(fileReader, options)) + { + // Just consume the stream + } + + // Assert - After + Assert.NotNull(GetFactoryField(client)); + Assert.True(factoryBuilderCalled, "Factory builder should be called when GetStreamingTextAsync is invoked"); + } + + [Fact] + public async Task MultipleRequests_ShouldReuseFactory() + { + // Arrange + int factoryBuilderCallCount = 0; + WhisperFactory FactoryBuilder() + { + factoryBuilderCallCount++; + return WhisperFactory.FromPath(_model.ModelFile); + } + + using var client = new WhisperSpeechToTextClient(FactoryBuilder); + var options = new SpeechToTextOptions().WithLanguage("en"); + + // Act - First request + using (var fileReader = await TestDataProvider.OpenFileStreamAsync("kennedy.wav")) + { + await client.GetTextAsync(fileReader, options); + } + + var factoryAfterFirstRequest = GetFactoryField(client); + Assert.NotNull(factoryAfterFirstRequest); + Assert.Equal(1, factoryBuilderCallCount); + + // Act - Second request + using (var fileReader = await TestDataProvider.OpenFileStreamAsync("kennedy.wav")) + { + await client.GetTextAsync(fileReader, options); + } + + // Assert + var factoryAfterSecondRequest = GetFactoryField(client); + Assert.NotNull(factoryAfterSecondRequest); + Assert.Equal(1, factoryBuilderCallCount); // Factory builder should only be called once + Assert.Same(factoryAfterFirstRequest, factoryAfterSecondRequest); // he same factory instance should be reused + } + + [Fact] + public async Task Dispose_ShouldDisposeFactory() + { + // Arrange + using var client = new WhisperSpeechToTextClient(_model.ModelFile); + var options = new SpeechToTextOptions().WithLanguage("en"); + + // Act - Create factory by using the client + using (var fileReader = await TestDataProvider.OpenFileStreamAsync("kennedy.wav")) + { + await client.GetTextAsync(fileReader, options); + } + + var factory = GetFactoryField(client); + Assert.NotNull(factory); + + // Act - Dispose the client + client.Dispose(); + + // Assert + // After disposal, the factory field should be null + Assert.Null(GetFactoryField(client)); + } + + [Fact] + public async Task FactoryBuilder_ReturningNull_ShouldThrowArgumentNullException() + { + // Arrange + using var client = new WhisperSpeechToTextClient(() => null!); + var options = new SpeechToTextOptions().WithLanguage("en"); + + // Act & Assert + using var fileReader = await TestDataProvider.OpenFileStreamAsync("kennedy.wav"); + await Assert.ThrowsAsync(() => client.GetTextAsync(fileReader, options)); + } + + private static WhisperFactory? GetFactoryField(WhisperSpeechToTextClient client) + { + var fieldInfo = typeof(WhisperSpeechToTextClient).GetField("_factory", BindingFlags.NonPublic | BindingFlags.Instance); + return fieldInfo?.GetValue(client) as WhisperFactory; + } +} From 62c5d17d11b636c0de0e1d4d3700e6dfa1d599b1 Mon Sep 17 00:00:00 2001 From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com> Date: Thu, 10 Apr 2025 16:56:14 +0100 Subject: [PATCH 8/9] Ensure disposal works as expected --- .../WhisperSpeechToTextClient.cs | 34 ++++++++++- ...isperSpeechToTextClientConstructorTests.cs | 56 ++++++++++++++++++- 2 files changed, 85 insertions(+), 5 deletions(-) diff --git a/Whisper.net/SpeechToTextClient/WhisperSpeechToTextClient.cs b/Whisper.net/SpeechToTextClient/WhisperSpeechToTextClient.cs index 7fb9ea4bb..f1bb7cba4 100644 --- a/Whisper.net/SpeechToTextClient/WhisperSpeechToTextClient.cs +++ b/Whisper.net/SpeechToTextClient/WhisperSpeechToTextClient.cs @@ -16,6 +16,7 @@ public sealed class WhisperSpeechToTextClient : ISpeechToTextClient private readonly Func _buildFactoryFunc; private WhisperFactory? _factory; private readonly object _factoryLock = new(); + private bool _disposed; /// /// Initializes a new instance of the class. @@ -46,6 +47,7 @@ public WhisperSpeechToTextClient(Func buildFactoryFunc) /// /// The WhisperFactory instance. /// Thrown when the factory builder returns null. + /// Thrown when the client has been disposed. private WhisperFactory GetFactory() { if (_factory is not null) @@ -55,13 +57,18 @@ private WhisperFactory GetFactory() lock (_factoryLock) { + if (_disposed) + { + throw new ObjectDisposedException(nameof(WhisperSpeechToTextClient)); + } + if (_factory is not null) { return _factory; } _factory = _buildFactoryFunc(); - + if (_factory is null) { throw new ArgumentNullException(nameof(_factory)); @@ -73,10 +80,21 @@ private WhisperFactory GetFactory() public void Dispose() { + if (_disposed) + { + return; + } + lock (_factoryLock) { + if (_disposed) + { + return; + } + _factory?.Dispose(); _factory = null; + _disposed = true; } } @@ -87,12 +105,17 @@ public void Dispose() public async IAsyncEnumerable GetStreamingTextAsync(Stream audioSpeechStream, SpeechToTextOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { + if (_disposed) + { + throw new ObjectDisposedException(nameof(WhisperSpeechToTextClient)); + } + if (audioSpeechStream is null) { throw new ArgumentNullException(nameof(audioSpeechStream)); } - using var processor = options.BuildWhisperProcessor(GetFactory()); + await using var processor = options.BuildWhisperProcessor(GetFactory()); var responseId = Guid.NewGuid().ToString(); await foreach (var segment in processor.ProcessAsync(audioSpeechStream, cancellationToken)) @@ -115,6 +138,11 @@ public async IAsyncEnumerable GetStreamingTextAsync( public async Task GetTextAsync(Stream audioSpeechStream, SpeechToTextOptions? options = null, CancellationToken cancellationToken = default) { + if (_disposed) + { + throw new ObjectDisposedException(nameof(WhisperSpeechToTextClient)); + } + if (audioSpeechStream is null) { throw new ArgumentNullException(nameof(audioSpeechStream)); @@ -122,7 +150,7 @@ public async Task GetTextAsync(Stream audioSpeechStream, S SpeechToTextResponse response = new(); - using var processor = options?.BuildWhisperProcessor(GetFactory()) ?? GetFactory().CreateBuilder().Build(); + await using var processor = options?.BuildWhisperProcessor(GetFactory()) ?? GetFactory().CreateBuilder().Build(); StringBuilder fullTranscription = new(); List segments = []; diff --git a/tests/Whisper.net.Tests/SpeechToText/WhisperSpeechToTextClientConstructorTests.cs b/tests/Whisper.net.Tests/SpeechToText/WhisperSpeechToTextClientConstructorTests.cs index d33d9f225..1d3939265 100644 --- a/tests/Whisper.net.Tests/SpeechToText/WhisperSpeechToTextClientConstructorTests.cs +++ b/tests/Whisper.net.Tests/SpeechToText/WhisperSpeechToTextClientConstructorTests.cs @@ -67,7 +67,7 @@ WhisperFactory FactoryBuilder() } using var client = new WhisperSpeechToTextClient(FactoryBuilder); - + // Assert - Before Assert.Null(GetFactoryField(client)); Assert.False(factoryBuilderCalled, "Factory builder should not be called until needed"); @@ -94,7 +94,7 @@ WhisperFactory FactoryBuilder() } using var client = new WhisperSpeechToTextClient(FactoryBuilder); - + // Assert - Before Assert.Null(GetFactoryField(client)); Assert.False(factoryBuilderCalled, "Factory builder should not be called until needed"); @@ -173,6 +173,58 @@ public async Task Dispose_ShouldDisposeFactory() Assert.Null(GetFactoryField(client)); } + [Fact] + public async Task AfterDispose_GetTextAsync_ShouldThrowObjectDisposedException() + { + // Arrange + var client = new WhisperSpeechToTextClient(_model.ModelFile); + var options = new SpeechToTextOptions().WithLanguage("en"); + + // Initialize the client by using it once + using (var fileReader = await TestDataProvider.OpenFileStreamAsync("kennedy.wav")) + { + await client.GetTextAsync(fileReader, options); + } + + // Act - Dispose the client + client.Dispose(); + + // Assert - Attempting to use the client after disposal should throw ObjectDisposedException + using var fileReader2 = await TestDataProvider.OpenFileStreamAsync("kennedy.wav"); + await Assert.ThrowsAsync(() => client.GetTextAsync(fileReader2, options)); + } + + [Fact] + public async Task AfterDispose_GetStreamingTextAsync_ShouldThrowObjectDisposedException() + { + // Arrange + var client = new WhisperSpeechToTextClient(_model.ModelFile); + var options = new SpeechToTextOptions().WithLanguage("en"); + + // Initialize the client by using it once + using (var fileReader = await TestDataProvider.OpenFileStreamAsync("kennedy.wav")) + { + await foreach (var _ in client.GetStreamingTextAsync(fileReader, options)) + { + // Just consume the stream + break; // We only need to process one item to initialize the client + } + } + + // Act - Dispose the client + client.Dispose(); + + // Assert - Attempting to use the client after disposal should throw ObjectDisposedException + using var fileReader2 = await TestDataProvider.OpenFileStreamAsync("kennedy.wav"); + await Assert.ThrowsAsync(async () => + { + await foreach (var _ in client.GetStreamingTextAsync(fileReader2, options)) + { + // This should throw before we get here + } + }); + } + [Fact] public async Task FactoryBuilder_ReturningNull_ShouldThrowArgumentNullException() { From d0e05ff91129f37c154c156aca06e2a3ffae0ccc Mon Sep 17 00:00:00 2001 From: Roger Barreto <19890735+RogerBarreto@users.noreply.github.com> Date: Thu, 10 Apr 2025 17:17:11 +0100 Subject: [PATCH 9/9] Add missing IT --- .../WhisperSpeechToTextClientTest.cs | 92 ++++++++++--------- 1 file changed, 49 insertions(+), 43 deletions(-) diff --git a/tests/Whisper.net.Tests/SpeechToText/WhisperSpeechToTextClientTest.cs b/tests/Whisper.net.Tests/SpeechToText/WhisperSpeechToTextClientTest.cs index 461add5b3..3ea9bcb81 100644 --- a/tests/Whisper.net.Tests/SpeechToText/WhisperSpeechToTextClientTest.cs +++ b/tests/Whisper.net.Tests/SpeechToText/WhisperSpeechToTextClientTest.cs @@ -9,17 +9,17 @@ namespace Whisper.net.Tests; public partial class WhisperSpeechToTextClientTest(TinyModelFixture model) : IClassFixture { - /* [Fact] public async Task TestHappyFlowAsync() { var segments = new List(); - var segmentsEnumerated = new List(); + var updatesEnumerated = new List(); var progress = new List(); var encoderBegins = new List(); using var factory = WhisperFactory.FromPath(model.ModelFile); - using var processor = factory.CreateBuilder() + + var options = new SpeechToTextOptions() .WithLanguage("en") .WithEncoderBeginHandler((e) => { @@ -27,53 +27,57 @@ public async Task TestHappyFlowAsync() return true; }) .WithProgressHandler(progress.Add) - .WithSegmentEventHandler(segments.Add) - .Build(); + .WithSegmentEventHandler(segments.Add); + + using var client = new WhisperSpeechToTextClient(() => factory); using var fileReader = await TestDataProvider.OpenFileStreamAsync("kennedy.wav"); - await foreach (var data in processor.ProcessAsync(fileReader)) + await foreach (var data in client.GetStreamingTextAsync(fileReader, options)) { - segmentsEnumerated.Add(data); + updatesEnumerated.Add(data); } - Assert.Equal(segments, segmentsEnumerated); + Assert.Equal(segments, updatesEnumerated.Select(u => u.RawRepresentation)); Assert.True(segments.Count > 0); Assert.True(progress.SequenceEqual(progress.OrderBy(x => x))); Assert.True(progress.Count > 1); Assert.Single(encoderBegins); Assert.Contains(segments, segmentData => segmentData.Text.Contains("nation should commit")); + Assert.Contains(updatesEnumerated, update => update.Text.Contains("nation should commit")); } [Fact] - public async Task ProcessAsync_Cancelled_WillCancellTheProcessing_AndDispose_WillWaitUntilFullyFinished() + public async Task WithSegmentEventHandler_Cancelled_WillCancellTheProcessing_AndDispose() { var segments = new List(); - var segmentsEnumerated = new List(); + var segmentsEnumerated = new List(); var cts = new CancellationTokenSource(); TaskCanceledException? taskCanceledException = null; var encoderBegins = new List(); using var factory = WhisperFactory.FromPath(model.ModelFile); - var processor = factory.CreateBuilder() - .WithLanguage("en") - .WithEncoderBeginHandler((e) => - { - encoderBegins.Add(e); - return true; - }) - .WithSegmentEventHandler(s => - { - segments.Add(s); - cts.Cancel(); - }) - .Build(); + + var options = new SpeechToTextOptions() + .WithLanguage("en") + .WithEncoderBeginHandler((e) => + { + encoderBegins.Add(e); + return true; + }) + .WithSegmentEventHandler(s => + { + segments.Add(s); + cts.Cancel(); + }); + + var client = new WhisperSpeechToTextClient(() => factory); using var fileReader = await TestDataProvider.OpenFileStreamAsync("kennedy.wav"); try { - await foreach (var data in processor.ProcessAsync(fileReader, cts.Token)) + await foreach (var update in client.GetStreamingTextAsync(fileReader, options, cts.Token)) { - segmentsEnumerated.Add(data); + segmentsEnumerated.Add(update); } } catch (TaskCanceledException ex) @@ -81,11 +85,11 @@ public async Task ProcessAsync_Cancelled_WillCancellTheProcessing_AndDispose_Wil taskCanceledException = ex; } - await processor.DisposeAsync(); + client.Dispose(); Assert.Empty(segmentsEnumerated); - Assert.Single( segments); - Assert.Single( encoderBegins); + Assert.Single(segments); + Assert.Single(encoderBegins); Assert.NotNull(taskCanceledException); Assert.Contains(segments, segmentData => segmentData.Text.Contains("nation should commit")); } @@ -93,17 +97,18 @@ public async Task ProcessAsync_Cancelled_WillCancellTheProcessing_AndDispose_Wil [Fact] public async Task ProcessAsync_WhenJunkChunkExists_ProcessCorrectly() { - var segments = new List(); + var segments = new List(); using var factory = WhisperFactory.FromPath(model.ModelFile); - await using var processor = factory.CreateBuilder() - .WithLanguage("en") - .Build(); + var options = new SpeechToTextOptions() + .WithLanguage("en"); + + using var client = new WhisperSpeechToTextClient(() => factory); using var fileReader = await TestDataProvider.OpenFileStreamAsync("junkchunk16khz.wav"); - await foreach (var segment in processor.ProcessAsync(fileReader)) + await foreach (var update in client.GetStreamingTextAsync(fileReader, options)) { - segments.Add(segment); + segments.Add(update); } Assert.True(segments.Count >= 1); @@ -112,21 +117,22 @@ public async Task ProcessAsync_WhenJunkChunkExists_ProcessCorrectly() [Fact] public async Task ProcessAsync_WhenMultichannel_ProcessCorrectly() { - var segments = new List(); + var segments = new List(); using var factory = WhisperFactory.FromPath(model.ModelFile); - await using var processor = factory.CreateBuilder() - .WithLanguage("en") - .Build(); + var options = new SpeechToTextOptions() + .WithLanguage("en"); + + using var client = new WhisperSpeechToTextClient(() => factory); using var fileReader = await TestDataProvider.OpenFileStreamAsync("multichannel.wav"); - await foreach (var segment in processor.ProcessAsync(fileReader)) + await foreach (var update in client.GetStreamingTextAsync(fileReader, options)) { - segments.Add(segment); + segments.Add(update); } Assert.True(segments.Count >= 1); - }*/ + } [Fact] public async Task GetStreamingTextAsync_CalledMultipleTimes_Serially_WillCompleteEverytime() @@ -135,7 +141,7 @@ public async Task GetStreamingTextAsync_CalledMultipleTimes_Serially_WillComplet var updates2 = new List(); var updates3 = new List(); - var client = new WhisperSpeechToTextClient(model.ModelFile); + using var client = new WhisperSpeechToTextClient(model.ModelFile); var options = new SpeechToTextOptions().WithLanguage("en"); using var fileReader = await TestDataProvider.OpenFileStreamAsync("kennedy.wav"); @@ -163,7 +169,7 @@ public async Task GetStreamingTextAsync_CalledMultipleTimes_Serially_WillComplet [Fact] public async Task GetTextAsync_CalledMultipleTimes_Serially_WillCompleteEverytime() { - var client = new WhisperSpeechToTextClient(model.ModelFile); + using var client = new WhisperSpeechToTextClient(model.ModelFile); var options = new SpeechToTextOptions().WithLanguage("en"); using var fileReader1 = await TestDataProvider.OpenFileStreamAsync("kennedy.wav");