131 changes: 131 additions & 0 deletions tests/test_attention_dispatch.py
@@ -0,0 +1,131 @@
# SPDX-License-Identifier: Apache-2.0
"""Tests for attention backend dispatch.

Unit tests verify detection heuristics against real mlx_lm modules
(no model weights, just module instantiation). The slow integration
test covers the full paged attention dispatch on Qwen3.5.
"""

from __future__ import annotations

import pytest

from vllm_metal.metal_kernel_backend.attention_linear import is_linear_attention
from vllm_metal.metal_kernel_backend.attention_sdpa import is_sdpa
from vllm_metal.paged_attention_common import find_attn_attr, find_layers

# ---------------------------------------------------------------------------
# Minimal ModelArgs for real mlx_lm module instantiation (no weights needed)
# ---------------------------------------------------------------------------

_QWEN3_ARGS_KWARGS = {
"model_type": "qwen3",
"hidden_size": 64,
"num_hidden_layers": 2,
"intermediate_size": 128,
"num_attention_heads": 4,
"num_key_value_heads": 2,
"rms_norm_eps": 1e-6,
"vocab_size": 100,
"max_position_embeddings": 512,
"rope_theta": 10000.0,
"head_dim": 16,
"tie_word_embeddings": False,
}

_QWEN35_ARGS_KWARGS = {
"hidden_size": 64,
"num_hidden_layers": 4,
"intermediate_size": 128,
"num_attention_heads": 4,
"num_key_value_heads": 2,
"rms_norm_eps": 1e-6,
"vocab_size": 100,
"max_position_embeddings": 512,
"rope_theta": 10000.0,
"head_dim": 16,
"tie_word_embeddings": False,
"full_attention_interval": 4,
}


# ---------------------------------------------------------------------------
# Detection against real mlx_lm modules
# ---------------------------------------------------------------------------


def test_qwen3_attention_detected_as_sdpa():
"""Real Qwen3 Attention module should be detected as SDPA."""
from mlx_lm.models.qwen3 import Attention, ModelArgs

args = ModelArgs(**_QWEN3_ARGS_KWARGS)
attn = Attention(args)

assert is_sdpa(attn)
assert not is_linear_attention(attn)


def test_qwen35_sdpa_layer_detected():
"""Qwen3.5 SDPA layer (every full_attention_interval-th) should have
self_attn detected as SDPA."""
from mlx_lm.models.qwen3_5 import DecoderLayer, TextModelArgs

args = TextModelArgs(**_QWEN35_ARGS_KWARGS)
# layer_idx=3 with full_attention_interval=4 → SDPA layer
layer = DecoderLayer(args, layer_idx=3)

assert find_attn_attr(layer) == "self_attn"
assert is_sdpa(layer.self_attn)
assert not is_linear_attention(layer.self_attn)


def test_qwen35_linear_layer_detected():
"""Qwen3.5 linear attention layer (GatedDeltaNet) should have
linear_attn detected as linear attention."""
from mlx_lm.models.qwen3_5 import DecoderLayer, TextModelArgs

args = TextModelArgs(**_QWEN35_ARGS_KWARGS)
# layer_idx=0 with full_attention_interval=4 → linear attention layer
layer = DecoderLayer(args, layer_idx=0)

assert find_attn_attr(layer) == "linear_attn"
assert is_linear_attention(layer.linear_attn)
assert not is_sdpa(layer.linear_attn)


def test_find_layers_on_qwen3_model():
"""find_layers should return the layer list from a real Qwen3 Model."""
from mlx_lm.models.qwen3 import Model, ModelArgs

args = ModelArgs(**_QWEN3_ARGS_KWARGS)
model = Model(args)
layers = find_layers(model)

assert len(layers) == args.num_hidden_layers
assert find_attn_attr(layers[0]) == "self_attn"


# ---------------------------------------------------------------------------
# Slow integration test
# ---------------------------------------------------------------------------


@pytest.mark.slow
@pytest.mark.xfail(
raises=NotImplementedError,
reason="Linear attention (GatedDeltaNet) Metal kernel not yet implemented",
strict=True,
)
def test_qwen35_paged_attention_raises_on_linear_layers():
"""Loading Qwen/Qwen3.5-0.8B with paged attention raises
NotImplementedError on the linear attention layers."""
from vllm import LLM, SamplingParams

with pytest.MonkeyPatch.context() as mp:
mp.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
mp.setenv("VLLM_METAL_USE_PAGED_ATTENTION", "1")
mp.setenv("VLLM_METAL_MEMORY_FRACTION", "0.2")

llm = LLM(model="Qwen/Qwen3.5-0.8B", max_model_len=512, max_num_seqs=1)
sp = SamplingParams(temperature=0, max_tokens=5)
llm.generate(["Hello"], sp)
68 changes: 68 additions & 0 deletions vllm_metal/metal_kernel_backend/attention_linear.py
@@ -0,0 +1,68 @@
# SPDX-License-Identifier: Apache-2.0
"""Linear attention (Gated DeltaNet) on Metal — NOT YET IMPLEMENTED.

Targets models like Qwen/Qwen3.5-0.8B (mlx_lm module: ``qwen3_5``) that use
hybrid architectures mixing SDPA and linear attention layers. In mlx_lm's
``qwen3_next``, the linear attention module is ``Qwen3NextGatedDeltaNet``;
in all known variants it lives on ``layer.linear_attn`` (as opposed to
``layer.self_attn`` for SDPA layers).

Detection heuristic: the module has ``conv1d`` (1-D convolution before
attention) and no ``q_proj`` (which would indicate SDPA). This works across
all known implementations:
- mlx_lm ``qwen3_next``: ``in_proj_qkvz`` + ``conv1d``
- mlx_lm ``qwen3_5``: ``in_proj_qkv`` + ``conv1d``
- mlx_vlm ``qwen3_5``: ``in_proj_qkv`` + ``conv1d``

All operations use MLX arrays end-to-end — no PyTorch MPS bridge.
"""

from __future__ import annotations

import mlx.core as mx
import mlx.nn as nn

from vllm_metal.metal_kernel_backend.cache import MetalPagedKVCache
from vllm_metal.paged_attention_common import PagedAttentionContext


def is_linear_attention(module: nn.Module) -> bool:
"""Return True if *module* is a linear attention layer (e.g. GatedDeltaNet).

Checks for ``conv1d`` (present in all known GatedDeltaNet variants) and
the absence of ``q_proj`` (which would indicate SDPA).
"""
return hasattr(module, "conv1d") and not hasattr(module, "q_proj")


def linear_attention_forward(
inner: nn.Module,
x: mx.array,
ctx: PagedAttentionContext,
kv_cache: MetalPagedKVCache,
layer_idx: int,
) -> mx.array:
"""Linear attention forward pass — not yet implemented.

This is a placeholder for future implementation. Contributing guide:

1. Linear attention (GatedDeltaNet) uses a recurrent state instead of
a standard KV cache. The cache layout will differ from SDPA — you
will likely need a separate cache class or a per-layer cache spec.

2. There is no softmax — the kernel computes gated delta updates:
``gated_delta_update(q, k, v, a, b, A_log, dt_bias, state, ...)``.

3. The model uses ``conv1d`` over concatenated Q/K/V before the
attention computation. This stateful convolution needs its own
cache slot (``MambaCache`` in mlx_lm).

4. Qwen3.5 is a hybrid model: SDPA layers (every
``full_attention_interval``-th layer) coexist with linear attention
layers. The patching mechanism in ``paged_attention.py`` needs to
handle both ``self_attn`` and ``linear_attn`` attributes on the same
model.

A naive reference for the gated delta recurrence is sketched at the
bottom of this module, for orientation only.
"""
raise NotImplementedError(
"Linear attention (GatedDeltaNet) is not yet implemented for Metal paged "
f"attention. Module: {type(inner).__name__}, layer: {layer_idx}. "
"See attention_linear.py for the contributing guide."
)
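

# ---------------------------------------------------------------------------
# Reference sketch (not used at runtime)
# ---------------------------------------------------------------------------
# A minimal, unoptimized sketch of the gated delta rule recurrence a future
# Metal kernel would implement (see item 2 of the contributing guide above).
# The (alpha, beta) gating parameterization and the (H, Dk, Dv) state shape
# are assumptions for illustration, not the exact mlx_lm signature.
def _gated_delta_reference(
    q: mx.array,  # (T, H, Dk) queries
    k: mx.array,  # (T, H, Dk) keys
    v: mx.array,  # (T, H, Dv) values
    alpha: mx.array,  # (T, H) per-step decay gate in [0, 1]
    beta: mx.array,  # (T, H) per-step update strength
    state: mx.array,  # (H, Dk, Dv) recurrent state carried across steps
) -> tuple[mx.array, mx.array]:
    """Return (outputs of shape (T, H, Dv), final state)."""
    outs = []
    for t in range(q.shape[0]):
        k_t = k[t][..., None]  # (H, Dk, 1)
        q_t = q[t][:, None, :]  # (H, 1, Dk)
        v_t = v[t][:, None, :]  # (H, 1, Dv)
        a_t = alpha[t][:, None, None]
        b_t = beta[t][:, None, None]
        state = state * a_t  # gated decay of the recurrent state
        # Delta rule: move the state's prediction for k_t toward v_t.
        pred = mx.matmul(mx.transpose(k_t, (0, 2, 1)), state)  # (H, 1, Dv)
        state = state + b_t * mx.matmul(k_t, v_t - pred)
        outs.append(mx.matmul(q_t, state)[:, 0, :])  # (H, Dv)
    return mx.stack(outs), state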
146 changes: 146 additions & 0 deletions vllm_metal/metal_kernel_backend/attention_sdpa.py
@@ -0,0 +1,146 @@
# SPDX-License-Identifier: Apache-2.0
"""Scaled dot-product attention (SDPA) on Metal.

Supports MHA, GQA, and MQA as variants of the same kernel — the head ratio
between ``n_heads`` (queries) and ``n_kv_heads`` (keys/values) is handled
transparently by the Metal paged attention kernel.

Handles models whose attention module exposes:
- ``q_proj``, ``k_proj``, ``v_proj``, ``o_proj`` linear projections
- ``rope`` for rotary position embeddings
- ``n_heads``, ``n_kv_heads`` head counts
- Optionally ``q_norm``, ``k_norm`` (Qwen3 per-head RMSNorm before RoPE)

Covers: Qwen3, Llama, Mistral, and other standard transformer architectures.

All operations use MLX arrays end-to-end — no PyTorch MPS bridge.
"""

from __future__ import annotations

import mlx.core as mx
import mlx.nn as nn

from vllm_metal.metal import get_ops
from vllm_metal.metal_kernel_backend.cache import MetalPagedKVCache
from vllm_metal.metal_kernel_backend.packed_prefill_compat import (
apply_packed_rope,
)
from vllm_metal.paged_attention_common import PagedAttentionContext


def is_sdpa(module: nn.Module) -> bool:
"""Return True if *module* is an SDPA attention layer (MHA, GQA, or MQA)."""
return (
hasattr(module, "q_proj")
and hasattr(module, "k_proj")
and hasattr(module, "v_proj")
and hasattr(module, "o_proj")
)
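

# Illustration only; not called at runtime. This hypothetical helper spells
# out the query-head to KV-head mapping that lets one kernel cover MHA, GQA,
# and MQA; the Metal kernel applies the equivalent rule internally.
def _kv_head_for_query_head(q_head: int, n_heads: int, n_kv_heads: int) -> int:
    """Map a query head index to the KV head it attends with."""
    # MHA: n_kv_heads == n_heads (group size 1). MQA: n_kv_heads == 1 (all
    # query heads share a single KV head). GQA: each group of
    # n_heads // n_kv_heads consecutive query heads shares one KV head.
    return q_head // (n_heads // n_kv_heads)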


def sdpa_forward(
inner: nn.Module,
x: mx.array,
ctx: PagedAttentionContext,
kv_cache: MetalPagedKVCache,
layer_idx: int,
) -> mx.array:
"""Full SDPA forward pass: project → norm → RoPE → Metal kernel.

Handles MHA, GQA, and MQA uniformly — the head ratio between
``inner.n_heads`` and ``inner.n_kv_heads`` is passed to the Metal
kernel which handles the broadcast internally.
"""
B, L, D = x.shape # noqa: N806

# --- Projections + reshape ---
queries = inner.q_proj(x).reshape(B, L, inner.n_heads, -1)
keys = inner.k_proj(x).reshape(B, L, inner.n_kv_heads, -1)
values = inner.v_proj(x).reshape(B, L, inner.n_kv_heads, -1)

# Qwen3 per-head RMSNorm before RoPE
if hasattr(inner, "q_norm"):
queries = inner.q_norm(queries)
if hasattr(inner, "k_norm"):
keys = inner.k_norm(keys)

# transpose → (B, heads, L, head_dim)
queries = queries.transpose(0, 2, 1, 3)
keys = keys.transpose(0, 2, 1, 3)
values = values.transpose(0, 2, 1, 3)

# --- RoPE (per-request position reset) ---
if not hasattr(inner, "rope"):
raise NotImplementedError(
f"Attention module {type(inner).__name__} does not have a 'rope' "
"attribute. Only RoPE-based models are supported by paged attention."
)

queries, keys = apply_packed_rope(
inner,
queries,
keys,
ctx.cu_seqlens,
offsets=ctx.offsets if ctx.offsets else None,
)

# --- Metal kernel dispatch ---
n_heads = queries.shape[1]
head_dim = queries.shape[3]

# Reshape to 3D (packed layout assumes B == 1): (1, heads, L, hd) → (L, heads, hd)
q_3d = mx.contiguous(queries[0].transpose(1, 0, 2).astype(kv_cache.dtype))
k_3d = mx.contiguous(keys[0].transpose(1, 0, 2).astype(kv_cache.dtype))
v_3d = mx.contiguous(values[0].transpose(1, 0, 2).astype(kv_cache.dtype))

slot_mapping = mx.array(ctx.slot_mapping, dtype=mx.int64)

# Build block_tables and seq_lens from context
max_blocks_per_seq = max(len(bt) for bt in ctx.block_tables)
block_tables_list = [
bt + [0] * (max_blocks_per_seq - len(bt)) for bt in ctx.block_tables
]
block_tables = mx.array(block_tables_list, dtype=mx.int32)
seq_lens = mx.array(ctx.context_lens, dtype=mx.int32)
cu_seqlens_q = mx.array(ctx.cu_seqlens, dtype=mx.int32)

# Allocate output buffer before eval so we can materialize everything in one call
out = mx.zeros((L, n_heads, head_dim), dtype=kv_cache.dtype)
mx.eval(q_3d, k_3d, v_3d, slot_mapping, block_tables, seq_lens, cu_seqlens_q, out)

ops = get_ops()

# Write K/V into paged cache BEFORE attention — the kernel reads from
# the paged cache via block_table, not from raw tensors.
ops.reshape_and_cache(
k_3d,
v_3d,
kv_cache.key_caches[layer_idx],
kv_cache.value_caches[layer_idx],
slot_mapping,
)

max_seq_len = max(ctx.context_lens)

ops.paged_attention_v2_online(
out,
q_3d,
kv_cache.key_caches[layer_idx],
kv_cache.value_caches[layer_idx],
kv_cache.num_kv_heads,
inner.scale,
0.0, # softcap (0 = disabled)
block_tables,
seq_lens,
cu_seqlens_q,
kv_cache.block_size,
max_seq_len,
-1, # sliding_window (-1 = disabled)
)

mx.synchronize()

# output: (L, n_heads, head_dim) → (B, L, n_heads * head_dim)
out = out.reshape(B, L, n_heads * head_dim)
return inner.o_proj(out)
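

# Reference sketch (not used at runtime): the paged-cache addressing that
# slot_mapping and block_tables above encode. A token at logical position
# `pos` in a sequence lives at this physical slot in the paged KV cache.
# Hypothetical helper, for orientation only.
def _physical_slot(block_table: list[int], pos: int, block_size: int) -> int:
    # The block table maps logical block index to physical block id; the
    # remainder is the offset inside that block.
    return block_table[pos // block_size] * block_size + pos % block_size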