
[Bugfix] Derive head_dim for Qwen3 paged attention path #232

Merged
LxYuan0420 merged 1 commit into vllm-project:main from ricky-chaoju:fix/qwen3-head-dim-attr
Apr 6, 2026

Conversation

@ricky-chaoju
Contributor

Summary

  • Fix AttributeError: 'Attention' object has no attribute 'head_dim' on Qwen3 models when using paged attention
  • Qwen3.5 saves self.head_dim in its Attention class, but Qwen3 does not — derive from k_proj.weight.shape[0] // n_kv_heads as fallback
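The fallback works because `k_proj` projects the hidden state to `n_kv_heads * head_dim` output features, so the first dimension of its weight divides evenly back into the head size. A quick shape check (illustrative numbers only, not taken from an actual Qwen3 config):

```python
# Illustrative values only; real Qwen3 configs will differ.
n_kv_heads, head_dim, hidden_size = 8, 128, 1024

# nn.Linear stores weight as (out_features, in_features),
# and k_proj's out_features is n_kv_heads * head_dim.
k_proj_weight_shape = (n_kv_heads * head_dim, hidden_size)

derived_head_dim = k_proj_weight_shape[0] // n_kv_heads
print(derived_head_dim)  # 128
```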

Regression

Introduced by #226 (attention_sdpa.py:68). All Qwen3 models on the paged KV path crash.

Test

VLLM_METAL_USE_PAGED_ATTENTION=1 VLLM_METAL_MEMORY_FRACTION=0.2 \
python -m pytest tests/test_paged_deterministic.py -v -s -m slow

6/6 passed (Qwen3-0.6B, paged path).

@ricky-chaoju ricky-chaoju marked this pull request as ready for review April 6, 2026 03:36
q_proj_out = inner.q_proj(x)
gate = None
- head_dim = inner.head_dim
+ head_dim = getattr(inner, "head_dim", inner.k_proj.weight.shape[0] // n_kv_heads)
Collaborator


Looks good overall. Since is_sdpa() already guarantees k_proj exists, we can make this even simpler (?):

head_dim = inner.k_proj.weight.shape[0] // n_kv_heads

Contributor Author


Simplified to derive directly from k_proj. Updated.

@WindChimeRan
Collaborator

We need a better CI to catch this bug earlier. @LxYuan0420 @ericcurtin

…ead_dim

Signed-off-by: RickyChen / 陳昭儒 <ricky.chen@infinirc.com>
Signed-off-by: RickyChen <rickychen@RickyChens-Mac-Pro.local>
@ricky-chaoju ricky-chaoju force-pushed the fix/qwen3-head-dim-attr branch from a2d8f95 to d423e29 Compare April 6, 2026 07:04
@ricky-chaoju ricky-chaoju reopened this Apr 6, 2026
@ricky-chaoju
Contributor Author

I accidentally pressed close; it's updated now.

@LxYuan0420
Collaborator

> We need a better CI to catch this bug earlier. @LxYuan0420 @ericcurtin

Maybe the best fix is a better unit test to cover the missing-head_dim SDPA case. That would have caught this early without heavy CI. Ricky can do a follow-up PR for that.

@LxYuan0420 LxYuan0420 merged commit bbacf63 into vllm-project:main Apr 6, 2026
5 checks passed
@ricky-chaoju ricky-chaoju deleted the fix/qwen3-head-dim-attr branch April 6, 2026 09:58
@ericcurtin
Collaborator

Running models in GitHub CI is hard as we don't really have access to GPUs. We can try to add more tests, though; Qwen3-0.6B is good for CPU inference. We just have to keep our build times low and take advantage of parallel builds.

@ricky-chaoju
Contributor Author

ricky-chaoju commented Apr 7, 2026

> Running models in GitHub CI is hard as we don't really have access to GPUs. We can try to add more tests, though; Qwen3-0.6B is good for CPU inference. We just have to keep our build times low and take advantage of parallel builds.

I currently have an idle M1 Pro (16GB) that I've been setting up as a self-hosted GitHub Actions runner. I've tested it on my fork and it's working.

If I can get permission to register a self-hosted runner on vllm-project/vllm-metal, I'm happy to contribute the compute for Metal GPU CI. It would enable running actual model inference tests that GitHub-hosted runners can't do.

The workflow adds a test-gpu job that runs on [self-hosted, macOS, ARM64], separate from the existing lint/test jobs, so it won't affect current CI if the runner goes offline.

@WindChimeRan
Collaborator

@LxYuan0420 @ericcurtin

My proposal:

Swap SmolLM2-135M for two models that actually match what we're building:

  • Qwen3-0.6B — GQA path
  • Qwen3.5-0.8B — linear attention path

Both with VLLM_METAL_USE_PAGED_ATTENTION=1, one prompt each. The paged path is where all the action is; no point smoke-testing the legacy path. We can use actions/cache on ~/.cache/huggingface/hub/ so model downloads only happen once.

@ricky-chaoju really appreciate the offer on the self-hosted runner! Might be simpler to host it on a cloud instance, though.
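A sketch of the actions/cache step for the Hugging Face hub cache described above (the job name, runner label, and cache key are illustrative, not the project's actual workflow):

```yaml
# Hypothetical workflow fragment; names and keys are illustrative.
jobs:
  test-paged:
    runs-on: macos-14
    steps:
      - uses: actions/checkout@v4
      - name: Cache Hugging Face models
        uses: actions/cache@v4
        with:
          path: ~/.cache/huggingface/hub
          key: hf-hub-${{ runner.os }}-qwen3-smoke
      - name: Paged-attention smoke test
        run: |
          VLLM_METAL_USE_PAGED_ATTENTION=1 \
          python -m pytest tests/ -m slow -v
```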

Alex-ai-future pushed a commit to Alex-ai-future/vllm-metal that referenced this pull request Apr 8, 2026
LxYuan0420 pushed a commit that referenced this pull request Apr 8, 2026
## Summary

This PR enables Qwen3.5 hybrid models (SDPA + GDN layers) to run on
Metal by implementing `update_block_size_for_backend()` to unify KV
cache page sizes, adding MLA (Multi-head Latent Attention) support, and
improving memory reporting for the MLX path.

## Problem

When trying to run Qwen3.5 (a hybrid model with SDPA + GDN layers) on
Metal, the following issues were encountered:

### 1. Hybrid Model Page Size Alignment Failure

vLLM's KV cache validation fails because SDPA page size and Mamba page
size are not divisible:

```
NotImplementedError: The page size of the layer is not divisible by the maximum page size.
Cannot unify by adjusting block_size.
```

**Test failure on main branch:**

```
FAILED tests/test_attention_dispatch.py::test_qwen35_paged_attention_hybrid - 
NotImplementedError: The page size of the layer is not divisible by the maximum page size.
```

### 2. Forced Paged Attention Causes OOM

Previously, paged attention was auto-enabled for hybrid models, which
allocates a large contiguous memory buffer that exceeds the capacity of
smaller Metal devices.

### 3. Inaccurate Memory Reporting

The MLX path reported memory based on `max_model_len`, which gave
misleadingly small values for the scheduler.

## Solution

### 1. Add `update_block_size_for_backend()` to MetalPlatform

Implements a 5-step process to unify page sizes for hybrid models:

1. **Compute attention page size per token** - Uses `MLAAttentionSpec`
for MLA models or `FullAttentionSpec` otherwise
2. **Get Mamba page size** - Queries model class for mamba state shape
and dtype
3. **Calculate block_size** - Ensures SDPA page_size >= Mamba
page_size using `kernel_block_alignment_size=32`
4. **Sync mamba_block_size** - If using align mode
5. **Pad mamba_page_size** - Matches SDPA page_size exactly

**Key insight:** This is a "logical" fix for vLLM's scheduler validation
only. The Metal plugin manages KV cache internally via MLX's
`make_prompt_cache()`, independent of vLLM's calculations.
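The core arithmetic of steps 1-5 can be sketched as follows. This is a simplified model of the calculation, not the plugin's actual code, and the function name and sizes are illustrative:

```python
import math

def unified_block_size(attn_page_size_1_token: int,
                       mamba_page_size: int,
                       kernel_block_alignment_size: int = 32) -> int:
    """Smallest aligned block_size whose attention page covers the Mamba page."""
    # Tokens per block needed so block_size * per-token page >= mamba page.
    tokens_needed = math.ceil(mamba_page_size / attn_page_size_1_token)
    # Round up to a multiple of the kernel alignment.
    blocks = math.ceil(tokens_needed / kernel_block_alignment_size)
    return blocks * kernel_block_alignment_size

# With illustrative page sizes this reproduces block_size=160,
# the value the PR reports hybrid models require:
print(unified_block_size(attn_page_size_1_token=1024,
                         mamba_page_size=160 * 1024))  # 160
```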

### 2. Add MLA Support

Hybrid models with MLA (e.g., DeepSeek variants) now use
`MLAAttentionSpec` for correct page size calculation:

```python
if getattr(model_config, "use_mla", False):
    attn_page_size_1_token = MLAAttentionSpec(...).page_size_bytes
else:
    attn_page_size_1_token = FullAttentionSpec(...).page_size_bytes
```

### 3. Add Early Error for Hybrid + Paged Attention

Raises a clear `ValueError` when users attempt to use paged attention
with hybrid models:

```
ValueError: Hybrid models (e.g., Qwen3.5) are not supported with paged attention on Metal. 
The Metal paged attention kernel only supports block_size in {8, 16, 32}, but hybrid models 
require block_size=160. Please remove VLLM_METAL_USE_PAGED_ATTENTION=1.
```

**Root cause:** Metal paged attention kernels only support `block_size ∈
{8, 16, 32}`, but Qwen3.5 requires `block_size=160`.
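The early check amounts to something like the following sketch (the actual guard lives in the plugin; these names are illustrative):

```python
# Block sizes the Metal paged attention kernels instantiate.
SUPPORTED_PAGED_BLOCK_SIZES = {8, 16, 32}

def check_paged_supported(use_paged_attention: bool,
                          is_hybrid: bool,
                          required_block_size: int) -> None:
    # Fail fast with a clear message instead of a cryptic kernel error later.
    if (use_paged_attention and is_hybrid
            and required_block_size not in SUPPORTED_PAGED_BLOCK_SIZES):
        raise ValueError(
            "Hybrid models are not supported with paged attention on Metal. "
            f"Supported block sizes: {sorted(SUPPORTED_PAGED_BLOCK_SIZES)}, "
            f"required: {required_block_size}. "
            "Please remove VLLM_METAL_USE_PAGED_ATTENTION=1."
        )
```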

### 4. Improve Memory Reporting

Changed MLX path to report 80% of remaining Metal memory instead of one
max-length sequence:

```python
_MLX_MEMORY_BUDGET_FRACTION = 0.8
metal_limit = mx.device_info()["max_recommended_working_set_size"]
model_memory = self._get_model_memory_usage()
available = int((metal_limit - model_memory) * _MLX_MEMORY_BUDGET_FRACTION)
```

**Before:** `reporting 4.29GB for scheduler admission control (one
max-length sequence, max_model_len=2048)`

**After:** `reporting 11.20 GB for scheduler (Metal limit: 16.00 GB,
Model: 2.00 GB, Remaining: 14.00 GB, KV budget: 11.20 GB)`
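The numbers in the "After" log line are consistent with the budget formula above:

```python
# Values taken from the "After" log line; the fraction is from the PR.
_MLX_MEMORY_BUDGET_FRACTION = 0.8
metal_limit_gb = 16.00
model_memory_gb = 2.00

remaining_gb = metal_limit_gb - model_memory_gb            # 14.00 GB
kv_budget_gb = remaining_gb * _MLX_MEMORY_BUDGET_FRACTION  # 11.20 GB
print(f"{kv_budget_gb:.2f} GB")  # 11.20 GB
```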

### 5. Remove Auto-Enable Paged Attention for Hybrid Models

MLX's `make_prompt_cache()` handles hybrid KV cache natively. Paged
attention is now opt-in rather than forced.

## Changes

### New Files

| File | Lines | Description |
| --- | --- | --- |
| `tests/test_platform_update_block_size.py` | 576 | Comprehensive unit tests (16 test cases) |

### Modified Files

| File | Changes | Description |
| --- | --- | --- |
| `vllm_metal/platform.py` | +172 | Add `update_block_size_for_backend()` with MLA support |
| `vllm_metal/v1/worker.py` | +41, -23 | Improve memory reporting, remove auto-enable paged |
| `vllm_metal/v1/model_runner.py` | +1, -1 | Fix MLX API: `mx.device_info()` |
| `tests/test_v1_worker.py` | +22 | Update memory reporting test |

## Test Results

### Unit Tests (16/16 Pass)

```bash
$ source .venv-vllm-metal/bin/activate && python -m pytest tests/test_platform_update_block_size.py -v

tests/test_platform_update_block_size.py::TestUpdateBlockSizeForBackend::test_hybrid_model_success PASSED
tests/test_platform_update_block_size.py::TestUpdateBlockSizeForBackend::test_hybrid_model_block_size_already_sufficient PASSED
tests/test_platform_update_block_size.py::TestUpdateBlockSizeForBackend::test_non_hybrid_model_skipped PASSED
tests/test_platform_update_block_size.py::TestUpdateBlockSizeForBackend::test_model_config_none PASSED
tests/test_platform_update_block_size.py::TestUpdateBlockSizeForBackend::test_model_resolution_failure PASSED
tests/test_platform_update_block_size.py::TestUpdateBlockSizeForBackend::test_get_mamba_state_shape_failure PASSED
tests/test_platform_update_block_size.py::TestUpdateBlockSizeForBackend::test_get_mamba_state_dtype_failure PASSED
tests/test_platform_update_block_size.py::TestUpdateBlockSizeForBackend::test_mamba_page_size_zero PASSED
tests/test_platform_update_block_size.py::TestUpdateBlockSizeForBackend::test_invalid_architecture PASSED
tests/test_platform_update_block_size.py::TestUpdateBlockSizeForBackend::test_block_size_increased_to_minimum PASSED
tests/test_platform_update_block_size.py::TestUpdateBlockSizeForBackend::test_mamba_cache_mode_align PASSED
tests/test_platform_update_block_size.py::TestUpdateBlockSizeForBackend::test_hybrid_with_paged_attention_raises_error PASSED
tests/test_platform_update_block_size.py::TestMLAModels::test_mla_hybrid_model_uses_mla_spec PASSED
tests/test_platform_update_block_size.py::TestMLAModels::test_mla_non_hybrid_skipped PASSED
tests/test_platform_update_block_size.py::TestMLAModels::test_mla_with_cache_dtype[bfloat16] PASSED
tests/test_platform_update_block_size.py::TestMLAModels::test_mla_with_cache_dtype[float16] PASSED

16 passed, 2 warnings in 2.69s
```

### Lint Checks

```bash
$ ruff check vllm_metal/ tests/
All checks passed!

$ ruff format --check vllm_metal/ tests/
5 files already formatted
```

## Usage

### Default Path (Recommended for Hybrid Models)

```bash
# No env var needed - uses MLX's native KV cache via make_prompt_cache()
vllm serve Qwen/Qwen3.5-0.8B
vllm serve Qwen/Qwen3.5-4B
vllm serve Qwen/Qwen3.5-14B
vllm serve Qwen/Qwen3.5-32B
```

### Paged Attention Path (Non-Hybrid Models Only)

```bash
# Only for non-hybrid models
VLLM_METAL_USE_PAGED_ATTENTION=1 vllm serve HuggingFaceTB/SmolLM2-135M-Instruct
```

### Unsupported Configuration (Will Raise Clear Error)

```bash
# This will now raise a helpful ValueError instead of cryptic kernel failure
VLLM_METAL_USE_PAGED_ATTENTION=1 vllm serve Qwen/Qwen3.5-0.8B
# ValueError: Hybrid models (e.g., Qwen3.5) are not supported with paged attention on Metal...
```

## Known Limitations

**Hybrid + Paged Attention is unsupported on Metal** due to kernel
limitations:

- Metal paged attention kernels only instantiate `block_size ∈ {8, 16,
32}`
- Hybrid models require `block_size=160` to satisfy vLLM's page size
divisibility validation
- Users should use the native MLX KV cache path (default) for hybrid
models

## Related Issues

- Fixes #184 (hybrid model initialization failure)
- Addresses reviewer feedback from #230 (@ricky-chaoju, @ericcurtin)
- Related to #232 (head_dim AttributeError on Qwen3)
- Related to #226 (regression fix)

## Commits

This PR includes the following commits:

```
964e9df [Metal] Fix inaccurate docstring and test comments
c87263b [Metal] Fix import order and add test for hybrid + paged error case
f7161cc [Metal] Address reviewer feedback on hybrid + paged attention
774c8bd  use new device
fb5e2d5 [Metal] Fix ruff format issues
5c7379b [Metal] Fix lint issues in MLA support changes
80b8d19 [Metal] Add MLA support to update_block_size_for_backend
733612b [Metal] Improve error handling in update_block_size_for_backend
36d8cb7 [Metal] Add unit tests for update_block_size_for_backend and improve error handling
b2973f0 Fix test for determine_available_memory single_sequence mode
a292980 Restore model_runner.py to upstream version
3f22b84 [Metal] Fix Qwen3.5 hybrid model initialization
d49d0fb [Metal] Fix hybrid model KV cache page size alignment
```

---------

Signed-off-by: Alex <alex.tech.lab@outlook.com>
Co-authored-by: Qwen-Coder <qwen-coder@alibabacloud.com>
