Skip to content

[Bugfix] Fix V2 model runner crash on hybrid attention models (Qwen3.5)#38081

Closed
Lidang-Jiang wants to merge 4 commits into
vllm-project:mainfrom
Lidang-Jiang:fix/v2-model-runner-hybrid-attention
Closed

[Bugfix] Fix V2 model runner crash on hybrid attention models (Qwen3.5)#38081
Lidang-Jiang wants to merge 4 commits into
vllm-project:mainfrom
Lidang-Jiang:fix/v2-model-runner-hybrid-attention

Conversation

@Lidang-Jiang
Copy link
Copy Markdown
Contributor

Summary

  • Fix V2 model runner (VLLM_USE_V2_MODEL_RUNNER=1) crash on hybrid attention models like Qwen3.5
  • Root cause: _reshape_kv_cache() in attn_utils.py only handled AttentionSpec, but Qwen3.5's linear attention (Gated DeltaNet) layers produce MambaSpec, causing AssertionError at startup
  • Fix: Port MambaSpec handling from V1 model runner's _reshape_kv_cache_tensors() to V2's _reshape_kv_cache(), using the same torch.as_strided approach for state tensor reshaping

Fixes #38041

Before Fix (crash log)

V2 model runner crashes with AssertionError on Qwen3.5
$ VLLM_USE_V2_MODEL_RUNNER=1 python -m vllm.entrypoints.openai.api_server \
    --model /ssd1/models/Qwen3.5-35B-A3B --trust-remote-code \
    --tensor-parallel-size 2 --dtype float16 --max-model-len 4096

