Chatterbox optimize cpp backend multilingual model for cuda#2

Closed
Zbig9000 wants to merge 9 commits into GustavoA1604:main from Zbig9000:chatterbox-Optimize-cpp-backend-multilingual-model-for-CUDA

@Zbig9000

# PR: QVAC-17873 [TTS GGML] Optimize cpp backend multilingual model for CUDA

Source: Zbig9000/chatterbox.cpp:chatterbox-Optimize-cpp-backend-multilingual-model-for-CUDA
Target: GustavoA1604/chatterbox.cpp:main (flows into multilingual on the next merge-down, same way the Metal patches did and the QVAC-17872 Vulkan patch does)
Companion upstream PR: ggml-org/ggml#1465 — open and mergeable, 12189/12189 test-backend-ops -b CUDA0 PASS
Companion upstream issue: ggml-org/ggml#1466 — Blackwell flash-attn config gap, the documented next step
Compare: main…chatterbox-Optimize-cpp-backend-multilingual-model-for-CUDA


## What problem does this PR solve?

Five related ggml-cuda issues that show up on chatterbox-style
workloads, plus the test/diagnostic infrastructure to keep them
fixed. Sibling work to the QVAC-17872 Vulkan PR; together they
get CUDA from 1.40× slower than Vulkan (pre-patch) to 1.13×
on long prompts (post-patch) on the same RTX 5090.

(a) Slow conv_transpose_1d kernel — HiFT vocoder bottleneck

ggml-cuda's conv_transpose_1d ships a textbook scalar kernel:
one CUDA thread per output pixel scanning the full IC × IL input
grid with a per-iteration skip conditional that only triggers on
~K/s0 of iterations. For HiFT shapes (L=303, IC=80, K=16, s0=8)
this is 67 % of total GPU time in the entire S3Gen graph
(135.98 ms across 4 calls on RTX 5090, single biggest call alone
runs 101 ms). HiFT total = 144 ms vs Vulkan's 34 ms on the same
hardware — the single biggest backend-vs-backend perf gap in any
stage of chatterbox.

(b) ggml_backend_cuda_graph_compute warmup never lands for autoregressive decode

The graph cache is gated by a 2-call warmup that requires every
property of every node to be byte-identical. Right default for
llama.cpp; wrong default for chatterbox T3 which builds a
fresh-but-topologically-identical cgraph per token with growing
K/V views. K's ne[1] grows by 1 per token and view offsets shift,
so warmup_complete keeps resetting and the captured graph is
never used. The ~90 ms gap between T3 GPU time (~70 ms) and T3
wall (163 ms) on RTX 5090 lives here.

(c) No cross-backend per-op timing logger on CUDA

ggml-vulkan ships GGML_VK_PERF_LOGGER=1 and we already have
parsing scripts for its output baked into FINDINGS.md.
ggml-cuda has no equivalent — characterising backend-level perf
required nsys (heavyweight, NVIDIA-only, sometimes needs root for
hardware counters) or one-off manual instrumentation in a debug
build.

(d) MUL_MAT_VEC + ADD + ADD 3-op fusion missing

ggml-vulkan already fuses MUL_MAT_VEC + ADD(bias) + ADD(residual) via the MUL_MAT_ADD_ADD shader; ggml-cuda only
fuses the 2-op pattern, so the residual ADD runs as a separate
launch. At chatterbox shapes (24 layers × 2 ADDs/layer per token)
this is a measurable ~67 ms / utterance difference — the top
remaining gap to ggml-vulkan after (a) lands.

(e) FlashAttention picker has no per-shape diagnostic override

The MMA_F16 variant ggml-cuda picks for chatterbox prompt-phase
attention is ~2× slower than ggml-vulkan's flash-attn shader on the
same shape. To rule this out as a picker-choice issue (vs a
kernel-quality issue) we needed a safe way to A/B the four
variants (tile, mma, wmma, vec) without rebuilding —
including arch / shape fall-back so the dispatcher doesn't ABORT
when an unsupported variant is forced.

## How does it solve it?

All five fixes ship as a single vendored patch
(patches/ggml-cuda-chatterbox-ops.patch, 1 046 lines) applied on
top of the same pinned ggml@58c38058 the Vulkan / Metal patches
target.

Fix for (a): warp-cooperative conv_transpose_1d kernel

Modelled on the Metal-patch design (one threadgroup per output
pixel + simdgroup reduction across input channels), translated to
CUDA primitives:

  1. Grid (OL, OC, 1) × block (32, 1, 1) — one CUDA warp per
    output pixel. Block-size constant drops 256 → 32.
  2. Compute i_start = ⌈(ol − K + 1) / s0⌉, i_end = ⌊ol / s0⌋
    analytically; skip conditional eliminated entirely. Inner i
    loop iterates over at most K/s0 + 1 = 3 positions instead of
    IL = 100+.
  3. Parallelise the IC reduction across the warp (each lane
    handles a strided slice ic = tid, tid+32, …); reduce across
    the warp with __shfl_xor_sync(0xFFFFFFFFu, v, …). Thread 0
    writes the output pixel.

~110 lines diff in ggml/src/ggml-cuda/conv-transpose-1d.{cu,cuh},
no API change.
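
The index arithmetic in step 2 can be checked on the CPU. The sketch below is not the patched kernel: it is a scalar C++ reference that computes one output pixel twice, once with the scan-and-skip loop and once with the analytical bounds, so the two can be compared. The tensor layouts, `ConvShape`, and the helper names are assumptions made for this illustration.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Floor / ceiling division for a positive divisor. Plain C++ integer
// division truncates toward zero, which is wrong for negative numerators.
inline int floor_div(int a, int b) {
    const int q = a / b;
    return (a % b != 0 && a < 0) ? q - 1 : q;
}
inline int ceil_div(int a, int b) { return -floor_div(-a, b); }

// Hypothetical shape holder; layouts assumed here:
//   src[ic * IL + i], ker[(ic * OC + oc) * K + k]
struct ConvShape { int IL, IC, OC, K, s0; };
inline int out_len(const ConvShape& s) { return (s.IL - 1) * s.s0 + s.K; }

// Scalar form: visit every input position and skip those whose kernel
// window does not cover output position ol.
float pixel_scan(const std::vector<float>& src, const std::vector<float>& ker,
                 const ConvShape& s, int oc, int ol) {
    float acc = 0.0f;
    for (int ic = 0; ic < s.IC; ++ic) {
        for (int i = 0; i < s.IL; ++i) {
            const int k = ol - i * s.s0;
            if (k < 0 || k >= s.K) continue;  // the per-iteration skip
            acc += ker[(ic * s.OC + oc) * s.K + k] * src[ic * s.IL + i];
        }
    }
    return acc;
}

// Bounded form: i_start = ceil((ol - K + 1) / s0), i_end = floor(ol / s0),
// clamped to the valid input range, so no skip test is needed.
float pixel_bounded(const std::vector<float>& src, const std::vector<float>& ker,
                    const ConvShape& s, int oc, int ol) {
    const int i_start = std::max(0, ceil_div(ol - s.K + 1, s.s0));
    const int i_end   = std::min(s.IL - 1, floor_div(ol, s.s0));
    float acc = 0.0f;
    for (int ic = 0; ic < s.IC; ++ic) {
        for (int i = i_start; i <= i_end; ++i) {
            const int k = ol - i * s.s0;
            acc += ker[(ic * s.OC + oc) * s.K + k] * src[ic * s.IL + i];
        }
    }
    return acc;
}
```

In the warp-cooperative version described above, the outer `ic` loop is what gets split across the 32 lanes; the bounded inner loop stays the same.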

Fix for (b): GGML_CUDA_FORCE_GRAPHS=1 opt-in

Extends the early-exit branch in ggml_backend_cuda_graph_compute
with an opt-in path that always uses the captured graph and relies
on the existing cudaGraphExecUpdate (with re-instantiate-on-
failure) wiring to absorb per-call data-pointer changes:

static const bool force_graphs = (getenv("GGML_CUDA_FORCE_GRAPHS") != nullptr);
if (force_graphs) {
    if (!graph->warmup_complete) graph->warmup_complete = true;
    use_cuda_graph              = true;
    cuda_graph_update_required  = properties_changed || graph->instance == nullptr;
}

Default behaviour unchanged when the env var is unset (every
non-chatterbox consumer is byte-identical to today). ~25 line
addition to ggml/src/ggml-cuda/ggml-cuda.cu.
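
To make the effect of the opt-in concrete, here is a toy model of the gating decision. It is a deliberately simplified sketch, not the ggml-cuda code: `GraphState`, `Decision`, `decide`, and `count_graph_launches` are hypothetical names, and the default branch only approximates the warmup rule described in (b).

```cpp
#include <cassert>

// Toy stand-ins for the per-graph bookkeeping (hypothetical names).
struct GraphState {
    bool warmup_complete = false;
    bool has_instance    = false;
};
struct Decision {
    bool use_cuda_graph;
    bool update_required;
};

Decision decide(GraphState& g, bool properties_changed, bool force_graphs) {
    Decision d{false, false};
    if (force_graphs) {
        // Opt-in path: always use the captured graph, let the update
        // step absorb whatever changed since the previous call.
        g.warmup_complete = true;
        d.use_cuda_graph  = true;
        d.update_required = properties_changed || !g.has_instance;
    } else if (properties_changed) {
        g.warmup_complete = false;  // any change restarts the warmup
    } else if (!g.warmup_complete) {
        g.warmup_complete = true;   // stable call: graph usable from the next call
    } else {
        d.use_cuda_graph  = true;
        d.update_required = !g.has_instance;
    }
    if (d.use_cuda_graph) g.has_instance = true;
    return d;
}

// Simulate n_calls graph computes and count how many used the CUDA graph.
int count_graph_launches(int n_calls, bool changes_every_call, bool force_graphs) {
    GraphState g;
    int launches = 0;
    for (int i = 0; i < n_calls; ++i) {
        const bool properties_changed = (i == 0) || changes_every_call;
        if (decide(g, properties_changed, force_graphs).use_cuda_graph) ++launches;
    }
    return launches;
}
```

With properties changing on every call, as in growing-KV step decode, the default branch of this model never launches a graph, while the forced branch launches one per call and flags an update each time.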

Fix for (c): GGML_CUDA_PERF_LOGGER=1 opt-in

Mirrors GGML_VK_PERF_LOGGER byte-for-byte in output format so
existing cross-backend grep / awk one-liners (FINDINGS.md /
FINDINGS_CUDA.md reproduction recipes) work for both backends:

----------------
CUDA Timings:
MUL_MAT q4_0 m=3072 n=383 k=1024: 24 x 241.979 us = 5807.507 us
FLASH_ATTN_EXT (64,16,411,1): 24 x 5996.571 us = 143917.704 us
…
Total time: 22480.220 us.

Implementation: ~280-line ggml_cuda_perf_logger Meyers-singleton
class with RAII scope helper around per-op dispatches,
cudaEventRecord pairs, aggregation by (op, dtype, shape) key,
sorted print at the end of each ggml_backend_cuda_graph_compute.
CUDA Graphs auto-disable when the env var is set (events would
either re-record on subsequent launches or hide inside
cudaGraphLaunch). Off by default; zero overhead in normal builds.
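
The aggregation and print step can be sketched independently of CUDA. The class below is a hypothetical `PerfAggregator`, not the patch's `ggml_cuda_perf_logger`: durations are passed in as plain microsecond values where the patch would use `cudaEventRecord` pairs, and the lexicographic sort order is an assumption, since the description only says the print is sorted.

```cpp
#include <cassert>
#include <iomanip>
#include <map>
#include <sstream>
#include <string>

// Per-key running totals.
struct OpStats {
    int    count    = 0;
    double total_us = 0.0;
};

class PerfAggregator {
public:
    // key is the "(op, dtype, shape)" string, e.g. "MUL_MAT q4_0 m=3072 n=383 k=1024".
    void record(const std::string& key, double us) {
        OpStats& s = stats_[key];
        s.count += 1;
        s.total_us += us;
    }

    // Render one block in the "N x avg us = total us" line format shown above.
    std::string report() const {
        std::ostringstream out;
        out << std::fixed << std::setprecision(3);
        out << "----------------\nCUDA Timings:\n";
        double grand_total = 0.0;
        for (const auto& kv : stats_) {  // std::map iterates in key order
            const OpStats& s = kv.second;
            out << kv.first << ": " << s.count << " x " << (s.total_us / s.count)
                << " us = " << s.total_us << " us\n";
            grand_total += s.total_us;
        }
        out << "Total time: " << grand_total << " us.\n";
        return out.str();
    }

    void clear() { stats_.clear(); }

private:
    std::map<std::string, OpStats> stats_;
};
```

Calling `record` once per dispatched op and `report` followed by `clear` at the end of each graph compute would reproduce the per-compute block structure described above.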

Fix for (d): MUL_MAT_VEC + ADD(bias) + ADD(residual) fusion

Direct port of the ggml-vulkan MUL_MAT_ADD_ADD shader fusion:

  1. New x_residual field on ggml_cuda_mm_fusion_args_* (same
    shape rules as x_bias; broadcasting rejected by the host-
    side detection logic).
  2. New 3-op pattern matcher in
    ggml_cuda_graph_evaluate_and_capture, placed above the
    existing 2-op {MUL_MAT, ADD} fusion so the greedy match
    prefers the larger fusion when both apply.
  3. mmvq.cu / mmvf.cu kernel templates extended to fold the
    residual into the matmul-vec writeback after bias and any GLU,
    matching ggml-vulkan's execution order.

Only MUL_MAT (not MUL_MAT_ID) is handled — the residual ADD
pattern doesn't appear in MoE expert routing in any model the
author has seen. ~150 line addition across
common.cuh, mmvq.cu, mmvf.cu, ggml-cuda.cu.
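
Two parts of this fix lend themselves to a small CPU illustration: the greedy matcher that tries the 3-op pattern before the 2-op one, and the fused writeback order (matmul result, then bias, then residual). The sketch below uses hypothetical names throughout and checks only the op sequence; the real matcher presumably also verifies data dependencies and shapes, which are omitted here.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

enum class Op { MUL_MAT, ADD, OTHER };

// Length of the fusion starting at node i: the larger pattern is tried first.
int fused_len_at(const std::vector<Op>& ops, std::size_t i) {
    auto is = [&](std::size_t j, Op o) { return j < ops.size() && ops[j] == o; };
    if (is(i, Op::MUL_MAT) && is(i + 1, Op::ADD) && is(i + 2, Op::ADD)) return 3;
    if (is(i, Op::MUL_MAT) && is(i + 1, Op::ADD)) return 2;
    return 1;
}

// Number of kernel launches after greedy fusion.
int count_launches(const std::vector<Op>& ops) {
    int n = 0;
    for (std::size_t i = 0; i < ops.size(); i += fused_len_at(ops, i)) ++n;
    return n;
}

// Unfused reference: y = W * x, with W stored row-major as rows x cols.
std::vector<float> matvec(const std::vector<float>& W, const std::vector<float>& x,
                          std::size_t rows, std::size_t cols) {
    std::vector<float> y(rows, 0.0f);
    for (std::size_t r = 0; r < rows; ++r)
        for (std::size_t c = 0; c < cols; ++c) y[r] += W[r * cols + c] * x[c];
    return y;
}
void add_inplace(std::vector<float>& y, const std::vector<float>& v) {
    for (std::size_t r = 0; r < y.size(); ++r) y[r] += v[r];
}

// Fused form: bias and residual are folded into the writeback of each row.
std::vector<float> matvec_bias_residual(const std::vector<float>& W,
                                        const std::vector<float>& x,
                                        const std::vector<float>& bias,
                                        const std::vector<float>& residual,
                                        std::size_t rows, std::size_t cols) {
    std::vector<float> y(rows);
    for (std::size_t r = 0; r < rows; ++r) {
        float acc = 0.0f;
        for (std::size_t c = 0; c < cols; ++c) acc += W[r * cols + c] * x[c];
        y[r] = (acc + bias[r]) + residual[r];  // bias first, then residual
    }
    return y;
}
```

For the sequence `MUL_MAT, ADD, ADD, OTHER`, this matcher yields two launches; a matcher that only knew the 2-op pattern would yield three.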

Fix for (e): GGML_CUDA_FATTN_KERNEL=tile|mma|wmma|vec opt-in

Wraps the existing FlashAttention picker
(ggml_cuda_get_best_fattn_kernel, renamed to _default) and
applies the env var only when the default heuristic chose
MMA_F16 (the documented A/B target). Two safety gates before
override takes effect:

  1. Per-arch availability — mirrors the picker's existing
    turing_mma_available / volta_mma_available /
    should_use_wmma_fattn checks. WMMA on Blackwell falls back
    to default with a one-shot GGML_LOG_WARN instead of
    tripping the dispatcher's GGML_ABORT for "no compiled SASS".
  2. Per-shape compatibility — VEC's compile-time templates
    only instantiate for Q.ne[1] <= 2 && K.ne[1] % FATTN_KQ_STRIDE == 0. Forcing VEC on chatterbox's prompt-
    phase or growing-KV step-decode shapes would otherwise trip
    CUDA error: invalid configuration argument. Falls back
    instead.
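
The two gates can be summarised as a small decision function. This is a sketch under stated assumptions rather than the code in fattn.cu: the enum, `DeviceCaps`, `AttnShape`, and the function names are hypothetical, the stride constant of 256 stands in for FATTN_KQ_STRIDE, and treating `tile` as always available is an assumption.

```cpp
#include <cassert>
#include <string>

enum class FattnKernel { NONE, TILE, MMA_F16, WMMA_F16, VEC };

struct DeviceCaps { bool mma_available; bool wmma_available; };  // hypothetical
struct AttnShape  { long q_ne1; long k_ne1; };                   // hypothetical

const long kKqStride = 256;  // assumed value of FATTN_KQ_STRIDE

// Map the env-var value to a variant; unknown strings mean "no override".
FattnKernel parse_override(const std::string& v) {
    if (v == "tile") return FattnKernel::TILE;
    if (v == "mma")  return FattnKernel::MMA_F16;
    if (v == "wmma") return FattnKernel::WMMA_F16;
    if (v == "vec")  return FattnKernel::VEC;
    return FattnKernel::NONE;
}

// Gate 1 (per-arch availability) and gate 2 (per-shape compatibility).
bool override_is_safe(FattnKernel want, const DeviceCaps& caps, const AttnShape& shape) {
    switch (want) {
        case FattnKernel::TILE:     return true;
        case FattnKernel::MMA_F16:  return caps.mma_available;
        case FattnKernel::WMMA_F16: return caps.wmma_available;
        case FattnKernel::VEC:
            return shape.q_ne1 <= 2 && shape.k_ne1 % kKqStride == 0;
        default:                    return false;
    }
    return false;
}

// The override only applies when the default heuristic chose MMA_F16;
// an unsafe request falls back to the default instead of aborting.
FattnKernel pick_kernel(FattnKernel default_choice, FattnKernel requested,
                        const DeviceCaps& caps, const AttnShape& shape) {
    if (requested == FattnKernel::NONE) return default_choice;
    if (default_choice != FattnKernel::MMA_F16) return default_choice;
    if (!override_is_safe(requested, caps, shape)) return default_choice;
    return requested;
}
```

For example, requesting `vec` on a prompt-phase shape with `q_ne1 = 411` returns the default MMA_F16 choice, which is the fall-back behaviour the second gate describes.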

The empirical finding from the variant sweep on RTX 5090 + Turbo
Q4_0: default and overrides to mma / wmma / vec are all
bit-identical (only tile actually changes kernel choice
— and it's 4 % slower than MMA). Conclusion: the picker is
already optimal for chatterbox on Blackwell; the remaining 67 ms /
utterance flash-attn gap is kernel-quality intrinsic to MMA_F16,
not a picker-selection issue. Documented as upstream issue
ggml-org/ggml#1466 (Blackwell-tuned config table missing — the
picker uses Ampere's sm_80 config on Blackwell because
ampere_mma_available(cc) returns true for any cc >= 800).

~140 line addition to ggml/src/ggml-cuda/fattn.cu.

One non-obvious design decision worth calling out

ggml_cuda_perf_logger's destructor runs at static destruction time
(after main(), possibly after libcudart's own statics tear down).
The dtor flushes any pending data but does not call
cudaEventDestroy — that can crash on a torn-down driver.
Letting the OS reclaim the events is safe: the logger is opt-in
via env var, the leaked memory is process-lifetime regardless,
and the event pool is bounded. Same lifetime model as the
Vulkan-pipeline-cache flush from the QVAC-17872 PR.

## Build system changes

  • scripts/setup-ggml.sh: now iterates over a PATCHES=(…) array
    (metal + cuda), stacks them on the same pinned commit 58c38058,
    remains idempotent. Idempotency check uses
    git apply --reverse --check (more discriminating than plain
    --check — survives manually-corrupted working trees).
  • patches/README.md: refreshed to list both patches, document
    the five CUDA opt-ins, and updated to reflect the 7 modified
    files under src/ggml-cuda/ (was 2 in earlier rounds).
  • CMakeLists.txt: adds the new test-cuda-ops target — Apple's
    test-metal-ops-style kernel-level CPU-vs-CUDA correctness test.

ggml/ is not checked in (gitignored) — setup-ggml.sh applies
both patches to a pristine clone, same model as the Metal patch.

## Risk assessment

  • Output is not bit-identical but end-to-end audio differs at
    -58 dBFS / SNR 58.5 dB — the same kind of FP-reduction-order
    variance that Metal's simd_sum kernel introduces. Below
    perceptual tolerance.
  • The fusion fix found a real latent bug during upstream prep:
    mmvf.cu's x_residual field was missing the
    (sample_dst, channel_bias) offset that x_bias has. Latent for
    chatterbox (ne[2]==ne[3]==1 makes the offset zero) but exposed
    immediately by upstream's
    test_mul_mat_vec_fusion(batch_dims=[4,2]) test. Fixed in commit
    bd37318, a back-port from ggml-org/ggml#1465.
  • No new dependencies. All CUDA primitives used
    (__shfl_xor_sync, __restrict__, cudaEvent_t,
    cudaGraphExecUpdate) are long-established: the oldest are
    supported from Compute Capability 3.0 (Kepler, 2012), and the
    rest on sm_70+.
  • CUDA Toolkit 12.8 strongly recommended (vs 12.0): native
    sm_120 SASS eliminates a 27 s cold-start PTX-JIT compile.
    12.0 still works at runtime via driver JIT but is a regression
    on first-launch latency.
  • Concurrent processes / threads: cudaEventCreate,
    cudaGraphExecUpdate, the existing fusion-engine paths are
    all already single-threaded per ggml_backend_cuda_context;
    no new sync introduced.
  • Opt-in env vars are off by default: GGML_CUDA_FORCE_GRAPHS,
    GGML_CUDA_PERF_LOGGER, GGML_CUDA_FATTN_KERNEL all read once
    on first use via function-local statics. Unset = byte-identical
    to today.
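
For readers who want to reproduce the kind of audio comparison quoted in the first risk item, the two figures can be computed from a reference and a test buffer as below. This is a generic sketch, not the script used for this PR; the function names are hypothetical and samples are assumed to be floats normalised to full scale 1.0.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Signal-to-noise ratio in dB, treating (ref - test) as the noise.
double snr_db(const std::vector<float>& ref, const std::vector<float>& test) {
    double signal = 0.0, noise = 0.0;
    const std::size_t n = std::min(ref.size(), test.size());
    for (std::size_t i = 0; i < n; ++i) {
        const double d = double(ref[i]) - double(test[i]);
        signal += double(ref[i]) * double(ref[i]);
        noise  += d * d;
    }
    if (noise == 0.0) return std::numeric_limits<double>::infinity();
    return 10.0 * std::log10(signal / noise);
}

// Peak level of the difference signal in dBFS (full scale = 1.0).
double diff_peak_dbfs(const std::vector<float>& ref, const std::vector<float>& test) {
    double peak = 0.0;
    const std::size_t n = std::min(ref.size(), test.size());
    for (std::size_t i = 0; i < n; ++i)
        peak = std::max(peak, std::fabs(double(ref[i]) - double(test[i])));
    if (peak == 0.0) return -std::numeric_limits<double>::infinity();
    return 20.0 * std::log10(peak);
}
```

Bit-identical buffers give an infinite SNR, which is the case the FORCE_GRAPHS md5 checks rely on; the patched kernels fall in the finite case because the reduction order changes.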

## How was it tested?

End-to-end on Linux x86-64, Ryzen 9 9950X3D, RTX 5090 32 GB,
NVIDIA driver 590.48.01, CUDA Toolkit 12.8 (Blackwell sm_120
native SASS).

Validation harness — 7 test artefacts, ~54 assertions, ~5 min total

# 1. Kernel-level CPU-vs-CUDA correctness for the patched ops.
#    23 cases: conv_transpose_1d (HiFT shapes), MUL_MAT_VEC + ADD + ADD
#    fusion (Q4_0, NMSE), flash_attn_ext (5 shapes incl. step + prompt).
cmake --build build-cuda12.8 --target test-cuda-ops -j
./build-cuda12.8/test-cuda-ops
# Expected: "All CUDA op tests PASSED"

# 2. Build-system regression: setup-ggml.sh idempotency + dirty-state
#    recovery + clean re-apply on the pinned commit + modified-file
#    count matches README.
./scripts/test-build-system.sh
# Expected: "All build-system tests PASSED"

# 3. End-to-end pipeline smoke: bit-identity of FORCE_GRAPHS, 18-run
#    stress matrix (3 seeds × 3 prompts × 2 modes), perf sanity,
#    env-var combination matrix (FORCE_GRAPHS × DISABLE_FUSION ×
#    DISABLE_GRAPHS × PERF_LOGGER) with bit-identity invariants.
./scripts/test-chatterbox-cuda.sh
# Expected: "All chatterbox.cpp CUDA smoke tests PASSED"

# 4. Perf-logger smoke: output format + parsing + hot-op presence +
#    graph-disable interaction + aggregate-time bound.
./scripts/test-cuda-perf-logger.sh
# Expected: "All GGML_CUDA_PERF_LOGGER tests PASSED"

# 5. FlashAttention variant sweep + fall-back smoke.
./scripts/bench-fattn-variants.sh
# Expected: "Fastest variant by T3: <variant> (… ms, Δ=… ms / …% vs default)"

# 6. Production stability soak — 50 sequential runs, asserts
#    bit-identity (md5 match) + ≤ 5 % T3/S3Gen drift + ≤ 10 MB
#    RSS growth + ≤ 50 MB ~/.nv/ComputeCache delta. Catches
#    process-level memory leaks, GPU pool fragmentation, JIT
#    recompile storms.
./scripts/test-stability.sh ./build-cuda12.8/chatterbox 50
# Expected: "All chatterbox.cpp CUDA stability tests PASSED (50 runs)"

# 7. Diversity sweep — 5 seeds (sampler is seed-responsive),
#    5 multilingual prompts (EN/FR/ES/DE/IT — all distinct outputs),
#    3 edge durations (very short / very long / whitespace-padded).
./scripts/test-diversity.sh
# Expected: "All chatterbox.cpp CUDA diversity tests PASSED"

All 7 PASS on the post-merge head of this branch
(bench-logs-cuda/regression-phase4-r9.log).

Performance — round 1 + round 2 + round 6 cumulative

5 fresh-process runs each, median of runs 2-5 (NVIDIA driver cache
warm), Turbo Q4_0:

| Stage | Stock | Patched | Speedup |
| --- | --- | --- | --- |
| conv_transpose_1d_kernel (HiFT) | 135.98 ms | 3.21 ms | 42× |
| [hift_total] | 144.7 ms | 30.0 ms | 4.7× |
| S3GEN_INFER_MS | 280 ms | 170 ms | 1.6× |
| Total GPU time / utterance (long prompt) | 698 657 µs | 614 ms (= -12 %) post-3op-fusion | 1.13× |

| Prompt tokens | Default T3 | FORCE_GRAPHS=1 T3 | Δ ms | Δ % |
| --- | --- | --- | --- | --- |
| 19 | 120 ms | 113 ms | -7 ms | -6 % |
| 43 | 163 ms | 150 ms | -13 ms | -8 % |
| 157 | 384 ms | 332 ms | -52 ms | -14 % |
| 231 | 523 ms | 453 ms | -70 ms | -13 % |

End-to-end audio output is bit-identical with vs without the
FORCE_GRAPHS env var (md5sum matches across 50-run soak).

CUDA ↔ Vulkan gap on long-prompt utterance went from 1.40× pre-patch
→ 1.13× post-patch. Remaining gap is FLASH_ATTN_EXT
kernel-quality at chatterbox shapes — documented as the upstream
issue ggml-org/ggml#1466, with the diagnostic infrastructure
(GGML_CUDA_PERF_LOGGER, GGML_CUDA_FATTN_KERNEL) shipped here so
multi-Blackwell maintainers can A/B candidate Blackwell configs
without a rebuild.

CUDA Toolkit 12.0 → 12.8 cold-start delta (validated 2026-04-27)

| Cold-start cost (fresh ~/.nv/ComputeCache) | Toolkit 12.0 | Toolkit 12.8 |
| --- | --- | --- |
| First-call wall | 27 s | 1.1 s |
| Cause | PTX → SASS JIT (sm_120 not native) | sm_120 SASS direct |

12.0 still works at runtime; 12.8 is strongly recommended as
the vcpkg ggml build dependency to eliminate the 27 s tax.

Platforms not tested locally

  • Mobile RTX (laptop 4050 / 4060): mechanism is identical
    (no Blackwell-specific code path); same caveat as the Vulkan
    PR's "Android Vulkan should show 15-25 % T3 win" — needs a
    follow-up bench on at least one mobile RTX SKU before claiming
    universal applicability.
  • Jetson Orin (Ampere embedded): should benefit from (a)
    conv_transpose rewrite (the kernel is bandwidth-starved on Orin
    too) and from (b) FORCE_GRAPHS on autoregressive workloads.
    Other ops are at parity with the Ampere config.

Out of scope (documented as follow-ups, not shipped here)

The companion inputFilesForAI/qvac-17872-findings/FINDINGS_CUDA.md
captures the full investigation including:

  • § 5.1.d — Blackwell-specific FlashAttention MMA config
    table is missing (ggml_cuda_fattn_mma_get_config_blackwell
    doesn't exist; sm_120 silently uses the Ampere sm_80 config).
    Filed as upstream issue ggml-org/ggml#1466 with full code
    references and reproducer. The 67 ms / utterance gap to
    ggml-vulkan flash-attn lives here and requires either NVIDIA
    Nsight Compute hardware counters or a multi-Blackwell A/B
    sweep best done by upstream maintainers.
  • § 5.5 — cuBLAS-LT integration for Q4_0 GEMM picker
    (speculative, large; not justified at chatterbox's current
    desktop perf level — RTF 0.09 on RTX 5090).
  • § 5.6 — Mobile / embedded NVIDIA sanity bench (no code
    change, just bench data needed).
  • § 5.7 — Production stability validation (50-run soak with
    test-stability.sh) — DONE in this PR.

The companion bench-logs-cuda/ directory contains every raw log
referenced above (warm-run-*.log, cold-run-*.log,
nsys-kernels-*.csv, perf-logger-sample.log,
fattn-variants-bench.log, stability-soak50.log,
diversity-test.log, regression-{baseline-r8,phase4-r9}.log,
upstream-{pr1,pr2,pr3,bundle}-*.log).

## Companion upstream work (ggml-org/ggml)

Three of the five fixes were also prepared as a single upstream PR
to ggml-org so they benefit everyone, not just chatterbox:

  • PR #1465
    ggml-cuda: warp-cooperative conv_transpose_1d, MUL_MAT_VEC + ADD + ADD fusion, GGML_CUDA_PERF_LOGGER env var.
    5 commits, +575 / -43, 12189/12189 test-backend-ops -b CUDA0
    PASS, including 105 new test cases added by the PR itself.
    OPEN, MERGEABLE, awaiting maintainer review. The mmvf.cu
    x_residual offset bug fix that's commit bd37318 here is
    also in this upstream PR (caught and fixed during the upstream
    test-before-change cycle).

  • Issue #1466
    ggml-cuda: flash-attn MMA picker has no Blackwell (sm_120) entry — silently uses Ampere config. Documents the remaining
    biggest gap with code references and reproducer.

GGML_CUDA_FORCE_GRAPHS and GGML_CUDA_FATTN_KERNEL are
intentionally kept chatterbox-local for now — both are niche
(growing-KV step decode and diagnostic A/B respectively) and
upstream maintainers may reasonably ask for a different framing
before they ship as default. They're available behind the env vars
in this PR, and we can split off upstream PRs later if there's
interest after #1465 lands.

The 3-op MUL_MAT_VEC + ADD(bias) + ADD(residual) fusion's mmvf.cu
kernel template was missing the (sample_dst, channel_bias) offset
for x_residual that x_bias has. Latent for chatterbox where
ne[2]==ne[3]==1 makes the offset zero, but exposed by upstream
test_mul_mat_vec_fusion(batch_dims=[4,2]) when porting the patch
to ggml-org/ggml.

Fix mirrors the existing x_bias offset path. Same fix applied
upstream in ggml-org/ggml#1465.
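
Why the bug stayed latent can be shown with a small CPU model. The sketch below assumes a contiguous layout and uses hypothetical names (`slice_offset`, `add_residual`); it is not the mmvf.cu code. It adds a residual to a batched output either with or without the per-(sample, channel) offset.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Flat offset of the (sample, channel) slice in a contiguous tensor
// laid out as [ne3 samples][ne2 channels][ne0 rows]. Assumed layout.
inline std::size_t slice_offset(std::size_t sample, std::size_t channel,
                                std::size_t ne0, std::size_t ne2) {
    return (sample * ne2 + channel) * ne0;
}

// dst += residual, slice by slice. With apply_offset == false the residual
// is always read from slice (0, 0), which models the missing offset.
void add_residual(std::vector<float>& dst, const std::vector<float>& residual,
                  std::size_t ne0, std::size_t ne2, std::size_t ne3,
                  bool apply_offset) {
    for (std::size_t s = 0; s < ne3; ++s) {
        for (std::size_t c = 0; c < ne2; ++c) {
            const std::size_t off     = slice_offset(s, c, ne0, ne2);
            const std::size_t res_off = apply_offset ? off : 0;
            for (std::size_t r = 0; r < ne0; ++r) {
                dst[off + r] += residual[res_off + r];
            }
        }
    }
}
```

With `ne2 == ne3 == 1` the only slice is (0, 0), so both variants read the same residual values; with batch dimensions such as 4 × 2 they diverge on every slice after the first.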

Made-with: Cursor
Zbig9000 pushed a commit to Zbig9000/chatterbox.cpp that referenced this pull request Apr 28, 2026
feat: expose tts-cpp as a library and make it consumable via vcpkg
Zbig9000 added a commit to Zbig9000/chatterbox.cpp that referenced this pull request May 5, 2026
…d 4)

PROGRESS.md §3.35 — T3 step-graph cache (multilingual CFG token
decode) opt-in via CHATTERBOX_T3_STEP_CACHE.  Per-(n_past,
is_uncond) std::list-LRU cache (cap 256) for build_step_graph_mtl;
saves ~3 ms per cache hit.  Single-utterance default-OFF (no
hits-to-amortise on synth #1) keeps the existing path
regression-free; server-mode opt-in shows ~15 % per-pass speedup
(~256 ms / synth #2 of multilingual at 136 tokens).  Tests:
src/test_t3_caches.cpp NEW with 99 checks (lifecycle + bit-exact
cold/warm logits + multi-synth amortisation timing).  Lifecycle
wired into free_t3 (CLI, both paths), Impl::free_model (Engine),
and an atexit fallback — all firing BEFORE ggml_backend_free.
Total cache test suite green: 80 + 99 + 6 + 99 = 284 / 284.
Zbig9000 added a commit to Zbig9000/chatterbox.cpp that referenced this pull request May 6, 2026
…d 4)

PROGRESS.md §3.35 — T3 step-graph cache (multilingual CFG token
decode) opt-in via CHATTERBOX_T3_STEP_CACHE.  Per-(n_past,
is_uncond) std::list-LRU cache (cap 256) for build_step_graph_mtl;
saves ~3 ms per cache hit.  Single-utterance default-OFF (no
hits-to-amortise on synth #1) keeps the existing path
regression-free; server-mode opt-in shows ~15 % per-pass speedup
(~256 ms / synth #2 of multilingual at 136 tokens).  Tests:
src/test_t3_caches.cpp NEW with 99 checks (lifecycle + bit-exact
cold/warm logits + multi-synth amortisation timing).  Lifecycle
wired into free_t3 (CLI, both paths), Impl::free_model (Engine),
and an atexit fallback — all firing BEFORE ggml_backend_free.
Total cache test suite green: 80 + 99 + 6 + 99 = 284 / 284.
@GustavoA1604

Owner

Closing, as we won't target CUDA for now.
