Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions parakeet-cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -554,8 +554,8 @@ if (PARAKEET_BUILD_TESTS)
parakeet_apply_ccache(test-sortformer-streaming)
parakeet_register_test(test-sortformer-streaming
LABEL "fixture"
ARGS "--model" "${_qvp_sfs_q8_gguf}" "--wav" "${_qvp_diar_wav}"
REQUIRES "${_qvp_sfs_q8_gguf}" "${_qvp_diar_wav}")
ARGS "--model" "${_qvp_sfsv21_q8_gguf}" "--wav" "${_qvp_diar_wav}"
REQUIRES "${_qvp_sfsv21_q8_gguf}" "${_qvp_diar_wav}")

# v2.1 AOSC speaker-correctness regression. Asserts speaker coverage,
# re-entry slot continuity (the AOSC contract), and frame-level DER
Expand Down
20 changes: 14 additions & 6 deletions parakeet-cpp/include/parakeet/diarization.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,20 @@ struct SortformerStreamingOptions {

// === AOSC (Audio-Online Speaker Cache, Sortformer v2.1) ===
// Cache-aware streaming forward (port of NeMo's `forward_streaming_step` +
// `streaming_update` + `_compress_spkcache`). On v2.1 models (auto-detected
// from encoder shape) and spkcache_enable=true, the engine concatenates the
// speaker cache + FIFO + current chunk's pre-encode embeddings, runs the
// conformer layers over the concat, then the diariser head, before updating
// the runtime cache. This preserves speaker identity across silences far
// longer than `history_ms`. v1 models always take the legacy path.
// `streaming_update` + `_compress_spkcache`). On v2.1 models with
// spkcache_enable=true, the engine concatenates the speaker cache + FIFO +
// current chunk's pre-encode embeddings, runs the conformer layers over the
// concat, then the diariser head, before updating the runtime cache. This
// preserves speaker identity across silences far longer than `history_ms`.
// v1 and v2 models always take the legacy path.
//
// Variant detection: prefers the converter's `parakeet.model_variant` GGUF
// metadata tag (a stable per-checkpoint string, e.g.
// `sortformer-streaming-v2.1-aosc`) so a future variant that happens to
// share the v2.1 encoder shape can't silently opt into AOSC. GGUFs that
// pre-date the tag fall back to the encoder-shape heuristic: v1 has
// n_layers=18 / n_mels=80, v2.1 has n_layers=17 / n_mels=128. Re-run the
// converter after upgrading to populate the tag.
//
// `mean_sil_emb` is RUNTIME state (zeros at session start, EMA of detected
// silence frames), NOT a learned tensor -- no converter changes required.
Expand Down
27 changes: 25 additions & 2 deletions parakeet-cpp/scripts/convert-nemo-to-gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,24 @@ def fuse_bn(weight, bias, running_mean, running_var, eps=1e-5):
return scale.astype(np.float32), shift.astype(np.float32)


def write_gguf(out: Path, cfg: dict, sd: dict, tok_bytes: bytes, quant: str):
def detect_sortformer_variant(ckpt: Path) -> str:
    """
    Map a NeMo Sortformer .nemo filename to a stable variant tag the C++
    loader can match against.

    The tag lets the loader tell the cache-aware v2.1 checkpoint apart from
    other checkpoints without relying on encoder shape alone, which is
    ambiguous against future variants.

    Args:
        ckpt: Path to the .nemo checkpoint. Only the filename stem is
            inspected; the file itself is never opened.

    Returns:
        One of "sortformer-streaming-v2.1-aosc", "sortformer-streaming-v2",
        "sortformer-v1", or "" when the filename is not a recognised
        Sortformer release.
    """
    import re  # local import: only this helper needs it

    stem = ckpt.stem
    # Compare whole version tokens rather than substrings. A bare
    # `"-v2" in stem` test would also accept "-v2.2" or "-v20", and
    # `"-v2.1" in stem` would accept "-v2.10", silently stamping a future
    # checkpoint with an existing variant tag.
    versions = set(re.findall(r"-v(\d+(?:\.\d+)*)", stem))

    if "streaming_sortformer" in stem:
        if "2.1" in versions:
            return "sortformer-streaming-v2.1-aosc"
        if "2" in versions:
            return "sortformer-streaming-v2"
    if "diar_sortformer" in stem and "1" in versions:
        return "sortformer-v1"
    # Unknown checkpoint: emit no tag rather than a misleading one.
    return ""


def write_gguf(out: Path, ckpt: Path, cfg: dict, sd: dict, tok_bytes: bytes, quant: str):
model_type = detect_model_type(cfg)

enc = cfg["encoder"]
Expand Down Expand Up @@ -331,6 +348,12 @@ def write_gguf(out: Path, cfg: dict, sd: dict, tok_bytes: bytes, quant: str):
writer.add_uint32("parakeet.sortformer.tf_n_heads", int(tfe["num_attention_heads"]))
writer.add_bool ("parakeet.sortformer.tf_pre_ln", bool(tfe.get("pre_ln", False)))
writer.add_string("parakeet.sortformer.tf_hidden_act", str(tfe.get("hidden_act", "relu")))
# Variant tag (preferred over shape-based detection on the C++ side).
# Empty string = unknown checkpoint; loader falls back to encoder
# shape so older GGUFs continue to load.
variant = detect_sortformer_variant(ckpt)
if variant:
writer.add_string("parakeet.model_variant", variant)
else:
pred_hidden = int(dec["prednet"]["pred_hidden"])
pred_rnn_layers = int(dec["prednet"]["pred_rnn_layers"])
Expand Down Expand Up @@ -610,7 +633,7 @@ def main():
ckpt = ensure_ckpt(args.ckpt, args.hf_repo)
cfg, sd, tok_bytes = load_nemo(ckpt)
args.out.parent.mkdir(parents=True, exist_ok=True)
write_gguf(args.out, cfg, sd, tok_bytes, args.quant)
write_gguf(args.out, ckpt, cfg, sd, tok_bytes, args.quant)


if __name__ == "__main__":
Expand Down
11 changes: 8 additions & 3 deletions parakeet-cpp/scripts/download-all-models.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
# as `.nemo` archives, ready for `convert-nemo-to-gguf.py`.
#
# Idempotent: skips files that already exist on disk. Re-run any time to top up.
# Total download budget on a clean machine: ~14 GiB at the time of writing
# Total download budget on a clean machine: ~14.5 GiB at the time of writing
# (TDT v3 + TDT 1.1b + CTC 0.6b + CTC 1.1b + TDT_CTC hybrid + EOU 120M +
# Sortformer v1 + streaming Sortformer v2). Already-cached checkpoints are
# untouched.
# Sortformer v1 + streaming Sortformer v2 + streaming Sortformer v2.1).
# Already-cached checkpoints are untouched.
#
# Usage:
# ./scripts/download-all-models.sh # everything
Expand Down Expand Up @@ -99,6 +99,11 @@ if [[ "${1:-all}" != "tdt" ]]; then
echo "== nemo: diar_streaming_sortformer_4spk-v2 (4-speaker, streaming-trained, ~470 MiB)"
fetch "https://huggingface.co/nvidia/diar_streaming_sortformer_4spk-v2/resolve/main/diar_streaming_sortformer_4spk-v2.nemo" \
"$NEMO_DIR/diar_streaming_sortformer_4spk-v2.nemo"

hr
echo "== nemo: diar_streaming_sortformer_4spk-v2.1 (4-speaker, streaming + AOSC fine-tune, ~470 MiB)"
fetch "https://huggingface.co/nvidia/diar_streaming_sortformer_4spk-v2.1/resolve/main/diar_streaming_sortformer_4spk-v2.1.nemo" \
"$NEMO_DIR/diar_streaming_sortformer_4spk-v2.1.nemo"
fi

hr
Expand Down
3 changes: 3 additions & 0 deletions parakeet-cpp/src/parakeet_ctc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -679,6 +679,9 @@ int load_from_gguf(const std::string & gguf_path,
else if (mtype_str == "sortformer") out_model.model_type = ParakeetModelType::SORTFORMER;
else out_model.model_type = ParakeetModelType::CTC;

// Optional variant tag (empty for legacy GGUFs that predate the key).
out_model.model_variant = get_str(g, "parakeet.model_variant", "");

if (out_model.model_type == ParakeetModelType::TDT) {
out_model.encoder_cfg.tdt_pred_hidden = get_u32(g, "parakeet.tdt.pred_hidden", 640);
out_model.encoder_cfg.tdt_pred_rnn_layers = get_u32(g, "parakeet.tdt.pred_rnn_layers", 2);
Expand Down
7 changes: 7 additions & 0 deletions parakeet-cpp/src/parakeet_ctc.h
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,13 @@ struct SortformerWeights {
struct ParakeetCtcModel {
ParakeetModelType model_type = ParakeetModelType::CTC;

// Optional GGUF metadata tag (key `parakeet.model_variant`). Carries
// a stable identifier for the converted checkpoint that the engine
// can match against -- preferred over shape-based heuristics where
// two variants share the same encoder shape (e.g. sortformer-v2 vs
// sortformer-v2.1-aosc). Empty if the GGUF predates the key.
std::string model_variant;

EncoderConfig encoder_cfg;
MelConfig mel_cfg;
BpeVocab vocab;
Expand Down
31 changes: 19 additions & 12 deletions parakeet-cpp/src/parakeet_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1455,11 +1455,16 @@ void SortformerStreamSession::Impl::process_chunk(int64_t window_start_sample,

// Remap cur_full into session-stable IDs and store as the new
// baseline so the next chunk's `compute_slot_remap_` can match
// against today's emitted identity scheme.
for (auto & f : cur_full) {
f.speaker_id = remap_id(f.speaker_id);
// against today's emitted identity scheme. AOSC anchors slot
// identity through the speaker cache, so `compute_slot_remap_`
// is never consulted on that path -- skip the storage and the
// identity-remap loop entirely.
if (!cache_active) {
for (auto & f : cur_full) {
f.speaker_id = remap_id(f.speaker_id);
}
prev_chunk_full_segments = std::move(cur_full);
}
prev_chunk_full_segments = std::move(cur_full);

// VadStateChanged from speaker_probs: a frame speaks if any speaker exceeds threshold;
// the chunk speaks if any emitting-frame qualifies; dominant speaker from mean probs.
Expand Down Expand Up @@ -1656,14 +1661,16 @@ std::unique_ptr<SortformerStreamSession> Engine::diarize_start(
impl->history_samples = opts.sample_rate * opts.history_ms / 1000;
impl->ring.reserve(impl->history_samples);

// v2.1 detection (Audio-Online Speaker Cache eligibility).
// v1 sortformer-4spk-v1.q8_0: encoder.n_layers=18, preproc.n_mels=80.
// v2.1 sortformer-streaming-v2.1.q8_0: encoder.n_layers=17, preproc.n_mels=128.
// The v2.1 fine-tune is what trained the cache-aware concat-then-graph
// forward path; enabling it on v1 would just be untrained noise.
const bool model_is_v2_1 =
pimpl_->model.encoder_cfg.n_layers == 17 &&
pimpl_->model.mel_cfg.n_mels == 128;
// v2.1 detection (Audio-Online Speaker Cache eligibility). Documented
// in detail next to SortformerStreamingOptions::spkcache_enable in
// include/parakeet/diarization.h. Prefer the explicit variant tag
// emitted by the converter; fall back to encoder shape for legacy
// GGUFs that pre-date the parakeet.model_variant key.
const std::string & variant = pimpl_->model.model_variant;
const bool model_is_v2_1 = !variant.empty()
? (variant == "sortformer-streaming-v2.1-aosc")
: (pimpl_->model.encoder_cfg.n_layers == 17 &&
pimpl_->model.mel_cfg.n_mels == 128);
impl->cache_active = opts.spkcache_enable && model_is_v2_1;

if (impl->cache_active) {
Expand Down
40 changes: 33 additions & 7 deletions parakeet-cpp/src/parakeet_sortformer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,14 @@ namespace parakeet {

namespace {

// Score sentinels for the speaker-cache compression top-K.
//
// These stand in for -inf / +inf. Finite extrema are used instead of
// std::numeric_limits<float>::infinity() so the values stay well-defined
// under the FE_DIVBYZERO-trapping FP modes that some host builds enable.
//
// The sentinels are only ever stored and compared with == / !=. Code that
// boosts scores checks for the sentinel first and skips it, so no
// arithmetic is performed on these values and they stay exactly equal to
// the constants below.
constexpr float k_score_neg_inf = std::numeric_limits<float>::lowest();
constexpr float k_score_pos_inf = std::numeric_limits<float>::max();

// Threshold speaker probabilities into time-sorted segments.
void sf_threshold_segments(const std::vector<float> & speaker_probs,
int T_enc, int num_spks,
Expand Down Expand Up @@ -256,7 +264,7 @@ static void compute_log_pred_scores(const float * preds, int n_frames, int num_s
static void disable_low_scores(std::vector<float> & scores,
const float * preds, int n_frames, int num_spks,
int min_pos_scores_per_spk) {
const float neg_inf = -1.0e30f /* very-negative sentinel; -inf is UB with current FP flags */;
const float neg_inf = k_score_neg_inf;

// First pass: non-speech -> -inf.
for (int t = 0; t < n_frames; ++t) {
Expand Down Expand Up @@ -313,7 +321,7 @@ static void boost_topk_scores(std::vector<float> & scores,
for (int i = 0; i < k; ++i) {
const int t = idx_buf[i];
float & s = scores[(size_t) t * num_spks + spk];
if (s != -1.0e30f /* very-negative sentinel; -inf is UB with current FP flags */) {
if (s != k_score_neg_inf) {
s += boost;
}
}
Expand Down Expand Up @@ -343,6 +351,24 @@ static void compress_speaker_cache(

const int A_sil = cfg.spkcache_sil_frames_per_spk;
const int spkcache_len_per_spk = spkcache_len / num_spks - A_sil;
if (spkcache_len_per_spk <= 0) {
// Degenerate config: num_spks * A_sil >= spkcache_len leaves no
// budget for retained frames, so the boost / top-K stages would
// run with non-positive k and (for nth_element) a negative
// distance. Fall back to a silence-only cache and bail.
cache.spkcache.assign((size_t) spkcache_len * D, 0.0f);
if (cache.mean_sil_emb.size() == (size_t) D) {
for (int r = 0; r < spkcache_len; ++r) {
std::memcpy(cache.spkcache.data() + (size_t) r * D,
cache.mean_sil_emb.data(),
(size_t) D * sizeof(float));
}
}
cache.spkcache_preds.assign((size_t) spkcache_len * num_spks, 0.0f);
cache.n_rows = spkcache_len;
cache.spkcache_preds_valid = true;
return;
}
const int strong_boost = (int) std::floor((float) spkcache_len_per_spk * cfg.strong_boost_rate);
const int weak_boost = (int) std::floor((float) spkcache_len_per_spk * cfg.weak_boost_rate);
const int min_pos_per = (int) std::floor((float) spkcache_len_per_spk * cfg.min_pos_scores_rate);
Expand All @@ -360,7 +386,7 @@ static void compress_speaker_cache(
for (int t = spkcache_len; t < n_frames; ++t) {
float * s = scores.data() + (size_t) t * num_spks;
for (int i = 0; i < num_spks; ++i) {
if (s[i] != -1.0e30f /* very-negative sentinel; -inf is UB with current FP flags */) {
if (s[i] != k_score_neg_inf) {
s[i] += cfg.scores_boost_latest;
}
}
Expand All @@ -378,7 +404,7 @@ static void compress_speaker_cache(
const int n_total = n_frames + A_sil;
if (A_sil > 0) {
scores.resize((size_t) n_total * num_spks);
const float pos_inf = 1.0e30f /* very-positive sentinel; +inf is UB with current FP flags */;
const float pos_inf = k_score_pos_inf;
for (int t = n_frames; t < n_total; ++t) {
float * s = scores.data() + (size_t) t * num_spks;
for (int i = 0; i < num_spks; ++i) s[i] = pos_inf;
Expand Down Expand Up @@ -409,7 +435,7 @@ static void compress_speaker_cache(
// speaker blocks contiguous; `torch.remainder(idx, n_frames)` returns the
// frame index; our `idx % n_total` does the same.)
for (int & idx : topk) {
if (flat_score(idx) == -1.0e30f /* very-negative sentinel; -inf is UB with current FP flags */) {
if (flat_score(idx) == k_score_neg_inf) {
idx = MAX_INDEX;
}
}
Expand Down Expand Up @@ -467,7 +493,7 @@ static void compress_speaker_cache(
// `lc` is the left-context offset within the chunk region; the committed-chunk
// preds start at index `prev_spkcache_n + prev_fifo_n + lc` and span `chunk_committed`.
static void streaming_update(SortformerSpeakerCache & cache,
const float * chunk_pre_encode_lc, int chunk_committed,
const float * committed_chunk_pre_encode, int chunk_committed,
const float * preds_full,
int prev_spkcache_len_at_call, int prev_fifo_len_at_call,
int lc,
Expand All @@ -492,7 +518,7 @@ static void streaming_update(SortformerSpeakerCache & cache,
const int new_fifo_after_append = cache.n_fifo + chunk_committed;
cache.fifo.resize((size_t) new_fifo_after_append * D);
std::memcpy(cache.fifo.data() + (size_t) cache.n_fifo * D,
chunk_pre_encode_lc,
committed_chunk_pre_encode,
(size_t) chunk_committed * D * sizeof(float));
cache.fifo_preds.resize((size_t) new_fifo_after_append * num_spks);
std::memcpy(cache.fifo_preds.data() + (size_t) cache.n_fifo * num_spks,
Expand Down
54 changes: 3 additions & 51 deletions parakeet-cpp/test/test_sortformer_aosc_speakers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
// ctest fixtures behave when their fixtures aren't on disk.

#include "parakeet/engine.h"
#include "test_utils.h"

#include <algorithm>
#include <cstdio>
Expand All @@ -64,57 +65,8 @@ namespace {

constexpr double FRAME_S = 0.01; // 10 ms grid

// Returns true when `p` can be opened for binary reading, i.e. the fixture
// is present and readable. The stream is closed again immediately.
bool file_exists(const std::string & p) {
    return std::ifstream(p, std::ios::binary).good();
}

// Minimal RIFF/WAVE reader for test fixtures. Accepts only uncompressed PCM
// (format tag 1), mono, 16-bit little-endian audio.
//
// Originally copied from test_sortformer_streaming.cpp because parakeet-cpp
// had no shared test-util header. This copy has since been hardened and is
// no longer byte-identical to that one:
//   - all multi-byte fields are decoded byte-wise, so there are no
//     unaligned / type-punned reads and no dependence on host endianness;
//   - a `fmt ` chunk shorter than the 16-byte PCM header is rejected
//     instead of being read out of bounds;
//   - the RIFF pad byte that follows odd-sized chunks is skipped;
//   - the `data` size is clamped to the bytes actually present, so a bogus
//     or truncated length neither over-allocates nor yields phantom zeros.
//
// On success returns true and fills `samples` (scaled by 1/32768) and
// `sample_rate`. On any failure returns false and leaves both outputs
// untouched.
bool load_wav_pcm16le_mono(const std::string & path,
                           std::vector<float> & samples,
                           int & sample_rate) {
    std::ifstream f(path, std::ios::binary);
    if (!f) return false;

    // Little-endian field decoders operating on raw bytes.
    const auto le16 = [](const unsigned char * p) -> uint16_t {
        return (uint16_t) (p[0] | (p[1] << 8));
    };
    const auto le32 = [](const unsigned char * p) -> uint32_t {
        return  (uint32_t) p[0]        | ((uint32_t) p[1] << 8) |
               ((uint32_t) p[2] << 16) | ((uint32_t) p[3] << 24);
    };

    char tag[4];
    if (!f.read(tag, 4) || std::memcmp(tag, "RIFF", 4) != 0) return false;
    f.ignore(4);  // overall RIFF size: not needed, chunks are walked below
    if (!f.read(tag, 4) || std::memcmp(tag, "WAVE", 4) != 0) return false;

    bool     fmt_ok = false;
    uint32_t srate  = 0;
    std::vector<unsigned char> data;

    for (;;) {
        unsigned char chunk_hdr[8];  // 4-byte id + 4-byte little-endian size
        if (!f.read((char *) chunk_hdr, 8)) break;
        const uint32_t sz  = le32(chunk_hdr + 4);
        const uint32_t pad = sz & 1u;  // chunks are padded to even length

        if (std::memcmp(chunk_hdr, "fmt ", 4) == 0) {
            if (sz < 16) return false;  // too short to hold the PCM header
            unsigned char fmt[16];
            if (!f.read((char *) fmt, 16)) return false;
            const uint16_t format   = le16(fmt);
            const uint16_t channels = le16(fmt + 2);
            const uint16_t bits     = le16(fmt + 14);
            srate = le32(fmt + 4);
            if (format != 1 || channels != 1 || bits != 16) return false;
            fmt_ok = true;
            // Skip any extension bytes beyond the 16-byte header, plus pad.
            f.ignore((std::streamsize) (sz - 16) + (std::streamsize) pad);
        } else if (std::memcmp(chunk_hdr, "data", 4) == 0) {
            // Clamp the declared size to what the file really contains.
            const std::streampos here = f.tellg();
            f.seekg(0, std::ios::end);
            const std::streampos end = f.tellg();
            f.seekg(here);
            if (!f || here < 0 || end < here) return false;
            const std::streamoff avail = end - here;
            const std::streamoff want  =
                ((std::streamoff) sz < avail) ? (std::streamoff) sz : avail;
            data.resize((size_t) want);
            f.read((char *) data.data(), (std::streamsize) want);
            data.resize((size_t) f.gcount());
            break;
        } else {
            f.ignore((std::streamsize) sz + (std::streamsize) pad);
        }
    }

    // Also rejects a `data` chunk that precedes `fmt `: the loop stops at
    // `data`, so the format would never have been validated.
    if (!fmt_ok || data.empty()) return false;

    const size_t n = data.size() / 2;  // a trailing odd byte is dropped
    sample_rate = (int) srate;
    samples.resize(n);
    for (size_t i = 0; i < n; ++i) {
        const int16_t v = (int16_t) le16(data.data() + 2 * i);
        samples[i] = (float) v / 32768.0f;
    }
    return true;
}
using parakeet_test::file_exists;
using parakeet_test::load_wav_pcm16le_mono;

struct RttmSeg {
double start_s;
Expand Down
Loading