Merged
Changes from 24 commits
Commits
70 commits
31d6779
nemotron support
nenad1002 Feb 10, 2026
9026781
ONNX 2 good version
nenad1002 Feb 11, 2026
8c6f4ed
Nemotron support
nenad1002 Feb 12, 2026
9dd6212
Support 4
nenad1002 Feb 12, 2026
8b0de45
First stream
nenad1002 Feb 12, 2026
0d83168
Overlap support
nenad1002 Feb 12, 2026
d2ff912
Nemotron support stream 3
nenad1002 Feb 12, 2026
c7ed0c9
Mi fix
nenad1002 Feb 13, 2026
b83b84f
Move mel stuff to separate file
nenad1002 Feb 13, 2026
ff275b4
Remove mel spectogram
nenad1002 Feb 13, 2026
32001f1
Revert non-needed changes
nenad1002 Feb 17, 2026
131db0c
Make sure genai_config.json defines model params
nenad1002 Feb 18, 2026
b262003
Point to latest extensions
nenad1002 Feb 20, 2026
5cc511d
Add tests
nenad1002 Feb 20, 2026
5cf0c59
Add a better test
nenad1002 Feb 20, 2026
f670d36
Remove text tokenizer and sr to genaiconfig
nenad1002 Feb 20, 2026
6e23d87
Remove dead code
nenad1002 Feb 20, 2026
fd47344
Abstract streaming ASR class
nenad1002 Feb 20, 2026
98d6e54
remove processor
nenad1002 Feb 21, 2026
06d05d0
Fix merge conflict
nenad1002 Mar 2, 2026
46a166d
Clean more code
nenad1002 Mar 2, 2026
8a5e912
Clean up examples
nenad1002 Mar 2, 2026
e10086f
Performance optimizations
nenad1002 Mar 2, 2026
5eb10ff
More cleaning
nenad1002 Mar 2, 2026
89b8bd5
Try removing warning
nenad1002 Mar 2, 2026
880143b
Add flag to tests
nenad1002 Mar 2, 2026
092d212
fix formatting
nenad1002 Mar 3, 2026
67e649c
Resolve Copilot comments
nenad1002 Mar 3, 2026
5be81c9
Fix formatting issue
nenad1002 Mar 3, 2026
b3c6411
Merge branch 'main' into nebanfic/nemotron-support-stream-3
nenad1002 Mar 3, 2026
165037b
Remove soundfile
nenad1002 Mar 3, 2026
5097afd
Remove dead tokenzier code
nenad1002 Mar 5, 2026
98e81b7
Adjust genai config to our exported models
nenad1002 Mar 5, 2026
0a6d87b
Resolve more comments
nenad1002 Mar 5, 2026
70f4e23
Avoid memset, memcpy and manual copy on GPU and whenever possible, ri…
nenad1002 Mar 5, 2026
4d8a0f5
Add consistency
nenad1002 Mar 5, 2026
96dafca
Big improvement - cache locality for frames
nenad1002 Mar 6, 2026
2499ab4
Csharp support
nenad1002 Mar 6, 2026
51a61c7
Add a check to the factory for StreamingASR
nenad1002 Mar 6, 2026
9e28df9
nemotron generator
nenad1002 Mar 6, 2026
8a1bef0
remove ProcessChunk from model.h
nenad1002 Mar 6, 2026
154b5aa
remove generate_next_tokens()
nenad1002 Mar 6, 2026
571f300
Rename processor
nenad1002 Mar 6, 2026
f61fc0a
C# sample and remove unnecessary files
nenad1002 Mar 9, 2026
8596fc1
Fix all
nenad1002 Mar 9, 2026
b762766
more fixes
nenad1002 Mar 9, 2026
e2ab1e7
samples change
nenad1002 Mar 9, 2026
4059341
Introduce NamedTensors on streaming processor
nenad1002 Mar 10, 2026
ca9a9f3
Remove speech section in genai_config
nenad1002 Mar 10, 2026
282a9f0
Reverse NativeMethods.cs formatting
nenad1002 Mar 10, 2026
dc46428
Some refactoring
nenad1002 Mar 10, 2026
7c636ba
Make streaming processor abstract class
nenad1002 Mar 10, 2026
33f809b
set_inputs
nenad1002 Mar 10, 2026
66ee360
Copilot suggestions
nenad1002 Mar 11, 2026
27ce2e5
Examples changes
nenad1002 Mar 12, 2026
8723d3b
More comments resolved
nenad1002 Mar 12, 2026
96ce812
SubStates
nenad1002 Mar 12, 2026
df8cb9b
More changes
nenad1002 Mar 12, 2026
c5ed7df
Resolvimg more comments
nenad1002 Mar 12, 2026
02c5fde
Mass copy
nenad1002 Mar 12, 2026
101113d
Copilot fixes
nenad1002 Mar 12, 2026
640e9af
Merge conflict fix
nenad1002 Mar 12, 2026
7d023ef
Potential fix for code scanning alert no. 798: Unused local variable
nenad1002 Mar 12, 2026
7283dd1
Fix clang
nenad1002 Mar 12, 2026
a3f77e4
Run clang
nenad1002 Mar 12, 2026
658e8de
fix tests
nenad1002 Mar 12, 2026
78b84a2
Resolve comments
nenad1002 Mar 13, 2026
ff890a9
Add C++ example
nenad1002 Mar 16, 2026
9473ab8
Semicolon on another line
nenad1002 Mar 16, 2026
8e683f3
Add C# sample readme
nenad1002 Mar 16, 2026
2 changes: 1 addition & 1 deletion cmake/deps.txt
@@ -14,7 +14,7 @@ pybind11;https://github.com/pybind/pybind11/archive/refs/tags/v2.13.6.zip;f78029
googletest;https://github.com/google/googletest/archive/530d5c8c84abd2a46f38583ee817743c9b3a42b4.zip;5e3a61db2aa975cfd0f97ba92c818744e7fa7034
microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.230629.1.zip;e4a542a323c070376f7c2d1973d0f7ddbc1d2fa5
directx_headers;https://github.com/microsoft/DirectX-Headers/archive/refs/tags/v1.613.1.zip;47653509a3371eabb156360f42faf582f314bf2e
-onnxruntime_extensions;https://github.com/microsoft/onnxruntime-extensions.git;087953cde6149e423c6848c40c3791264272706c
+onnxruntime_extensions;https://github.com/microsoft/onnxruntime-extensions.git;45a76bce69f874ad980933504700fc110ebf1ecb

# These two dependencies are for the optional constrained decoding feature (USE_GUIDANCE)
llguidance;https://github.com/microsoft/llguidance.git;94fa39128ef184ffeda33845f6d333f332a34b4d
99 changes: 99 additions & 0 deletions examples/python/nemotron_streaming.py
@@ -0,0 +1,99 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import argparse
import os
import sys
import time
import re
import numpy as np
import onnxruntime_genai as og

SAMPLE_RATE = 16000
CHUNK_SAMPLES = 8960
CHUNK_DURATION = CHUNK_SAMPLES / SAMPLE_RATE


def load_audio(audio_path):
    import soundfile as sf
    audio, sr = sf.read(audio_path, dtype="float32")
    if len(audio.shape) > 1:
        audio = audio.mean(axis=1)
    if sr != SAMPLE_RATE:
        import scipy.signal
        num_samples = int(len(audio) * SAMPLE_RATE / sr)
        audio = scipy.signal.resample(audio, num_samples).astype(np.float32)
    return audio


def load_tokenizer(model_path):
    import sentencepiece as spm
    path = os.path.join(model_path, "tokenizer.model")
    if not os.path.exists(path):
        return None
    sp = spm.SentencePieceProcessor()
    sp.Load(path)
    return sp


def parse_token_ids(raw_text):
    return [int(m.group(1)) for m in re.finditer(r'<(\d+)>', raw_text)]


def simulate_microphone(model_path, audio_path):
    audio = load_audio(audio_path)
    duration = len(audio) / SAMPLE_RATE
    num_chunks = (len(audio) + CHUNK_SAMPLES - 1) // CHUNK_SAMPLES
    print(f"Audio: {duration:.1f}s | {num_chunks} chunks × {CHUNK_DURATION*1000:.0f}ms")

    config = og.Config(model_path)
    model = og.Model(config)
    sp = load_tokenizer(model_path)
    asr = og.StreamingASR(model)

    print("-" * 60)
    stream_start = time.time()

    for i in range(0, len(audio), CHUNK_SAMPLES):
        chunk = audio[i:i + CHUNK_SAMPLES]
        if len(chunk) < CHUNK_SAMPLES:
            chunk = np.pad(chunk, (0, CHUNK_SAMPLES - len(chunk)))
        chunk = chunk.astype(np.float32)
        raw_text = asr.transcribe_chunk(chunk)
        if raw_text:
            print(raw_text, end="", flush=True)

    for _ in range(4):
        silence = np.zeros(CHUNK_SAMPLES, dtype=np.float32)
        raw_text = asr.transcribe_chunk(silence)
        if raw_text:
            print(raw_text, end="", flush=True)

    total_wall = time.time() - stream_start

    full_raw = asr.get_transcript()
    if sp:
        all_ids = parse_token_ids(full_raw)
        final_text = sp.Decode(all_ids) if all_ids else full_raw
    else:
        final_text = full_raw

    print(f"\n{'=' * 60}")
    print(f" {final_text.strip()}")
    print(f"{'=' * 60}")
    print(f" Audio: {duration:.2f}s | Wall: {total_wall:.2f}s | RTF: {duration/total_wall:.2f}x")


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--audio_file", type=str, required=True)
    args = parser.parse_args()
    if not os.path.exists(args.audio_file):
        print(f"Error: {args.audio_file} not found")
        sys.exit(1)
    simulate_microphone(args.model_path, args.audio_file)


if __name__ == "__main__":
    main()
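
Reviewer note: the example script above splits audio into fixed 8960-sample chunks (560 ms at 16 kHz) and zero-pads the last partial chunk before each `transcribe_chunk` call. The chunking arithmetic can be checked in isolation with the minimal sketch below. It uses plain Python lists instead of NumPy arrays, and `split_into_chunks` and `chunk_count` are hypothetical helper names that are not part of the PR.

```python
SAMPLE_RATE = 16000   # Hz, same constant as the example script
CHUNK_SAMPLES = 8960  # samples per chunk, same constant as the example script


def split_into_chunks(samples, chunk_samples=CHUNK_SAMPLES):
    """Split a sequence of floats into fixed-size chunks, zero-padding the last one."""
    chunks = []
    for start in range(0, len(samples), chunk_samples):
        chunk = list(samples[start:start + chunk_samples])
        # Pad the final partial chunk with silence (zeros) up to the full length.
        chunk.extend([0.0] * (chunk_samples - len(chunk)))
        chunks.append(chunk)
    return chunks


def chunk_count(num_samples, chunk_samples=CHUNK_SAMPLES):
    """Ceiling division, the same expression the script uses for num_chunks."""
    return (num_samples + chunk_samples - 1) // chunk_samples


if __name__ == "__main__":
    audio = [0.1] * 20000  # hypothetical 1.25 s of audio at 16 kHz
    chunks = split_into_chunks(audio)
    print(f"{len(chunks)} chunks, {chunk_count(len(audio))} expected")
    print(f"{CHUNK_SAMPLES / SAMPLE_RATE * 1000:.0f} ms per chunk")
```

The loop and the ceiling division always agree on the number of chunks, which is why the script can print `num_chunks` before streaming starts.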
139 changes: 137 additions & 2 deletions src/config.cpp
@@ -315,6 +315,14 @@ struct DecoderInputs_Element : JSON::Element {
      v_.past_sequence_lengths = JSON::Get<std::string_view>(value);
    } else if (name == "block_table") {
      v_.block_table = JSON::Get<std::string_view>(value);
    } else if (name == "targets") {
      v_.targets = JSON::Get<std::string_view>(value);
    } else if (name == "target_length") {
      v_.target_length = JSON::Get<std::string_view>(value);
    } else if (name == "states_1") {
      v_.states_1 = JSON::Get<std::string_view>(value);
    } else if (name == "states_2") {
      v_.states_2 = JSON::Get<std::string_view>(value);
    } else {
      throw JSON::unknown_value_error{};
    }
@@ -340,6 +348,14 @@ struct DecoderOutputs_Element : JSON::Element {
      v_.output_cross_qk_names = JSON::Get<std::string_view>(value);
    } else if (name == "rnn_states") {
      v_.rnn_states = JSON::Get<std::string_view>(value);
    } else if (name == "outputs") {
      v_.outputs = JSON::Get<std::string_view>(value);
    } else if (name == "prednet_lengths") {
      v_.prednet_lengths = JSON::Get<std::string_view>(value);
    } else if (name == "states_1") {
      v_.states_1 = JSON::Get<std::string_view>(value);
    } else if (name == "states_2") {
      v_.states_2 = JSON::Get<std::string_view>(value);
    } else {
      throw JSON::unknown_value_error{};
    }
@@ -557,10 +573,10 @@ struct Decoder_Element : JSON::Element {
      v_.hidden_size = static_cast<int>(JSON::Get<double>(value));
    } else if (name == "num_attention_heads") {
      v_.num_attention_heads = static_cast<int>(JSON::Get<double>(value));
    } else if (name == "num_key_value_heads") {
      v_.num_key_value_heads = static_cast<int>(JSON::Get<double>(value));
    } else if (name == "num_hidden_layers") {
      v_.num_hidden_layers = static_cast<int>(JSON::Get<double>(value));
    } else if (name == "num_key_value_heads") {
      v_.num_key_value_heads = static_cast<int>(JSON::Get<double>(value));
    } else if (name == "head_size") {
      v_.head_size = static_cast<int>(JSON::Get<double>(value));
    } else {
@@ -827,6 +843,50 @@ struct Speech_Element : JSON::Element {
      v_.config_filename = JSON::Get<std::string_view>(value);
    } else if (name == "adapter_filename") {
      v_.adapter_filename = JSON::Get<std::string_view>(value);
    } else if (name == "num_mels") {
      v_.num_mels = static_cast<int>(JSON::Get<double>(value));
    } else if (name == "fft_size") {
      v_.fft_size = static_cast<int>(JSON::Get<double>(value));
    } else if (name == "hop_length") {
      v_.hop_length = static_cast<int>(JSON::Get<double>(value));
    } else if (name == "win_length") {
      v_.win_length = static_cast<int>(JSON::Get<double>(value));
    } else if (name == "preemph") {
      v_.preemph = static_cast<float>(JSON::Get<double>(value));
    } else if (name == "log_eps") {
      v_.log_eps = static_cast<float>(JSON::Get<double>(value));
    } else if (name == "subsampling_factor") {
      v_.subsampling_factor = static_cast<int>(JSON::Get<double>(value));
    } else if (name == "left_context") {
      v_.left_context = static_cast<int>(JSON::Get<double>(value));
    } else if (name == "conv_context") {
      v_.conv_context = static_cast<int>(JSON::Get<double>(value));
    } else if (name == "pre_encode_cache_size") {
      v_.pre_encode_cache_size = static_cast<int>(JSON::Get<double>(value));
    } else if (name == "sample_rate") {
      v_.sample_rate = static_cast<int>(JSON::Get<double>(value));
    } else if (name == "chunk_samples") {
      v_.chunk_samples = static_cast<int>(JSON::Get<double>(value));
    } else if (name == "blank_id") {
      v_.blank_id = static_cast<int>(JSON::Get<double>(value));
    } else if (name == "max_symbols_per_step") {
      v_.max_symbols_per_step = static_cast<int>(JSON::Get<double>(value));
    } else if (name == "enc_in_length") {
      v_.enc_in_length = JSON::Get<std::string_view>(value);
    } else if (name == "enc_in_cache_channel") {
      v_.enc_in_cache_channel = JSON::Get<std::string_view>(value);
    } else if (name == "enc_in_cache_time") {
      v_.enc_in_cache_time = JSON::Get<std::string_view>(value);
    } else if (name == "enc_in_cache_channel_len") {
      v_.enc_in_cache_channel_len = JSON::Get<std::string_view>(value);
    } else if (name == "enc_out_length") {
      v_.enc_out_length = JSON::Get<std::string_view>(value);
    } else if (name == "enc_out_cache_channel") {
      v_.enc_out_cache_channel = JSON::Get<std::string_view>(value);
    } else if (name == "enc_out_cache_time") {
      v_.enc_out_cache_time = JSON::Get<std::string_view>(value);
    } else if (name == "enc_out_cache_channel_len") {
      v_.enc_out_cache_channel_len = JSON::Get<std::string_view>(value);
    } else {
      throw JSON::unknown_value_error{};
    }
@@ -860,6 +920,77 @@ struct Speech_Element : JSON::Element {
  SpeechOutputs_Element outputs_{v_.outputs};
};

struct JoinerInputs_Element : JSON::Element {
  explicit JoinerInputs_Element(Config::Model::Joiner::Inputs& v) : v_{v} {}

  void OnValue(std::string_view name, JSON::Value value) override {
    if (name == "encoder_outputs") {
      v_.encoder_outputs = JSON::Get<std::string_view>(value);
    } else if (name == "decoder_outputs") {
      v_.decoder_outputs = JSON::Get<std::string_view>(value);
    } else {
      throw JSON::unknown_value_error{};
    }
  }

 private:
  Config::Model::Joiner::Inputs& v_;
};

struct JoinerOutputs_Element : JSON::Element {
  explicit JoinerOutputs_Element(Config::Model::Joiner::Outputs& v) : v_{v} {}

  void OnValue(std::string_view name, JSON::Value value) override {
    if (name == "logits") {
      v_.logits = JSON::Get<std::string_view>(value);
    } else {
      throw JSON::unknown_value_error{};
    }
  }

 private:
  Config::Model::Joiner::Outputs& v_;
};

struct Joiner_Element : JSON::Element {
  explicit Joiner_Element(Config::Model::Joiner& v) : v_{v} {}

  void OnValue(std::string_view name, JSON::Value value) override {
    if (name == "filename") {
      v_.filename = JSON::Get<std::string_view>(value);
    } else {
      throw JSON::unknown_value_error{};
    }
  }

  Element& OnObject(std::string_view name) override {
    if (name == "session_options") {
      v_.session_options = Config::SessionOptions{};
      session_options_ = std::make_unique<SessionOptions_Element>(*v_.session_options);
      return *session_options_;
    }
    if (name == "run_options") {
      v_.run_options = Config::RunOptions{};
      run_options_ = std::make_unique<RunOptions_Element>(*v_.run_options);
      return *run_options_;
    }
    if (name == "inputs") {
      return inputs_;
    }
    if (name == "outputs") {
      return outputs_;
    }
    throw JSON::unknown_value_error{};
  }

 private:
  Config::Model::Joiner& v_;
  std::unique_ptr<SessionOptions_Element> session_options_;
  std::unique_ptr<RunOptions_Element> run_options_;
  JoinerInputs_Element inputs_{v_.inputs};
  JoinerOutputs_Element outputs_{v_.outputs};
};

struct EmbeddingInputs_Element : JSON::Element {
  explicit EmbeddingInputs_Element(Config::Model::Embedding::Inputs& v) : v_{v} {}

@@ -986,6 +1117,9 @@ struct Model_Element : JSON::Element {
    if (name == "speech") {
      return speech_;
    }
    if (name == "joiner") {
      return joiner_;
    }
    throw JSON::unknown_value_error{};
  }

@@ -997,6 +1131,7 @@ struct Model_Element : JSON::Element {
  Vision_Element vision_{v_.vision};
  Embedding_Element embedding_{v_.embedding};
  Speech_Element speech_{v_.speech};
  Joiner_Element joiner_{v_.joiner};
};

int SafeDoubleToInt(double x, std::string_view name) {
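
Reviewer note: `Joiner_Element` and its nested elements accept a fixed set of keys and throw `JSON::unknown_value_error` for anything else, so a typo in the `joiner` section of `genai_config.json` fails at load time. The sketch below mirrors that key set in Python as a pre-flight check. It is an illustration only: `validate_joiner` is a hypothetical helper, `"joiner.onnx"` is an invented filename, and the nested schemas of `session_options` and `run_options` are not shown in this diff, so they are accepted without inspection.

```python
import json

# Keys accepted by Joiner_Element and its Inputs/Outputs elements in this diff.
SCALAR_KEYS = {"filename"}
OBJECT_KEYS = {
    "session_options": None,  # nested schema not shown here; contents not checked
    "run_options": None,      # nested schema not shown here; contents not checked
    "inputs": {"encoder_outputs", "decoder_outputs"},
    "outputs": {"logits"},
}


def validate_joiner(section):
    """Raise ValueError for keys the C++ parser would reject as unknown."""
    for key, value in section.items():
        if key in SCALAR_KEYS:
            if not isinstance(value, str):
                raise ValueError(f"joiner.{key} must be a string")
        elif key in OBJECT_KEYS:
            if not isinstance(value, dict):
                raise ValueError(f"joiner.{key} must be an object")
            allowed = OBJECT_KEYS[key]
            if allowed is not None:
                unknown = set(value) - allowed
                if unknown:
                    raise ValueError(f"unknown joiner.{key} keys: {sorted(unknown)}")
        else:
            raise ValueError(f"unknown joiner key: {key}")
    return section


# Hypothetical fragment: the filename is invented; the tensor names are the
# defaults declared in Config::Model::Joiner in src/config.h.
EXAMPLE = json.loads("""
{
  "filename": "joiner.onnx",
  "inputs": {"encoder_outputs": "encoder_outputs", "decoder_outputs": "decoder_outputs"},
  "outputs": {"logits": "outputs"}
}
""")

if __name__ == "__main__":
    validate_joiner(EXAMPLE)
    print("joiner section accepted")
```

Because `Model_Element::OnObject` dispatches on `"joiner"`, this section sits under `model` in the config, next to `speech` and `decoder`.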
53 changes: 53 additions & 0 deletions src/config.h
@@ -202,6 +202,32 @@ struct Config {
      std::string config_filename{"audio_processor_config.json"};
      std::optional<std::string> adapter_filename{};

      // Mel spectrogram / streaming ASR parameters
      int num_mels{};
      int fft_size{};
      int hop_length{};
      int win_length{};
      float preemph{};
      float log_eps{};
      int subsampling_factor{};
      int left_context{};
      int conv_context{};
      int pre_encode_cache_size{};
      int sample_rate{};
      int chunk_samples{};
      int blank_id{};
      int max_symbols_per_step{};

      // Cache-aware streaming encoder I/O names
      std::string enc_in_length{"length"};
      std::string enc_in_cache_channel{"cache_last_channel"};
      std::string enc_in_cache_time{"cache_last_time"};
      std::string enc_in_cache_channel_len{"cache_last_channel_len"};
      std::string enc_out_length{"encoded_lengths"};
      std::string enc_out_cache_channel{"cache_last_channel_next"};
      std::string enc_out_cache_time{"cache_last_time_next"};
      std::string enc_out_cache_channel_len{"cache_last_channel_next_len"};

      struct Inputs {
        std::string audio_embeds{Defaults::AudioEmbedsName};
        std::string attention_mask{Defaults::AudioAttentionMaskName};
@@ -214,6 +240,21 @@ struct Config {
      } outputs;
    } speech;

    struct Joiner {
      std::string filename;
      std::optional<SessionOptions> session_options;
      std::optional<RunOptions> run_options;

      struct Inputs {
        std::string encoder_outputs{"encoder_outputs"};
        std::string decoder_outputs{"decoder_outputs"};
      } inputs;

      struct Outputs {
        std::string logits{"outputs"};
      } outputs;
    } joiner;

    struct Decoder {
      std::string filename;
      SessionOptions session_options;
@@ -255,6 +296,12 @@ struct Config {
        std::string cumulative_sequence_lengths{Defaults::CumulativeSequenceLengthsName};
        std::string past_sequence_lengths{Defaults::PastSequenceLengthsName};
        std::string block_table{Defaults::BlockTableName};

        // RNNT decoder inputs
        std::string targets;
        std::string target_length;
        std::string states_1;
        std::string states_2;
      } inputs;

      struct Outputs {
@@ -264,6 +311,12 @@ struct Config {
        std::string present_names;  // When key/value pairs are combined
        std::string output_cross_qk_names{"output_cross_qk_%d"};
        std::string rnn_states{Defaults::RnnStatesName};

        // RNNT decoder outputs
        std::string outputs;
        std::string prednet_lengths;
        std::string states_1;
        std::string states_2;
      } outputs;

      struct PipelineModel {
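
Reviewer note: the new `blank_id` and `max_symbols_per_step` fields, together with the RNNT decoder and joiner tensor names, point at greedy transducer (RNN-T) decoding. The PR's actual decoding loop is not part of this section, so the following is only a generic sketch of that technique for readers unfamiliar with it. `predict` and `join` are hypothetical callables standing in for the decoder and joiner ONNX sessions, and priming the prediction network with the blank token is an assumption, not something the diff confirms.

```python
def greedy_rnnt_decode(encoder_frames, predict, join, blank_id, max_symbols_per_step):
    """Generic greedy transducer decoding over a sequence of encoder frames.

    predict(last_token, state) -> (decoder_output, new_state)
    join(encoder_frame, decoder_output) -> sequence of logits over the vocabulary
    """
    tokens = []
    # Assumption: the prediction network is primed with the blank token.
    dec_out, state = predict(blank_id, None)
    for frame in encoder_frames:
        # Emit at most max_symbols_per_step tokens before moving to the next frame.
        for _ in range(max_symbols_per_step):
            logits = join(frame, dec_out)
            best = max(range(len(logits)), key=logits.__getitem__)
            if best == blank_id:
                break  # blank means: nothing more to emit for this frame
            tokens.append(best)
            # Only non-blank tokens advance the prediction network.
            dec_out, state = predict(best, state)
    return tokens


if __name__ == "__main__":
    BLANK = 2  # toy vocabulary: tokens 0 and 1, plus blank

    def toy_predict(last_token, state):
        # Toy prediction network: its output is just the last emitted token.
        return last_token, state

    def toy_join(frame, dec_out):
        # Toy joiner: emit the frame's token once, then emit blank.
        logits = [0.0, 0.0, 0.0]
        logits[BLANK if dec_out == frame else frame] = 1.0
        return logits

    print(greedy_rnnt_decode([0, 1, BLANK], toy_predict, toy_join, BLANK, 10))
```

The cap matters for streaming: without it, a joiner that never predicts blank on some frame would stall the stream on that frame.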
1 change: 1 addition & 0 deletions src/generators.cpp
@@ -4,6 +4,7 @@
// Modifications Copyright(C) 2026 Advanced Micro Devices, Inc. All rights reserved.

#include "generators.h"
#include "streaming_asr.h"
#include "sequences.h"
#include "models/env_utils.h"
#include "models/model.h"