diff --git a/examples/c/CMakeLists.txt b/examples/c/CMakeLists.txt index 226e88d994..7cd300cbf7 100644 --- a/examples/c/CMakeLists.txt +++ b/examples/c/CMakeLists.txt @@ -26,6 +26,7 @@ option(MODEL_CHAT "Build the Model Chat example" OFF) option(MODEL_QA "Build the Model Q&A example" OFF) option(MODEL_MM "Build the Model Multimodal example" OFF) option(WHISPER "Build the Whisper example" OFF) +option(NEMOTRON_SPEECH "Build the Nemotron Speech Streaming example" OFF) if(USE_CXX) add_compile_definitions(USE_CXX) @@ -126,3 +127,9 @@ if(WHISPER) target_link_libraries(whisper PRIVATE nlohmann_json::nlohmann_json) target_link_libraries(whisper PRIVATE CLI11::CLI11) endif() + +if(NEMOTRON_SPEECH) + add_executable(nemotron_speech ${EXAMPLES_SOURCE_DIR}/nemotron_speech.cpp) + prepare_executable(nemotron_speech) + target_link_libraries(nemotron_speech PRIVATE nlohmann_json::nlohmann_json) +endif() diff --git a/examples/c/src/nemotron_speech.cpp b/examples/c/src/nemotron_speech.cpp new file mode 100644 index 0000000000..0f4c8f791c --- /dev/null +++ b/examples/c/src/nemotron_speech.cpp @@ -0,0 +1,240 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +// +// nemotron_speech.cpp — Streaming ASR example using StreamingProcessor + Generator API. +// +// Usage: +// ./nemotron_speech --model_path /path/to/nemotron-model --audio_file /path/to/audio.wav + +#include +#include +#include +#include +#include +#include + +#include +#include "ort_genai.h" + +struct AudioConfig { + int sample_rate; + int chunk_samples; +}; + +AudioConfig LoadConfig(const std::string& model_path) { + std::string config_path = model_path + "/genai_config.json"; + std::ifstream f(config_path); + if (!f.is_open()) { + throw std::runtime_error("Cannot open " + config_path); + } + auto config = nlohmann::json::parse(f); + return { + config["model"]["sample_rate"].get(), + config["model"]["chunk_samples"].get(), + }; +} + +// Simple WAV loader — expects 16-bit PCM, mono or stereo. +// Returns float32 samples normalized to [-1, 1]. +std::vector LoadWav(const std::string& path, int target_sample_rate) { + std::ifstream file(path, std::ios::binary); + if (!file.is_open()) { + throw std::runtime_error("Cannot open audio file: " + path); + } + + // Read WAV header + char riff[4]; + file.read(riff, 4); + if (std::memcmp(riff, "RIFF", 4) != 0) { + throw std::runtime_error("Not a valid WAV file (missing RIFF header)"); + } + + file.seekg(4, std::ios::cur); // Skip file size + + char wave[4]; + file.read(wave, 4); + if (std::memcmp(wave, "WAVE", 4) != 0) { + throw std::runtime_error("Not a valid WAV file (missing WAVE marker)"); + } + + // Find fmt chunk + int16_t num_channels = 0; + int32_t sample_rate = 0; + int16_t bits_per_sample = 0; + + while (file.good()) { + char chunk_id[4]; + int32_t chunk_size; + file.read(chunk_id, 4); + file.read(reinterpret_cast(&chunk_size), 4); + + if (std::memcmp(chunk_id, "fmt ", 4) == 0) { + int16_t audio_format; + file.read(reinterpret_cast(&audio_format), 2); + file.read(reinterpret_cast(&num_channels), 2); + file.read(reinterpret_cast(&sample_rate), 4); + file.seekg(6, std::ios::cur); // Skip byte rate + block align + file.read(reinterpret_cast(&bits_per_sample), 2); + if (chunk_size > 16) { + file.seekg(chunk_size - 16, std::ios::cur); + } + } else if (std::memcmp(chunk_id, "data", 4) == 0) { + int num_samples = chunk_size / (bits_per_sample / 8) / num_channels; + std::vector audio(num_samples); + + if (bits_per_sample == 16) { + std::vector raw(num_samples * num_channels); + file.read(reinterpret_cast(raw.data()), chunk_size); + for (int i = 0; i < num_samples; i++) { + if (num_channels == 1) { + audio[i] = raw[i] / 32768.0f; + } else { + // Average channels + float sum = 0.0f; + for (int c = 0; c < num_channels; c++) { + sum += raw[i * num_channels + c]; + } + audio[i] = (sum / num_channels) / 32768.0f; + } + } + } else if (bits_per_sample == 32) { + // Assume float32 + std::vector raw(num_samples * num_channels); + file.read(reinterpret_cast(raw.data()), chunk_size); + for (int i = 0; i < num_samples; i++) { + if (num_channels == 1) { + audio[i] = raw[i]; + } else { + float sum = 0.0f; + for (int c = 0; c < num_channels; c++) { + sum += raw[i * num_channels + c]; + } + audio[i] = sum / num_channels; + } + } + } else { + throw std::runtime_error("Unsupported bits per sample: " + std::to_string(bits_per_sample)); + } + + // Basic resampling if needed (linear interpolation) + if (sample_rate != target_sample_rate) { + int new_len = static_cast(audio.size() * static_cast(target_sample_rate) / sample_rate); + std::vector resampled(new_len); + for (int i = 0; i < new_len; i++) { + double src_idx = i * static_cast(audio.size() - 1) / (new_len - 1); + int idx0 = static_cast(src_idx); + int idx1 = std::min(idx0 + 1, static_cast(audio.size()) - 1); + double frac = src_idx - idx0; + resampled[i] = static_cast(audio[idx0] * (1.0 - frac) + audio[idx1] * frac); + } + return resampled; + } + + return audio; + } else { + file.seekg(chunk_size, std::ios::cur); + } + } + + throw std::runtime_error("No data chunk found in WAV file"); +} + +std::string DecodeTokens(OgaGenerator& generator, OgaTokenizerStream& tokenizer_stream) { + std::string text; + while (!generator.IsDone()) { + generator.GenerateNextToken(); + auto next_tokens = generator.GetNextTokens(); + if (!next_tokens.empty()) { + const char* token_text = tokenizer_stream.Decode(next_tokens[0]); + if (token_text && token_text[0] != '\0') { + std::cout << token_text << std::flush; + text += token_text; + } + } + } + return text; +} + +void StreamingTranscribe(const std::string& model_path, const std::string& audio_path) { + auto [sample_rate, chunk_samples] = LoadConfig(model_path); + + std::cout << "Loading audio: " << audio_path << std::endl; + auto audio = LoadWav(audio_path, sample_rate); + double duration = static_cast(audio.size()) / sample_rate; + + std::cout << "Loading model: " << model_path << std::endl; + auto config = OgaConfig::Create(model_path.c_str()); + auto model = OgaModel::Create(*config); + auto processor = OgaStreamingProcessor::Create(*model); + auto tokenizer = OgaTokenizer::Create(*model); + auto tokenizer_stream = OgaTokenizerStream::Create(*tokenizer); + auto params = OgaGeneratorParams::Create(*model); + auto generator = OgaGenerator::Create(*model, *params); + + std::cout << " Sample rate: " << sample_rate << ", Chunk: " << chunk_samples << " samples" << std::endl; + std::cout << " Audio duration: " << duration << "s" << std::endl; + std::cout << std::string(60, '-') << std::endl; + + auto start = std::chrono::high_resolution_clock::now(); + std::string full_transcript; + + // Stream audio in chunks + for (size_t i = 0; i < audio.size(); i += chunk_samples) { + size_t remaining = std::min(static_cast(chunk_samples), audio.size() - i); + auto inputs = processor->Process(audio.data() + i, remaining); + if (inputs) { + generator->SetInputs(*inputs); + full_transcript += DecodeTokens(*generator, *tokenizer_stream); + } + } + + // Flush remaining audio + { + auto inputs = processor->Flush(); + if (inputs && inputs.get()) { + generator->SetInputs(*inputs); + full_transcript += DecodeTokens(*generator, *tokenizer_stream); + } + } + + auto end = std::chrono::high_resolution_clock::now(); + double wall_time = std::chrono::duration(end - start).count(); + + std::cout << "\n" + << std::string(60, '=') << std::endl; + std::cout << " " << full_transcript << std::endl; + std::cout << std::string(60, '=') << std::endl; + std::cout << " Audio: " << duration << "s | Wall: " << wall_time << "s | RTF: " << (duration / wall_time) << "x" << std::endl; +} + +int main(int argc, char* argv[]) { + if (argc < 3) { + std::cerr << "Usage: " << argv[0] << " --model_path --audio_file " << std::endl; + return 1; + } + + std::string model_path; + std::string audio_file; + + for (int i = 1; i < argc; i++) { + if (std::string(argv[i]) == "--model_path" && i + 1 < argc) { + model_path = argv[++i]; + } else if (std::string(argv[i]) == "--audio_file" && i + 1 < argc) { + audio_file = argv[++i]; + } + } + + if (model_path.empty() || audio_file.empty()) { + std::cerr << "Both --model_path and --audio_file are required." << std::endl; + return 1; + } + + try { + StreamingTranscribe(model_path, audio_file); + } catch (const std::exception& e) { + std::cerr << "Error: " << e.what() << std::endl; + return 1; + } + + return 0; +} diff --git a/examples/csharp/NemotronSpeech/NemotronSpeech.csproj b/examples/csharp/NemotronSpeech/NemotronSpeech.csproj new file mode 100644 index 0000000000..e71682a7f1 --- /dev/null +++ b/examples/csharp/NemotronSpeech/NemotronSpeech.csproj @@ -0,0 +1,22 @@ + + + + net8.0 + Exe + enable + enable + true + + + + + + + + + + + + + + diff --git a/examples/csharp/NemotronSpeech/Program.cs b/examples/csharp/NemotronSpeech/Program.cs new file mode 100644 index 0000000000..0a434370ea --- /dev/null +++ b/examples/csharp/NemotronSpeech/Program.cs @@ -0,0 +1,101 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using CommonUtils; +using Microsoft.ML.OnnxRuntimeGenAI; +using NAudio.Wave; +using NAudio.Wave.SampleProviders; +using System.Text.Json; + +if (args.Length < 2) { + Console.WriteLine("Usage: NemotronSpeech [execution_provider]"); + return; +} + +string modelPath = args[0]; +string audioFile = args[1]; +string executionProvider = args.Length > 2 ? args[2] : "follow_config"; + +// Read sample_rate and chunk_samples from genai_config.json +var configJson = JsonDocument.Parse(File.ReadAllText(Path.Combine(modelPath, "genai_config.json"))); +var modelConfig = configJson.RootElement.GetProperty("model"); +int sampleRate = modelConfig.GetProperty("sample_rate").GetInt32(); +int chunkSize = modelConfig.GetProperty("chunk_samples").GetInt32(); + +// Load audio, convert to mono, and resample to match the model's expected sample rate +float[] audio = LoadAudio(audioFile, sampleRate); +Console.WriteLine($"Audio: {audio.Length / (double)sampleRate:F1}s ({audio.Length} samples)"); + +using var config = Common.GetConfig(path: modelPath, ep: executionProvider, null, new GeneratorParamsArgs()); +using var model = new Model(config); +using var processor = new StreamingProcessor(model); +using var tokenizer = new Tokenizer(model); +using var tokenizerStream = tokenizer.CreateStream(); +using var genParams = new GeneratorParams(model); +using var generator = new Generator(model, genParams); +Console.WriteLine(new string('-', 60)); +string fullTranscript = ""; + +for (int i = 0; i < audio.Length; i += chunkSize) { + int remaining = Math.Min(chunkSize, audio.Length - i); + float[] chunk = new float[remaining]; + Array.Copy(audio, i, chunk, 0, remaining); + + using var inputs = processor.Process(chunk); + if (inputs != null) { + generator.SetInputs(inputs); + fullTranscript += DecodeTokens(generator, tokenizerStream); + } +} + +// Flush remaining buffered audio +using var flushInputs = processor.Flush(); +if (flushInputs != null) { + generator.SetInputs(flushInputs); + fullTranscript += DecodeTokens(generator, tokenizerStream); +} + +Console.WriteLine($"\n{new string('=', 60)}"); +Console.WriteLine($" {fullTranscript.Trim()}"); +Console.WriteLine(new string('=', 60)); + +static string DecodeTokens(Generator generator, TokenizerStream tokenizerStream) { + string text = ""; + while (!generator.IsDone()) { + generator.GenerateNextToken(); + var tokens = generator.GetNextTokens(); + if (tokens.Length > 0) { + string tokenText = tokenizerStream.Decode(tokens[0]); + if (!string.IsNullOrEmpty(tokenText)) { + Console.Write(tokenText); + text += tokenText; + } + } + } + return text; +} + +static float[] LoadAudio(string path, int targetSampleRate) { + using var reader = new AudioFileReader(path); + + // Convert to mono if needed + ISampleProvider source = reader; + if (reader.WaveFormat.Channels > 1) { + source = new StereoToMonoSampleProvider(source); + } + + // Resample if needed + if (reader.WaveFormat.SampleRate != targetSampleRate) { + source = new WdlResamplingSampleProvider(source, targetSampleRate); + } + + var samples = new List(); + // Allocate memory to read, any num works. + float[] buffer = new float[4096]; + int read; + while ((read = source.Read(buffer, 0, buffer.Length)) > 0) { + for (int i = 0; i < read; i++) + samples.Add(buffer[i]); + } + return samples.ToArray(); +} diff --git a/examples/csharp/NemotronSpeech/README.md b/examples/csharp/NemotronSpeech/README.md new file mode 100644 index 0000000000..699a1e6c2d --- /dev/null +++ b/examples/csharp/NemotronSpeech/README.md @@ -0,0 +1,60 @@ +# Nemotron Speech Streaming ASR — C# Example + +This example demonstrates real-time streaming speech recognition using the +NVIDIA Nemotron Speech Streaming model with the ONNX Runtime GenAI C# API. + +Audio is streamed through the model in chunks (simulating a microphone feed), +and transcribed text is printed incrementally as it becomes available. + +## Prerequisites + +- .NET 8.0 SDK or later +- ONNX Runtime GenAI C# package ([installation instructions](https://onnxruntime.ai/docs/genai/howto/install)) +- A Nemotron Speech Streaming ONNX model (e.g., `nvidia/nemotron-speech-streaming-en-0.6b`) + +## Build + +```bash +cd examples/csharp/ +dotnet build NemotronSpeech -c Release +``` + +## Run + +```bash +cd ./NemotronSpeech/bin/Release/net8.0/ +./NemotronSpeech [execution_provider] +``` + +### Arguments + +| Argument | Description | +|---|---| +| `model_path` | Path to the Nemotron ONNX model directory | +| `audio_file.wav` | Path to a WAV audio file (any sample rate — resampled automatically) | +| `execution_provider` | *(Optional)* Execution provider: `cpu`, `cuda`, `dml`, or `follow_config` (default: `follow_config`) | + +### Example + +```bash +# CPU inference +./NemotronSpeech /path/to/nemotron-cpu-int4 /path/to/audio.wav + +# CUDA inference +./NemotronSpeech /path/to/nemotron-cuda-int4 /path/to/audio.wav cuda +``` + +### Output + +The example prints transcribed text incrementally as each audio chunk is processed, +followed by a summary with the full transcript, audio duration, wall-clock time, +and real-time factor (RTFx). + +``` +------------------------------------------------------------ + This is an example of streaming speech recognition... +============================================================ + This is an example of streaming speech recognition using Nemotron. +============================================================ + Audio: 10.50s | Wall: 2.13s | RTFx: 4.93x +``` diff --git a/examples/python/nemotron_speech.py b/examples/python/nemotron_speech.py new file mode 100644 index 0000000000..3535f5dd09 --- /dev/null +++ b/examples/python/nemotron_speech.py @@ -0,0 +1,104 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import argparse +import json +import os +import sys +import time +import numpy as np +import onnxruntime_genai as og +from common import get_config + + +def load_config(model_path): + """Read sample_rate and chunk_samples from genai_config.json.""" + config_path = os.path.join(model_path, "genai_config.json") + with open(config_path, "r") as f: + config = json.load(f) + sample_rate = config["model"]["sample_rate"] + chunk_samples = config["model"]["chunk_samples"] + return sample_rate, chunk_samples + + +def load_audio(audio_path, sample_rate): + import soundfile as sf + audio, sr = sf.read(audio_path, dtype="float32") + if len(audio.shape) > 1: + audio = audio.mean(axis=1) + if sr != sample_rate: + import scipy.signal + num_samples = int(len(audio) * sample_rate / sr) + audio = scipy.signal.resample(audio, num_samples).astype(np.float32) + return audio + + +def decode_tokens(generator, tokenizer_stream): + """Decode all available tokens from the generator, returning the text.""" + text = "" + while not generator.is_done(): + generator.generate_next_token() + tokens = generator.get_next_tokens() + if len(tokens) > 0: + token_text = tokenizer_stream.decode(tokens[0]) + if token_text: + print(token_text, end="", flush=True) + text += token_text + return text + + +def simulate_microphone(model_path, audio_path, execution_provider): + """Stream audio through Generator + StreamingProcessor API.""" + sample_rate, chunk_samples = load_config(model_path) + audio = load_audio(audio_path, sample_rate) + duration = len(audio) / sample_rate + + config = get_config(model_path, execution_provider) + model = og.Model(config) + processor = og.StreamingProcessor(model) + tokenizer = og.Tokenizer(model) + tokenizer_stream = tokenizer.create_stream() + params = og.GeneratorParams(model) + generator = og.Generator(model, params) + + print("-" * 60) + stream_start = time.time() + full_transcript = "" + + for i in range(0, len(audio), chunk_samples): + chunk = audio[i:i + chunk_samples].astype(np.float32) + inputs = processor.process(chunk) + if inputs is not None: + generator.set_inputs(inputs) + full_transcript += decode_tokens(generator, tokenizer_stream) + + # Flush remaining audio + inputs = processor.flush() + if inputs is not None: + generator.set_inputs(inputs) + full_transcript += decode_tokens(generator, tokenizer_stream) + + total_wall = time.time() - stream_start + + print(f"\n{'=' * 60}") + print(f" {full_transcript.strip()}") + print(f"{'=' * 60}") + print(f" Audio: {duration:.2f}s | Wall: {total_wall:.2f}s | RTF: {duration/total_wall:.2f}x") + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--model_path", type=str, required=True) + parser.add_argument("--audio_file", type=str, required=True) + parser.add_argument("-e", "--execution_provider", type=str, required=False, default="follow_config", + choices=["cpu", "cuda", "dml", "follow_config"], + help="Execution provider to run with. Defaults to follow_config.") + args = parser.parse_args() + if not os.path.exists(args.audio_file): + print(f"Error: {args.audio_file} not found") + sys.exit(1) + simulate_microphone(args.model_path, args.audio_file, args.execution_provider) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/config.cpp b/src/config.cpp index 8baacca9eb..c5485f08de 100644 --- a/src/config.cpp +++ b/src/config.cpp @@ -2,6 +2,7 @@ // Licensed under the MIT License. // Modifications Copyright(C) 2024-2025 Advanced Micro Devices, Inc. All rights reserved. #include "generators.h" +#include "models/model_type.h" #include "runtime_settings.h" #include "json.h" #include @@ -241,6 +242,14 @@ struct EncoderInputs_Element : JSON::Element { v_.position_ids = JSON::Get(value); } else if (name == "audio_features") { v_.audio_features = JSON::Get(value); + } else if (name == "input_lengths") { + v_.input_lengths = JSON::Get(value); + } else if (name == "cache_last_channel") { + v_.cache_last_channel = JSON::Get(value); + } else if (name == "cache_last_time") { + v_.cache_last_time = JSON::Get(value); + } else if (name == "cache_last_channel_len") { + v_.cache_last_channel_len = JSON::Get(value); } else { throw JSON::unknown_value_error{}; } @@ -258,6 +267,14 @@ struct EncoderOutputs_Element : JSON::Element { v_.hidden_states = JSON::Get(value); } else if (name == "encoder_outputs") { v_.encoder_outputs = JSON::Get(value); + } else if (name == "output_lengths") { + v_.output_lengths = JSON::Get(value); + } else if (name == "cache_last_channel_next") { + v_.cache_last_channel_next = JSON::Get(value); + } else if (name == "cache_last_time_next") { + v_.cache_last_time_next = JSON::Get(value); + } else if (name == "cache_last_channel_len_next") { + v_.cache_last_channel_len_next = JSON::Get(value); } else if (name == "cross_present_key_names") { v_.cross_present_key_names = JSON::Get(value); } else if (name == "cross_present_value_names") { @@ -315,6 +332,14 @@ struct DecoderInputs_Element : JSON::Element { v_.past_sequence_lengths = JSON::Get(value); } else if (name == "block_table") { v_.block_table = JSON::Get(value); + } else if (name == "targets") { + v_.targets = JSON::Get(value); + } else if (name == "target_length") { + v_.target_length = JSON::Get(value); + } else if (name == "lstm_hidden_state") { + v_.lstm_hidden_state = JSON::Get(value); + } else if (name == "lstm_cell_state") { + v_.lstm_cell_state = JSON::Get(value); } else { throw JSON::unknown_value_error{}; } @@ -340,6 +365,14 @@ struct DecoderOutputs_Element : JSON::Element { v_.output_cross_qk_names = JSON::Get(value); } else if (name == "rnn_states") { v_.rnn_states = JSON::Get(value); + } else if (name == "outputs") { + v_.outputs = JSON::Get(value); + } else if (name == "prednet_lengths") { + v_.prednet_lengths = JSON::Get(value); + } else if (name == "lstm_hidden_state") { + v_.lstm_hidden_state = JSON::Get(value); + } else if (name == "lstm_cell_state") { + v_.lstm_cell_state = JSON::Get(value); } else { throw JSON::unknown_value_error{}; } @@ -557,10 +590,10 @@ struct Decoder_Element : JSON::Element { v_.hidden_size = static_cast(JSON::Get(value)); } else if (name == "num_attention_heads") { v_.num_attention_heads = static_cast(JSON::Get(value)); - } else if (name == "num_key_value_heads") { - v_.num_key_value_heads = static_cast(JSON::Get(value)); } else if (name == "num_hidden_layers") { v_.num_hidden_layers = static_cast(JSON::Get(value)); + } else if (name == "num_key_value_heads") { + v_.num_key_value_heads = static_cast(JSON::Get(value)); } else if (name == "head_size") { v_.head_size = static_cast(JSON::Get(value)); } else { @@ -862,6 +895,77 @@ struct Speech_Element : JSON::Element { SpeechOutputs_Element outputs_{v_.outputs}; }; +struct JoinerInputs_Element : JSON::Element { + explicit JoinerInputs_Element(Config::Model::Joiner::Inputs& v) : v_{v} {} + + void OnValue(std::string_view name, JSON::Value value) override { + if (name == "encoder_outputs") { + v_.encoder_outputs = JSON::Get(value); + } else if (name == "decoder_outputs") { + v_.decoder_outputs = JSON::Get(value); + } else { + throw JSON::unknown_value_error{}; + } + } + + private: + Config::Model::Joiner::Inputs& v_; +}; + +struct JoinerOutputs_Element : JSON::Element { + explicit JoinerOutputs_Element(Config::Model::Joiner::Outputs& v) : v_{v} {} + + void OnValue(std::string_view name, JSON::Value value) override { + if (name == "logits") { + v_.logits = JSON::Get(value); + } else { + throw JSON::unknown_value_error{}; + } + } + + private: + Config::Model::Joiner::Outputs& v_; +}; + +struct Joiner_Element : JSON::Element { + explicit Joiner_Element(Config::Model::Joiner& v) : v_{v} {} + + void OnValue(std::string_view name, JSON::Value value) override { + if (name == "filename") { + v_.filename = JSON::Get(value); + } else { + throw JSON::unknown_value_error{}; + } + } + + Element& OnObject(std::string_view name) override { + if (name == "session_options") { + v_.session_options = Config::SessionOptions{}; + session_options_ = std::make_unique(*v_.session_options); + return *session_options_; + } + if (name == "run_options") { + v_.run_options = Config::RunOptions{}; + run_options_ = std::make_unique(*v_.run_options); + return *run_options_; + } + if (name == "inputs") { + return inputs_; + } + if (name == "outputs") { + return outputs_; + } + throw JSON::unknown_value_error{}; + } + + private: + Config::Model::Joiner& v_; + std::unique_ptr session_options_; + std::unique_ptr run_options_; + JoinerInputs_Element inputs_{v_.inputs}; + JoinerOutputs_Element outputs_{v_.outputs}; +}; + struct EmbeddingInputs_Element : JSON::Element { explicit EmbeddingInputs_Element(Config::Model::Embedding::Inputs& v) : v_{v} {} @@ -961,6 +1065,34 @@ struct Model_Element : JSON::Element { v_.video_token_id = static_cast(JSON::Get(value)); } else if (name == "vision_start_token_id") { v_.vision_start_token_id = static_cast(JSON::Get(value)); + } else if (name == "num_mels") { + v_.num_mels = static_cast(JSON::Get(value)); + } else if (name == "fft_size") { + v_.fft_size = static_cast(JSON::Get(value)); + } else if (name == "hop_length") { + v_.hop_length = static_cast(JSON::Get(value)); + } else if (name == "win_length") { + v_.win_length = static_cast(JSON::Get(value)); + } else if (name == "preemph") { + v_.preemph = static_cast(JSON::Get(value)); + } else if (name == "log_eps") { + v_.log_eps = static_cast(JSON::Get(value)); + } else if (name == "subsampling_factor") { + v_.subsampling_factor = static_cast(JSON::Get(value)); + } else if (name == "left_context") { + v_.left_context = static_cast(JSON::Get(value)); + } else if (name == "conv_context") { + v_.conv_context = static_cast(JSON::Get(value)); + } else if (name == "pre_encode_cache_size") { + v_.pre_encode_cache_size = static_cast(JSON::Get(value)); + } else if (name == "sample_rate") { + v_.sample_rate = static_cast(JSON::Get(value)); + } else if (name == "chunk_samples") { + v_.chunk_samples = static_cast(JSON::Get(value)); + } else if (name == "blank_id") { + v_.blank_id = static_cast(JSON::Get(value)); + } else if (name == "max_symbols_per_step") { + v_.max_symbols_per_step = static_cast(JSON::Get(value)); } else { throw JSON::unknown_value_error{}; } @@ -988,6 +1120,9 @@ struct Model_Element : JSON::Element { if (name == "speech") { return speech_; } + if (name == "joiner") { + return joiner_; + } throw JSON::unknown_value_error{}; } @@ -999,6 +1134,7 @@ struct Model_Element : JSON::Element { Vision_Element vision_{v_.vision}; Embedding_Element embedding_{v_.embedding}; Speech_Element speech_{v_.speech}; + Joiner_Element joiner_{v_.joiner}; }; int SafeDoubleToInt(double x, std::string_view name) { @@ -1387,7 +1523,7 @@ void OverlayConfig(Config& config, std::string_view json) { Config::Config(const fs::path& path, std::string_view json_overlay) : config_path{path} { ParseConfig(path / "genai_config.json", json_overlay, *this); - if (model.context_length == 0) { + if (model.context_length == 0 && !ModelType::IsRNNT(model.type)) { throw std::runtime_error("model context_length is 0 or was not set. It must be greater than 0"); } diff --git a/src/config.h b/src/config.h index bdf7c64160..4fe52b9059 100644 --- a/src/config.h +++ b/src/config.h @@ -60,6 +60,25 @@ struct Config { static constexpr std::string_view EncoderHiddenStatesName = "encoder_hidden_states"; static constexpr std::string_view EncoderOutputsName = "encoder_outputs"; static constexpr std::string_view EncoderAttentionMaskName = "encoder_attention_mask"; + + // Cache-aware streaming encoder names + static constexpr std::string_view EncoderInputLengthsName = "length"; + static constexpr std::string_view CacheLastChannelName = "cache_last_channel"; + static constexpr std::string_view CacheLastTimeName = "cache_last_time"; + static constexpr std::string_view CacheLastChannelLenName = "cache_last_channel_len"; + static constexpr std::string_view EncoderOutputLengthsName = "encoded_lengths"; + static constexpr std::string_view CacheLastChannelNextName = "cache_last_channel_next"; + static constexpr std::string_view CacheLastTimeNextName = "cache_last_time_next"; + static constexpr std::string_view CacheLastChannelLenNextName = "cache_last_channel_len_next"; + + // Cross present key/value names + static constexpr std::string_view CrossPresentKeyName = "present_key_cross_%d"; + static constexpr std::string_view CrossPresentValueName = "present_value_cross_%d"; + + // Joiner names + static constexpr std::string_view JoinerEncoderOutputsName = "encoder_outputs"; + static constexpr std::string_view JoinerDecoderOutputsName = "decoder_outputs"; + static constexpr std::string_view JoinerLogitsName = "outputs"; }; fs::path config_path; // Path of the config directory @@ -116,6 +135,22 @@ struct Config { int vocab_size{}; int context_length{}; + // Streaming ASR / RNNT model parameters + int num_mels{}; + int fft_size{}; + int hop_length{}; + int win_length{}; + float preemph{}; + float log_eps{}; + int subsampling_factor{}; + int left_context{}; + int conv_context{}; + int pre_encode_cache_size{}; + int sample_rate{}; + int chunk_samples{}; + int blank_id{}; + int max_symbols_per_step{}; + struct Encoder { std::string filename; std::optional session_options; @@ -133,12 +168,22 @@ struct Config { std::string attention_mask{Defaults::AttentionMaskName}; std::string position_ids{Defaults::PositionIdsName}; std::string audio_features{Defaults::AudioFeaturesName}; + // Cache-aware streaming encoder I/O names + std::string input_lengths{Defaults::EncoderInputLengthsName}; + std::string cache_last_channel{Defaults::CacheLastChannelName}; + std::string cache_last_time{Defaults::CacheLastTimeName}; + std::string cache_last_channel_len{Defaults::CacheLastChannelLenName}; } inputs; struct Outputs { std::string encoder_outputs{Defaults::EncoderOutputsName}; std::string hidden_states{Defaults::EncoderHiddenStatesName}; - std::string cross_present_key_names{"present_key_cross_%d"}, cross_present_value_names{"present_value_cross_%d"}; + std::string cross_present_key_names{Defaults::CrossPresentKeyName}, cross_present_value_names{Defaults::CrossPresentValueName}; + // Cache-aware streaming encoder output names + std::string output_lengths{Defaults::EncoderOutputLengthsName}; + std::string cache_last_channel_next{Defaults::CacheLastChannelNextName}; + std::string cache_last_time_next{Defaults::CacheLastTimeNextName}; + std::string cache_last_channel_len_next{Defaults::CacheLastChannelLenNextName}; } outputs; } encoder; @@ -215,6 +260,21 @@ struct Config { } outputs; } speech; + struct Joiner { + std::string filename; + std::optional session_options; + std::optional run_options; + + struct Inputs { + std::string encoder_outputs{Defaults::JoinerEncoderOutputsName}; + std::string decoder_outputs{Defaults::JoinerDecoderOutputsName}; + } inputs; + + struct Outputs { + std::string logits{Defaults::JoinerLogitsName}; + } outputs; + } joiner; + struct Decoder { std::string filename; SessionOptions session_options; @@ -256,6 +316,12 @@ struct Config { std::string cumulative_sequence_lengths{Defaults::CumulativeSequenceLengthsName}; std::string past_sequence_lengths{Defaults::PastSequenceLengthsName}; std::string block_table{Defaults::BlockTableName}; + + // RNNT decoder inputs + std::string targets; + std::string target_length; + std::string lstm_hidden_state; + std::string lstm_cell_state; } inputs; struct Outputs { @@ -265,6 +331,12 @@ struct Config { std::string present_names; // When key/value pairs are combined std::string output_cross_qk_names{"output_cross_qk_%d"}; std::string rnn_states{Defaults::RnnStatesName}; + + // RNNT decoder outputs + std::string outputs; + std::string prednet_lengths; + std::string lstm_hidden_state; + std::string lstm_cell_state; } outputs; struct PipelineModel { diff --git a/src/csharp/NativeMethods.cs b/src/csharp/NativeMethods.cs index a1121c672c..fd773bb1f5 100644 --- a/src/csharp/NativeMethods.cs +++ b/src/csharp/NativeMethods.cs @@ -440,5 +440,23 @@ public static extern UIntPtr OgaSequencesGetSequenceCount(IntPtr /* const OgaSeq [DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)] public static extern IntPtr /* OgaResult* */ OgaUnloadAdapter(IntPtr /* OgaAdapters* */ adapters, byte[] /* const char* */ adapterName); + + // StreamingProcessor API + [DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)] + public static extern IntPtr /* OgaResult* */ OgaCreateStreamingProcessor(IntPtr /* const OgaModel* */ model, + out IntPtr /* OgaStreamingProcessor** */ processor); + + [DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)] + public static extern void OgaDestroyStreamingProcessor(IntPtr /* OgaStreamingProcessor* */ processor); + + [DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)] + public static extern unsafe IntPtr /* OgaResult* */ OgaStreamingProcessorProcess(IntPtr /* OgaStreamingProcessor* */ processor, + float* /* const float* */ audioData, + UIntPtr /* size_t */ numSamples, + out IntPtr /* OgaNamedTensors** */ out_named_tensors); + + [DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)] + public static extern IntPtr /* OgaResult* */ OgaStreamingProcessorFlush(IntPtr /* OgaStreamingProcessor* */ processor, + out IntPtr /* OgaNamedTensors** */ out_named_tensors); } } diff --git a/src/csharp/StreamingProcessor.cs b/src/csharp/StreamingProcessor.cs new file mode 100644 index 0000000000..7f4675d231 --- /dev/null +++ b/src/csharp/StreamingProcessor.cs @@ -0,0 +1,71 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; + +namespace Microsoft.ML.OnnxRuntimeGenAI +{ + public class StreamingProcessor : IDisposable + { + private IntPtr _processorHandle; + private bool _disposed = false; + + public StreamingProcessor(Model model) + { + Result.VerifySuccess(NativeMethods.OgaCreateStreamingProcessor(model.Handle, out _processorHandle)); + } + + internal IntPtr Handle { get { return _processorHandle; } } + + /// + /// Feed a chunk of raw PCM audio (mono, float32, 16kHz). + /// Returns a NamedTensors if a full chunk is ready, or null if more audio is needed. + /// + public NamedTensors? Process(float[] audioData) + { + IntPtr outHandle = IntPtr.Zero; + unsafe + { + fixed (float* audioPtr = audioData) + { + Result.VerifySuccess(NativeMethods.OgaStreamingProcessorProcess( + _processorHandle, audioPtr, (UIntPtr)audioData.Length, out outHandle)); + } + } + return outHandle != IntPtr.Zero ? new NamedTensors(outHandle) : null; + } + + /// + /// Flush remaining buffered audio (pads with silence). + /// Returns a NamedTensors or null if the buffer was empty. + /// + public NamedTensors? Flush() + { + IntPtr outHandle = IntPtr.Zero; + Result.VerifySuccess(NativeMethods.OgaStreamingProcessorFlush(_processorHandle, out outHandle)); + return outHandle != IntPtr.Zero ? new NamedTensors(outHandle) : null; + } + + ~StreamingProcessor() + { + Dispose(false); + } + + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } + + protected virtual void Dispose(bool disposing) + { + if (_disposed) + { + return; + } + NativeMethods.OgaDestroyStreamingProcessor(_processorHandle); + _processorHandle = IntPtr.Zero; + _disposed = true; + } + } +} diff --git a/src/generators.cpp b/src/generators.cpp index 67b7f34130..bb2d7efbb9 100644 --- a/src/generators.cpp +++ b/src/generators.cpp @@ -4,9 +4,12 @@ // Modifications Copyright(C) 2026 Advanced Micro Devices, Inc. All rights reserved. #include "generators.h" +#include "models/streaming_processor.h" +#include "models/nemotron_speech.h" #include "sequences.h" #include "models/env_utils.h" #include "models/model.h" +#include "models/model_type.h" #include "models/decoder_only.h" #include "constrained_logits_processor.h" #include "search.h" @@ -345,6 +348,13 @@ std::unique_ptr CreateSearch(const GeneratorParams& params) { } Generator::Generator(const Model& model, const GeneratorParams& params) : model_{model.shared_from_this()} { + // RNNT models don't use the traditional search/logits pipeline, + // so skip the standard validations and just create the state. + if (ModelType::IsRNNT(model.config_->model.type)) { + state_ = model.CreateState({}, params); + return; + } + if (params.search.max_length == 0) throw std::runtime_error("search max_length is 0"); if (params.search.max_length > model.config_->model.context_length) @@ -504,11 +514,20 @@ void Generator::SetRuntimeOption(const char* key, const char* value) { } size_t Generator::TokenCount() const { + if (auto* speech_state = dynamic_cast(state_.get())) + return speech_state->TokenCount(); return static_cast(search_->GetSequenceLength()); } bool Generator::IsDone() { ThrowErrorIfSessionTerminated(state_->session_terminated_); + + if (auto* speech_state = dynamic_cast(state_.get())) { + // Pending mel input means we haven't started processing this chunk yet + if (!extra_inputs_.empty()) return false; + return speech_state->IsChunkDone(); + } + if (computed_logits_) { return false; } @@ -538,6 +557,15 @@ void Generator::GenerateNextToken() { DurationTrace trace{"Generator::GenerateNextToken"}; ThrowErrorIfSessionTerminated(state_->session_terminated_); + + // RNNT models: yield one token per call from the decoder state machine + if (auto* speech_state = dynamic_cast(state_.get())) { + state_->SetExtraInputs(extra_inputs_); + extra_inputs_.clear(); + speech_state->StepToken(); + return; + } + if (search_->GetSequenceLength() == 0 && !computed_logits_) throw std::runtime_error("GenerateNextToken called with no prior state. Please call AppendTokens, SetLogits, or SetInputs before calling GenerateNextToken."); diff --git a/src/leakcheck.h b/src/leakcheck.h index 54abba29fc..8be2652454 100644 --- a/src/leakcheck.h +++ b/src/leakcheck.h @@ -14,6 +14,7 @@ struct Generator; struct Model; struct Request; struct Search; +struct StreamingProcessor; struct Tensor; struct Tokenizer; struct TokenizerStream; @@ -25,7 +26,7 @@ struct LeakTypeList { static bool Dump(); }; -using LeakTypes = LeakTypeList; +using LeakTypes = LeakTypeList; template struct LeakChecked { diff --git a/src/models/model.cpp b/src/models/model.cpp index 5c4477fc00..c28f2bc4a4 100644 --- a/src/models/model.cpp +++ b/src/models/model.cpp @@ -16,6 +16,7 @@ #include "gpt.h" #include "decoder_only.h" #include "whisper.h" +#include "nemotron_speech.h" #include "multi_modal.h" #include "marian.h" #include "decoder_only_pipeline.h" @@ -1278,6 +1279,8 @@ std::shared_ptr CreateModel(OrtEnv& ort_env, std::unique_ptr conf return std::make_shared(std::move(config), ort_env); if (ModelType::IsLLM(config->model.type)) return std::make_shared(std::move(config), ort_env); + if (ModelType::IsRNNT(config->model.type)) + return std::make_shared(std::move(config), ort_env); if (ModelType::IsALM(config->model.type)) return std::make_shared(std::move(config), ort_env); if (ModelType::IsVLM(config->model.type)) diff --git a/src/models/model_type.h b/src/models/model_type.h index c2ee97b7cb..d2c4468d61 100644 --- a/src/models/model_type.h +++ b/src/models/model_type.h @@ -36,6 +36,12 @@ struct ModelType { return std::find(ALM.begin(), ALM.end(), model_type) != ALM.end(); } + inline static bool IsRNNT(const std::string& model_type) { + // RNNT models bypass the search/logits pipeline entirely. + static constexpr std::array rnnt_types = {"nemotron_speech"}; + return std::find(rnnt_types.begin(), rnnt_types.end(), model_type) != rnnt_types.end(); + } + inline static bool IsMMM(const std::string& model_type) { // Multi-modal model (MMM) static constexpr std::array MMM = {"phi4mm"}; diff --git a/src/models/nemotron_speech.cpp b/src/models/nemotron_speech.cpp new file mode 100644 index 0000000000..003535069d --- /dev/null +++ b/src/models/nemotron_speech.cpp @@ -0,0 +1,535 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include + +#include "../generators.h" +#include "nemo_mel_spectrogram.h" +#include "nemotron_speech.h" + +namespace Generators { + +void NemotronCacheConfig::PopulateFromConfig(const Config& config) { + const auto& enc = config.model.encoder; + const auto& dec = config.model.decoder; + const auto& jo = config.model.joiner; + + // Encoder dimensions + hidden_dim = enc.hidden_size; + num_encoder_layers = enc.num_hidden_layers; + + // Decoder dimensions (LSTM) + decoder_lstm_dim = dec.hidden_size; + decoder_lstm_layers = dec.num_hidden_layers; + + // Speech / mel feature config (now at model level) + num_mels = config.model.num_mels; + fft_size = config.model.fft_size; + hop_length = config.model.hop_length; + win_length = config.model.win_length; + preemph = config.model.preemph; + log_eps = config.model.log_eps; + subsampling_factor = config.model.subsampling_factor; + left_context = config.model.left_context; + conv_context = config.model.conv_context; + pre_encode_cache_size = config.model.pre_encode_cache_size; + sample_rate = config.model.sample_rate; + chunk_samples = config.model.chunk_samples; + blank_id = config.model.blank_id; + max_symbols_per_step = config.model.max_symbols_per_step; + + // Vocab size from top-level config + vocab_size = config.model.vocab_size; + + // Encoder I/O names + enc_in_audio = enc.inputs.audio_features; + enc_out_encoded = enc.outputs.encoder_outputs; + + // Encoder cache I/O names (from encoder inputs/outputs) + enc_in_length = enc.inputs.input_lengths; + enc_in_cache_channel = enc.inputs.cache_last_channel; + enc_in_cache_time = enc.inputs.cache_last_time; + enc_in_cache_channel_len = enc.inputs.cache_last_channel_len; + enc_out_length = enc.outputs.output_lengths; + enc_out_cache_channel = enc.outputs.cache_last_channel_next; + enc_out_cache_time = enc.outputs.cache_last_time_next; + enc_out_cache_channel_len = enc.outputs.cache_last_channel_len_next; + + // Joiner I/O names + join_in_encoder = jo.inputs.encoder_outputs; + join_in_decoder = jo.inputs.decoder_outputs; + join_out_logits = jo.outputs.logits; + + // Decoder I/O names (RNNT prediction network) + dec_in_targets = dec.inputs.targets; + dec_in_target_length = dec.inputs.target_length; + dec_in_lstm_hidden = dec.inputs.lstm_hidden_state; + dec_in_lstm_cell = dec.inputs.lstm_cell_state; + dec_out_outputs = dec.outputs.outputs; + dec_out_prednet_lengths = dec.outputs.prednet_lengths; + dec_out_lstm_hidden = dec.outputs.lstm_hidden_state; + dec_out_lstm_cell = dec.outputs.lstm_cell_state; +} + +void NemotronEncoderCache::Initialize(const NemotronCacheConfig& cfg, const SessionInfo& session_info, OrtAllocator& allocator, DeviceInterface& device) { + auto cache_channel_type = session_info.GetInputDataType(cfg.enc_in_cache_channel); + auto cache_time_type = session_info.GetInputDataType(cfg.enc_in_cache_time); + auto cache_channel_len_type = session_info.GetInputDataType(cfg.enc_in_cache_channel_len); + + // cache_last_channel: [batch, num_layers, left_context, hidden_dim] + auto ch_shape = std::array{1, cfg.num_encoder_layers, cfg.left_context, cfg.hidden_dim}; + cache_last_channel = OrtValue::CreateTensor(allocator, ch_shape, cache_channel_type); + ByteWrapTensor(*GetDeviceInterface(DeviceType::CPU), *cache_last_channel).Zero(); + + // cache_last_time: [batch, num_layers, hidden_dim, conv_context] + auto tm_shape = std::array{1, cfg.num_encoder_layers, cfg.hidden_dim, cfg.conv_context}; + cache_last_time = OrtValue::CreateTensor(allocator, tm_shape, cache_time_type); + ByteWrapTensor(*GetDeviceInterface(DeviceType::CPU), *cache_last_time).Zero(); + + // cache_last_channel_len: [1] + auto len_shape = std::array{1}; + cache_last_channel_len = OrtValue::CreateTensor(allocator, len_shape, cache_channel_len_type); + *cache_last_channel_len->GetTensorMutableData() = 0; +} + +void NemotronEncoderCache::Reset(const NemotronCacheConfig& cfg, const SessionInfo& session_info, OrtAllocator& allocator, DeviceInterface& device) { + Initialize(cfg, session_info, allocator, device); +} + +void NemotronDecoderState::Initialize(const NemotronCacheConfig& cfg, const SessionInfo& session_info, OrtAllocator& allocator, DeviceInterface& device) { + auto lstm_hidden_type = session_info.GetInputDataType(cfg.dec_in_lstm_hidden); + auto lstm_cell_type = session_info.GetInputDataType(cfg.dec_in_lstm_cell); + + // LSTM states: [lstm_layers, 1, lstm_dim] + auto state_shape = std::array{cfg.decoder_lstm_layers, 1, cfg.decoder_lstm_dim}; + lstm_hidden_state = OrtValue::CreateTensor(allocator, state_shape, lstm_hidden_type); + ByteWrapTensor(*GetDeviceInterface(DeviceType::CPU), *lstm_hidden_state).Zero(); + + lstm_cell_state = OrtValue::CreateTensor(allocator, state_shape, lstm_cell_type); + ByteWrapTensor(*GetDeviceInterface(DeviceType::CPU), *lstm_cell_state).Zero(); + + last_token = cfg.blank_id; // Start with blank/SOS token +} + +void NemotronDecoderState::Reset(const NemotronCacheConfig& cfg, const SessionInfo& session_info, OrtAllocator& allocator, DeviceInterface& device) { + Initialize(cfg, session_info, allocator, device); +} + +NemotronSpeechModel::NemotronSpeechModel(std::unique_ptr config, OrtEnv& ort_env) + : Model{std::move(config)} { + cache_config_ = NemotronCacheConfig{}; + cache_config_.PopulateFromConfig(*config_); + + // Create session options + encoder_session_options_ = OrtSessionOptions::Create(); + decoder_session_options_ = OrtSessionOptions::Create(); + joiner_session_options_ = OrtSessionOptions::Create(); + + if (config_->model.encoder.session_options.has_value()) { + CreateSessionOptionsFromConfig(config_->model.encoder.session_options.value(), + *encoder_session_options_, true, false); + } else { + CreateSessionOptionsFromConfig(config_->model.decoder.session_options, + *encoder_session_options_, true, false); + } + CreateSessionOptionsFromConfig(config_->model.decoder.session_options, + *decoder_session_options_, true, false); + if (config_->model.joiner.session_options.has_value()) { + CreateSessionOptionsFromConfig(config_->model.joiner.session_options.value(), + *joiner_session_options_, true, false); + } else { + CreateSessionOptionsFromConfig(config_->model.decoder.session_options, + *joiner_session_options_, true, false); + } + + // Load the three ONNX models + std::string encoder_filename = config_->model.encoder.filename; + if (encoder_filename.empty()) encoder_filename = "encoder.onnx"; + + std::string decoder_filename = config_->model.decoder.filename; + if (decoder_filename.empty()) decoder_filename = "decoder.onnx"; + + std::string joiner_filename = config_->model.joiner.filename; + if (joiner_filename.empty()) joiner_filename = "joiner.onnx"; + + session_encoder_ = CreateSession(ort_env, encoder_filename, encoder_session_options_.get()); + session_decoder_ = CreateSession(ort_env, decoder_filename, decoder_session_options_.get()); + session_joiner_ = CreateSession(ort_env, joiner_filename, joiner_session_options_.get()); + + session_info_.Add(*session_encoder_); + session_info_.Add(*session_decoder_); + session_info_.Add(*session_joiner_); +} + +std::unique_ptr NemotronSpeechModel::CreateState(DeviceSpan /*sequence_lengths*/, + const GeneratorParams& params) const { + return std::make_unique(*this, params); +} + +NemotronEncoderSubState::NemotronEncoderSubState(const NemotronSpeechModel& model, const GeneratorParams& params) + : State{params, model}, + model_{model} { + auto& cfg = model_.cache_config_; + auto& allocator = model_.allocator_cpu_; + auto& device = *model_.p_device_; + + cache_.Initialize(cfg, model_.session_info_, allocator, device); + + // Create signal_length tensor + auto len_type = model_.session_info_.GetInputDataType(cfg.enc_in_length); + auto len_shape = std::array{1}; + signal_length_ = OrtValue::CreateTensor(allocator, len_shape, len_type); + + // Register inputs: mel, length, cache_channel, cache_time, cache_channel_len + mel_input_idx_ = inputs_.size(); + input_names_.push_back(cfg.enc_in_audio.c_str()); + inputs_.push_back(nullptr); + + length_input_idx_ = inputs_.size(); + input_names_.push_back(cfg.enc_in_length.c_str()); + inputs_.push_back(signal_length_.get()); + + cache_channel_input_idx_ = inputs_.size(); + input_names_.push_back(cfg.enc_in_cache_channel.c_str()); + inputs_.push_back(cache_.cache_last_channel.get()); + + cache_time_input_idx_ = inputs_.size(); + input_names_.push_back(cfg.enc_in_cache_time.c_str()); + inputs_.push_back(cache_.cache_last_time.get()); + + cache_channel_len_input_idx_ = inputs_.size(); + input_names_.push_back(cfg.enc_in_cache_channel_len.c_str()); + inputs_.push_back(cache_.cache_last_channel_len.get()); + + // Register outputs: encoded, length, cache_channel_next, cache_time_next, cache_channel_len_next + output_names_.push_back(cfg.enc_out_encoded.c_str()); + outputs_.push_back(nullptr); + + output_names_.push_back(cfg.enc_out_length.c_str()); + outputs_.push_back(nullptr); + + output_names_.push_back(cfg.enc_out_cache_channel.c_str()); + outputs_.push_back(nullptr); + + output_names_.push_back(cfg.enc_out_cache_time.c_str()); + outputs_.push_back(nullptr); + + output_names_.push_back(cfg.enc_out_cache_channel_len.c_str()); + outputs_.push_back(nullptr); + + // Set run options from config + if (model_.config_->model.encoder.run_options.has_value()) { + State::SetRunOptions(model_.config_->model.encoder.run_options.value()); + } +} + +void NemotronEncoderSubState::SetMelInput(OrtValue* mel_tensor, int64_t total_mel_frames) { + inputs_[mel_input_idx_] = mel_tensor; + *signal_length_->GetTensorMutableData() = total_mel_frames; +} + +void NemotronEncoderSubState::UpdateCacheInputs() { + inputs_[cache_channel_input_idx_] = cache_.cache_last_channel.get(); + inputs_[cache_time_input_idx_] = cache_.cache_last_time.get(); + inputs_[cache_channel_len_input_idx_] = cache_.cache_last_channel_len.get(); +} + +DeviceSpan NemotronEncoderSubState::Run(int /*total_length*/, DeviceSpan& /*next_tokens*/, DeviceSpan /*next_indices*/) { + State::Run(*model_.session_encoder_); + return {}; +} + +NemotronPredictionSubState::NemotronPredictionSubState(const NemotronSpeechModel& model, const GeneratorParams& params) + : State{params, model}, + model_{model} { + auto& cfg = model_.cache_config_; + auto& allocator = model_.allocator_cpu_; + auto& device = *model_.p_device_; + + lstm_state_.Initialize(cfg, model_.session_info_, allocator, device); + + // Create targets and target_length tensors + auto targets_type = model_.session_info_.GetInputDataType(cfg.dec_in_targets); + auto targets_shape = std::array{1, 1}; + targets_ = OrtValue::CreateTensor(allocator, targets_shape, targets_type); + + auto tgt_len_type = model_.session_info_.GetInputDataType(cfg.dec_in_target_length); + auto tgt_len_shape = std::array{1}; + target_length_ = OrtValue::CreateTensor(allocator, tgt_len_shape, tgt_len_type); + *target_length_->GetTensorMutableData() = 1; + + // Register inputs + targets_input_idx_ = inputs_.size(); + input_names_.push_back(cfg.dec_in_targets.c_str()); + inputs_.push_back(targets_.get()); + + target_length_input_idx_ = inputs_.size(); + input_names_.push_back(cfg.dec_in_target_length.c_str()); + inputs_.push_back(target_length_.get()); + + lstm_hidden_input_idx_ = inputs_.size(); + input_names_.push_back(cfg.dec_in_lstm_hidden.c_str()); + inputs_.push_back(lstm_state_.lstm_hidden_state.get()); + + lstm_cell_input_idx_ = inputs_.size(); + input_names_.push_back(cfg.dec_in_lstm_cell.c_str()); + inputs_.push_back(lstm_state_.lstm_cell_state.get()); + + // Register outputs + output_names_.push_back(cfg.dec_out_outputs.c_str()); + outputs_.push_back(nullptr); + + output_names_.push_back(cfg.dec_out_prednet_lengths.c_str()); + outputs_.push_back(nullptr); + + output_names_.push_back(cfg.dec_out_lstm_hidden.c_str()); + outputs_.push_back(nullptr); + + output_names_.push_back(cfg.dec_out_lstm_cell.c_str()); + outputs_.push_back(nullptr); + + // Set run options from config + if (model_.config_->model.decoder.run_options.has_value()) { + State::SetRunOptions(model_.config_->model.decoder.run_options.value()); + } +} + +void NemotronPredictionSubState::UpdateInputs() { + *targets_->GetTensorMutableData() = lstm_state_.last_token; + inputs_[lstm_hidden_input_idx_] = lstm_state_.lstm_hidden_state.get(); + inputs_[lstm_cell_input_idx_] = lstm_state_.lstm_cell_state.get(); +} + +DeviceSpan NemotronPredictionSubState::Run(int /*total_length*/, DeviceSpan& /*next_tokens*/, DeviceSpan /*next_indices*/) { + State::Run(*model_.session_decoder_); + return {}; +} + +NemotronJoinerSubState::NemotronJoinerSubState(const NemotronSpeechModel& model, const GeneratorParams& params) + : State{params, model}, + model_{model} { + auto& cfg = model_.cache_config_; + + // Register inputs + encoder_input_idx_ = inputs_.size(); + input_names_.push_back(cfg.join_in_encoder.c_str()); + inputs_.push_back(nullptr); // Set before each run + + decoder_input_idx_ = inputs_.size(); + input_names_.push_back(cfg.join_in_decoder.c_str()); + inputs_.push_back(nullptr); // Set before each run + + // Register output + output_names_.push_back(cfg.join_out_logits.c_str()); + outputs_.push_back(nullptr); + + // Set run options from config + if (model_.config_->model.joiner.run_options.has_value()) { + State::SetRunOptions(model_.config_->model.joiner.run_options.value()); + } +} + +void NemotronJoinerSubState::SetInputFrames(OrtValue* encoder_frame, OrtValue* decoder_frame) { + inputs_[encoder_input_idx_] = encoder_frame; + inputs_[decoder_input_idx_] = decoder_frame; +} + +DeviceSpan NemotronJoinerSubState::Run(int /*total_length*/, DeviceSpan& /*next_tokens*/, DeviceSpan /*next_indices*/) { + State::Run(*model_.session_joiner_); + return {}; +} + +NemotronSpeechState::NemotronSpeechState(const NemotronSpeechModel& model, + const GeneratorParams& params) + : State{params, model}, + nemotron_model_{model} { + cache_config_ = model.cache_config_; + + encoder_state_ = std::make_unique(model, params); + prediction_state_ = std::make_unique(model, params); + joiner_state_ = std::make_unique(model, params); + + // Pre-allocate encoder frame for joiner input + auto enc_out_type = model_.session_info_.GetOutputDataType(cache_config_.enc_out_encoded); + auto frame_shape = std::array{1, 1, cache_config_.hidden_dim}; + encoder_frame_ = OrtValue::CreateTensor(model_.allocator_cpu_, frame_shape, enc_out_type); +} + +NemotronSpeechState::~NemotronSpeechState() = default; + +DeviceSpan NemotronSpeechState::Run(int /*total_length*/, + DeviceSpan& /*next_tokens*/, + DeviceSpan /*next_indices*/) { + throw std::runtime_error( + "NemotronSpeechState::Run() is not used directly. " + "Use Generator::GenerateNextToken() with set_inputs."); +} + +void NemotronSpeechState::SetExtraInputs(const std::vector& extra_inputs) { + for (const auto& input : extra_inputs) { + if (input.name == Config::Defaults::AudioFeaturesName || input.name == cache_config_.enc_in_audio) { + current_mel_ = input.tensor; + need_encoder_run_ = true; + chunk_done_ = false; + } + } +} + +OrtValue* NemotronSpeechState::GetInput(const char* name) { + if (auto* val = encoder_state_->GetInput(name)) return val; + if (auto* val = prediction_state_->GetInput(name)) return val; + if (auto* val = joiner_state_->GetInput(name)) return val; + return State::GetInput(name); +} + +OrtValue* NemotronSpeechState::GetOutput(const char* name) { + if (auto* val = encoder_state_->GetOutput(name)) return val; + if (auto* val = prediction_state_->GetOutput(name)) return val; + if (auto* val = joiner_state_->GetOutput(name)) return val; + return State::GetOutput(name); +} + +void NemotronSpeechState::ResetStreamingState() { + auto& allocator = model_.allocator_cpu_; + auto& device = *model_.p_device_; + + encoder_state_->cache_.Reset(cache_config_, model_.session_info_, allocator, device); + encoder_state_->UpdateCacheInputs(); + encoder_state_->first_run_ = true; + + prediction_state_->lstm_state_.Reset(cache_config_, model_.session_info_, allocator, device); + prediction_state_->UpdateInputs(); + prediction_state_->first_run_ = true; + + joiner_state_->first_run_ = true; + + current_mel_.reset(); + encoded_output_.reset(); + encoded_len_ = 0; + time_step_ = 0; + symbol_step_ = 0; + need_encoder_run_ = false; + chunk_done_ = true; + last_tokens_.clear(); +} + +void NemotronSpeechState::RunEncoder() { + if (!current_mel_ || !current_mel_->ort_tensor_) + throw std::runtime_error("No mel input set. Call generator.set_model_input(\"audio_features\", mel) first."); + + OrtValue* mel_tensor = current_mel_->ort_tensor_.get(); + int64_t total_mel_frames = mel_tensor->GetTensorTypeAndShapeInfo()->GetShape()[1]; + + encoder_state_->SetMelInput(mel_tensor, total_mel_frames); + encoder_state_->UpdateCacheInputs(); + + DeviceSpan dummy_tokens; + encoder_state_->Run(0, dummy_tokens); + + // Grab encoder outputs + encoded_output_.reset(encoder_state_->outputs_[0]); + encoder_state_->outputs_[0] = nullptr; + encoded_len_ = *encoder_state_->outputs_[1]->GetTensorData(); + + // Cache outputs are already moved into encoder_state_->cache_ by Run() + // But State::Run uses output pointers differently - the outputs are written to by ORT + // We need to take ownership of the cache outputs + encoder_state_->cache_.cache_last_channel.reset(encoder_state_->outputs_[2]); + encoder_state_->outputs_[2] = nullptr; + encoder_state_->cache_.cache_last_time.reset(encoder_state_->outputs_[3]); + encoder_state_->outputs_[3] = nullptr; + encoder_state_->cache_.cache_last_channel_len.reset(encoder_state_->outputs_[4]); + encoder_state_->outputs_[4] = nullptr; + + current_mel_.reset(); +} + +std::span NemotronSpeechState::StepToken() { + if (need_encoder_run_) { + RunEncoder(); + need_encoder_run_ = false; + time_step_ = 0; + symbol_step_ = 0; + } + + last_tokens_.clear(); + + auto enc_shape = encoded_output_->GetTensorTypeAndShapeInfo()->GetShape(); + int64_t time_steps = std::min(enc_shape[1], encoded_len_); + int64_t hidden_dim = enc_shape[2]; + size_t frame_bytes = static_cast(hidden_dim) * sizeof(float); + + auto enc_span = ByteWrapTensor(*model_.p_device_, *encoded_output_); + auto frame_span = ByteWrapTensor(*model_.p_device_, *encoder_frame_); + auto& allocator = model_.allocator_cpu_; + + DeviceSpan dummy_tokens; + + while (time_step_ < time_steps) { + // Copy current encoder frame + auto src_frame = enc_span.subspan(static_cast(time_step_) * frame_bytes, frame_bytes); + frame_span.CopyFrom(src_frame); + + // Run prediction network + prediction_state_->UpdateInputs(); + prediction_state_->Run(0, dummy_tokens); + + // Reshape decoder output for joiner: [1, dim] -> [1, 1, dim] + auto dec_out_shape = prediction_state_->outputs_[0]->GetTensorTypeAndShapeInfo()->GetShape(); + auto decoder_frame_shape = std::array{1, 1, dec_out_shape[1]}; + auto dec_out_type = model_.session_info_.GetOutputDataType(cache_config_.dec_out_outputs); + auto decoder_frame = OrtValue::CreateTensor(allocator, decoder_frame_shape, dec_out_type); + ByteWrapTensor(*model_.p_device_, *decoder_frame) + .CopyFrom(ByteWrapTensor(*model_.p_device_, *prediction_state_->outputs_[0])); + + // Run joiner + joiner_state_->SetInputFrames(encoder_frame_.get(), decoder_frame.get()); + joiner_state_->Run(0, dummy_tokens); + + // Argmax over logits + const float* logits_data = joiner_state_->outputs_[0]->GetTensorData(); + auto logits_shape = joiner_state_->outputs_[0]->GetTensorTypeAndShapeInfo()->GetShape(); + int total_logits = 1; + for (auto d : logits_shape) total_logits *= static_cast(d); + + int best_token = 0; + float best_score = logits_data[0]; + for (int i = 1; i < total_logits; ++i) { + if (logits_data[i] > best_score) { + best_score = logits_data[i]; + best_token = i; + } + } + + if (best_token == cache_config_.blank_id) { + time_step_++; + symbol_step_ = 0; + continue; + } + + // Non-blank: emit token, update LSTM state from prediction outputs + prediction_state_->lstm_state_.last_token = best_token; + prediction_state_->lstm_state_.lstm_hidden_state.reset(prediction_state_->outputs_[2]); + prediction_state_->outputs_[2] = nullptr; + prediction_state_->lstm_state_.lstm_cell_state.reset(prediction_state_->outputs_[3]); + prediction_state_->outputs_[3] = nullptr; + + symbol_step_++; + if (symbol_step_ >= cache_config_.max_symbols_per_step) { + time_step_++; + symbol_step_ = 0; + } + + last_tokens_.push_back(static_cast(best_token)); + token_count_++; + return last_tokens_; + } + + // Exhausted all time steps + chunk_done_ = true; + return last_tokens_; +} + +} // namespace Generators diff --git a/src/models/nemotron_speech.h b/src/models/nemotron_speech.h new file mode 100644 index 0000000000..138a3f9dab --- /dev/null +++ b/src/models/nemotron_speech.h @@ -0,0 +1,232 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +// +// Nemotron Speech Streaming ASR model support. +#pragma once + +#include "model.h" +#include "audio_features.h" + +namespace Generators { + +struct NemotronCacheConfig { + // Encoder dimensions (from encoder.hidden_size / num_hidden_layers) + int num_encoder_layers{}; + int hidden_dim{}; + int left_context{}; + int conv_context{}; + + // Decoder LSTM dimensions (from decoder.hidden_size / num_hidden_layers) + int decoder_lstm_dim{}; + int decoder_lstm_layers{}; + + // Vocabulary + int vocab_size{}; + int blank_id{}; + + // Streaming chunk config + int chunk_frames{}; + int sample_rate{}; + int chunk_samples{}; + int subsampling_factor{}; + int max_symbols_per_step{}; + + // Mel spectrogram parameters + int num_mels{}; + int fft_size{}; + int hop_length{}; + int win_length{}; + float preemph{}; + float log_eps{}; + + // Pre-encode cache + int pre_encode_cache_size{}; + + // Encoder I/O names + std::string enc_in_audio; + std::string enc_in_length; + std::string enc_in_cache_channel; + std::string enc_in_cache_time; + std::string enc_in_cache_channel_len; + std::string enc_out_encoded; + std::string enc_out_length; + std::string enc_out_cache_channel; + std::string enc_out_cache_time; + std::string enc_out_cache_channel_len; + + // Decoder (prediction network) I/O names + std::string dec_in_targets; + std::string dec_in_target_length; + std::string dec_in_lstm_hidden; + std::string dec_in_lstm_cell; + std::string dec_out_outputs; + std::string dec_out_prednet_lengths; + std::string dec_out_lstm_hidden; + std::string dec_out_lstm_cell; + + // Joiner I/O names + std::string join_in_encoder; + std::string join_in_decoder; + std::string join_out_logits; + + void PopulateFromConfig(const Config& config); +}; + +/// Holds the rolling encoder cache state between streaming chunks. +struct NemotronEncoderCache { + std::unique_ptr cache_last_channel; + std::unique_ptr cache_last_time; + std::unique_ptr cache_last_channel_len; + + void Initialize(const NemotronCacheConfig& cfg, const SessionInfo& session_info, OrtAllocator& allocator, DeviceInterface& device); + void Reset(const NemotronCacheConfig& cfg, const SessionInfo& session_info, OrtAllocator& allocator, DeviceInterface& device); +}; + +/// Holds the RNNT decoder LSTM hidden states between decoding steps. +struct NemotronDecoderState { + std::unique_ptr lstm_hidden_state; + std::unique_ptr lstm_cell_state; + int last_token{0}; // Last emitted non-blank token (for autoregressive feedback) + + void Initialize(const NemotronCacheConfig& cfg, const SessionInfo& session_info, OrtAllocator& allocator, DeviceInterface& device); + void Reset(const NemotronCacheConfig& cfg, const SessionInfo& session_info, OrtAllocator& allocator, DeviceInterface& device); +}; + +struct NemotronSpeechModel : Model { + NemotronSpeechModel(std::unique_ptr config, OrtEnv& ort_env); + + std::unique_ptr CreateState(DeviceSpan sequence_lengths, + const GeneratorParams& params) const override; + + // Three ONNX sessions: encoder, decoder (prediction network), joiner + std::unique_ptr session_encoder_; + std::unique_ptr session_decoder_; + std::unique_ptr session_joiner_; + + std::unique_ptr encoder_session_options_; + std::unique_ptr decoder_session_options_; + std::unique_ptr joiner_session_options_; + + NemotronCacheConfig cache_config_; +}; + +/// Sub-state for the streaming encoder. +struct NemotronEncoderSubState : State { + NemotronEncoderSubState(const NemotronSpeechModel& model, const GeneratorParams& params); + + DeviceSpan Run(int total_length, DeviceSpan& next_tokens, + DeviceSpan next_indices = {}) override; + + /// Set mel input and update registered input pointers. + void SetMelInput(OrtValue* mel_tensor, int64_t total_mel_frames); + + /// Update registered input pointers after cache is modified. + void UpdateCacheInputs(); + + private: + friend struct NemotronSpeechState; + + const NemotronSpeechModel& model_; + NemotronEncoderCache cache_; + std::unique_ptr signal_length_; + + // Indices into inputs_/outputs_ vectors + size_t mel_input_idx_{}; + size_t length_input_idx_{}; + size_t cache_channel_input_idx_{}; + size_t cache_time_input_idx_{}; + size_t cache_channel_len_input_idx_{}; +}; + +/// Sub-state for the RNNT prediction network (decoder LSTM). +struct NemotronPredictionSubState : State { + NemotronPredictionSubState(const NemotronSpeechModel& model, const GeneratorParams& params); + + DeviceSpan Run(int total_length, DeviceSpan& next_tokens, + DeviceSpan next_indices = {}) override; + + /// Update LSTM state input pointers before each run. + void UpdateInputs(); + + private: + friend struct NemotronSpeechState; + + const NemotronSpeechModel& model_; + NemotronDecoderState lstm_state_; + std::unique_ptr targets_; + std::unique_ptr target_length_; + + size_t targets_input_idx_{}; + size_t target_length_input_idx_{}; + size_t lstm_hidden_input_idx_{}; + size_t lstm_cell_input_idx_{}; +}; + +/// Sub-state for the joiner network. +struct NemotronJoinerSubState : State { + NemotronJoinerSubState(const NemotronSpeechModel& model, const GeneratorParams& params); + + DeviceSpan Run(int total_length, DeviceSpan& next_tokens, + DeviceSpan next_indices = {}) override; + + /// Update encoder/decoder frame input pointers before each run. + void SetInputFrames(OrtValue* encoder_frame, OrtValue* decoder_frame); + + private: + friend struct NemotronSpeechState; + + const NemotronSpeechModel& model_; + + size_t encoder_input_idx_{}; + size_t decoder_input_idx_{}; +}; + +/// Orchestrator state for the full RNNT pipeline. +struct NemotronSpeechState : State { + NemotronSpeechState(const NemotronSpeechModel& model, const GeneratorParams& params); + ~NemotronSpeechState() override; + + DeviceSpan Run(int total_length, DeviceSpan& next_tokens, + DeviceSpan next_indices = {}) override; + + void SetExtraInputs(const std::vector& extra_inputs) override; + + std::span StepToken(); + bool IsChunkDone() const { return chunk_done_; } + std::span GetStepTokens() const { return last_tokens_; } + size_t TokenCount() const { return token_count_; } + void ResetStreamingState(); + + OrtValue* GetInput(const char* name) override; + OrtValue* GetOutput(const char* name) override; + + private: + const NemotronSpeechModel& nemotron_model_; + NemotronCacheConfig cache_config_; + + std::unique_ptr encoder_state_; + std::unique_ptr prediction_state_; + std::unique_ptr joiner_state_; + + // Current mel input + std::shared_ptr current_mel_; + + // Encoder output persisted across StepToken calls + std::unique_ptr encoded_output_; + int64_t encoded_len_{0}; + + // Pre-allocated encoder frame for joiner input + std::unique_ptr encoder_frame_; + + // Decoder state machine + int64_t time_step_{0}; + int symbol_step_{0}; + bool need_encoder_run_{false}; + bool chunk_done_{true}; + std::vector last_tokens_; + size_t token_count_{}; // Total tokens emitted across all chunks + + void RunEncoder(); +}; + +} // namespace Generators diff --git a/src/models/nemotron_streaming_processor.cpp b/src/models/nemotron_streaming_processor.cpp new file mode 100644 index 0000000000..f9a5cbfe90 --- /dev/null +++ b/src/models/nemotron_streaming_processor.cpp @@ -0,0 +1,138 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include + +#include "../generators.h" +#include "nemotron_streaming_processor.h" + +namespace Generators { + +NemotronStreamingProcessor::NemotronStreamingProcessor(Model& model) + : model_{model} { + auto* nemotron_model = dynamic_cast(&model); + if (!nemotron_model) { + throw std::runtime_error("NemotronStreamingProcessor requires a nemotron_speech model type. Got: " + model.config_->model.type); + } + + cache_config_ = nemotron_model->cache_config_; + + if (cache_config_.pre_encode_cache_size <= 0) { + throw std::runtime_error("NemotronStreamingProcessor requires pre_encode_cache_size > 0. Got: " + + std::to_string(cache_config_.pre_encode_cache_size)); + } + + // Initialize mel extractor from config + nemo_mel::NemoMelConfig mel_cfg{ + cache_config_.num_mels, cache_config_.fft_size, + cache_config_.hop_length, cache_config_.win_length, + cache_config_.sample_rate, + cache_config_.preemph, cache_config_.log_eps}; + mel_extractor_ = nemo_mel::NemoStreamingMelExtractor{mel_cfg}; + + // Initialize mel pre-encode cache (time-major ring buffer, zeros for first chunk) + mel_pre_encode_cache_.assign( + static_cast(cache_config_.pre_encode_cache_size) * cache_config_.num_mels, 0.0f); + cache_pos_ = 0; +} + +NemotronStreamingProcessor::~NemotronStreamingProcessor() = default; + +std::unique_ptr NemotronStreamingProcessor::Process(const float* audio_data, size_t num_samples) { + // Append incoming audio to accumulation buffer + audio_buffer_.insert(audio_buffer_.end(), audio_data, audio_data + num_samples); + + const size_t chunk_size = static_cast(cache_config_.chunk_samples); + + // Process the first complete chunk available + if (audio_buffer_.size() >= chunk_size) { + auto mel = BuildMelTensor(audio_buffer_.data(), chunk_size); + audio_buffer_.erase(audio_buffer_.begin(), + audio_buffer_.begin() + static_cast(chunk_size)); + auto result = std::make_unique(); + result->emplace(Config::Defaults::AudioFeaturesName, std::make_shared(std::move(mel))); + return result; + } + + return nullptr; // Not enough audio yet +} + +std::unique_ptr NemotronStreamingProcessor::Flush() { + if (audio_buffer_.empty()) { + return nullptr; + } + + const size_t chunk_size = static_cast(cache_config_.chunk_samples); + audio_buffer_.resize(chunk_size, 0.0f); // Pad with silence + + auto mel = BuildMelTensor(audio_buffer_.data(), chunk_size); + audio_buffer_.clear(); + auto result = std::make_unique(); + result->emplace(Config::Defaults::AudioFeaturesName, std::make_shared(std::move(mel))); + return result; +} + +std::unique_ptr NemotronStreamingProcessor::BuildMelTensor(const float* audio_chunk, size_t chunk_samples) { + auto& allocator = model_.allocator_cpu_; + + // Compute mel spectrogram for this chunk: returns [num_mels, num_frames] (frequency-major) + auto [mel_data, num_frames] = mel_extractor_.Process(audio_chunk, chunk_samples); + + const int cache_size = cache_config_.pre_encode_cache_size; + const int num_mels = cache_config_.num_mels; + const int total_mel_frames = cache_size + num_frames; + + // Create output tensor: [1, total_mel_frames, num_mels] (time-major) + auto signal_type = model_.session_info_.GetInputDataType(cache_config_.enc_in_audio); + + // TODO: Optimize for GPU/CUDA later, CPU always expects float32. + if (signal_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) { + throw std::runtime_error("NemotronStreamingProcessor only supports float32 encoder input. Got type: " + std::to_string(signal_type)); + } + auto signal_shape = std::array{1, total_mel_frames, num_mels}; + auto processed_signal = OrtValue::CreateTensor(allocator, signal_shape, signal_type); + float* signal_data = processed_signal->GetTensorMutableData(); + + // Materialize cache frames from ring buffer (oldest-first starting at cache_pos_) + // Use at most 2 memcpys instead of per-frame copies + int first_run = std::min(cache_size - cache_pos_, cache_size); + std::memcpy(signal_data, + mel_pre_encode_cache_.data() + cache_pos_ * num_mels, + first_run * num_mels * sizeof(float)); + if (first_run < cache_size) { + std::memcpy(signal_data + first_run * num_mels, + mel_pre_encode_cache_.data(), + (cache_size - first_run) * num_mels * sizeof(float)); + } + + // Transpose mel from [num_mels, num_frames] directly into output tensor after cache + float* out_ptr = signal_data + cache_size * num_mels; + for (int t = 0; t < num_frames; ++t) { + for (int m = 0; m < num_mels; ++m) { + out_ptr[t * num_mels + m] = mel_data[m * num_frames + t]; + } + } + + // Update ring buffer with the last cache_size frames (or all if fewer) + int frames_to_cache = std::min(num_frames, cache_size); + const float* cache_src = out_ptr + (num_frames - frames_to_cache) * num_mels; + int frames_to_end = std::min(frames_to_cache, cache_size - cache_pos_); + std::memcpy(mel_pre_encode_cache_.data() + cache_pos_ * num_mels, + cache_src, + frames_to_end * num_mels * sizeof(float)); + if (frames_to_end < frames_to_cache) { + std::memcpy(mel_pre_encode_cache_.data(), + cache_src + frames_to_end * num_mels, + (frames_to_cache - frames_to_end) * num_mels * sizeof(float)); + } + cache_pos_ = (cache_pos_ + frames_to_cache) % cache_size; + + return processed_signal; +} + +std::unique_ptr CreateStreamingProcessor(Model& model) { + return std::make_unique(model); +} + +} // namespace Generators diff --git a/src/models/nemotron_streaming_processor.h b/src/models/nemotron_streaming_processor.h new file mode 100644 index 0000000000..405f2b79c0 --- /dev/null +++ b/src/models/nemotron_streaming_processor.h @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +// +// NemotronStreamingProcessor, Nemotron-specific streaming mel spectrogram extraction. +#pragma once + +#include "streaming_processor.h" +#include "nemo_mel_spectrogram.h" +#include "nemotron_speech.h" + +namespace Generators { + +/// Nemotron-specific streaming processor that converts raw PCM audio into +/// mel spectrogram tensors for the cache-aware FastConformer encoder. +struct NemotronStreamingProcessor : StreamingProcessor { + explicit NemotronStreamingProcessor(Model& model); + ~NemotronStreamingProcessor() override; + + std::unique_ptr Process(const float* audio_data, size_t num_samples) override; + std::unique_ptr Flush() override; + + int GetChunkSamples() const { return cache_config_.chunk_samples; } + int GetSampleRate() const { return cache_config_.sample_rate; } + + private: + Model& model_; + NemotronCacheConfig cache_config_; + + // Log-mel feature extraction + nemo_mel::NemoStreamingMelExtractor mel_extractor_; + + // Mel pre-encode cache: ring buffer of last pre_encode_cache_size frames. + std::vector mel_pre_encode_cache_; + int cache_pos_{0}; + + // Audio accumulation buffer for incoming PCM samples + std::vector audio_buffer_; + + std::unique_ptr BuildMelTensor(const float* audio_chunk, size_t chunk_samples); +}; + +} // namespace Generators diff --git a/src/models/streaming_processor.h b/src/models/streaming_processor.h new file mode 100644 index 0000000000..de8eeec3f8 --- /dev/null +++ b/src/models/streaming_processor.h @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +// +// StreamingProcessor, abstract base class for streaming processing. +#pragma once + +#include "model.h" + +namespace Generators { + +/// Abstract base class for streaming processors. +struct StreamingProcessor : LeakChecked { + virtual ~StreamingProcessor() = default; + + /// Feed raw data. + /// Returns a NamedTensors when a full chunk is ready, or nullptr if more data is needed. + virtual std::unique_ptr Process(const float* data, size_t num_samples) = 0; + + /// Flush remaining buffered data. + /// Returns final NamedTensors, or nullptr if buffer is empty. + virtual std::unique_ptr Flush() = 0; +}; + +std::unique_ptr CreateStreamingProcessor(Model& model); + +} // namespace Generators diff --git a/src/ort_genai.h b/src/ort_genai.h index d6ea59d3f1..567f495e98 100644 --- a/src/ort_genai.h +++ b/src/ort_genai.h @@ -886,3 +886,25 @@ inline int GetCurrentGpuDeviceId() { } } // namespace Oga + +struct OgaStreamingProcessor : OgaAbstract { + static std::unique_ptr Create(OgaModel& model) { + OgaStreamingProcessor* p; + OgaCheckResult(OgaCreateStreamingProcessor(&model, &p)); + return std::unique_ptr(p); + } + + std::unique_ptr Process(const float* audio_data, size_t num_samples) { + OgaNamedTensors* out; + OgaCheckResult(OgaStreamingProcessorProcess(this, audio_data, num_samples, &out)); + return std::unique_ptr(out); // May be nullptr if not enough audio + } + + std::unique_ptr Flush() { + OgaNamedTensors* out; + OgaCheckResult(OgaStreamingProcessorFlush(this, &out)); + return std::unique_ptr(out); + } + + static void operator delete(void* p) { OgaDestroyStreamingProcessor(reinterpret_cast(p)); } +}; diff --git a/src/ort_genai_c.cpp b/src/ort_genai_c.cpp index 41c70cdb3c..408be4991d 100644 --- a/src/ort_genai_c.cpp +++ b/src/ort_genai_c.cpp @@ -13,6 +13,8 @@ #include "search.h" #include "smartptrs.h" #include "engine/engine.h" +#include "models/streaming_processor.h" +#include "models/nemotron_speech.h" namespace Generators { @@ -60,6 +62,7 @@ struct OgaTokenizer : Generators::Tokenizer, OgaAbstract {}; struct OgaTokenizerStream : Generators::TokenizerStream, OgaAbstract {}; struct OgaEngine : Generators::Engine, OgaAbstract {}; struct OgaRequest : Generators::Request, OgaAbstract {}; +struct OgaStreamingProcessor : Generators::StreamingProcessor, OgaAbstract {}; // Helper function to return a shared pointer as a raw pointer. It won't compile if the types are wrong. // Exposed types that are internally owned by shared_ptrs inherit from ExternalRefCounted. Then we @@ -477,6 +480,13 @@ OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken(OgaGenerator* generator) OgaResult* OGA_API_CALL OgaGenerator_GetNextTokens(const OgaGenerator* generator, const int32_t** out, size_t* out_count) { OGA_TRY + // For RNNT models, search_ is not used; return tokens from last StepToken + if (auto* speech_state = dynamic_cast(generator->state_.get())) { + auto tokens = speech_state->GetStepTokens(); + *out = tokens.data(); + *out_count = tokens.size(); + return nullptr; + } auto tokens = generator->search_->GetNextTokens().CopyDeviceToCpu(); *out = tokens.data(); *out_count = tokens.size(); @@ -1058,7 +1068,7 @@ OgaResult* OGA_API_CALL OgaRequestGetOpaqueData(OgaRequest* request, void** data void OGA_API_CALL OgaDestroyStringArray(OgaStringArray* string_array) { delete string_array; } void OGA_API_CALL OgaDestroyResult(OgaResult* p) { delete p; } -void OGA_API_CALL OgaDestroyString(const char* p) { delete p; } +void OGA_API_CALL OgaDestroyString(const char* p) { delete[] p; } void OGA_API_CALL OgaDestroySequences(OgaSequences* p) { delete p; } void OGA_API_CALL OgaDestroyConfig(OgaConfig* p) { delete p; } void OGA_API_CALL OgaDestroyModel(OgaModel* p) { p->ExternalRelease(); } @@ -1084,4 +1094,38 @@ void OGA_API_CALL OgaUnregisterExecutionProviderLibrary(const char* registration Ort::UnregisterExecutionProviderLibrary(&(Generators::GetOrtEnv()), registration_name); } +OgaResult* OGA_API_CALL OgaCreateStreamingProcessor(OgaModel* model, OgaStreamingProcessor** out) { + OGA_TRY + auto processor = Generators::CreateStreamingProcessor(*model); + *out = ReturnUnique(std::move(processor)); + return nullptr; + OGA_CATCH +} + +OgaResult* OGA_API_CALL OgaStreamingProcessorProcess(OgaStreamingProcessor* processor, const float* audio_data, size_t num_samples, OgaNamedTensors** out) { + OGA_TRY + auto result = processor->Process(audio_data, num_samples); + if (result) { + *out = ReturnUnique(std::move(result)); + } else { + *out = nullptr; + } + return nullptr; + OGA_CATCH +} + +OgaResult* OGA_API_CALL OgaStreamingProcessorFlush(OgaStreamingProcessor* processor, OgaNamedTensors** out) { + OGA_TRY + auto result = processor->Flush(); + if (result) { + *out = ReturnUnique(std::move(result)); + } else { + *out = nullptr; + } + return nullptr; + OGA_CATCH +} + +void OGA_API_CALL OgaDestroyStreamingProcessor(OgaStreamingProcessor* p) { delete p; } + } // extern "C" diff --git a/src/ort_genai_c.h b/src/ort_genai_c.h index de7d4fa484..9cc26eba88 100644 --- a/src/ort_genai_c.h +++ b/src/ort_genai_c.h @@ -79,6 +79,7 @@ typedef struct OgaStringArray OgaStringArray; typedef struct OgaAdapters OgaAdapters; typedef struct OgaEngine OgaEngine; typedef struct OgaRequest OgaRequest; +typedef struct OgaStreamingProcessor OgaStreamingProcessor; //! @} @@ -1146,6 +1147,39 @@ OGA_EXPORT void OGA_API_CALL OgaRegisterExecutionProviderLibrary(const char* reg */ OGA_EXPORT void OGA_API_CALL OgaUnregisterExecutionProviderLibrary(const char* registration_name); +/** + * \brief Creates a StreamingProcessor for mel spectrogram extraction from raw audio. + * \param[in] model The model to create the processor for (must be nemotron_speech type). + * \param[out] out Pointer to store the created StreamingProcessor instance. + * \return OgaResult on error, nullptr on success. + */ +OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateStreamingProcessor(OgaModel* model, OgaStreamingProcessor** out); + +/** + * \brief Process a chunk of raw PCM audio and return a NamedTensors if a full chunk is ready. + * \param[in] processor The StreamingProcessor instance. + * \param[in] audio_data Pointer to float32 PCM audio samples (mono, model sample rate). + * \param[in] num_samples Number of audio samples. + * \param[out] out Pointer to store the NamedTensors. Set to nullptr if not enough audio yet. + * Caller must free with OgaDestroyNamedTensors. + * \return OgaResult on error, nullptr on success. + */ +OGA_EXPORT OgaResult* OGA_API_CALL OgaStreamingProcessorProcess(OgaStreamingProcessor* processor, const float* audio_data, size_t num_samples, OgaNamedTensors** out); + +/** + * \brief Flush remaining buffered audio (pads with silence). + * \param[in] processor The StreamingProcessor instance. + * \param[out] out Pointer to store the NamedTensors. Set to nullptr if buffer was empty. + * \return OgaResult on error, nullptr on success. + */ +OGA_EXPORT OgaResult* OGA_API_CALL OgaStreamingProcessorFlush(OgaStreamingProcessor* processor, OgaNamedTensors** out); + +/** + * \brief Destroy a StreamingProcessor instance. + * \param[in] processor The StreamingProcessor instance to destroy. + */ +OGA_EXPORT void OGA_API_CALL OgaDestroyStreamingProcessor(OgaStreamingProcessor* processor); + #ifdef __cplusplus } #endif diff --git a/src/python/python.cpp b/src/python/python.cpp index 51bab89a1b..175439d0b1 100644 --- a/src/python/python.cpp +++ b/src/python/python.cpp @@ -476,7 +476,8 @@ PYBIND11_MODULE(onnxruntime_genai, m) { .def_property_readonly("type", [](const OgaModel& model) -> std::string { return model.GetType().p_; }) .def_property_readonly( "device_type", [](const OgaModel& model) -> std::string { return model.GetDeviceType().p_; }, "The device type the model is running on") - .def("create_multimodal_processor", [](const OgaModel& model) { return OgaMultiModalProcessor::Create(model); }); + .def("create_multimodal_processor", [](const OgaModel& model) { return OgaMultiModalProcessor::Create(model); }) + .def("create_streaming_processor", [](OgaModel& model) { return OgaStreamingProcessor::Create(model); }, "Create a StreamingProcessor for mel spectrogram extraction from raw audio."); pybind11::class_(m, "Generator") .def(pybind11::init()) @@ -631,6 +632,36 @@ PYBIND11_MODULE(onnxruntime_genai, m) { .def("remove_request", &OgaEngine::Remove) .def("has_pending_requests", &OgaEngine::HasPendingRequests); + pybind11::class_(m, "StreamingProcessor") + .def(pybind11::init([](OgaModel& model) { return OgaStreamingProcessor::Create(model); }), + "Create a StreamingProcessor for mel spectrogram extraction.\n" + "The model must be of type 'nemotron_speech'.") + .def( + "process", + [](OgaStreamingProcessor& proc, pybind11::array_t audio_chunk) -> pybind11::object { + auto buf = audio_chunk.request(); + if (buf.ndim != 1) { + throw std::runtime_error("audio_chunk must be a 1-D array, got " + std::to_string(buf.ndim) + "-D"); + } + auto result = proc.Process(static_cast(buf.ptr), static_cast(buf.size)); + if (result) { + return pybind11::cast(std::move(result)); + } + return pybind11::none(); + }, + pybind11::arg("audio_chunk"), + "Feed raw PCM audio. Returns a NamedTensors if a full chunk is ready, or None if more audio is needed.") + .def( + "flush", + [](OgaStreamingProcessor& proc) -> pybind11::object { + auto result = proc.Flush(); + if (result) { + return pybind11::cast(std::move(result)); + } + return pybind11::none(); + }, + "Flush remaining buffered audio (pads with silence). Returns NamedTensors or None."); + m.def("set_log_options", &SetLogOptions); m.def("set_log_callback", &SetLogCallback); diff --git a/test/c_api_tests.cpp b/test/c_api_tests.cpp index b0026435f9..86b93d85bf 100644 --- a/test/c_api_tests.cpp +++ b/test/c_api_tests.cpp @@ -1,7 +1,10 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include +#include #include // for memcmp +#include #include #include #include @@ -1362,3 +1365,148 @@ TEST(CAPITests, SetGuidance) { #endif } #endif + +#ifndef STREAMING_ASR_PATH +#define STREAMING_ASR_PATH MODEL_PATH "nemotron-speech-streaming" +#endif + +// Helper: if mel is not null, set inputs and run the decode loop +static void DecodeInputs(OgaGenerator& generator, OgaNamedTensors* mel) { + if (mel) { + generator.SetInputs(*mel); + while (!generator.IsDone()) { + generator.GenerateNextToken(); + } + } +} + +// Test creating a Generator + StreamingProcessor from a nemotron_speech model +TEST(CAPITests, StreamingASRCreate) { + if (!std::filesystem::exists(STREAMING_ASR_PATH)) + GTEST_SKIP() << "Streaming ASR model not found at " << STREAMING_ASR_PATH; + auto model = OgaModel::Create(STREAMING_ASR_PATH); + auto processor = OgaStreamingProcessor::Create(*model); + ASSERT_NE(processor, nullptr); + auto params = OgaGeneratorParams::Create(*model); + auto generator = OgaGenerator::Create(*model, *params); + ASSERT_NE(generator, nullptr); +} + +// Test transcribing silence (all zeros) via GenerateNextToken +TEST(CAPITests, StreamingASRTranscribeSilence) { + if (!std::filesystem::exists(STREAMING_ASR_PATH)) + GTEST_SKIP() << "Streaming ASR model not found at " << STREAMING_ASR_PATH; + auto model = OgaModel::Create(STREAMING_ASR_PATH); + auto processor = OgaStreamingProcessor::Create(*model); + auto params = OgaGeneratorParams::Create(*model); + auto generator = OgaGenerator::Create(*model, *params); + + constexpr size_t chunk_samples = 8960; + std::vector silence(chunk_samples, 0.0f); + + auto mel = processor->Process(silence.data(), silence.size()); + DecodeInputs(*generator, mel.get()); + SUCCEED(); +} + +// Test feeding multiple chunks and decoding via GenerateNextToken +TEST(CAPITests, StreamingASRMultipleChunks) { + if (!std::filesystem::exists(STREAMING_ASR_PATH)) + GTEST_SKIP() << "Streaming ASR model not found at " << STREAMING_ASR_PATH; + auto model = OgaModel::Create(STREAMING_ASR_PATH); + auto processor = OgaStreamingProcessor::Create(*model); + auto params = OgaGeneratorParams::Create(*model); + auto generator = OgaGenerator::Create(*model, *params); + + constexpr size_t chunk_samples = 8960; + std::vector silence(chunk_samples, 0.0f); + + for (int i = 0; i < 5; ++i) { + auto mel = processor->Process(silence.data(), silence.size()); + DecodeInputs(*generator, mel.get()); + } + SUCCEED(); +} + +// Test flush processes remaining buffered audio +TEST(CAPITests, StreamingASRFlush) { + if (!std::filesystem::exists(STREAMING_ASR_PATH)) + GTEST_SKIP() << "Streaming ASR model not found at " << STREAMING_ASR_PATH; + auto model = OgaModel::Create(STREAMING_ASR_PATH); + auto processor = OgaStreamingProcessor::Create(*model); + auto params = OgaGeneratorParams::Create(*model); + auto generator = OgaGenerator::Create(*model, *params); + + constexpr size_t chunk_samples = 8960; + std::vector silence(chunk_samples, 0.0f); + processor->Process(silence.data(), silence.size()); + + auto mel = processor->Flush(); + DecodeInputs(*generator, mel.get()); + SUCCEED(); +} + +// Test transcribing a synthetic sine wave via GenerateNextToken +TEST(CAPITests, StreamingASRSineWave) { + if (!std::filesystem::exists(STREAMING_ASR_PATH)) + GTEST_SKIP() << "Streaming ASR model not found at " << STREAMING_ASR_PATH; + auto model = OgaModel::Create(STREAMING_ASR_PATH); + auto processor = OgaStreamingProcessor::Create(*model); + auto params = OgaGeneratorParams::Create(*model); + auto generator = OgaGenerator::Create(*model, *params); + + constexpr size_t chunk_samples = 8960; + constexpr float sample_rate = 16000.0f; + constexpr float frequency = 440.0f; + + std::vector audio(chunk_samples); + for (size_t i = 0; i < chunk_samples; ++i) { + audio[i] = 0.5f * std::sin(2.0f * 3.14159265f * frequency * static_cast(i) / sample_rate); + } + + for (int i = 0; i < 4; ++i) { + auto mel = processor->Process(audio.data(), audio.size()); + ASSERT_NE(mel, nullptr); + DecodeInputs(*generator, mel.get()); + } + + auto flush_mel = processor->Flush(); + DecodeInputs(*generator, flush_mel.get()); + SUCCEED(); +} + +// Test raw C API for StreamingProcessor + Generator +TEST(CAPITests, StreamingASRRawCAPI) { + if (!std::filesystem::exists(STREAMING_ASR_PATH)) + GTEST_SKIP() << "Streaming ASR model not found at " << STREAMING_ASR_PATH; + OgaModel* model = nullptr; + ASSERT_EQ(OgaCreateModel(STREAMING_ASR_PATH, &model), nullptr); + ASSERT_NE(model, nullptr); + + OgaStreamingProcessor* processor = nullptr; + ASSERT_EQ(OgaCreateStreamingProcessor(model, &processor), nullptr); + ASSERT_NE(processor, nullptr); + + OgaGeneratorParams* params = nullptr; + ASSERT_EQ(OgaCreateGeneratorParams(model, ¶ms), nullptr); + OgaGenerator* generator = nullptr; + ASSERT_EQ(OgaCreateGenerator(model, params, &generator), nullptr); + ASSERT_NE(generator, nullptr); + + constexpr size_t chunk_samples = 8960; + std::vector silence(chunk_samples, 0.0f); + + OgaNamedTensors* inputs = nullptr; + ASSERT_EQ(OgaStreamingProcessorProcess(processor, silence.data(), silence.size(), &inputs), nullptr); + ASSERT_NE(inputs, nullptr); + ASSERT_EQ(OgaGenerator_SetInputs(generator, inputs), nullptr); + while (!OgaGenerator_IsDone(generator)) { + ASSERT_EQ(OgaGenerator_GenerateNextToken(generator), nullptr); + } + OgaDestroyNamedTensors(inputs); + + OgaDestroyGenerator(generator); + OgaDestroyGeneratorParams(params); + OgaDestroyStreamingProcessor(processor); + OgaDestroyModel(model); +} diff --git a/test/python/conftest.py b/test/python/conftest.py index 8e5fccd745..78d4a77aa9 100644 --- a/test/python/conftest.py +++ b/test/python/conftest.py @@ -88,6 +88,16 @@ def path_for_model(request): return functools.partial(get_path_for_model, request.config.getoption("--test_models")) +@pytest.fixture +def nemotron_speech_model_path(request): + """Return the path to a nemotron_speech model directory, or skip if not available.""" + test_data = request.config.getoption("--test_models") + model_path = os.path.join(test_data, "nemotron-speech-streaming") + if not os.path.exists(model_path): + pytest.skip(f"Nemotron speech model not found at {model_path}") + return model_path + + @pytest.fixture def test_data_path(request): return request.config.getoption("--test_models") diff --git a/test/python/test_onnxruntime_genai_api.py b/test/python/test_onnxruntime_genai_api.py index 602570325a..4738a577cd 100644 --- a/test/python/test_onnxruntime_genai_api.py +++ b/test/python/test_onnxruntime_genai_api.py @@ -961,3 +961,186 @@ def test_audio_preprocessing_multiple_audios(test_data_path, relative_model_path decoder_prompt_tokens = ["<|startoftranscript|>", "<|en|>", "<|transcribe|>", "<|notimestamps|>"] prompts = ["".join(decoder_prompt_tokens)] * batch_size _ = processor(prompts, audios=audios) + + +def test_streaming_asr_create(nemotron_speech_model_path): + """Test that Generator + StreamingProcessor can be created from a nemotron_speech model.""" + model = og.Model(nemotron_speech_model_path) + processor = og.StreamingProcessor(model) + assert processor is not None + params = og.GeneratorParams(model) + generator = og.Generator(model, params) + assert generator is not None + + +def _load_streaming_config(model_path): + """Read sample_rate and chunk_samples from genai_config.json.""" + import json + config_path = os.path.join(model_path, "genai_config.json") + with open(config_path, "r") as f: + config = json.load(f) + return config["model"]["sample_rate"], config["model"]["chunk_samples"] + + +def _decode_inputs(generator, inputs, tokenizer_stream=None): + """Common helper: set inputs on the generator and decode all tokens. + + Returns the decoded text if tokenizer_stream is provided, otherwise empty string. + """ + if inputs is None: + return "" + generator.set_inputs(inputs) + text = "" + while not generator.is_done(): + generator.generate_next_token() + tokens = generator.get_next_tokens() + if tokenizer_stream is not None: + for token in tokens: + token_text = tokenizer_stream.decode(token) + if token_text: + text += token_text + return text + + +def test_streaming_asr_transcribe_silence(nemotron_speech_model_path): + """Test transcribing a chunk of silence (all zeros) does not crash.""" + sample_rate, chunk_samples = _load_streaming_config(nemotron_speech_model_path) + model = og.Model(nemotron_speech_model_path) + processor = og.StreamingProcessor(model) + tokenizer = og.Tokenizer(model) + tokenizer_stream = tokenizer.create_stream() + params = og.GeneratorParams(model) + generator = og.Generator(model, params) + + silence = np.zeros(chunk_samples, dtype=np.float32) + mel = processor.process(silence) + text = _decode_inputs(generator, mel, tokenizer_stream) + assert isinstance(text, str) + + +def test_streaming_asr_flush(nemotron_speech_model_path): + """Test that flush processes remaining buffered audio.""" + sample_rate, chunk_samples = _load_streaming_config(nemotron_speech_model_path) + model = og.Model(nemotron_speech_model_path) + processor = og.StreamingProcessor(model) + params = og.GeneratorParams(model) + generator = og.Generator(model, params) + + silence = np.zeros(chunk_samples, dtype=np.float32) + processor.process(silence) + + mel = processor.flush() + _decode_inputs(generator, mel) + + +def test_streaming_asr_sine_wave(nemotron_speech_model_path): + """Test transcribing a synthetic sine wave (non-trivial mel features).""" + sample_rate, chunk_samples = _load_streaming_config(nemotron_speech_model_path) + model = og.Model(nemotron_speech_model_path) + processor = og.StreamingProcessor(model) + tokenizer = og.Tokenizer(model) + tokenizer_stream = tokenizer.create_stream() + params = og.GeneratorParams(model) + generator = og.Generator(model, params) + + frequency = 440.0 # A4 note + + # Generate 440Hz sine wave + t = np.arange(chunk_samples, dtype=np.float32) / sample_rate + audio = (0.5 * np.sin(2.0 * np.pi * frequency * t)).astype(np.float32) + + transcript = "" + for _ in range(4): + mel = processor.process(audio) + transcript += _decode_inputs(generator, mel, tokenizer_stream) + + mel = processor.flush() + transcript += _decode_inputs(generator, mel, tokenizer_stream) + + assert isinstance(transcript, str) + + +def test_streaming_asr_config_model_type(nemotron_speech_model_path): + """Test that a nemotron_speech model reports the correct type.""" + model = og.Model(nemotron_speech_model_path) + assert model.type == "nemotron_speech" + + +def _word_error_rate(reference: str, hypothesis: str) -> float: + """Compute Word Error Rate (WER) using edit distance on word sequences.""" + import re + + def normalize(text): + text = re.sub(r"[^\w\s]", "", text.lower()) + return text.split() + + r = normalize(reference) + h = normalize(hypothesis) + d = [[0] * (len(h) + 1) for _ in range(len(r) + 1)] + for i in range(len(r) + 1): + d[i][0] = i + for j in range(len(h) + 1): + d[0][j] = j + for i in range(1, len(r) + 1): + for j in range(1, len(h) + 1): + if r[i - 1] == h[j - 1]: + d[i][j] = d[i - 1][j - 1] + else: + d[i][j] = 1 + min(d[i - 1][j], d[i][j - 1], d[i - 1][j - 1]) + return d[len(r)][len(h)] / max(len(r), 1) + + +def test_streaming_asr_transcription_quality(nemotron_speech_model_path, test_data_path): + """Test that transcription of a known audio file has acceptable WER.""" + try: + import soundfile as sf + except ImportError: + pytest.skip("soundfile not installed") + return + + audio_path = os.path.join(test_data_path, "audios", "1272-141231-0002.mp3") + if not os.path.exists(audio_path): + pytest.skip(f"Test audio not found: {audio_path}") + + # Load audio as float32 mono, resample to model's sample rate + audio, sr = sf.read(audio_path, dtype="float32") + if len(audio.shape) > 1: + audio = audio.mean(axis=1) + sample_rate, chunk_samples = _load_streaming_config(nemotron_speech_model_path) + if sr != sample_rate: + try: + import scipy.signal + num_samples = int(len(audio) * sample_rate / sr) + audio = scipy.signal.resample(audio, num_samples).astype(np.float32) + except ImportError: + pytest.skip(f"Audio is {sr}Hz and scipy not available for resampling") + + # Transcribe using Generator + StreamingProcessor + model = og.Model(nemotron_speech_model_path) + processor = og.StreamingProcessor(model) + tokenizer = og.Tokenizer(model) + tokenizer_stream = tokenizer.create_stream() + params = og.GeneratorParams(model) + generator = og.Generator(model, params) + + transcript = "" + for start in range(0, len(audio), chunk_samples): + chunk = audio[start : start + chunk_samples].astype(np.float32) + mel = processor.process(chunk) + transcript += _decode_inputs(generator, mel, tokenizer_stream) + + mel = processor.flush() + transcript += _decode_inputs(generator, mel, tokenizer_stream) + + reference = ( + "the cut on his chest still dripping blood the ache of his overstrained eyes " + "even the soaring arena around him with the thousands of spectators were " + "trivialities not worth thinking about" + ) + + wer = _word_error_rate(reference, transcript) + assert wer < 0.15, ( + f"WER too high: {wer:.1%}\n" + f" Reference: {reference}\n" + f" Hypothesis: {transcript.lower()}" + ) \ No newline at end of file diff --git a/test/python/test_onnxruntime_genai_e2e.py b/test/python/test_onnxruntime_genai_e2e.py index 5e82e4fdca..fcc16529f0 100644 --- a/test/python/test_onnxruntime_genai_e2e.py +++ b/test/python/test_onnxruntime_genai_e2e.py @@ -158,6 +158,34 @@ def run_tool_calling(): run_subprocess(command, cwd=cwd, log=log).check_returncode() +def run_nemotron_speech(): + """Run Nemotron Speech Streaming ASR E2E test by invoking the nemotron_speech.py example.""" + log.debug("Running Nemotron Speech Python E2E Test") + + # Look for nemotron speech model in test_models directory + cwd = os.path.dirname(os.path.abspath(__file__)) + model_path = os.path.join(cwd, "..", "test_models", "nemotron-speech-streaming") + if not os.path.exists(model_path): + log.info(f"Nemotron speech model not found at {model_path}, skipping E2E test.") + return + + # Look for a test audio file + audio_path = os.path.join(cwd, "..", "test_models", "audios", "1272-141231-0002.mp3") + if not os.path.exists(audio_path): + log.info(f"Test audio file not found at {audio_path}, skipping E2E test.") + return + + command = [ + sys.executable, + os.path.join(cwd, "..", "..", "examples", "python", "nemotron_speech.py"), + "--model_path", + model_path, + "--audio_file", + audio_path, + ] + run_subprocess(command, cwd=cwd, log=log).check_returncode() + + def get_args(): parser = argparse.ArgumentParser() @@ -187,5 +215,8 @@ def get_args(): # Run Whisper E2E tests run_whisper() + # Run Nemotron Speech E2E tests + run_nemotron_speech() + # Run tool calling E2E tests run_tool_calling()