feat: manual KV-cached loop for code predictor #6

Merged
marksverdhei merged 5 commits into main from feat/code-predictor-manual-kv-cache
Jan 30, 2026
@marksverdhei marksverdhei commented Jan 30, 2026

Summary

  • Replace the code predictor's HuggingFace GenerationMixin.generate() call with an explicit prefill + decode loop that directly manages a DynamicCache
  • Add generate_codes() method to Qwen3TTSTalkerCodePredictorModelForConditionalGeneration with manual KV cache management
  • Add _sample_token() utility for top-k/top-p/temperature sampling without HF's logits processing pipeline
  • Add 13 unit + regression tests covering _sample_token(), generate_codes(), and HF generate() parity
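
The order of operations behind a top-k/top-p/temperature sampler can be sketched as below. This is a minimal illustration, not the PR's `_sample_token()`: it works on one row of logits as a plain Python list rather than a batched tensor, and the name `sample_token_sketch`, its argument order, and the `rng` parameter are assumptions made for the example. Filtering follows the usual temperature, then top-k, then top-p order.

```python
import math
import random


def sample_token_sketch(logits, do_sample, top_k, top_p, temperature, rng=random):
    """Pick one token id from a single row of logits (a list of floats)."""
    if not do_sample:
        # Greedy decoding: index of the largest logit.
        return max(range(len(logits)), key=logits.__getitem__)

    # Temperature scaling, then rank candidates from highest to lowest.
    scaled = [x / temperature for x in logits]
    order = sorted(range(len(scaled)), key=scaled.__getitem__, reverse=True)

    # Top-k: keep only the k highest-scoring candidates (0 disables it).
    if top_k and top_k > 0:
        order = order[:top_k]

    # Softmax over the survivors (subtract the max for numerical stability).
    peak = scaled[order[0]]
    weights = [math.exp(scaled[i] - peak) for i in order]
    total = sum(weights)
    probs = [w / total for w in weights]

    # Top-p (nucleus): keep the smallest prefix whose mass reaches top_p.
    if top_p is not None and top_p < 1.0:
        cumulative, cut = 0.0, len(order)
        for n, p in enumerate(probs, start=1):
            cumulative += p
            if cumulative >= top_p:
                cut = n
                break
        order, probs = order[:cut], probs[:cut]

    # random.choices treats the weights as relative, so no renormalization is needed.
    return rng.choices(order, weights=probs, k=1)[0]
```

With `do_sample=False` the other parameters are ignored, which is the property the greedy parity tests rely on.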

Benchmark

NVIDIA RTX 3090, 0.6B code predictor (208M params, 5 layers, 31 forward passes per call: 1 prefill + 30 decode steps), 100 iterations after 10 warmup iterations:

| Scenario | generate_codes() | HF generate() | Speedup |
|---|---|---|---|
| Greedy (do_sample=False) | 79.16 ms | 91.85 ms | 1.16x (13.8%) |
| Sampling (top_k=50, t=0.9) | 84.86 ms | 96.98 ms | 1.14x (12.5%) |

Both methods were measured on the same GPU, in the same process, with the same model weights and the same input tensor. The improvement comes from eliminating HF's per-step framework overhead (logits processors, stopping criteria, kwargs bookkeeping). Because the code predictor is invoked once per talker token, and an utterance spans hundreds of talker tokens, the per-call savings add up over a full generation.
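
For reference, the two figures in the Speedup column can be derived from the measured latencies as follows: the factor is baseline divided by new latency, and the percentage is the reduction relative to the baseline. The last lines extrapolate to a whole utterance, assuming each measurement is one code predictor invocation; the 300-token utterance length is a hypothetical value chosen only for illustration.

```python
def speedup(baseline_ms, new_ms):
    """Return (speedup factor, percent latency reduction vs. baseline)."""
    factor = baseline_ms / new_ms
    reduction_pct = 100.0 * (baseline_ms - new_ms) / baseline_ms
    return factor, reduction_pct


greedy = speedup(91.85, 79.16)
sampling = speedup(96.98, 84.86)
print(f"greedy:   {greedy[0]:.2f}x, {greedy[1]:.1f}% lower latency")      # 1.16x, 13.8%
print(f"sampling: {sampling[0]:.2f}x, {sampling[1]:.1f}% lower latency")  # 1.14x, 12.5%

# Hypothetical utterance length, not a measured value.
talker_tokens = 300
saved_s = talker_tokens * (91.85 - 79.16) / 1000.0
print(f"about {saved_s:.1f} s saved over {talker_tokens} talker tokens (greedy)")
```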

Reproducible via: python benchmarks/bench_code_predictor.py --device cuda --iters 100
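
The benchmark script is not reproduced in this conversation. A warmup-then-measure harness of the kind described (10 warmup calls, 100 timed calls) might look like the sketch below. The `bench` helper and its `sync` hook are assumptions for illustration; on a GPU the hook would typically be `torch.cuda.synchronize`, so that queued kernels are included in the timing.

```python
import statistics
import time


def bench(fn, iters=100, warmup=10, sync=None):
    """Return the mean latency of fn() in milliseconds."""
    sync = sync or (lambda: None)  # no-op when timing CPU-only code

    # Warmup calls are executed but not timed.
    for _ in range(warmup):
        fn()

    samples = []
    for _ in range(iters):
        sync()  # make sure earlier work has finished
        start = time.perf_counter()
        fn()
        sync()  # wait for any asynchronous work started by fn()
        samples.append((time.perf_counter() - start) * 1000.0)
    return statistics.mean(samples)


if __name__ == "__main__":
    print(f"{bench(lambda: sum(range(10_000))):.3f} ms")
```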

Motivation

The code predictor runs 31 sequential forward passes per talker token (producing codebooks 2–32). Each invocation previously went through HF's full generate() infrastructure, adding per-step overhead from:

  • Logits processors pipeline
  • Stopping criteria evaluation
  • Model kwargs preparation and _update_model_kwargs_for_generation
  • GenerationMixin scaffolding

The manual loop eliminates this overhead and is a prerequisite for future CUDA graph capture of the code predictor inner loop (Phase 1a), since the loop now has deterministic control flow with no framework callbacks.

Changes

  • _sample_token(): Module-level function implementing greedy/top-k/top-p/temperature sampling
  • generate_codes(): Explicit prefill (2 tokens) + decode (30 steps) loop with DynamicCache, step-specific codec_embedding layers and lm_head selection
  • Updated Qwen3TTSTalkerForConditionalGeneration.forward() to call generate_codes() instead of generate()
  • test_code_predictor_generate.py: 13 tests across 3 classes (TestSampleToken, TestGenerateCodes, TestGenerateCodesRegression)
  • benchmarks/bench_code_predictor.py: Reproducible benchmark script
  • benchmarks/audio_quality_test.py: End-to-end audio generation test
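
The control flow of the prefill + decode loop can be illustrated with a toy stand-in for the model. Everything below is an assumption made for the sketch: `ToyBackbone`, `toy_embed`, `toy_head`, the list-based cache, and the way embedding and head indices are paired with steps. Only the shape of the loop (a 2-position prefill, 30 single-position decode steps, 31 codes returned) comes from the description above.

```python
NUM_CODE_GROUPS = 32  # the loop fills codebooks 2..32, i.e. 31 codes
VOCAB = 16            # toy codebook size, an assumption for the sketch


class ToyBackbone:
    """Stand-in for the transformer; the 'KV cache' is a list of past inputs."""

    def __init__(self):
        self.calls = 0

    def forward(self, new_inputs, cache):
        self.calls += 1
        cache.extend(new_inputs)  # only the new positions are added
        return sum(cache) % 997   # toy hidden state for the last position


def toy_embed(step, token):
    """Step-specific embedding: a different mapping for each codebook."""
    return (step + 1) * 31 + token


def toy_head(step, hidden):
    """Step-specific output head: returns VOCAB toy logits."""
    return [((hidden + step) * (v + 3)) % 13 for v in range(VOCAB)]


def greedy(logits):
    return max(range(len(logits)), key=logits.__getitem__)


def generate_codes_sketch(prefill_inputs, backbone, pick=greedy):
    assert len(prefill_inputs) == 2
    cache = []  # fresh cache for every invocation
    hidden = backbone.forward(list(prefill_inputs), cache)  # prefill: 2 positions
    codes = [pick(toy_head(0, hidden))]
    for step in range(1, NUM_CODE_GROUPS - 1):  # 30 decode steps
        x = toy_embed(step - 1, codes[-1])
        hidden = backbone.forward([x], cache)  # 1 new position, cache reused
        codes.append(pick(toy_head(step, hidden)))
    return codes, cache


backbone = ToyBackbone()
codes, cache = generate_codes_sketch((3, 5), backbone)
print(len(codes), backbone.calls, len(cache))  # prints: 31 31 32
```

The loop has a fixed trip count and no callbacks, which is the property the description cites as the prerequisite for CUDA graph capture.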

Test plan

  • Unit tests for _sample_token() (greedy, batch, shape, top-k, temperature, top-p, valid range)
  • Functional tests for generate_codes() (shape, determinism, valid range, batch)
  • Regression tests verifying generate_codes() matches HF generate() under greedy decoding (4 seeds)
  • Benchmark on RTX 3090 showing 12.5-13.8% lower latency (1.14x-1.16x speedup)
  • End-to-end TTS audio generation verified with Qwen3-TTS-12Hz-1.7B-CustomVoice (intelligible English speech output)
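
The greedy parity check in the test plan follows a common pattern: run two implementations of the same decode on seeded inputs and require identical outputs. The sketch below shows that pattern with toy decoders, where a version that recomputes from the full history stands in for the reference path and a version that keeps running state stands in for the cached loop. All names here are hypothetical; the real tests compare `generate_codes()` against HF `generate()`.

```python
import random

STEPS = 31
VOCAB = 16


def argmax(xs):
    return max(range(len(xs)), key=xs.__getitem__)


def toy_logits(state, step):
    """Toy deterministic model: logits depend on the summed history."""
    return [((state + step) * (v + 3)) % 13 for v in range(VOCAB)]


def reference_decode(prompt):
    """Re-reads the whole history at every step (no cache)."""
    seq, out = list(prompt), []
    for step in range(STEPS):
        tok = argmax(toy_logits(sum(seq), step))
        out.append(tok)
        seq.append(tok)
    return out


def cached_decode(prompt):
    """Keeps a running summary instead of re-reading the history."""
    running, out = sum(prompt), []
    for step in range(STEPS):
        tok = argmax(toy_logits(running, step))
        out.append(tok)
        running += tok
    return out


def test_parity_across_seeds():
    for seed in range(4):  # four seeds, as in the test plan
        rng = random.Random(seed)
        prompt = [rng.randrange(1000), rng.randrange(1000)]
        assert cached_decode(prompt) == reference_decode(prompt)


test_parity_across_seeds()
```

Greedy decoding is used for the comparison because it removes random sampling from the picture, so any mismatch points at the loop or cache handling rather than at RNG state.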

🤖 Generated with Claude Code

marksverdhei and others added 4 commits January 30, 2026 16:33
…ctor

Replace the code predictor's HuggingFace GenerationMixin.generate() call
with an explicit prefill + decode loop that directly manages a DynamicCache.

The code predictor runs 31 sequential forward passes per talker token to
produce codebooks 2-32. The previous path went through HF's full generate()
infrastructure on each invocation, adding per-step overhead from logits
processors, stopping criteria, model-kwargs bookkeeping, and
GenerationMixin scaffolding.

The new generate_codes() method:
- Creates a DynamicCache and runs a 2-token prefill
- Loops 30 decode steps with explicit KV cache reuse
- Uses step-specific codec_embedding layers and lm_heads directly
- Handles top-k / top-p / temperature sampling via _sample_token()
- Returns the same [B, num_code_groups-1] tensor the caller expects

This is a prerequisite for future CUDA graph capture of the code predictor
inner loop (Phase 1a), since the manual loop has deterministic control flow
and no framework callbacks.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

14 tests across 3 classes:
- TestSampleToken: greedy, batch, shape, top-k, temperature, top-p, valid range
- TestGenerateCodes: shape, determinism, None kwargs, valid range, batch
- TestGenerateCodesRegression: generate_codes() matches HF generate() under
  greedy decoding across multiple seeds

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

Sampling parameters (do_sample, top_p, top_k, temperature) are now
required — the canonical defaults live in _merge_generate_kwargs() only.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

RTX 3090, 0.6B code predictor (208M params, 5 layers, 31 decode steps):
- Greedy:   79.16ms vs 91.85ms → 1.16x (13.8% faster)
- Sampling: 84.86ms vs 96.98ms → 1.14x (12.5% faster)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

@marksverdhei

Author

Related upstream issues (vllm-project/vllm-omni)

This PR addresses part of the work tracked in several upstream issues:

  • #938 — Umbrella RFC for Qwen3-TTS optimization. Tracks CUDA graph acceleration, disaggregated pipeline, and streaming — all of which require replacing HF generate() with explicit loop control (which this PR does for the code predictor).

  • #976 — RFC to separate Qwen3-TTS into a 2-stage vLLM-native pipeline. Explicitly calls out that "Qwen3-TTS currently relies on HF-style generation paths" and targets making it "vLLM-native (no HF generate)." Our generate_codes() is a step toward this.

  • #1061 — Bug report: "qwen3 tts is slow on 5090, cuda Graph not enable." The code predictor runs outside vLLM's engine loop via HF generate(), so CUDA graphs can't capture it. The manual loop in this PR is a prerequisite for enabling graph capture.

  • PR #907 — "[1/N][Perf] Optimize Qwen3-TTS with vLLM's native ops" — upstream WIP to replace HF-style operations with vLLM-native ones. Our work is aligned with this direction.

  • #690 — RFC to refactor Qwen3Omni talker MTP cudagraph implementation. Related prior art for the Omni model's talker; the TTS code predictor needs the same treatment.

Standalone script that loads the real Qwen3-TTS model and generates
speech through the full pipeline (talker → generate_codes() → speech
decoder) to verify audio output quality.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
@marksverdhei marksverdhei merged commit d553cb0 into main Jan 30, 2026
@marksverdhei marksverdhei deleted the feat/code-predictor-manual-kv-cache branch January 30, 2026 17:09
@marksverdhei marksverdhei restored the feat/code-predictor-manual-kv-cache branch January 30, 2026 17:15