feat: manual KV-cached loop for code predictor #6
…ctor

Replace the code predictor's HuggingFace GenerationMixin.generate() call with an explicit prefill + decode loop that directly manages a DynamicCache.

The code predictor runs 31 sequential forward passes per talker token to produce codebooks 2-32. The previous path went through HF's full generate() infrastructure on each invocation, adding per-step overhead from logits processors, stopping criteria, model-kwargs bookkeeping, and GenerationMixin scaffolding.

The new generate_codes() method:
- Creates a DynamicCache and runs a 2-token prefill
- Loops 30 decode steps with explicit KV cache reuse
- Uses step-specific codec_embedding layers and lm_heads directly
- Handles top-k / top-p / temperature sampling via _sample_token()
- Returns the same [B, num_code_groups-1] tensor the caller expects

This is a prerequisite for future CUDA graph capture of the code predictor inner loop (Phase 1a), since the manual loop has deterministic control flow and no framework callbacks.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
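The control flow described in this commit message can be illustrated with a toy, framework-free sketch. Everything here is a hypothetical stand-in: `ToyCache` plays the role of `DynamicCache`, `toy_forward`, `toy_embed` and `make_head` replace the real transformer, `codec_embedding` layers and `lm_head`s, and the sketch handles a single sequence with greedy selection only. It is meant to show the step count (one prefill pass over 2 tokens, then 30 single-token decode steps, 31 codes in total), not the actual `generate_codes()` implementation.

```python
from typing import Callable, List, Tuple

NUM_CODE_GROUPS = 32  # codebooks 1..32; the predictor fills codebooks 2..32
PREFILL_LEN = 2       # prefill length stated in the commit message


class ToyCache:
    """Hypothetical stand-in for DynamicCache: remembers every past input."""

    def __init__(self) -> None:
        self.entries: List[int] = []

    def __len__(self) -> int:
        return len(self.entries)


def toy_forward(inputs: List[int], cache: ToyCache) -> int:
    """Hypothetical forward pass: extend the cache, return a fake hidden state."""
    cache.entries.extend(inputs)
    return sum(cache.entries)


def toy_embed(step: int, code: int) -> int:
    """Hypothetical step-specific embedding of the previously predicted code."""
    return code + step


def make_head(step: int, vocab: int = 16) -> Callable[[int], int]:
    """Hypothetical step-specific head: hidden state -> greedy code id."""
    return lambda hidden: (hidden * (step + 3)) % vocab


def generate_codes_sketch(prefill: List[int]) -> Tuple[List[int], int, ToyCache]:
    assert len(prefill) == PREFILL_LEN
    heads = [make_head(i) for i in range(NUM_CODE_GROUPS - 1)]
    cache = ToyCache()

    # Prefill: a single forward pass over both prompt positions fills the cache.
    hidden = toy_forward(prefill, cache)
    passes = 1
    codes = [heads[0](hidden)]

    # Decode: each step feeds only the newest code; the cache supplies the past.
    for step in range(1, NUM_CODE_GROUPS - 1):
        hidden = toy_forward([toy_embed(step, codes[-1])], cache)
        passes += 1
        codes.append(heads[step](hidden))

    return codes, passes, cache
```

Because the loop bounds are fixed constants and no callbacks run between steps, every invocation executes the same sequence of operations, which is the property the commit message cites as the prerequisite for CUDA graph capture.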
14 tests across 3 classes:
- TestSampleToken: greedy, batch, shape, top-k, temperature, top-p, valid range
- TestGenerateCodes: shape, determinism, None kwargs, valid range, batch
- TestGenerateCodesRegression: generate_codes() matches HF generate() under greedy decoding across multiple seeds

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Sampling parameters (do_sample, top_p, top_k, temperature) are now required; the canonical defaults live only in _merge_generate_kwargs().

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
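As an illustration of what a sampler with required sampling parameters might look like, here is a minimal pure-Python sketch for a single row of logits. The name `sample_token_sketch`, its signature, the `rng` argument and the order of operations (temperature, then top-k, then top-p) are assumptions made for this example; the PR's `_sample_token()` operates on batched tensors and may differ in its details.

```python
import math
import random
from typing import List, Optional


def sample_token_sketch(
    logits: List[float],
    *,
    do_sample: bool,      # keyword-only and without defaults: callers must
    top_k: int,           # pass every sampling parameter explicitly
    top_p: float,
    temperature: float,
    rng: Optional[random.Random] = None,  # hypothetical hook for seeding
) -> int:
    """Pick one token id from a single row of logits."""
    if not do_sample:
        # Greedy: index of the largest logit.
        return max(range(len(logits)), key=logits.__getitem__)

    rng = rng or random.Random()
    scaled = [x / temperature for x in logits]

    # Top-k: keep the k highest-scoring ids (all ids if top_k <= 0).
    order = sorted(range(len(scaled)), key=scaled.__getitem__, reverse=True)
    if top_k > 0:
        order = order[:top_k]

    # Softmax over the survivors (subtract the max for numerical stability).
    peak = scaled[order[0]]
    weights = [math.exp(scaled[i] - peak) for i in order]
    total = sum(weights)
    probs = [w / total for w in weights]

    # Top-p: keep the smallest prefix whose cumulative probability reaches top_p.
    kept, cumulative = 0, 0.0
    for p in probs:
        kept += 1
        cumulative += p
        if cumulative >= top_p:
            break

    # random.Random.choices renormalises the remaining weights internally.
    return rng.choices(order[:kept], weights=probs[:kept], k=1)[0]
```

Making the parameters keyword-only with no defaults means a forgotten argument fails loudly at the call site instead of silently falling back to a second, possibly divergent, set of defaults.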
RTX 3090, 0.6B code predictor (208M params, 5 layers, 31 decode steps):
- Greedy: 79.16ms vs 91.85ms → 1.16x (13.8% faster)
- Sampling: 84.86ms vs 96.98ms → 1.14x (12.5% faster)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Related upstream issues (vllm-project/vllm-omni)
This PR addresses part of the work tracked in several upstream issues.
Standalone script that loads the real Qwen3-TTS model and generates speech through the full pipeline (talker → generate_codes() → speech decoder) to verify audio output quality.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Summary
- Replace the code predictor's HuggingFace GenerationMixin.generate() call with an explicit prefill + decode loop that directly manages a DynamicCache
- Add generate_codes() method to Qwen3TTSTalkerCodePredictorModelForConditionalGeneration with manual KV cache management
- Add _sample_token() utility for top-k/top-p/temperature sampling without HF's logits processing pipeline
- Add tests covering _sample_token(), generate_codes(), and HF generate() parity

Benchmark
NVIDIA RTX 3090, 0.6B code predictor (208M params, 5 layers, 31 decode steps), 100 iterations after 10 warmup:
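| Mode | generate_codes() | generate() | Speedup |
|---|---|---|---|
| Greedy (do_sample=False) | 79.16 ms | 91.85 ms | 1.16x (13.8% faster) |
| Sampling (top_k=50, t=0.9) | 84.86 ms | 96.98 ms | 1.14x (12.5% faster) |

Both methods were measured on the same GPU, in the same process, with the same model weights and the same input tensor. The improvement comes from eliminating HF's per-step framework overhead (logits processors, stopping criteria, kwargs bookkeeping). Since the code predictor is invoked once per talker token, and an utterance spans hundreds of tokens, the per-invocation savings add up.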
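The speedup and percentage figures follow directly from the timings reported for the RTX 3090 run, and the same arithmetic gives a rough feel for the cumulative saving. In the sketch below, the helper name `summarize` and the token count of 300 are hypothetical (the PR only says "hundreds of tokens"), and it assumes each reported time covers one code predictor invocation.

```python
from typing import Tuple

# Timings in milliseconds, as reported for the RTX 3090 run.
NEW_MS = {"greedy": 79.16, "sampling": 84.86}   # generate_codes()
OLD_MS = {"greedy": 91.85, "sampling": 96.98}   # HF generate()

HYPOTHETICAL_TALKER_TOKENS = 300  # assumption for illustration only


def summarize(mode: str, tokens: int = HYPOTHETICAL_TALKER_TOKENS) -> Tuple[float, float, float]:
    """Return (speedup factor, % time reduction, seconds saved over `tokens` calls)."""
    old, new = OLD_MS[mode], NEW_MS[mode]
    speedup = old / new                        # how many times faster
    reduction_pct = (old - new) / old * 100    # share of the old time removed
    saved_s = (old - new) * tokens / 1000      # cumulative saving in seconds
    return speedup, reduction_pct, saved_s


for mode in NEW_MS:
    speedup, pct, saved = summarize(mode)
    print(f"{mode}: {speedup:.2f}x, {pct:.1f}% faster, ~{saved:.1f}s saved over "
          f"{HYPOTHETICAL_TALKER_TOKENS} tokens")
```

Note that the percentage is measured against the old time, which is why a 1.16x speedup corresponds to a 13.8% reduction rather than 16%.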
Reproducible via:
python benchmarks/bench_code_predictor.py --device cuda --iters 100

Motivation
The code predictor runs 31 sequential forward passes per talker token (producing codebooks 2–32). Each invocation previously went through HF's full generate() infrastructure, adding per-step overhead from:
- Logits processors
- Stopping criteria
- Model-kwargs bookkeeping (_update_model_kwargs_for_generation)
- GenerationMixin scaffolding

The manual loop eliminates this overhead and is a prerequisite for future CUDA graph capture of the code predictor inner loop (Phase 1a), since the loop now has deterministic control flow with no framework callbacks.
Changes
- _sample_token(): Module-level function implementing greedy/top-k/top-p/temperature sampling
- generate_codes(): Explicit prefill (2 tokens) + decode (30 steps) loop with DynamicCache, step-specific codec_embedding layers and lm_head selection
- Qwen3TTSTalkerForConditionalGeneration.forward(): now calls generate_codes() instead of generate()
- test_code_predictor_generate.py: 13 tests across 3 classes (TestSampleToken, TestGenerateCodes, TestGenerateCodesRegression)
- benchmarks/bench_code_predictor.py: Reproducible benchmark script
- benchmarks/audio_quality_test.py: End-to-end audio generation test

Test plan
- _sample_token() (greedy, batch, shape, top-k, temperature, top-p, valid range)
- generate_codes() (shape, determinism, valid range, batch)
- generate_codes() matches HF generate() under greedy decoding (4 seeds)

🤖 Generated with Claude Code