
[Bugfix] Fix GDN FLA kernel crashes with NULL_BLOCK_ID=0 CUDA graph padding#39064

Merged
vadiklyutiy merged 6 commits into vllm-project:main from vibhavagarwal5:fix/cudagraph-null-block-crash
Apr 11, 2026

Conversation

@vibhavagarwal5
Contributor

@vibhavagarwal5 vibhavagarwal5 commented Apr 6, 2026

Summary

Fixes #39025 — GDN (Gated Delta Network) hybrid models crash with IMA (Illegal Memory Access) errors under CUDA graphs + TP>1 on Blackwell GPUs.

Root cause: Commit bcc6f6744 (Mar 30) changed CUDA graph block table padding from fill_(-1) (PAD_SLOT_ID) to fill_(NULL_BLOCK_ID=0). The FLA SSM kernels guard padded entries with state_idx < 0, which catches -1 but not 0. This causes padded CUDA graph entries to read/write ssm_state[0], corrupting the state of real sequence 0.

Fix: Change all 5 guard comparisons across both FLA kernels to use <= 0 / > 0:

| File | Location | Old guard | New guard |
| --- | --- | --- | --- |
| fused_sigmoid_gating.py | Initial state load (L114) | `state_idx < 0` | `state_idx <= 0` |
| fused_sigmoid_gating.py | Final state store (L163) | `final_state_idx >= 0` | `final_state_idx > 0` |
| fused_recurrent.py | Multi-token initial state load (L114) | `state_idx < 0` | `state_idx <= 0` |
| fused_recurrent.py | Multi-token final state store (L158) | `final_state_idx >= 0` | `final_state_idx > 0` |
| fused_recurrent.py | Packed decode kernel (L295) | `state_idx < 0` | `state_idx <= 0` |

Note: causal_conv1d.py and mamba_ssm.py already use == null_block_id and are not affected.
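The guard change can be illustrated with a minimal plain-Python sketch (not the actual Triton kernels; `update_state_*` and the `state` list are illustrative stand-ins for the kernel guard and `ssm_state`):

```python
# Sketch of why `< 0` misses the new padding sentinel while `<= 0` catches both.
PAD_SLOT_ID = -1   # old padding sentinel
NULL_BLOCK_ID = 0  # new padding sentinel for CUDA graph block table entries

def update_state_old_guard(state, state_idx, value):
    # Old guard: only skips PAD_SLOT_ID (-1); treats index 0 as a real slot.
    if state_idx < 0:
        return
    state[state_idx] += value

def update_state_new_guard(state, state_idx, value):
    # New guard: skips both -1 and the reserved null block 0.
    if state_idx <= 0:
        return
    state[state_idx] += value

state = [0.0, 0.0, 0.0]
# A padded CUDA-graph entry carries index NULL_BLOCK_ID == 0:
update_state_old_guard(state, NULL_BLOCK_ID, 1.0)
print(state[0])  # 1.0 -> padded entry corrupted slot 0

state = [0.0, 0.0, 0.0]
update_state_new_guard(state, NULL_BLOCK_ID, 1.0)
print(state[0])  # 0.0 -> padded entry skipped
```

Slot 0 is reserved as the null block after #35431, so skipping it in the new guard never drops a real sequence's state.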

Test plan

  • Verified crash reproduces on 4x Blackwell RTX PRO 6000, TP=2, Qwen/Qwen3.5-35B-A3B
  • Confirmed --enforce-eager prevents crash (CUDA graph padding is the trigger)
  • fused_sigmoid_gating.py fix alone passes at 100 concurrency but crashes at 200
  • Full fix (both files): 10,000 requests (50 rounds × 200 concurrency, mixed prompts, ~2.8M tokens) — zero failures
  • pytest tests/models/test_gdn.py -v passes

🤖 Generated with Claude Code

Co-authored-by: Claude noreply@anthropic.com

@mergify mergify bot added nvidia v1 bug Something isn't working labels Apr 6, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request replaces the NULL_BLOCK_ID constant with -1 when filling unused block table entries in the GPU model runner to prevent illegal memory access during CUDA graph replays. I have no feedback to provide.

@vibhavagarwal5 vibhavagarwal5 force-pushed the fix/cudagraph-null-block-crash branch from 611405f to a031d72 Compare April 6, 2026 13:03
@vibhavagarwal5 vibhavagarwal5 changed the title [Bugfix] Fix CUDA graph illegal memory access with NULL_BLOCK_ID=0 under TP [Bugfix] Fix GDN SSM kernel crash with NULL_BLOCK_ID=0 padding in CUDA graphs Apr 6, 2026
@vibhavagarwal5 vibhavagarwal5 marked this pull request as ready for review April 6, 2026 15:02
Comment thread vllm/model_executor/layers/fla/ops/fused_sigmoid_gating.py
Natfii added a commit to Navi-AI-Lab/nvllm that referenced this pull request Apr 7, 2026
…flags

FLA SSM kernels checked `state_idx < 0` for PAD_SLOT_ID (-1), but PR vllm-project#35431
changed CUDA graph padding to NULL_BLOCK_ID (0). Changed to `state_idx <= 0`
in fused_recurrent.py and fused_sigmoid_gating.py to handle both values.
Matches upstream PR vllm-project#39064.

Also adds --enable-auto-tool-choice --tool-call-parser hermes to run scripts
and benchmark result patterns to .gitignore.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@vibhavagarwal5 vibhavagarwal5 changed the title [Bugfix] Fix GDN SSM kernel crash with NULL_BLOCK_ID=0 padding in CUDA graphs [Bugfix] Fix GDN FLA kernel crashes with NULL_BLOCK_ID=0 CUDA graph padding Apr 7, 2026
@vibhavagarwal5
Contributor Author

Investigation update — the fix was incomplete, now complete

After some back-and-forth trying to reproduce: the original fix (only fused_sigmoid_gating.py) was partial. Here's what we found on 4x Blackwell RTX PRO 6000 with TP=2:

The bug

Commit bcc6f6744 changed CUDA graph block table padding from fill_(-1) to fill_(NULL_BLOCK_ID=0). The FLA SSM kernels use state_idx < 0 guards, which catches -1 but not 0 — so padded entries silently read/write ssm_state[0], corrupting whichever real sequence lives there.

Why it was hard to reproduce

  • At low concurrency (≤100), the fused_sigmoid_gating.py-only fix is sufficient because the packed decode fast path (fused_recurrent.py) doesn't hit the corrupt state often enough to crash.
  • At higher concurrency (200), the packed decode path triggers the same bug in fused_recurrent.py — 3 more locations with the identical < 0 vs <= 0 guard mismatch.

Test results

| Scenario | Concurrency | Rounds | Result |
| --- | --- | --- | --- |
| No fix (baseline) | 100 | 30 | CRASH round 18 |
| fused_sigmoid_gating.py only | 100 | 30 | PASS |
| fused_sigmoid_gating.py only | 200 | 30 | CRASH round 1 |
| --enforce-eager (no CUDA graphs) | 200 | 10 | PASS |
| Full fix (both files) | 200 | 50 | PASS (10K reqs, ~2.8M tokens) |

The latest push adds the missing fused_recurrent.py fixes (3 locations). All 5 guard changes are now in place.

@Gregory-Pereira
Contributor

Please let me know when it's ready for a test. Happy to get some practice running feature branches on GPUs.

@vibhavagarwal5
Contributor Author

It's ready @Gregory-Pereira, please do test once.

@gaby

gaby commented Apr 9, 2026

Ping @Gregory-Pereira @vadiklyutiy @ZJY0516 This fix is needed for qwen3.5 to stop crashing when using Tensor Parallel.

Collaborator

@vadiklyutiy vadiklyutiy left a comment


For me it is unclear why idx=0 doesn't appear as a real index.
Maybe the issue is that padding should fill with -1 instead of 0?

Comment thread vllm/model_executor/layers/fla/ops/fused_recurrent.py
Comment thread vllm/model_executor/layers/fla/ops/fused_recurrent.py
@github-project-automation github-project-automation bot moved this to In review in NVIDIA Apr 9, 2026

@Alberto-Codes Alberto-Codes left a comment


Independent verification — straggler sweep

Dug into this while it was blocking TurboQuant validation on Qwen3.5-35B-A3B downstream (the crash reported against #38479 traces back here). Posting the evidence for anyone wanting to verify the scope before approving.

1. The fix is the complete straggler set. Grepped upstream/main for all guards that might have missed the NULL_BLOCK_ID=0 convention change from #35431:

$ git grep -nE 'state_idx\s*(<=?|>=?)\s*0|final_state_idx\s*(<=?|>=?)\s*0' \
    -- 'vllm/model_executor/layers/**' 'csrc/**' 'vllm/v1/**'
vllm/model_executor/layers/fla/ops/fused_recurrent.py:114:     if state_idx < 0:
vllm/model_executor/layers/fla/ops/fused_recurrent.py:158:     if final_state_idx >= 0:
vllm/model_executor/layers/fla/ops/fused_recurrent.py:295:     if state_idx < 0:
vllm/model_executor/layers/fla/ops/fused_sigmoid_gating.py:114: if state_idx < 0:
vllm/model_executor/layers/fla/ops/fused_sigmoid_gating.py:163: if final_state_idx >= 0:

Exactly the 5 guards this PR flips. No other kernel or site silently accepts 0 as a valid state index.

2. NULL_BLOCK_ID producers (what actually fills padded slots with 0):

  • vllm/v1/attention/backends/gdn_attn.py:372, 418 — GDN state indices
  • vllm/v1/attention/backends/mamba_attn.py:507 — mamba state indices
  • vllm/v1/worker/gpu_model_runner.py:2152 — block table tensor

3. Correct consumers (using the new convention):

  • vllm/model_executor/layers/mamba/ops/mamba_ssm.py — state_batch_idx != null_block_id (L206), token_dst_idx != null_block_id (L260); the kernel takes null_block_id as a param
  • vllm/model_executor/layers/mamba/ops/causal_conv1d.py — accepts both pad_slot_id and null_block_id params, handles each in the right code path

4. Verified not-a-straggler: vllm/model_executor/layers/lightning_attn.py:622 uses slot_id == pad_slot_id, but it consumes a separate slot_idx tensor that traces back to slot_mappings.fill_(PAD_SLOT_ID) in vllm/v1/worker/gpu/block_table.py:164. Different tensor, different padding convention, still -1. Not affected by #35431.

5. Why the <= 0 guard is correct in response to @vadiklyutiy's concern: NULL_BLOCK_ID = 0 is a reserved sentinel by definition — real sequences never get assigned slot 0 after #35431. The strongest argument is that mamba_ssm.py consumes the same ssm_state_indices tensor and already treats 0 as padding. Two kernels reading the same tensor with disagreeing sentinel semantics is a bug regardless of the specific values.

6. Minor follow-up (not blocking): stale comment at vllm/v1/worker/gpu_model_runner.py:3738 references blk_table_tensor on a line that operates on slot_mapping, and mentions "mamba PAD_SLOT_ID" when mamba now uses NULL_BLOCK_ID for state indices. Code is correct, comment is misleading. Janitor PR material.

Longer-term thought: the three Triton kernel families in vLLM (mamba/ops/, fla/ops/, and lightning_attn.py) each invented their own padding-guard pattern. Parameterizing the FLA kernels with null_block_id: tl.constexpr (matching mamba_ssm's pattern) instead of <= 0 would eliminate the straggler class permanently. Happy to open a follow-up if maintainers want it — not a blocker for this PR.

cc @LucasWilkinson @MatthewBonanni since you co-authored #35431 and would know if I'm missing any producer/consumer site.
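The parameterization suggested above can be sketched in plain Python (the real FLA kernels would take `null_block_id` as a `tl.constexpr` Triton parameter; `load_initial_state` and its arguments are illustrative, not vLLM's API):

```python
# Sketch: pass the padding sentinel in explicitly instead of hardcoding `<= 0`,
# matching the pattern mamba_ssm.py already uses.
PAD_SLOT_ID = -1
NULL_BLOCK_ID = 0

def load_initial_state(state, state_idx, null_block_id, pad_slot_id=PAD_SLOT_ID):
    # Explicit sentinel comparison: the caller decides which values mean
    # "padding", so a future change of convention only touches the call site.
    if state_idx == null_block_id or state_idx == pad_slot_id:
        return None  # padded entry: nothing to load
    return state[state_idx]

state = {1: "seq-a", 2: "seq-b"}
print(load_initial_state(state, 1, NULL_BLOCK_ID))             # seq-a
print(load_initial_state(state, NULL_BLOCK_ID, NULL_BLOCK_ID)) # None
```

With this shape, a kernel can never "miss" a sentinel value the producer uses, which is exactly the straggler class this PR is fixing.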

@Alberto-Codes

@vadiklyutiy Your question is the right one to ask, and I think it's answerable without reproducing — the answer is in the parent commit that introduced the regression.

NULL_BLOCK_ID = 0 is explicitly defined as a reserved sentinel in vllm/v1/attention/backends/utils.py:45, right alongside PAD_SLOT_ID = -1. The parent commit #35431 "Use null block (0) for padded block table entries" (co-authored by @LucasWilkinson and @MatthewBonanni) intentionally reserved slot 0 as the null block — real sequences never get assigned to it. So state_idx == 0 cannot be a valid entry; any 0 a kernel sees is definitionally padding.

The strongest correctness argument is actually this: vllm/model_executor/layers/mamba/ops/mamba_ssm.py consumes the same ssm_state_indices tensor that these FLA kernels consume, and it already uses state_batch_idx != null_block_id (L206) with null_block_id=NULL_BLOCK_ID passed through. Two kernels reading the same tensor must agree on its sentinel semantics. Right now they don't — mamba_ssm treats 0 as padding, FLA treats 0 as valid. That's the straggler #39064 is fixing.

I grepped upstream/main to confirm the scope and posted the full sweep evidence in a separate comment above. The 5 guards in this PR are the complete straggler set.

Reverting to -1 padding would work semantically but would require reverting #35431 (which was itself a bugfix touching 15 files). The FLA catch-up is the smaller change.
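The "disagreeing sentinel semantics" argument can be sketched as two consumers of the same index tensor (function names are illustrative, not vLLM's; the lists stand in for `ssm_state_indices`):

```python
NULL_BLOCK_ID = 0

def mamba_style_consumer(indices, null_block_id=NULL_BLOCK_ID):
    # New convention: index == null_block_id is padding, everything else is real.
    return [i for i in indices if i != null_block_id]

def fla_style_consumer_buggy(indices):
    # Pre-fix FLA convention: only negative indices are padding.
    return [i for i in indices if i >= 0]

# The same state-index tensor, padded with NULL_BLOCK_ID by the producer:
indices = [3, 7, NULL_BLOCK_ID, NULL_BLOCK_ID]
print(mamba_style_consumer(indices))      # [3, 7]
print(fla_style_consumer_buggy(indices))  # [3, 7, 0, 0] -> padding leaks through
```

Whatever the sentinel values are, both consumers must filter the same entries; the pre-fix FLA guard let the padded zeros through as "real" slot-0 accesses.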

@vadiklyutiy
Collaborator

Please fix the DCO.

@vadiklyutiy vadiklyutiy added the verified Run pre-commit for new contributors without triggering other tests label Apr 10, 2026
@Gregory-Pereira
Contributor

Sorry, I had hoped to reproduce this but realized that since this is Blackwell-specific I cannot; I'm stuck with H100s and H200s. Hopefully I'll be getting more in the future.

@gaby

gaby commented Apr 10, 2026

Sorry, I had hoped to reproduce this but realized that since this is Blackwell-specific I cannot; I'm stuck with H100s and H200s. Hopefully I'll be getting more in the future.

@Gregory-Pereira I have this issue happening with H100 NVL GPUs.

vibhav-agarwal and others added 2 commits April 10, 2026 15:20
…A graphs

The FLA SSM kernel (fused_sigmoid_gating_delta_rule_update_kernel) used
`state_idx < 0` to skip padded entries, but CUDA graph block table padding
uses NULL_BLOCK_ID=0. Since 0 is not < 0, the kernel processed padded entries
as real requests, reading/writing ssm_state[0] and corrupting a live cache
slot. This caused cascading memory corruption manifesting as illegal memory
access in downstream MoE GEMM kernels.

Fix: change the guard to `state_idx <= 0` (initial state load) and
`final_state_idx > 0` (final state store), making it consistent with the
causal_conv1d kernel which already guards on `== NULL_BLOCK_ID`.

Only affects GDN hybrid models (e.g. Qwen3.5-35B-A3B) under CUDA graphs
with TP>1. Pure transformer models are unaffected as they have no GDN layers.

Fixes: vllm-project#39025

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Vibhav Agarwal <vibhavagarwal5@gmail.com>
Signed-off-by: vibhavagarwal5 <vibhavagarwal5@gmail.com>
The fused_recurrent kernel has the identical bug as fused_sigmoid_gating:
`state_idx < 0` guards only catch PAD_SLOT_ID=-1 but miss NULL_BLOCK_ID=0,
which is used for CUDA graph block table padding. This causes ssm_state[0]
corruption for padded entries.

Affects 3 locations in fused_recurrent.py:
- Multi-token initial state load guard (line 114)
- Multi-token final state store guard (line 158)
- Packed decode kernel guard (line 295)

Without this fix, the fused_sigmoid_gating fix alone still crashes under
higher concurrency (200 concurrent requests) because fused_recurrent is
used in the packed decode fast path.

Tested: 10,000 requests (50 rounds x 200 concurrency, mixed short/medium/long
prompts, ~2.8M tokens) with zero failures on 4x Blackwell RTX PRO 6000, TP=2.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Vibhav Agarwal <vibhavagarwal5@gmail.com>
Signed-off-by: vibhavagarwal5 <vibhavagarwal5@gmail.com>
@vibhavagarwal5 vibhavagarwal5 force-pushed the fix/cudagraph-null-block-crash branch from 9b2342f to 5362db8 Compare April 10, 2026 15:20
@vibhavagarwal5
Contributor Author

@vadiklyutiy DCO is fixed

Collaborator

@MatthewBonanni MatthewBonanni left a comment


LGTM, thanks for the fix!

Comment thread vllm/model_executor/layers/fla/ops/fused_recurrent.py Outdated
@Gregory-Pereira
Contributor

@Gregory-Pereira I have this issue happening with H100 NVL GPUs

Starting testing

Per review feedback from @MatthewBonanni — we should only be using
NULL_BLOCK_ID for padding block table entries now, so the comments
no longer reference PAD_SLOT_ID.

Signed-off-by: vibhavagarwal5 <vibhavagarwal5@gmail.com>
@vibhavagarwal5 vibhavagarwal5 force-pushed the fix/cudagraph-null-block-crash branch from 237d705 to 2177912 Compare April 10, 2026 18:11
@vibhavagarwal5
Contributor Author

@MatthewBonanni fixed the comments, please approve again.

@MatthewBonanni
Collaborator

@vadiklyutiy padding with 0 is correct for the block table, see #35431

@MatthewBonanni MatthewBonanni added the ready ONLY add when PR is ready to merge/full CI is needed label Apr 10, 2026
@Gregory-Pereira
Contributor

Gregory-Pereira commented Apr 10, 2026

Validated on H200 as best I could @gaby. Started with the relevant Python tests shared in the PR:

=== Ready ===
vLLM 0.19.1rc1.dev195+g21779125b from /tmp/vllm-fix/vllm/__init__.py

=== Running tests: tests/v1/attention/test_gdn_metadata_builder.py -v ===
============================= test session starts ==============================
platform linux -- Python 3.12.13, pytest-9.0.3, pluggy-1.6.0 -- /usr/bin/python3
cachedir: .pytest_cache
hypothesis profile 'default'
rootdir: /tmp/vllm-fix
configfile: pyproject.toml
plugins: typeguard-4.5.1, hypothesis-6.151.12, timeout-2.4.0, shard-0.1.2, rerunfailures-16.1, mock-3.15.1, forked-1.6.0, asyncio-1.3.0, hydra-core-1.3.2, buildkite-test-collector-0.1.9, cov-7.1.0, schemathesis-4.15.1, anyio-4.13.0
asyncio: mode=Mode.STRICT, debug=False, asyncio_default_fixture_loop_scope=None, asyncio_default_test_loop_scope=function
collecting ... collected 8 items
Running 8 items in this shard: tests/v1/attention/test_gdn_metadata_builder.py::test_gdn_build_classification[mixed_decode_and_spec_decode], tests/v1/attention/test_gdn_metadata_builder.py::test_gdn_build_classification[pure_spec_decode], tests/v1/attention/test_gdn_metadata_builder.py::test_gdn_build_classification[pure_regular_decode], tests/v1/attention/test_gdn_metadata_builder.py::test_gdn_build_classification[spec_decode_with_real_prefill], tests/v1/attention/test_gdn_metadata_builder.py::test_gdn_build_classification[prefill_decode_and_spec_decode], tests/v1/attention/test_gdn_metadata_builder.py::test_gdn_build_classification[multiple_decodes_reclassified], tests/v1/attention/test_gdn_metadata_builder.py::test_gdn_build_classification[zero_length_padding_with_spec], tests/v1/attention/test_gdn_metadata_builder.py::test_has_initial_state_after_reclassification

tests/v1/attention/test_gdn_metadata_builder.py::test_gdn_build_classification[mixed_decode_and_spec_decode] PASSED [ 12%]
tests/v1/attention/test_gdn_metadata_builder.py::test_gdn_build_classification[pure_spec_decode] PASSED [ 25%]
tests/v1/attention/test_gdn_metadata_builder.py::test_gdn_build_classification[pure_regular_decode] PASSED [ 37%]
tests/v1/attention/test_gdn_metadata_builder.py::test_gdn_build_classification[spec_decode_with_real_prefill] PASSED [ 50%]
tests/v1/attention/test_gdn_metadata_builder.py::test_gdn_build_classification[prefill_decode_and_spec_decode] PASSED [ 62%]
tests/v1/attention/test_gdn_metadata_builder.py::test_gdn_build_classification[multiple_decodes_reclassified] PASSED [ 75%]
tests/v1/attention/test_gdn_metadata_builder.py::test_gdn_build_classification[zero_length_padding_with_spec] PASSED [ 87%]
tests/v1/attention/test_gdn_metadata_builder.py::test_has_initial_state_after_reclassification PASSED [100%]

=============================== warnings summary ===============================
<frozen importlib._bootstrap>:488
  <frozen importlib._bootstrap>:488: DeprecationWarning: builtin type SwigPyPacked has no __module__ attribute

<frozen importlib._bootstrap>:488
  <frozen importlib._bootstrap>:488: DeprecationWarning: builtin type SwigPyObject has no __module__ attribute

tests/v1/attention/test_gdn_metadata_builder.py: 14 warnings
  /usr/local/lib/python3.12/dist-packages/torch/jit/_script.py:365: DeprecationWarning: `torch.jit.script_method` is deprecated. Please switch to `torch.compile` or `torch.export`.
    warnings.warn(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
======================= 8 passed, 16 warnings in 12.91s ========================
sys:1: DeprecationWarning: builtin type swigvarlink has no __module__ attribute

=== Tests PASSED ===

After this the server starts up fine; I submitted 10k requests with 200 concurrency and it looks fine:

stern "vllm-test-fix-39064" | grep -v "POST /v1/completions HTTP/1.1"

... starting after vLLM startup logs

vllm-test-fix-39064 vllm (APIServer pid=1) INFO:     127.0.0.1:59892 - "GET /metrics HTTP/1.1" 200 OK
vllm-test-fix-39064 vllm (APIServer pid=1) INFO 04-10 18:45:57 [loggers.py:259] Engine 000: Avg prompt throughput: 51.2 tokens/s, Avg generation throughput: 0.2 tokens/s, Running: 17 reqs, Waiting: 183 reqs, GPU KV cache usage: 0.9%, Prefix cache hit rate: 0.0%
vllm-test-fix-39064 vllm (APIServer pid=1) INFO:     127.0.0.1:59892 - "GET /metrics HTTP/1.1" 200 OK
vllm-test-fix-39064 vllm (APIServer pid=1) INFO 04-10 18:46:07 [loggers.py:259] Engine 000: Avg prompt throughput: 10188.3 tokens/s, Avg generation throughput: 2559.7 tokens/s, Running: 0 reqs, Waiting: 0 reqs, GPU KV cache usage: 0.0%, Prefix cache hit rate: 0.0%
vllm-test-fix-39064 vllm (APIServer pid=1) INFO:     127.0.0.1:48844 - "GET /metrics HTTP/1.1" 200 OK
vllm-test-fix-39064 vllm (APIServer pid=1) INFO:     127.0.0.1:48844 - "GET /metrics HTTP/1.1" 200 OK
vllm-test-fix-39064 vllm (APIServer pid=1) INFO 04-10 18:46:17 [loggers.py:259] Engine 000: Avg prompt throughput: 10239.7 tokens/s, Avg generation throughput: 2559.9 tokens/s, Running: 0 reqs, Waiting: 0 reqs, GPU KV cache usage: 0.0%, Prefix cache hit rate: 0.0%
vllm-test-fix-39064 vllm (APIServer pid=1) INFO:     127.0.0.1:39168 - "GET /metrics HTTP/1.1" 200 OK
vllm-test-fix-39064 vllm (APIServer pid=1) INFO 04-10 18:46:27 [loggers.py:259] Engine 000: Avg prompt throughput: 10238.1 tokens/s, Avg generation throughput: 326.7 tokens/s, Running: 200 reqs, Waiting: 0 reqs, GPU KV cache usage: 11.0%, Prefix cache hit rate: 0.0%
vllm-test-fix-39064 vllm (APIServer pid=1) INFO:     127.0.0.1:39168 - "GET /metrics HTTP/1.1" 200 OK
vllm-test-fix-39064 vllm (APIServer pid=1) INFO 04-10 18:46:37 [loggers.py:259] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 2233.3 tokens/s, Running: 0 reqs, Waiting: 0 reqs, GPU KV cache usage: 0.0%, Prefix cache hit rate: 0.0%
vllm-test-fix-39064 vllm (APIServer pid=1) INFO:     127.0.0.1:55788 - "GET /metrics HTTP/1.1" 200 OK
vllm-test-fix-39064 vllm (APIServer pid=1) INFO:     127.0.0.1:55788 - "GET /metrics HTTP/1.1" 200 OK
vllm-test-fix-39064 vllm (APIServer pid=1) INFO 04-10 18:46:47 [loggers.py:259] Engine 000: Avg prompt throughput: 10238.9 tokens/s, Avg generation throughput: 2559.7 tokens/s, Running: 0 reqs, Waiting: 0 reqs, GPU KV cache usage: 0.0%, Prefix cache hit rate: 0.0%
vllm-test-fix-39064 vllm (APIServer pid=1) INFO:     127.0.0.1:40298 - "GET /metrics HTTP/1.1" 200 OK
vllm-test-fix-39064 vllm (APIServer pid=1) INFO:     127.0.0.1:40298 - "GET /metrics HTTP/1.1" 200 OK
vllm-test-fix-39064 vllm (APIServer pid=1) INFO 04-10 18:46:57 [loggers.py:259] Engine 000: Avg prompt throughput: 10239.0 tokens/s, Avg generation throughput: 2559.8 tokens/s, Running: 0 reqs, Waiting: 0 reqs, GPU KV cache usage: 0.0%, Prefix cache hit rate: 0.0%
vllm-test-fix-39064 vllm (APIServer pid=1) INFO:     127.0.0.1:40522 - "GET /metrics HTTP/1.1" 200 OK
vllm-test-fix-39064 vllm (APIServer pid=1) INFO:     127.0.0.1:40522 - "GET /metrics HTTP/1.1" 200 OK
vllm-test-fix-39064 vllm (APIServer pid=1) INFO 04-10 18:47:07 [loggers.py:259] Engine 000: Avg prompt throughput: 10239.2 tokens/s, Avg generation throughput: 2559.8 tokens/s, Running: 0 reqs, Waiting: 0 reqs, GPU KV cache usage: 0.0%, Prefix cache hit rate: 0.0%
vllm-test-fix-39064 vllm (APIServer pid=1) INFO:     127.0.0.1:35442 - "GET /metrics HTTP/1.1" 200 OK
vllm-test-fix-39064 vllm (APIServer pid=1) INFO:     127.0.0.1:35442 - "GET /metrics HTTP/1.1" 200 OK
vllm-test-fix-39064 vllm (APIServer pid=1) INFO 04-10 18:47:17 [loggers.py:259] Engine 000: Avg prompt throughput: 10240.2 tokens/s, Avg generation throughput: 2560.1 tokens/s, Running: 0 reqs, Waiting: 0 reqs, GPU KV cache usage: 0.0%, Prefix cache hit rate: 0.0%
vllm-test-fix-39064 vllm (APIServer pid=1) INFO:     127.0.0.1:35768 - "GET /metrics HTTP/1.1" 200 OK
vllm-test-fix-39064 vllm (APIServer pid=1) INFO 04-10 18:47:27 [loggers.py:259] Engine 000: Avg prompt throughput: 10238.4 tokens/s, Avg generation throughput: 287.5 tokens/s, Running: 200 reqs, Waiting: 0 reqs, GPU KV cache usage: 11.0%, Prefix cache hit rate: 0.0%
vllm-test-fix-39064 vllm (APIServer pid=1) INFO:     127.0.0.1:35768 - "GET /metrics HTTP/1.1" 200 OK
vllm-test-fix-39064 vllm (APIServer pid=1) INFO 04-10 18:47:37 [loggers.py:259] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 2272.3 tokens/s, Running: 0 reqs, Waiting: 0 reqs, GPU KV cache usage: 0.0%, Prefix cache hit rate: 0.0%
vllm-test-fix-39064 vllm (APIServer pid=1) INFO:     127.0.0.1:34508 - "GET /metrics HTTP/1.1" 200 OK
vllm-test-fix-39064 vllm (APIServer pid=1) INFO:     127.0.0.1:34508 - "GET /metrics HTTP/1.1" 200 OK
vllm-test-fix-39064 vllm (APIServer pid=1) INFO 04-10 18:47:47 [loggers.py:259] Engine 000: Avg prompt throughput: 10239.7 tokens/s, Avg generation throughput: 2559.9 tokens/s, Running: 0 reqs, Waiting: 0 reqs, GPU KV cache usage: 0.0%, Prefix cache hit rate: 0.0%
vllm-test-fix-39064 vllm (APIServer pid=1) INFO:     127.0.0.1:56752 - "GET /metrics HTTP/1.1" 200 OK
vllm-test-fix-39064 vllm (APIServer pid=1) INFO:     127.0.0.1:56752 - "GET /metrics HTTP/1.1" 200 OK
vllm-test-fix-39064 vllm (APIServer pid=1) INFO 04-10 18:47:57 [loggers.py:259] Engine 000: Avg prompt throughput: 10239.3 tokens/s, Avg generation throughput: 2559.8 tokens/s, Running: 0 reqs, Waiting: 0 reqs, GPU KV cache usage: 0.0%, Prefix cache hit rate: 0.0%
vllm-test-fix-39064 vllm (APIServer pid=1) INFO:     127.0.0.1:40420 - "GET /health HTTP/1.1" 200 OK
vllm-test-fix-39064 vllm (APIServer pid=1) INFO 04-10 18:48:07 [loggers.py:259] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 0.0 tokens/s, Running: 0 reqs, Waiting: 0 reqs, GPU KV cache usage: 0.0%, Prefix cache hit rate: 0.0%

@Gregory-Pereira
Contributor

Ran again with 500 concurrency and got 9/10 rounds through, but then it hung on shared memory broadcast and crashed:

vllm-test-fix-39064 vllm (APIServer pid=1) INFO 04-10 18:55:57 [loggers.py:259] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 0.0 tokens/s, Running: 393 reqs, Waiting: 107 reqs, GPU KV cache usage: 21.7%, Prefix cache hit rate: 0.0%
vllm-test-fix-39064 vllm (EngineCore pid=2127) INFO 04-10 18:56:46 [shm_broadcast.py:681] No available shared memory broadcast block found in 60 seconds. This typically happens when some processes are hanging or doing some time-consuming work (e.g. compilation, weight/kv cache quantization).
vllm-test-fix-39064 vllm (EngineCore pid=2127) INFO 04-10 18:57:46 [shm_broadcast.py:681] No available shared memory broadcast block found in 60 seconds. This typically happens when some processes are hanging or doing some time-consuming work (e.g. compilation, weight/kv cache quantization).
vllm-test-fix-39064 vllm (EngineCore pid=2127) INFO 04-10 18:58:46 [shm_broadcast.py:681] No available shared memory broadcast block found in 60 seconds. This typically happens when some processes are hanging or doing some time-consuming work (e.g. compilation, weight/kv cache quantization).
vllm-test-fix-39064 vllm (EngineCore pid=2127) INFO 04-10 18:59:46 [shm_broadcast.py:681] No available shared memory broadcast block found in 60 seconds. This typically happens when some processes are hanging or doing some time-consuming work (e.g. compilation, weight/kv cache quantization).
...
(APIServer pid=1) ERROR 04-10 19:00:53 [serving.py:448]   File "/tmp/vllm-fix/vllm/v1/engine/output_processor.py", line 85, in get
(APIServer pid=1) ERROR 04-10 19:00:53 [serving.py:448]     raise output
(APIServer pid=1) ERROR 04-10 19:00:53 [serving.py:448]   File "/tmp/vllm-fix/vllm/v1/engine/async_llm.py", line 657, in output_handler
(APIServer pid=1) ERROR 04-10 19:00:53 [serving.py:448]     outputs = await engine_core.get_output_async()
(APIServer pid=1) ERROR 04-10 19:00:53 [serving.py:448]               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(APIServer pid=1) ERROR 04-10 19:00:53 [serving.py:448]   File "/tmp/vllm-fix/vllm/v1/engine/core_client.py", line 998, in get_output_async
(APIServer pid=1) ERROR 04-10 19:00:53 [serving.py:448]     raise self._format_exception(outputs) from None
(APIServer pid=1) ERROR 04-10 19:00:53 [serving.py:448] vllm.v1.engine.exceptions.EngineDeadError: EngineCore encountered an issue. See stack trace (above) for the root cause.
(APIServer pid=1) INFO:     127.0.0.1:37998 - "GET /metrics HTTP/1.1" 200 OK
(APIServer pid=1) INFO:     Shutting down
(APIServer pid=1) INFO:     Waiting for application shutdown.
(APIServer pid=1) INFO:     Application shutdown complete.
(APIServer pid=1) INFO:     Finished server process [1]
/usr/lib/python3.12/multiprocessing/resource_tracker.py:279: UserWarning: resource_tracker: There appear to be 2 leaked semaphore objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '
/usr/lib/python3.12/multiprocessing/resource_tracker.py:279: UserWarning: resource_tracker: There appear to be 3 leaked shared_memory objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '

Seems unrelated though; not seeing anything related to IMA.

@vadiklyutiy vadiklyutiy self-requested a review April 10, 2026 21:00
@github-project-automation github-project-automation bot moved this from In review to Ready in NVIDIA Apr 10, 2026
@vadiklyutiy vadiklyutiy enabled auto-merge (squash) April 10, 2026 21:16
@vadiklyutiy vadiklyutiy merged commit d4cb783 into vllm-project:main Apr 11, 2026
56 checks passed
@github-project-automation github-project-automation bot moved this from Ready to Done in NVIDIA Apr 11, 2026
@vibhavagarwal5 vibhavagarwal5 deleted the fix/cudagraph-null-block-crash branch April 11, 2026 09:49
wojciech-wais pushed a commit to wojciech-wais/vllm that referenced this pull request Apr 13, 2026
…adding (vllm-project#39064)

Signed-off-by: Vibhav Agarwal <vibhavagarwal5@gmail.com>
Co-authored-by: vibhav-agarwal <vibhav.agarwal@glance.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>

Labels

bug Something isn't working nvidia ready ONLY add when PR is ready to merge/full CI is needed v1 verified Run pre-commit for new contributors without triggering other tests

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

[Bug]: CUDA illegal memory access with CUDA graphs enabled under high concurrency (Qwen3.5-35B-A3B, tp=2)

7 participants