
[Bugfix][VoxCPM2]: Fix vectorized_gather OOB under concurrent prefill+decode batches#2903

Merged
hsliuustc0106 merged 5 commits into vllm-project:main from Sy0307:fix/voxcpm2-batch-crash-and-perf
Apr 19, 2026

Conversation


@Sy0307 Sy0307 commented Apr 18, 2026

Purpose

Fixes the vectorized_gather_kernel: index out of bounds / illegal memory access crash on VoxCPM2 concurrent serving after #2803. Reproduces within 1-2 staggered rounds once requests overlap in a prefill+decode mixed batch.

Root cause

VoxCPM2TalkerForConditionalGeneration.preprocess() evicted stale per-request states at the start of every prefill using self._pending_requests as the "current batch". But _pending_requests is cleared at the end of each forward() and repopulated by vLLM's runner one request at a time — when preprocess() runs for the k-th request, it only holds the first k-1.

In a prefill+decode mixed batch (1 new prefill + N cached decodes), when the prefill is walked mid-list, the cached decode requests scheduled after it are not yet in _pending_requests, get classified as stale, and are removed from _active_states. The next forward() lazily recreates empty states (prefill_completed=False), falls into the silent-skip else branch, and drops those tokens from residual_inputs. torch.cat(residual_inputs) is now shorter than attn_metadata.slot_mapping; residual_model's PagedAttention reads/writes KV with the wrong cu_seqlens, softmax overflows to NaN on the prefill slice, NaN propagates into the CFM (LocDiT) solver, and the next indexing kernel crashes.

Smoking gun (instrumented step 59): pending=4 but sm_len=43 with 9 query_start_loc segments — vLLM scheduled 8 requests, talker only saw 4.
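The prefix-only view described above can be reproduced in a few lines (a minimal sketch with hypothetical names mirroring the PR's description, not the actual vllm-omni code):

```python
# Simulates the buggy eviction: preprocess() treated _pending_requests --
# a per-step prefix, repopulated one request at a time -- as the full
# active batch, so decodes scheduled *after* a new prefill look "stale".

def buggy_evict(active_states: dict, pending_prefix: list) -> None:
    """Evict any state not in the (incomplete) pending list -- the bug."""
    for req_id in list(active_states):
        if req_id not in pending_prefix:
            del active_states[req_id]  # cached decode states wrongly removed

# Batch of 1 new prefill + 3 cached decodes, scheduled in this order:
active = {"d1": "decode", "d2": "decode", "d3": "decode"}
schedule = ["d1", "new_prefill", "d2", "d3"]

pending = []
for req_id in schedule:
    pending.append(req_id)
    if req_id == "new_prefill":
        # preprocess() runs here: d2/d3 are not yet in `pending`.
        buggy_evict(active, pending)

print(sorted(active))  # ['d1'] -- d2/d3 were evicted as "stale"
```

Only the decode walked before the prefill survives; the next forward() then recreates empty states for d2/d3 and desynchronizes the attention metadata exactly as described above.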

Why this only surfaced after #2803

The bug predates #2803. In the eager / torch.compile path, the metadata desync produces numerically wrong but finite outputs (subtly bad audio, no crash). #2803's manual CUDA Graph pins attention metadata pointers and buffer sizes at capture time; on replay, the shorter batch_in against captured-size metadata becomes a hard OOB. #2803 did not introduce the bug; it converted silent data corruption into a loud crash.

Fix

  • Remove the _pending_requests-based eviction in preprocess().
  • Rely on on_requests_finished() → _flush_deferred_cleanup() at the end of forward(), which already receives the correct finished_req_ids from vLLM.
  • One-shot logger.warning in _get_or_create_state() when _active_states grows past max(512, 4 * max_num_seqs) as a regression guard.
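The regression guard in the last bullet could look roughly like this (a sketch with assumed names; the actual change lives in voxcpm2_talker.py and may differ in detail):

```python
import logging

logger = logging.getLogger(__name__)

# Floor for the leak-warn threshold (the PR later extracts this as a
# module-level constant, _ACTIVE_STATE_LEAK_WARN_MIN).
_ACTIVE_STATE_LEAK_WARN_MIN = 512


class TalkerStateTracker:
    """Sketch of the one-shot regression guard described in the fix."""

    def __init__(self, max_num_seqs: int):
        self._active_states: dict[str, object] = {}
        self._warn_threshold = max(_ACTIVE_STATE_LEAK_WARN_MIN, 4 * max_num_seqs)
        self._warned = False  # one-shot: warn at most once per process

    def get_or_create_state(self, req_id: str) -> object:
        state = self._active_states.get(req_id)
        if state is None:
            state = object()
            self._active_states[req_id] = state
            if not self._warned and len(self._active_states) > self._warn_threshold:
                self._warned = True  # never evicts; only flags a likely leak
                logger.warning(
                    "_active_states grew past %d entries; possible state leak",
                    self._warn_threshold,
                )
        return state
```

The guard deliberately warns rather than evicts: after this PR, eviction is driven only by vLLM's finished_req_ids, so an oversized map is a symptom to report, not a condition to "fix" inline.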

Investigated and ruled out

  • async_scheduling=true race. Looked correlated but not causal; async_scheduling=true (repo default) is stable with this fix.
  • CUDA Graph stale block_table / scheduler_metadata. Tried owned block-table buffers, MemPool(no_split=True), scheduler_metadata=None at capture, FA3 wake hook — none fixed staggered load.
  • Inductor buffer reuse. .detach().clone() on compiled outputs is a real correctness win for a different symptom (logits_indices getting INT_MAX-ish). Disabling Inductor entirely still crashes on staggered load.
  • LocDiT / CFM not batch-safe. Crash lands inside LocDiT.decoder but LocDiT was the victim, not the villain — it received NaN inputs produced upstream.
  • FA3 wake hook (use_full_cuda_graph=True, max_num_splits=32). Discarded: regressed small-batch RTF 0.106 → 0.205 (−39…−61% throughput at c=1/2/4). Not needed.

Deliberately NOT changed

Total change: 28 lines in a single file (voxcpm2_talker.py).

Test Plan

Run on 1× H20, clean origin/main worktree with this branch applied (no YAML changes):

  • Staggered 60 requests, 0.5 s gap (forces prefill+decode overlap — exact bug trigger).
  • Poisson realistic load, 80 requests, mean gap 0.4 s, short/medium/long mix.
  • Mixed-length c=8 × 20 rounds (4 short + 4 long per round).
  • Realistic load × 3 consecutive rounds.
  • Audio validity spot check (short/medium/long/zh: sample rate, duration, peak, non-zero).
  • Server log scan: grep -E "EngineDead|illegal|Traceback|NaN|out of bounds" returns 0.
  • voxcpm2_ab_bench_online.py (same script as #2803) for perf comparison.
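The first two load shapes above can be sketched as send-time schedules (hypothetical helper names; the real benchmark is voxcpm2_ab_bench_online.py, which this sketch does not reproduce):

```python
import random


def staggered_schedule(n: int, gap_s: float) -> list[float]:
    """Send times for the staggered test: request i fires at i * gap_s,
    forcing new prefills to overlap in-flight decodes."""
    return [i * gap_s for i in range(n)]


def poisson_schedule(n: int, mean_gap_s: float, seed: int = 0) -> list[float]:
    """Send times with exponentially distributed gaps (Poisson arrivals)."""
    rng = random.Random(seed)
    times, t = [], 0.0
    for _ in range(n):
        times.append(t)
        t += rng.expovariate(1.0 / mean_gap_s)
    return times


sched = staggered_schedule(60, 0.5)
print(sched[0], sched[1], sched[-1])  # 0.0 0.5 29.5
```

A 0.5 s gap against multi-second generations guarantees that by the second or third request the batch already mixes one prefill with cached decodes, which is the exact shape that triggered the bug.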

Test Result

Stability (540 requests cumulative, 0 failures)

| Test | Result |
| --- | --- |
| Staggered 60 × gap 0.5 s | 60/60 |
| Poisson realistic 80, mean gap 0.4 s, mixed length | 80/80 (short 25/25, medium 28/28, long 27/27) |
| Mixed-length c=8 × 20 rounds | 160/160 |
| Realistic load × 3 rounds | 240/240 |

Server-log scans: EngineDeadError=0, illegal memory access=0, Traceback=0.

Audio validity spot check

| Prompt | Duration | Peak | Non-zero samples |
| --- | --- | --- | --- |
| short | 2.08 s | 21270 | 61722 |
| medium | 9.60 s | 32017 | 357972 |
| long | 14.88 s | 28264 | 615253 |
| zh | 5.12 s | 22993 | 177958 |

48 kHz, non-truncated, non-silent, WAV decodes cleanly.

Performance (no regression)

| Metric | #2803 baseline | This PR |
| --- | --- | --- |
| Avg RTF | 0.106 | 0.107 |
| Throughput c=1 | 0.99 req/s | 1.19 req/s |
| Throughput c=2 | 1.24 req/s | 1.23 req/s |
| Throughput c=4 | 1.41 req/s | 1.32 req/s |
c=1 slightly faster; c=2/c=4 within run-to-run noise.

Before this PR

Same worktree without the fix: staggered 60-request run crashes with vectorized_gather_kernel: index out of bounds in 1-2 rounds, engine dies, subsequent requests return HTTP 400 (EngineDeadError).

cc @linyueqian @hsliuustc0106

Root cause of the c>=2 vectorized_gather OOB / illegal memory access
reported on VoxCPM2 after PR vllm-project#2803: preprocess() used _pending_requests
(per-step prefix, cleared each forward) as if it were the full active
batch. When a new prefill was scheduled after cached decode requests in
the same batch, the decode requests were wrongly classified as stale and
removed from _active_states; the next forward then recreated empty
states, silently skipped residual_model for them, and desynchronized
attn metadata -- producing NaN on the prefill slice and eventually
crashing inside the CFM/LocDiT kernels.

State cleanup is now driven solely by on_requests_finished ->
_flush_deferred_cleanup at the end of forward().  A one-shot leak
warning in _get_or_create_state guards against a future regression.

Signed-off-by: Sy03 <1370724210@qq.com>
@Sy0307 Sy0307 requested a review from hsliuustc0106 as a code owner April 18, 2026 14:48
```python
self._results_queue: list[tuple[str, torch.Tensor | None]] = []
self._audio_queue: list[tuple[str, Any]] = []
self._deferred_cleanup_ids: set[str] = set()
self._active_state_warn_threshold = max(512, 4 * self._max_batch_size)
```
Collaborator


512 hardcoded

Contributor Author


Solved.



@hsliuustc0106 hsliuustc0106 left a comment


BLOCKER scan:

| Category | Result |
| --- | --- |
| Correctness | PASS |
| Reliability/Safety | PASS |
| Breaking Changes | PASS |
| Test Coverage | needs tests |
| Documentation | PASS |
| Security | PASS |

BLOCKING ISSUES:

  1. [Test Coverage] No regression test — this bugfix lacks an automated test to prevent silent regressions.

VERDICT: REQUEST_CHANGES



@linyueqian linyueqian left a comment


Verified on H20 (1× H20 141GB) against merged main a683b1dd.

Crash reproduced on main (pre-PR): 30 staggered voice-clone requests at 0.5 s gap → 0/30 succeed, Worker log shows torch.AcceleratorError: CUDA error: an illegal memory access was encountered, engine dies, all 30 return HTTP 400. Matches the PR description exactly.

With this PR (30f528f7): 30/30 and then 60/60 staggered runs succeed, output durations 2.7–5.1 s, no MAX_DECODE_STEPS, no _active_states size warn. grep -E "EngineDead|illegal memory access|vectorized_gather|out of bounds|NaN|Traceback|AssertionError" returns 0.

Diagnosis is sound: I traced preprocess()'s eviction, forward()'s _pending_requests.clear() at the tail (L821), and subsequent preprocess() calls appending one request at a time at L1164. When a prefill runs mid-batch, _pending_requests genuinely is a prefix-only view: any prefill_completed=True decode scheduled after it gets evicted, forward()'s _switch_to_request silently recreates a prefill_completed=False state, which then hits the skip-else at L768-771, making torch.cat(residual_inputs) shorter than the CUDA-graph-captured metadata slots. The #2803 CUDA graph just turned what used to be silent data corruption into a loud OOB, as the PR description correctly attributes.

Fix is strictly more correct than the removed logic. Approving on correctness; agree with @hsliuustc0106 that a regression test should land before merge.

@hsliuustc0106 hsliuustc0106 dismissed stale reviews from themself April 18, 2026 15:09

Consolidating reviews


@hsliuustc0106 hsliuustc0106 left a comment


BLOCKING:

  • Test Coverage — This bugfix needs a regression test. The manual 540-request validation is strong evidence but not reproducible in CI. A test that exercises the prefill+decode mixed batch path (even a lightweight unit test mocking the eviction condition) would prevent silent regressions.
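A lightweight unit test along the lines the review suggests might pin the contract like this (a sketch with stand-in classes and no GPU, CUDA Graph, or compile dependency; not the tests that actually landed in test_talker_state_eviction.py):

```python
# Pins two contracts from this PR: (1) preprocess() never evicts, so a
# decode scheduled after a mid-batch prefill survives; (2) cleanup is
# driven solely by finished_req_ids via the deferred-cleanup flush.

class FakeTalker:
    """Stand-in tracking only the state-eviction contract."""

    def __init__(self):
        self.active_states = {"decode_a": {}, "decode_b": {}}
        self.finished_ids: set[str] = set()

    def preprocess(self, req_id: str) -> None:
        # Fixed behavior: preprocess() only registers, never evicts.
        self.active_states.setdefault(req_id, {})

    def on_requests_finished(self, req_ids: set[str]) -> None:
        self.finished_ids |= req_ids

    def flush_deferred_cleanup(self) -> None:
        for req_id in self.finished_ids:
            self.active_states.pop(req_id, None)
        self.finished_ids.clear()


def test_mid_batch_prefill_keeps_cached_decodes():
    talker = FakeTalker()
    for req_id in ["decode_a", "new_prefill", "decode_b"]:
        talker.preprocess(req_id)
    assert set(talker.active_states) == {"decode_a", "decode_b", "new_prefill"}


def test_cleanup_only_via_finished_ids():
    talker = FakeTalker()
    talker.on_requests_finished({"decode_a"})
    talker.flush_deferred_cleanup()
    assert set(talker.active_states) == {"decode_b"}
```

Under the pre-fix behavior, the first test would fail with decode_b missing, which is exactly the regression this PR guards against.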

- Extract leak-warn threshold floor to module-level constant
  _ACTIVE_STATE_LEAK_WARN_MIN = 512 (was hardcoded).
- Annotate leak warn with max_batch_size; mark one-shot by design.
- Rewrite preprocess() prefill-branch comment to explain why
  _pending_requests is unsafe (per-step prefix) and name the real
  cleanup path (on_requests_finished -> _flush_deferred_cleanup,
  fed by vLLM scheduler._free_request via gpu_ar_model_runner.py).
- Add regression tests pinning state-eviction and deferred-cleanup
  contracts (no GPU / CUDA graph / compile dependency).

Signed-off-by: Sy03 <1370724210@qq.com>

Sy0307 commented Apr 18, 2026

All issues mentioned in the review comments have been fixed. @linyueqian @hsliuustc0106


@hsliuustc0106 hsliuustc0106 left a comment


Good bugfix with comprehensive regression tests.

Minor suggestion: consider adding an e2e online serving test that reproduces the specific prefill+decode mixed batch pattern that triggered the crash, to complement the unit-level regression tests.

Sy0307 added 3 commits April 19, 2026 14:01
voxcpm2_talker.py imports librosa at module scope.  The lightweight
unit-test environments (simple-unit-test, diffusion-cache-backend-test,
cuda-unit-test-*) don't install librosa, so collecting the test file
fails with ModuleNotFoundError before any test runs.

Add pytest.importorskip("librosa") next to the existing torch skip so
those shards skip cleanly.  Full TTS-stack shards (which install
librosa) still exercise the regression tests.

Signed-off-by: Sy03 <1370724210@qq.com>
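pytest.importorskip skips the whole module at collection time when the import fails, instead of erroring with ModuleNotFoundError. A minimal illustration (demonstrated with a stdlib module so the snippet runs anywhere; the real test file guards torch and librosa as the commit message describes):

```python
import pytest

# In the real test file the guards would be:
#   torch = pytest.importorskip("torch")
#   librosa = pytest.importorskip("librosa")
# importorskip returns the module when available, or skips the entire
# file at collection time when it is not.
json = pytest.importorskip("json")
print(json.__name__)  # json
```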
Complements the unit-level regression in
tests/model_executor/models/voxcpm2/test_talker_state_eviction.py
with an e2e offline_inference test that exercises the real scheduler
path: one long prompt (stays in decode across many steps) plus several
short prompts (keep entering prefill), reproducing the
"1 prefill + N cached decode" batch shape that triggered the original
crash.

Rides the existing VoxCPM2 Native AR E2E Test Buildkite shard
(gpu_1_queue, HF_TOKEN, voxcpm install) — no new shard required.

Signed-off-by: Sy03 <1370724210@qq.com>
Signed-off-by: Sy03 <1370724210@qq.com>
@hsliuustc0106 hsliuustc0106 merged commit 26edc7f into vllm-project:main Apr 19, 2026
8 checks passed
lvliang-intel pushed a commit to lvliang-intel/vllm-omni that referenced this pull request Apr 20, 2026
…+decode batches (vllm-project#2903)

Signed-off-by: Sy03 <1370724210@qq.com>
qinganrice pushed a commit to qinganrice/vllm-omni that referenced this pull request Apr 23, 2026
…+decode batches (vllm-project#2903)

Signed-off-by: Sy03 <1370724210@qq.com>

Labels

ready label to trigger buildkite CI
