
[Bugfix] Translate hybrid block_size for Metal paged attention kernel#235

Merged
ericcurtin merged 3 commits into vllm-project:main from ricky-chaoju:fix/hybrid-paged-attention-block-size
Apr 7, 2026
Conversation

@ricky-chaoju
Contributor

Summary

  • Fix RuntimeError when running Qwen3.5 hybrid model with paged attention: `Unable to load function paged_attention_..._bs544_...`
  • vLLM inflates block_size to 544 to align attention pages with mamba pages in hybrid models, but the Metal kernel only has instantiations for [8, 16, 32]
  • Add block-size translation in attention_sdpa.py: reshape cache (zero-copy) and expand block tables so the kernel sees a compatible block_size

For hybrid models (e.g. Qwen3.5), vLLM sets block_size=544 to align attention page size with mamba page size. The Metal paged attention kernel is template-instantiated for block sizes [8, 16, 32] only.

The fix picks the largest kernel-supported block size that divides evenly into the cache block size (544 % 32 = 0, so kernel uses 32 with ratio=17), then:

  1. Reshapes cache: [num_blocks, 544, heads, hd] -> [num_blocks*17, 32, heads, hd] (zero-copy, same physical memory)
  2. Expands block tables: each vLLM block b becomes 17 kernel blocks [b*17, ..., b*17+16]

Non-hybrid models (block_size=16) are unaffected (fast path skips translation).

Test

  • pytest tests/test_attention_dispatch.py -v -m "not slow" -- 4/4 passed
  • pytest tests/test_attention_dispatch.py::test_qwen35_paged_attention_hybrid -- passed (previously RuntimeError)
  • Unit tests for _pick_kernel_block_size and _build_block_tables translation logic

… attention

Signed-off-by: RickyChen / 陳昭儒 <rickychen@infinirc.com>
@ricky-chaoju ricky-chaoju marked this pull request as ready for review April 7, 2026 04:11
@ricky-chaoju ricky-chaoju force-pushed the fix/hybrid-paged-attention-block-size branch from 6417c9a to 251ece6 on April 7, 2026
Collaborator

@LxYuan0420 LxYuan0420 left a comment


do we have any unit test to cover this changes?

Comment thread on vllm_metal/metal_kernel_backend/attention_sdpa.py (outdated)
Signed-off-by: RickyChen / 陳昭儒 <rickychen@infinirc.com>
@ricky-chaoju
Contributor Author

ricky-chaoju commented Apr 7, 2026

do we have any unit test to cover this changes?

Added tests/test_block_size_translation.py with 10 test cases covering `_pick_kernel_block_size` and `_build_block_tables` (translation, padding, exact match, indivisible error).

Signed-off-by: RickyChen / 陳昭儒 <rickychen@infinirc.com>
@ricky-chaoju ricky-chaoju requested a review from LxYuan0420 April 7, 2026 08:42
@ericcurtin ericcurtin merged commit 0f6f76b into vllm-project:main Apr 7, 2026
5 checks passed
@ricky-chaoju ricky-chaoju deleted the fix/hybrid-paged-attention-block-size branch April 7, 2026 10:03
WindChimeRan pushed a commit that referenced this pull request Apr 8, 2026
## Summary
- Add Qwen3.5-0.8B smoke test alongside the existing Qwen3-0.6B test,
covering the hybrid SDPA + GDN linear attention paged path end-to-end
- Fix `json.load` → `json.loads(strict=False)` for both smoke tests —
responses containing newlines (e.g. Qwen3.5 output) cause `Invalid
control character` with strict parsing
- Pin model revision to `2fc06364715b967f1860aea9cf38778875588b17`
- Use longer health check timeout for Qwen3.5 (`--retry 30 --retry-delay
10`)
- Use `--max-num-seqs 1` and `VLLM_METAL_MEMORY_FRACTION=0.8` for
Qwen3.5 to fit within the CI runner's ~5GB Metal memory (hybrid models
allocate GDN linear state per slot, default 256 slots would exceed
budget)

Depends on #235 (merged) for the block_size translation fix.
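The `strict=False` fix addresses responses containing literal newlines, which Python's default JSON parser rejects as unescaped control characters. A minimal reproduction:

```python
import json

# A response body with a literal newline inside a JSON string value,
# as produced by multi-line Qwen3.5 output.
raw = '{"text": "line one\nline two"}'

# Default strict parsing rejects unescaped control characters:
try:
    json.loads(raw)
except json.JSONDecodeError:
    pass  # "Invalid control character at ..."

# strict=False tolerates control characters inside strings:
obj = json.loads(raw, strict=False)
```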

---------

Signed-off-by: RickyChen / 陳昭儒 <rickychen@infinirc.com>
Alex-ai-future pushed a commit to Alex-ai-future/vllm-metal that referenced this pull request Apr 8, 2026
Alex-ai-future pushed a commit to Alex-ai-future/vllm-metal that referenced this pull request Apr 8, 2026
Alex-ai-future added a commit to Alex-ai-future/vllm-metal that referenced this pull request Apr 8, 2026
Explains the block-size translation mechanism (PR vllm-project#235) when users
enable paged attention for hybrid models like Qwen3.5.

The warning describes:
- Why translation is needed (vLLM requires block_size=160, Metal kernel
  only supports {8, 16, 32})
- How it works (each vLLM block splits into multiple kernel blocks,
  cache is reshaped zero-copy)
- That the default MLX path is recommended for hybrid models

Co-authored-by: Qwen-Coder <qwen-coder@alibabacloud.com>
Signed-off-by: Alex <alex.tech.lab@outlook.com>
Alex-ai-future added a commit to Alex-ai-future/vllm-metal that referenced this pull request Apr 8, 2026
…ence estimate

Updates test expectations to match the implementation changes:
- test_hybrid_with_paged_attention_logs_warning: Verify warning is logged
  instead of ValueError (PR vllm-project#235 made hybrid + paged attention supported)
- test_determine_available_memory_single_sequence_mode: Restore to test
  one-sequence estimate (PR vllm-project#229 design) instead of 80% memory fraction

Also fixes test fixtures to include required vllm_config attribute.

Co-authored-by: Qwen-Coder <qwen-coder@alibabacloud.com>
Signed-off-by: Alex <alex.tech.lab@outlook.com>