(Worker pid=73635) INFO [gpu_worker.py:272] Using V2 Model Runner
(Worker_TP0 pid=73635) INFO [model_runner.py:266] Loading model from scratch...
(Worker_TP0 pid=73635) INFO [qwen3_next.py:202] Using Triton/FLA GDN prefill kernel
...
(Worker_TP0 pid=73635) ERROR [multiproc_executor.py:949]
  File "vllm/v1/worker/gpu/attn_utils.py", line 166, in init_kv_cache
    kv_caches = _reshape_kv_cache(
  File "vllm/v1/worker/gpu/attn_utils.py", line 122, in _reshape_kv_cache
    assert isinstance(kv_cache_spec, AttentionSpec)
AssertionError

RuntimeError: Engine core initialization failed.

After Fix (successful run)

V2 model runner successfully loads and serves Qwen3.5
$ VLLM_USE_V2_MODEL_RUNNER=1 python -m vllm.entrypoints.openai.api_server \
    --model /ssd1/models/Qwen3.5-35B-A3B --trust-remote-code \
    --tensor-parallel-size 2 --dtype float16 --max-model-len 4096

(APIServer) INFO [model.py:541] Resolved architecture: Qwen3_5MoeForConditionalGeneration
(Worker) INFO [gpu_worker.py:272] Using V2 Model Runner
(Worker_TP0) INFO [qwen3_next.py:202] Using Triton/FLA GDN prefill kernel
(Worker_TP0) INFO [gpu_worker.py:436] Available KV cache memory: 36.24 GiB
(EngineCore) INFO [kv_cache_utils.py:1319] GPU KV cache size: 949,344 tokens
(APIServer) INFO: Application startup complete.
(APIServer) INFO: Uvicorn running on http://0.0.0.0:8562

$ curl http://localhost:8562/v1/chat/completions \
    -d '{"model":"/ssd1/models/Qwen3.5-35B-A3B","messages":[{"role":"user","content":"Hello"}],"max_tokens":50}'

{"id":"chatcmpl-aef24845d232b4b0","object":"chat.completion","created":1774424213,
 "model":"/ssd1/models/Qwen3.5-35B-A3B",
 "choices":[{"index":0,"message":{"role":"assistant","content":"Thinking Process:\n\n1. ..."},
 "finish_reason":"length"}],
 "usage":{"prompt_tokens":19,"total_tokens":69,"completion_tokens":50}}

Test plan

  • Verified V2 model runner starts successfully with Qwen3.5-35B-A3B (TP=2, float16)
  • Verified inference produces valid output via /v1/chat/completions
  • pre-commit checks passed (ruff check, ruff format, mypy, typos, SPDX headers)
  • Not duplicating any existing PR (verified via gh pr list --search)

Notes

  • This is AI-assisted work (Claude). All changes reviewed by human.
  • The V1 model runner (gpu_model_runner.py) already handles both AttentionSpec and MambaSpec correctly. This PR aligns V2 model runner behavior with V1.
  • Only 1 file changed: vllm/v1/worker/gpu/attn_utils.py

@mergify mergify Bot added qwen Related to Qwen models v1 bug Something isn't working labels Mar 25, 2026
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces support for Mamba-like models by adding a new MambaSpec type for KV cache handling. The _reshape_kv_cache function has been updated to differentiate between AttentionSpec and MambaSpec for KV cache processing. A review comment suggests an optimization for calculating tensor strides within the MambaSpec handling to avoid unnecessary temporary tensor allocations and improve performance.

dtype_size = get_dtype_size(dtype)
num_element_per_page = kv_cache_spec.page_size_bytes // dtype_size
target_shape = (num_blocks, *shape)
stride = torch.empty(target_shape).stride()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The use of torch.empty(target_shape).stride() creates a temporary tensor allocation on the device just to retrieve its stride. While this might be optimized by PyTorch for small tensors, it's generally more efficient to calculate the stride directly for a C-contiguous tensor, especially in performance-critical loops within a library like VLLM. This avoids unnecessary memory allocations and potential overhead.

                    # Calculate stride for a C-contiguous tensor
                    current_stride = 1
                    strides = [1] * len(target_shape)
                    for i in range(len(target_shape) - 1, -1, -1):
                        strides[i] = current_stride
                        current_stride *= target_shape[i]
                    stride = tuple(strides)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the suggestion! This pattern (torch.empty(target_shape).stride()) is intentionally kept as-is because it's directly ported from V1 model runner's _reshape_kv_cache_tensors() at gpu_model_runner.py:6629 — I wanted to maintain consistency between V1 and V2 implementations.

If we want to optimize this to avoid the temporary allocation, it would be better to update both V1 and V2 together in a follow-up PR. Happy to do that if a maintainer thinks it's worth the change.

@Lidang-Jiang
Copy link
Copy Markdown
Contributor Author

@WoosukKwon Hi, could you please add the ready label so CI can run?

This PR fixes V2 model runner crash on hybrid attention models (Qwen3.5). The root cause is that _reshape_kv_cache() in attn_utils.py only handled AttentionSpec, but Qwen3.5's linear attention (Gated DeltaNet) layers produce MambaSpec, causing AssertionError at startup.

The fix ports MambaSpec handling from V1 model runner's _reshape_kv_cache_tensors() to V2's _reshape_kv_cache(), using the same torch.as_strided approach. Only 1 file changed. Tested on A800 with Qwen3.5-35B-A3B (TP=2).

Thanks!

@WoosukKwon WoosukKwon added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 25, 2026
@WoosukKwon
Copy link
Copy Markdown
Collaborator

Review from Codex:

  1. The port looks incomplete for true hybrid-attention support, not just startup. attn_utils.py:112 now reshapes MambaSpec, but MRV2 never ports the legacy post-step that rewrites attention-cache layout when attention and Mamba coexist.
    The old runner still does that in gpu_model_runner.py:6645, and KV-transfer code explicitly assumes it already happened in utils.py:504. So this fixes the AssertionError, but it may still leave hybrid KV caches in the wrong physical
    layout for connector / transfer paths.
  2. There is no regression coverage for the case this patch is touching. The commit changes core KV-cache initialization in attn_utils.py:112 and relies on padded-page as_strided behavior plus hybrid layout conventions, but it adds no test
    exercising mixed AttentionSpec + MambaSpec initialization on MRV2. That is exactly the sort of change that can look correct and still regress at runtime.

@WoosukKwon
Copy link
Copy Markdown
Collaborator

I also got this error:

(EngineCore pid=1619481)   File "/home/woosuk/workspace/vllm/vllm/v1/attention/backends/flashinfer.py", line 1634, in forward
(EngineCore pid=1619481)     trtllm_batch_decode_with_kv_cache(
(EngineCore pid=1619481)   File "/home/woosuk/workspace/vllm/.venv/lib/python3.12/site-packages/flashinfer/decode.py", line 2399, in trtllm_batch_decode_with_kv_cache
(EngineCore pid=1619481)     run_func(
(EngineCore pid=1619481)   File "python/tvm_ffi/cython/function.pxi", line 929, in tvm_ffi.core.Function.__call__
(EngineCore pid=1619481) RuntimeError: Error in function 'hashID' at /workspace/include/flashinfer/trtllm/fmha/fmhaKernels.cuh:151: The numTokensPerPage must be power of 2.

@Lidang-Jiang Lidang-Jiang force-pushed the fix/v2-model-runner-hybrid-attention branch from c05081c to cb7dcf1 Compare March 26, 2026 07:59
@Lidang-Jiang
Copy link
Copy Markdown
Contributor Author

@WoosukKwon Thanks for the thorough review! I've addressed all 3 issues in the latest commit:

1. Incomplete port — missing hybrid layout adjustment
Ported _update_hybrid_attention_mamba_layout() from V1 (gpu_model_runner.py:6648). It's now called at the end of _reshape_kv_cache() when both AttentionSpec and MambaSpec are present, adjusting attention KV cache strides from (2, num_blocks, ...) to (num_blocks, 2, ...) so blocks can be shared correctly. This matches what kv_connector/utils.py:504 expects.

2. Missing test
Added test_v2_reshape_kv_cache_hybrid_attention_mamba() regression test that covers:

  • Mixed AttentionSpec + MambaSpec KV cache initialization
  • Virtual block splitting (verifies kernel_num_blocks = num_blocks × split_factor)
  • Mamba state tensor shapes and dtypes
  • Data isolation between attention and mamba caches
  • _update_hybrid_attention_mamba_layout() stride adjustment correctness

3. FlashInfer numTokensPerPage must be power of 2
Root cause: V2's _reshape_kv_cache() was passing the raw block_size (e.g., 560 for hybrid models) directly to get_kv_cache_shape(), but FlashInfer only supports power-of-2 page sizes.

Fix: Wired in prepare_kernel_block_sizes() (already available in utils.py) to compute backend-compatible kernel block sizes, then restructured init_attn_backend() into 3 phases:

  1. Build attention groups
  2. Compute kernel_block_sizes via prepare_kernel_block_sizes()
  3. Create metadata builders with correct kernel_block_size

The kernel_block_sizes are returned from init_attn_backend() and passed through to _reshape_kv_cache() for virtual block splitting — same approach as V1.

@mergify
Copy link
Copy Markdown
Contributor

mergify Bot commented Mar 26, 2026

Hi @Lidang-Jiang, the pre-commit checks have failed. Please run:

uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

@Lidang-Jiang
Copy link
Copy Markdown
Contributor Author

Fixed. The mypy failure was a type annotation mismatch:

vllm/v1/worker/gpu/model_runner.py:386: error: Argument 4 to "init_kv_cache" has incompatible type
"dict[str, type[AttentionBackend]]"; expected "dict[str, AttentionBackend]"  [arg-type]

init_attn_backend() returns dict[str, type[AttentionBackend]] (classes), but init_kv_cache() and _reshape_kv_cache() had their parameter annotated as dict[str, AttentionBackend] (instances). Updated both annotations to dict[str, type[AttentionBackend]].

Will push the fix shortly.

@Lidang-Jiang Lidang-Jiang force-pushed the fix/v2-model-runner-hybrid-attention branch from 3ac1139 to 3df9935 Compare March 26, 2026 08:34
@mergify
Copy link
Copy Markdown
Contributor

mergify Bot commented Mar 30, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @Lidang-Jiang.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Mar 30, 2026
@Lidang-Jiang Lidang-Jiang force-pushed the fix/v2-model-runner-hybrid-attention branch from 3df9935 to 420630c Compare April 1, 2026 12:58
@mergify mergify Bot removed the needs-rebase label Apr 1, 2026
Lidang-Jiang and others added 4 commits April 2, 2026 10:12
The V2 model runner's `_reshape_kv_cache()` only handled `AttentionSpec`
but Qwen3.5's linear attention (Gated DeltaNet) layers produce `MambaSpec`,
causing an `AssertionError` at startup.

Port MambaSpec handling from V1 model runner's `_reshape_kv_cache_tensors()`
to V2's `_reshape_kv_cache()`, using the same `torch.as_strided` approach
for state tensor reshaping.

Fixes vllm-project#38041

Co-authored-by: Claude <noreply@anthropic.com>

Signed-off-by: Lidang-Jiang <lidangjiang@gmail.com>
Add type: ignore[assignment] for MambaSpec branch where
list[torch.Tensor] is assigned to dict[str, torch.Tensor].
This matches V1 model runner (gpu_model_runner.py:6641).

Co-authored-by: Claude <noreply@anthropic.com>

Signed-off-by: Lidang-Jiang <lidangjiang@gmail.com>
…t, missing test, FlashInfer error)

Address review feedback from @WoosukKwon:

1. Port _update_hybrid_attention_mamba_layout() from V1 for correct
   stride adjustment when attention and Mamba layers coexist
2. Wire in prepare_kernel_block_sizes() for virtual block splitting,
   fixing FlashInfer "numTokensPerPage must be power of 2" error
3. Restructure init_attn_backend() into 3 phases and return
   kernel_block_sizes for downstream use
4. Add regression test for hybrid AttentionSpec + MambaSpec KV cache
   initialization with virtual block splitting

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Signed-off-by: Lidang-Jiang <lidangjiang@gmail.com>
Change dict[str, AttentionBackend] to dict[str, type[AttentionBackend]]
in init_kv_cache() and _reshape_kv_cache() to match the actual return
type of init_attn_backend().

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Lidang-Jiang <lidangjiang@gmail.com>
@Lidang-Jiang Lidang-Jiang force-pushed the fix/v2-model-runner-hybrid-attention branch from 420630c to 6d644c5 Compare April 2, 2026 02:12
@troycheng
Copy link
Copy Markdown

Are there any recent updates on this PR?

Copilot AI added a commit to Nekofish-L/vllm that referenced this pull request Apr 30, 2026
…ntion models (Qwen3.5)

- Add MambaSpec handling to _reshape_kv_cache in attn_utils.py to fix AssertionError
- Add _update_hybrid_attention_mamba_layout for hybrid attention/Mamba models
- Add virtual block splitting via kernel_block_sizes parameter
- Update init_kv_cache to compute kernel_block_sizes via prepare_kernel_block_sizes
- Pass attn_groups to init_kv_cache in model_runner.py
- Add regression test test_v2_reshape_kv_cache_hybrid_attention_mamba

Co-authored-by: GitHub Copilot

Agent-Logs-Url: https://github.com/Nekofish-L/vllm/sessions/2ce02e3b-348c-472e-a23f-c53db4db2d96

Co-authored-by: Nekofish-L <29830327+Nekofish-L@users.noreply.github.com>
@njhill njhill added the v2 label May 20, 2026
@mergify
Copy link
Copy Markdown
Contributor

mergify Bot commented May 20, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @Lidang-Jiang.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label May 20, 2026
@Lidang-Jiang
Copy link
Copy Markdown
Contributor Author

Closing this PR because it is now superseded by the merged upstream work in #35520 and #42766.

Current main already includes the MRV2 Qwen3.5 / Mamba-hybrid support and the later KV cache kernel_block_size follow-up, so rebasing this older patch would mostly reintroduce stale/conflicting code.

I also checked a small ModelScope model (Qwen/Qwen3.5-0.8B) locally with MRV2 enabled. The local smoke could not complete on my WSL + 4GB laptop GPU environment: the source checkout cannot launch without a locally built vllm._C, and the installed wheel reaches Qwen3_5ForConditionalGeneration + Using V2 Model Runner but then stops on local UVA availability. That is an environment limitation, not the original _reshape_kv_cache AttentionSpec assertion from #38041.

Thanks for the reviews and the follow-up implementation work.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

bug Something isn't working needs-rebase qwen Related to Qwen models ready ONLY add when PR is ready to merge/full CI is needed v1 v2

Projects

None yet

Development

Successfully merging this pull request may close these issues.

V2 model runner crashes on Qwen3.5 mixed attention (linear + full)

4 participants