
[Qwen3.5] C++ Metal kernel for hybrid SDPA + GDN paged inference (Stage C) #226

Merged
ericcurtin merged 16 commits into vllm-project:main from ricky-chaoju:gdn-cpp-kernel
Apr 4, 2026
Conversation

@ricky-chaoju (Contributor) commented Apr 3, 2026

Summary

  • Add C++ nanobind Metal kernel for GDN linear attention with in-place paged state management
  • Enable end-to-end vllm serve for Qwen3.5 hybrid models (75% GDN + 25% SDPA)
  • Decompose mlx_lm GDN forward: conv1d (per-request) + recurrent update (batched C++ kernel)
  • Add rope validation compat patch for transformers 5.4+
  • Add M-RoPE support for mlx_vlm Qwen3.5 SDPA layers
  • Golden token test: 4/4 MATCH on Qwen3.5-0.8B (paged vs inline cache)
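The conv1d + recurrent-update decomposition above can be sketched as follows. This is an illustrative, simplified gated linear-attention recurrence in NumPy, not the actual GDN delta rule or the Metal kernel; the `gdn_decode_step` name and all shapes are assumptions:

```python
import numpy as np

def gdn_decode_step(q, k, v, g, state):
    """Simplified gated linear-attention decode step, batched over requests.

    q, k: (B, H, Dk)   v: (B, H, Dv)   g: (B, H) decay gate in [0, 1]
    state: (B, H, Dk, Dv) recurrent state pool, kept in float32
    to avoid float16 overflow during accumulation.
    """
    state *= g[..., None, None]                   # decay previous state in place
    state += np.einsum("bhk,bhv->bhkv", k, v)     # rank-1 update per head
    o = np.einsum("bhk,bhkv->bhv", q, state)      # read out against the state
    return o.astype(np.float16), state

rng = np.random.default_rng(0)
B, H, Dk, Dv = 4, 2, 8, 8
state = np.zeros((B, H, Dk, Dv), dtype=np.float32)
q = rng.standard_normal((B, H, Dk)).astype(np.float32)
k = rng.standard_normal((B, H, Dk)).astype(np.float32)
v = rng.standard_normal((B, H, Dv)).astype(np.float32)
g = np.full((B, H), 0.9, dtype=np.float32)
o, state = gdn_decode_step(q, k, v, g, state)
print(o.shape)  # (4, 2, 8)
```

Because the state update is a plain batched einsum over all requests, it maps naturally onto a single kernel dispatch, while the stateful conv1d stays per-request.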

Benchmark (Qwen3.5-0.8B, 10 reqs, input=128, output=64, 3-run avg)

|                               | Prompt tok/s | Gen tok/s |
| ----------------------------- | ------------ | --------- |
| MLX-native (per-request loop) | 257          | 28.6      |
| C++ Metal kernel (batched)    | 998          | 110.9     |
| Speedup                       | 3.9x         | 3.9x      |

Design

  • GDN recurrent state uses float32 pool to avoid float16 overflow during accumulation
  • Conv1d remains per-request (stateful), recurrent update is batched via C++ kernel
  • Auto-enables paged attention for hybrid models in worker.load_model()
  • support_hybrid_kv_cache() = True enables vLLM's hybrid KV cache manager

Ref: #194 (Stage C)
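The float32-pool design choice above can be demonstrated with a toy accumulation; the step count and increment are illustrative only:

```python
import numpy as np

# float16 tops out at ~65504, so a state that keeps accumulating
# updates across a long decode can saturate to inf.
# Accumulating in float32 (the pool dtype chosen here) does not.
steps, increment = 2000, 40.0
acc16 = np.float16(0.0)
acc32 = np.float32(0.0)
for _ in range(steps):
    acc16 = np.float16(acc16 + np.float16(increment))
    acc32 = np.float32(acc32 + np.float32(increment))
print(np.isinf(acc16), acc32)  # True 80000.0
```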

Signed-off-by: RickyChen / 陳昭儒 <rickychen@infinirc.com>
@ricky-chaoju ricky-chaoju marked this pull request as ready for review April 3, 2026 13:44
@ericcurtin (Collaborator)

@ricky-chaoju @WindChimeRan I'm ok with this, but are we sure we want to do C++? Maybe we'd have to remove the mlx dependency, but it might be worth it:

https://github.com/ericcurtin/inferrs

@ericcurtin (Collaborator)

Dead `_orig_init` variable in compat.py; no Dk > 256 guard in the GDN kernel; the golden test's `generate()` lacks subprocess isolation.

Signed-off-by: RickyChen / 陳昭儒 <rickychen@infinirc.com>
@WindChimeRan (Collaborator)

@ericcurtin Thanks for the pointer to inferrs. I want to make sure I understand your suggestion correctly.

Are you proposing that we replace MLX with Candle as the tensor/model framework? That would mean rewriting the model files from mlx_lm to candle-transformers, and using Candle's Metal backend instead of MLX's.

BTW, I found this: https://github.com/EricLBuehler/candle-vllm (haven't fully digested it yet).

Or are you suggesting something narrower - like using Rust only for the Metal kernel dispatch layer (the C++ nanobind bridge in this PR), while keeping MLX for model execution?

Want to understand the scope before we discuss tradeoffs.

@ericcurtin (Collaborator)

> @ericcurtin Thanks for the pointer to inferrs. I want to make sure I understand your suggestion correctly.
>
> Are you proposing that we replace MLX with Candle as the tensor/model framework? That would mean rewriting the model files from mlx_lm to candle-transformers, and using Candle's Metal backend instead of MLX's.

Pretty much. It would also be ok to consider doing our own kernels based on Candle. The creator of MLX left Apple, and that makes me unsure about it. I'm sure it will continue as a project, but that leaves a dent.

Also, I think C/C++ for new code in 2026 is a bad idea. Rust is simply better: the errors the Rust compiler generates are so valuable.

But curious what @WindChimeRan @LxYuan0420 @mgoin @robertgshaw2-redhat think...

> BTW, I found this: https://github.com/EricLBuehler/candle-vllm (haven't fully digested it yet).
>
> Or are you suggesting something narrower, like using Rust only for the Metal kernel dispatch layer (the C++ nanobind bridge in this PR), while keeping MLX for model execution?
>
> Want to understand the scope before we discuss tradeoffs.

@ericcurtin (Collaborator)

Merging to keep things moving, but I'm still curious about people's thoughts on the above.

@ericcurtin ericcurtin merged commit a06cd65 into vllm-project:main Apr 4, 2026
5 checks passed
@LxYuan0420 (Collaborator)

My view is:

  • there is nothing fundamentally wrong with the current repo, and it is still moving in a good direction
  • Rust/Candle is exciting and promising, and I would be happy to explore it
  • personally I would prefer to explore that path in a separate repo first

so if we do want to seriously explore that direction, it might be good to first agree on what the starting repo should be, so the effort stays focused?

CC: @ericcurtin

@WindChimeRan (Collaborator)

@ericcurtin @LxYuan0420

> The creator of MLX left Apple, so that makes me unsure about it.

Actually I have the same feeling ... but we don't have much choice.

My take is that the tensor framework and model ecosystem choice matters more than the C++ vs Rust question. A few thoughts:

  • MLX has a lazy evaluation graph while Candle is eager; I think MLX is still better here. (Ideally I would want CUDA-graph-style replay, or to write Triton directly, but neither is compatible with Apple silicon.)
  • Ollama and LM Studio have both moved to MLX backends. The ecosystem is growing.
  • From applebench, our main performance gap is vs llama.cpp (3-6x). I don't think switching to Candle would close that gap.

The benefits of Candle over MLX are unclear to me right now.

The applebench result is concerning: vllm-metal is still too slow. I'm looking into fine-grained profiling now.

WindChimeRan added a commit to WindChimeRan/vllm-metal that referenced this pull request Apr 6, 2026

Resolve conflict in paged_ops.cpp: keep both paged_attention_primitive (ours) and gdn_linear_attention (upstream vllm-project#226) bindings.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: ran <hzz5361@psu.edu>
@WindChimeRan (Collaborator)

@ricky-chaoju This PR modified Qwen3's hot path (e.g., vllm_metal/metal_kernel_backend/attention_sdpa.py).

I'm not sure this is the right way to patch it. Other models may need similar hacks, and then attention_sdpa will balloon. At the very least we need a benchmark showing whether Qwen3 has a performance regression.

@ricky-chaoju (Contributor, Author)

> @ricky-chaoju This PR modified Qwen3's hot path (e.g., vllm_metal/metal_kernel_backend/attention_sdpa.py).
>
> I'm not sure this is the right way to patch it. Other models may need similar hacks, and then attention_sdpa will balloon. At the very least we need a benchmark showing whether Qwen3 has a performance regression.

The improvement comes from #226's unified varlen kernel, not from #232. A direct A/B of #232 alone is not possible because main without #232 crashes on Qwen3 (head_dim AttributeError). A micro-benchmark confirms the per-call overhead of the fix is ~0.25 µs on a ~ms-scale forward pass: no regression.

LxYuan0420 pushed a commit that referenced this pull request Apr 6, 2026
## Summary

- Fix `AttributeError: 'Attention' object has no attribute 'head_dim'`
on Qwen3 models when using paged attention
- Qwen3.5 saves `self.head_dim` in its Attention class, but Qwen3 does
not — derive from `k_proj.weight.shape[0] // n_kv_heads` as fallback

## Regression

Introduced by #226 (`attention_sdpa.py:68`). All Qwen3 models on the
paged KV path crash.

## Test

```
VLLM_METAL_USE_PAGED_ATTENTION=1 VLLM_METAL_MEMORY_FRACTION=0.2 \
python -m pytest tests/test_paged_deterministic.py -v -s -m slow
```

6/6 passed (Qwen3-0.6B, paged path).

Signed-off-by: RickyChen / 陳昭儒 <ricky.chen@infinirc.com>
Signed-off-by: RickyChen <rickychen@RickyChens-Mac-Pro.local>
Co-authored-by: RickyChen <rickychen@RickyChens-Mac-Pro.local>
Alex-ai-future pushed a commit to Alex-ai-future/vllm-metal that referenced this pull request Apr 8, 2026
LxYuan0420 pushed a commit that referenced this pull request Apr 8, 2026
## Summary

This PR enables Qwen3.5 hybrid models (SDPA + GDN layers) to run on
Metal by implementing `update_block_size_for_backend()` to unify KV
cache page sizes, adding MLA (Multi-head Latent Attention) support, and
improving memory reporting for the MLX path.

## Problem

When trying to run Qwen3.5 (a hybrid model with SDPA + GDN layers) on
Metal, the following issues were encountered:

### 1. Hybrid Model Page Size Alignment Failure

vLLM's KV cache validation fails because SDPA page size and Mamba page
size are not divisible:

```
NotImplementedError: The page size of the layer is not divisible by the maximum page size.
Cannot unify by adjusting block_size.
```

**Test failure on main branch:**

```
FAILED tests/test_attention_dispatch.py::test_qwen35_paged_attention_hybrid - 
NotImplementedError: The page size of the layer is not divisible by the maximum page size.
```

### 2. Forced Paged Attention Causes OOM

Previously, paged attention was auto-enabled for hybrid models, which
allocates a large contiguous memory buffer that exceeds the capacity of
smaller Metal devices.

### 3. Inaccurate Memory Reporting

The MLX path reported memory based on `max_model_len`, which gave
misleadingly small values for the scheduler.

## Solution

### 1. Add `update_block_size_for_backend()` to MetalPlatform

Implements a 5-step process to unify page sizes for hybrid models:

1. **Compute attention page size per token**: uses `MLAAttentionSpec` for MLA models, `FullAttentionSpec` otherwise
2. **Get Mamba page size**: queries the model class for the mamba state shape and dtype
3. **Calculate block_size**: ensures SDPA page_size >= Mamba page_size using `kernel_block_alignment_size=32`
4. **Sync mamba_block_size**: if using align mode
5. **Pad mamba_page_size**: to match the SDPA page_size exactly
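A back-of-the-envelope sketch of steps 1-3; the helper name and the byte counts are illustrative assumptions, not the actual implementation:

```python
import math

def unify_block_size(attn_page_size_1_token: int, mamba_page_size: int,
                     alignment: int = 32) -> int:
    """Smallest block_size (a multiple of `alignment`) whose SDPA page
    (block_size tokens * bytes per token) covers one Mamba page."""
    tokens_needed = math.ceil(mamba_page_size / attn_page_size_1_token)
    return math.ceil(tokens_needed / alignment) * alignment

# Illustrative numbers: a 4 KiB per-token SDPA page and a 600 KiB Mamba
# state page need 150 tokens, rounded up to the next multiple of 32.
print(unify_block_size(4096, 600 * 1024))  # 160
```

This also makes the Metal-kernel conflict concrete: a large Mamba state pushes the unified block_size well past the `{8, 16, 32}` sizes the paged kernels instantiate.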

**Key insight:** This is a "logical" fix for vLLM's scheduler validation
only. The Metal plugin manages KV cache internally via MLX's
`make_prompt_cache()`, independent of vLLM's calculations.

### 2. Add MLA Support

Hybrid models with MLA (e.g., DeepSeek variants) now use
`MLAAttentionSpec` for correct page size calculation:

```python
if getattr(model_config, "use_mla", False):
    attn_page_size_1_token = MLAAttentionSpec(...).page_size_bytes
else:
    attn_page_size_1_token = FullAttentionSpec(...).page_size_bytes
```

### 3. Add Early Error for Hybrid + Paged Attention

Raises a clear `ValueError` when users attempt to use paged attention
with hybrid models:

```
ValueError: Hybrid models (e.g., Qwen3.5) are not supported with paged attention on Metal. 
The Metal paged attention kernel only supports block_size in {8, 16, 32}, but hybrid models 
require block_size=160. Please remove VLLM_METAL_USE_PAGED_ATTENTION=1.
```

**Root cause:** Metal paged attention kernels only support `block_size ∈
{8, 16, 32}`, but Qwen3.5 requires `block_size=160`.

### 4. Improve Memory Reporting

Changed MLX path to report 80% of remaining Metal memory instead of one
max-length sequence:

```python
_MLX_MEMORY_BUDGET_FRACTION = 0.8
metal_limit = mx.device_info()["max_recommended_working_set_size"]
model_memory = self._get_model_memory_usage()
available = int((metal_limit - model_memory) * _MLX_MEMORY_BUDGET_FRACTION)
```

**Before:** `reporting 4.29GB for scheduler admission control (one
max-length sequence, max_model_len=2048)`

**After:** `reporting 11.20 GB for scheduler (Metal limit: 16.00 GB,
Model: 2.00 GB, Remaining: 14.00 GB, KV budget: 11.20 GB)`
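The reported figures are internally consistent, as a quick arithmetic check shows:

```python
# Reproduce the KV budget arithmetic from the log line above.
metal_limit_gb = 16.00   # max_recommended_working_set_size
model_memory_gb = 2.00
budget_fraction = 0.8    # _MLX_MEMORY_BUDGET_FRACTION

remaining_gb = metal_limit_gb - model_memory_gb
kv_budget_gb = remaining_gb * budget_fraction
print(f"Remaining: {remaining_gb:.2f} GB, KV budget: {kv_budget_gb:.2f} GB")
# Remaining: 14.00 GB, KV budget: 11.20 GB
```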

### 5. Remove Auto-Enable Paged Attention for Hybrid Models

MLX's `make_prompt_cache()` handles hybrid KV cache natively. Paged
attention is now opt-in rather than forced.

## Changes

### New Files

| File | Lines | Description |
| --- | --- | --- |
| `tests/test_platform_update_block_size.py` | 576 | Comprehensive unit tests (16 test cases) |

### Modified Files

| File | Changes | Description |
| --- | --- | --- |
| `vllm_metal/platform.py` | +172 | Add `update_block_size_for_backend()` with MLA support |
| `vllm_metal/v1/worker.py` | +41, -23 | Improve memory reporting, remove auto-enable paged |
| `vllm_metal/v1/model_runner.py` | +1, -1 | Fix MLX API: `mx.device_info()` |
| `tests/test_v1_worker.py` | +22 | Update memory reporting test |

## Test Results

### Unit Tests (16/16 Pass)

```bash
$ source .venv-vllm-metal/bin/activate && python -m pytest tests/test_platform_update_block_size.py -v

tests/test_platform_update_block_size.py::TestUpdateBlockSizeForBackend::test_hybrid_model_success PASSED
tests/test_platform_update_block_size.py::TestUpdateBlockSizeForBackend::test_hybrid_model_block_size_already_sufficient PASSED
tests/test_platform_update_block_size.py::TestUpdateBlockSizeForBackend::test_non_hybrid_model_skipped PASSED
tests/test_platform_update_block_size.py::TestUpdateBlockSizeForBackend::test_model_config_none PASSED
tests/test_platform_update_block_size.py::TestUpdateBlockSizeForBackend::test_model_resolution_failure PASSED
tests/test_platform_update_block_size.py::TestUpdateBlockSizeForBackend::test_get_mamba_state_shape_failure PASSED
tests/test_platform_update_block_size.py::TestUpdateBlockSizeForBackend::test_get_mamba_state_dtype_failure PASSED
tests/test_platform_update_block_size.py::TestUpdateBlockSizeForBackend::test_mamba_page_size_zero PASSED
tests/test_platform_update_block_size.py::TestUpdateBlockSizeForBackend::test_invalid_architecture PASSED
tests/test_platform_update_block_size.py::TestUpdateBlockSizeForBackend::test_block_size_increased_to_minimum PASSED
tests/test_platform_update_block_size.py::TestUpdateBlockSizeForBackend::test_mamba_cache_mode_align PASSED
tests/test_platform_update_block_size.py::TestUpdateBlockSizeForBackend::test_hybrid_with_paged_attention_raises_error PASSED
tests/test_platform_update_block_size.py::TestMLAModels::test_mla_hybrid_model_uses_mla_spec PASSED
tests/test_platform_update_block_size.py::TestMLAModels::test_mla_non_hybrid_skipped PASSED
tests/test_platform_update_block_size.py::TestMLAModels::test_mla_with_cache_dtype[bfloat16] PASSED
tests/test_platform_update_block_size.py::TestMLAModels::test_mla_with_cache_dtype[float16] PASSED

16 passed, 2 warnings in 2.69s
```

### Lint Checks

```bash
$ ruff check vllm_metal/ tests/
All checks passed!

$ ruff format --check vllm_metal/ tests/
5 files already formatted
```

## Usage

### Default Path (Recommended for Hybrid Models)

```bash
# No env var needed - uses MLX's native KV cache via make_prompt_cache()
vllm serve Qwen/Qwen3.5-0.8B
vllm serve Qwen/Qwen3.5-4B
vllm serve Qwen/Qwen3.5-14B
vllm serve Qwen/Qwen3.5-32B
```

### Paged Attention Path (Non-Hybrid Models Only)

```bash
# Only for non-hybrid models
VLLM_METAL_USE_PAGED_ATTENTION=1 vllm serve HuggingFaceTB/SmolLM2-135M-Instruct
```

### Unsupported Configuration (Will Raise Clear Error)

```bash
# This will now raise a helpful ValueError instead of cryptic kernel failure
VLLM_METAL_USE_PAGED_ATTENTION=1 vllm serve Qwen/Qwen3.5-0.8B
# ValueError: Hybrid models (e.g., Qwen3.5) are not supported with paged attention on Metal...
```

## Known Limitations

**Hybrid + Paged Attention is unsupported on Metal** due to kernel
limitations:

- Metal paged attention kernels only instantiate `block_size ∈ {8, 16,
32}`
- Hybrid models require `block_size=160` to satisfy vLLM's page size
divisibility validation
- Users should use the native MLX KV cache path (default) for hybrid
models

## Related Issues

- Fixes #184 (hybrid model initialization failure)
- Addresses reviewer feedback from #230 (@ricky-chaoju, @ericcurtin)
- Related to #232 (head_dim AttributeError on Qwen3)
- Related to #226 (regression fix)

## Commits

This PR includes the following commits:

```
964e9df [Metal] Fix inaccurate docstring and test comments
c87263b [Metal] Fix import order and add test for hybrid + paged error case
f7161cc [Metal] Address reviewer feedback on hybrid + paged attention
774c8bd  use new device
fb5e2d5 [Metal] Fix ruff format issues
5c7379b [Metal] Fix lint issues in MLA support changes
80b8d19 [Metal] Add MLA support to update_block_size_for_backend
733612b [Metal] Improve error handling in update_block_size_for_backend
36d8cb7 [Metal] Add unit tests for update_block_size_for_backend and improve error handling
b2973f0 Fix test for determine_available_memory single_sequence mode
a292980 Restore model_runner.py to upstream version
3f22b84 [Metal] Fix Qwen3.5 hybrid model initialization
d49d0fb [Metal] Fix hybrid model KV cache page size alignment
```

---------

Signed-off-by: Alex <alex.tech.lab@outlook.com>
Co-authored-by: Qwen-Coder <qwen-coder@alibabacloud.com>
ericcurtin pushed a commit that referenced this pull request Apr 8, 2026
This PR:
- Fixes under-reported GDN recurrent memory sizing by accounting for the float32 recurrent state (introduced in #226).
- Adds a unit test covering the corrected linear cache byte calculation.

Notes:
- Scope is limited to hybrid linear attention accounting in
`linear_cache_bytes_per_slot()`.
- SDPA KV sizing is unchanged.
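A sketch of the corrected accounting; the function signature and shape parameters are illustrative assumptions, not the exact code:

```python
import numpy as np

FLOAT32_BYTES = np.dtype(np.float32).itemsize  # state pool dtype, not model dtype

def linear_cache_bytes_per_slot(num_gdn_layers: int, num_heads: int,
                                head_k_dim: int, head_v_dim: int) -> int:
    """Bytes per request slot for GDN recurrent state. The pool is kept
    in float32 (see #226), so sizing it with the float16 model dtype
    would under-report by 2x, which is the bug fixed here."""
    state_elems = num_heads * head_k_dim * head_v_dim
    return num_gdn_layers * state_elems * FLOAT32_BYTES

# Illustrative shapes: 18 GDN layers, 16 heads, 128x128 state per head.
print(linear_cache_bytes_per_slot(18, 16, 128, 128))  # 18874368 (18 MiB)
```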

Signed-off-by: Yuan Lik Xun <lxyuan0420@gmail.com>