Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions tests/test_attention_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,15 +111,17 @@ def test_find_layers_on_qwen3_model():


@pytest.mark.slow
def test_qwen35_paged_attention_raises_on_hybrid():
"""Loading Qwen/Qwen3.5-0.8B with paged attention raises RuntimeError
at setup — hybrid models are not yet supported on the paged path."""
from vllm import LLM
def test_qwen35_paged_attention_hybrid():
"""Qwen3.5 hybrid model loads and generates with paged attention."""
from vllm import LLM, SamplingParams

with pytest.MonkeyPatch.context() as mp:
mp.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
mp.setenv("VLLM_METAL_USE_PAGED_ATTENTION", "1")
mp.setenv("VLLM_METAL_MEMORY_FRACTION", "0.2")
mp.setenv("VLLM_METAL_MEMORY_FRACTION", "0.3")

with pytest.raises(RuntimeError, match="not yet supported for hybrid"):
LLM(model="Qwen/Qwen3.5-0.8B", max_model_len=512, max_num_seqs=1)
llm = LLM(model="Qwen/Qwen3.5-0.8B", max_model_len=512, max_num_seqs=1)
sp = SamplingParams(temperature=0, max_tokens=5)
outputs = llm.generate(["The capital of France is"], sp)
assert len(outputs) == 1
assert len(outputs[0].outputs[0].token_ids) > 0
3 changes: 3 additions & 0 deletions tests/test_paged_prefix_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ def _make_paged_runner(num_layers: int = 2) -> mr.MetalModelRunner:
runner._paged_block_size = 4
runner._paged_request_seq_lens = {}
runner._request_states = {}
runner._gdn_req_to_slot = {}
runner._gdn_free_slots = []
runner.model_args = {}
runner._rust_state_manager = None
runner.num_layers = num_layers
runner.device = torch.device("cpu")
Expand Down
1 change: 1 addition & 0 deletions tests/test_v1_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def test_load_model_delegates_paged_attention_setup_decision(
expect_cap_call: int,
) -> None:
model_runner = MagicMock()
model_runner.is_hybrid = False
model_runner.should_setup_paged_attention.return_value = runner_allows_setup
worker = _make_worker(model_runner, use_paged_attention=use_paged_attention)
worker._setup_paged_attention = MagicMock()
Expand Down
162 changes: 162 additions & 0 deletions tools/test_qwen35_golden.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
"""Qwen3.5 golden token deterministic test: paged vs mlx_lm ground truth.

Verifies that the hybrid paged attention path (SDPA + GDN) produces the
same tokens as the MLX inline cache path for Qwen3.5.

Not in CI — requires local model weights.

Usage:
# Generate golden tokens (MLX inline cache, greedy):
VLLM_ENABLE_V1_MULTIPROCESSING=0 python tools/test_qwen35_golden.py --gen-golden

# Run deterministic test (paged path vs golden):
VLLM_ENABLE_V1_MULTIPROCESSING=0 python tools/test_qwen35_golden.py

# Custom model path:
VLLM_ENABLE_V1_MULTIPROCESSING=0 python tools/test_qwen35_golden.py \
--model /path/to/Qwen3.5-0.8B
"""

import argparse
import os
import sys

os.environ.setdefault("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

from vllm import LLM, SamplingParams # noqa: E402

# Default for the --model CLI option; can be overridden through the
# QWEN35_MODEL_PATH environment variable (hub id or local path).
MODEL_DEFAULT = os.environ.get("QWEN35_MODEL_PATH", "Qwen/Qwen3.5-4B")
# Default for the --max-tokens CLI option (tokens generated per prompt).
MAX_TOKENS = 20

# Fixed prompt set that every run generates from; results are keyed by these
# exact strings, so both runs being compared must use the same list.
PROMPTS = [
    "The capital of France is",
    "One plus one equals",
    "The largest planet in our solar system is",
    "Machine learning is a branch of",
]


def generate(model: str, max_tokens: int) -> dict[str, list[int]]:
    """Generate from every prompt in ``PROMPTS`` with temperature 0.

    Args:
        model: Model identifier or path handed to ``LLM``.
        max_tokens: Upper bound on generated tokens per prompt.

    Returns:
        Mapping of each prompt string to the token ids of its first output.
    """
    engine = LLM(model=model, max_model_len=512, max_num_seqs=1)
    greedy = SamplingParams(temperature=0, max_tokens=max_tokens)
    return {
        completion.prompt: list(completion.outputs[0].token_ids)
        for completion in engine.generate(PROMPTS, greedy)
    }


def print_golden(results: dict[str, list[int]], label: str) -> None:
    """Write ``results`` to stdout as a ``GOLDEN_<label>`` dict literal.

    The output is meant to be copy-pasted into source code. Token lists are
    aligned by padding after each prompt; at least one space is always
    emitted, even for prompts longer than the alignment width.

    Args:
        results: Mapping of prompt to generated token ids.
        label: Suffix for the printed variable name.
    """
    print(f"\nGOLDEN_{label} = {{")
    for prompt, ids in results.items():
        gap = " " * max(55 - len(prompt), 1)
        print(f" {prompt!r}:{gap}{ids},")
    print("}")


def _run_in_subprocess(
    model: str, max_tokens: int, paged: bool
) -> dict[str, list[int]]:
    """Run generation in a fresh interpreter and return its token ids.

    A separate process is used so the two runs being compared do not share
    memory. The child's configuration is carried entirely by its
    environment, which is built here.

    Args:
        model: Model identifier or path passed to ``LLM`` in the child.
        max_tokens: Upper bound on generated tokens per prompt.
        paged: Whether the child should use the paged attention path.

    Returns:
        Mapping of each prompt in ``PROMPTS`` to its generated token ids.

    Raises:
        RuntimeError: If the child exits non-zero or prints no
            ``GOLDEN_JSON:`` line.
        subprocess.TimeoutExpired: If the child runs longer than 600 seconds.
    """
    import json
    import subprocess

    env = os.environ.copy()
    env["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
    if paged:
        env["VLLM_METAL_USE_PAGED_ATTENTION"] = "1"
        env.setdefault("VLLM_METAL_MEMORY_FRACTION", "0.5")
    else:
        # Pin the flag off explicitly: the parent's environment is copied
        # above, so an exported VLLM_METAL_USE_PAGED_ATTENTION=1 would
        # otherwise leak into the reference run and make the comparison
        # paged-vs-paged.
        # NOTE(review): assumes "0" is read as "disabled", matching how this
        # script's own --gen-golden path interprets the variable.
        env["VLLM_METAL_USE_PAGED_ATTENTION"] = "0"

    # The child inherits ``env`` before it imports anything, so the script
    # itself does not need to touch os.environ.
    script = f"""
import json
from vllm import LLM, SamplingParams
llm = LLM(model={model!r}, max_model_len=512, max_num_seqs=1)
sp = SamplingParams(temperature=0, max_tokens={max_tokens})
prompts = {PROMPTS!r}
outputs = llm.generate(prompts, sp)
result = {{o.prompt: list(o.outputs[0].token_ids) for o in outputs}}
print("GOLDEN_JSON:" + json.dumps(result))
"""
    proc = subprocess.run(
        [sys.executable, "-c", script],
        capture_output=True,
        text=True,
        env=env,
        timeout=600,
    )
    if proc.returncode != 0:
        # Only the tail of stderr is shown; slicing already handles output
        # shorter than the limit.
        print(proc.stderr[-2000:])
        raise RuntimeError(f"Subprocess failed (paged={paged})")

    # The child may print other output; the payload is the marked line.
    marker = "GOLDEN_JSON:"
    for line in proc.stdout.splitlines():
        if line.startswith(marker):
            return json.loads(line[len(marker) :])
    raise RuntimeError("No GOLDEN_JSON output found")


def run_test(model: str, max_tokens: int) -> bool:
    """Check that the paged path reproduces the non-paged path's tokens.

    Both paths are run in their own subprocess; a per-prompt verdict and an
    overall summary are printed.

    Args:
        model: Model identifier or path to evaluate.
        max_tokens: Upper bound on generated tokens per prompt.

    Returns:
        ``True`` when every prompt yields identical token ids on both paths.
    """
    print("=== Step 1: MLX inline cache (ground truth) ===")
    reference = _run_in_subprocess(model, max_tokens, paged=False)

    print("=== Step 2: Paged attention path ===")
    candidate = _run_in_subprocess(model, max_tokens, paged=True)

    print("\n=== Results ===")
    failures = 0
    for prompt in PROMPTS:
        expected = reference[prompt]
        actual = candidate[prompt]
        if expected == actual:
            status = "MATCH"
            print(f" [{status}] {prompt!r}")
            continue

        status = "MISMATCH"
        failures += 1
        # First position where the two sequences disagree, if any; when one
        # sequence is a strict prefix of the other there is none.
        first_diff = next(
            (
                (idx, want, got)
                for idx, (want, got) in enumerate(
                    zip(expected, actual, strict=False)
                )
                if want != got
            ),
            None,
        )
        if first_diff is None:
            print(
                f" [{status}] {prompt!r} — length differs: "
                f"mlx={len(expected)} vs paged={len(actual)}"
            )
        else:
            idx, want, got = first_diff
            print(
                f" [{status}] {prompt!r} — diverges at token {idx}: "
                f"mlx={want} vs paged={got}"
            )

    all_match = failures == 0
    print(f"\n{'ALL PASSED' if all_match else 'SOME MISMATCHES'}")
    return all_match


if __name__ == "__main__":
    # The module docstring doubles as the --help text; the raw formatter
    # keeps its usage examples laid out as written.
    cli = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description=__doc__,
    )
    cli.add_argument("--model", default=MODEL_DEFAULT)
    cli.add_argument("--max-tokens", type=int, default=MAX_TOKENS)
    cli.add_argument(
        "--gen-golden",
        help="Just print golden token IDs and exit",
        action="store_true",
    )
    opts = cli.parse_args()

    if not opts.gen_golden:
        # Comparison mode: the exit status reports whether both paths agreed.
        passed = run_test(opts.model, opts.max_tokens)
        sys.exit(0 if passed else 1)

    # Golden mode: which path is dumped follows the caller's environment.
    use_paged = os.environ.get("VLLM_METAL_USE_PAGED_ATTENTION", "0") == "1"
    label = "PAGED" if use_paged else "MLX"
    print(f"Generating golden tokens ({label} path, {opts.model})")
    golden = _run_in_subprocess(opts.model, opts.max_tokens, paged=use_paged)
    print_golden(golden, label)
4 changes: 4 additions & 0 deletions vllm_metal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@ def _register() -> str | None:
"""
_apply_macos_defaults()

from vllm_metal.compat import apply_compat_patches

apply_compat_patches()

from vllm_metal.platform import MetalPlatform

if MetalPlatform.is_available():
Expand Down
101 changes: 101 additions & 0 deletions vllm_metal/compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# SPDX-License-Identifier: Apache-2.0
"""Compatibility patches for vLLM + transformers version mismatches.

Applied once at platform registration time. Each patch is guarded by
try/except so it degrades silently if the target module changes.
"""

from __future__ import annotations

import logging

logger = logging.getLogger(__name__)

_APPLIED = False


def apply_compat_patches() -> None:
    """Install every known compatibility patch; later calls do nothing."""
    global _APPLIED  # noqa: PLW0603
    if not _APPLIED:
        # The flag is set before patching, so a failing patch is not retried
        # on the next call.
        _APPLIED = True
        _patch_qwen35_rope_validation()


def _patch_qwen35_rope_validation() -> None:
"""Fix vLLM 0.17.1 Qwen3.5 config vs transformers >=5.4 rope validation.

vLLM's ``Qwen3_5TextConfig.__init__`` hardcodes
``kwargs["ignore_keys_at_rope_validation"] = [...]`` (a list), but
transformers 5.4+ does ``received_keys -= ignore_keys`` which requires
a set.

Upstream: vllm-project/vllm#34604 fixed this but was reverted in #34610.
Remove this patch when vllm-metal upgrades to a vLLM version with the fix.
"""
from importlib.util import find_spec

if find_spec("vllm.transformers_utils.configs.qwen3_5") is None:
return

try:
from transformers.modeling_rope_utils import RopeConfigBase

rope_config_base = RopeConfigBase
except ImportError:
rope_config_base = None

if rope_config_base is None:
# Try the direct path
try:
import transformers.modeling_rope_utils as _rope

_orig_check = _rope._check_received_keys

def _safe_check(
rope_type,
received_keys,
required_keys,
optional_keys=None,
ignore_keys=None,
):
if ignore_keys is not None and isinstance(ignore_keys, list):
ignore_keys = set(ignore_keys)
return _orig_check(
rope_type, received_keys, required_keys, optional_keys, ignore_keys
)

_rope._check_received_keys = _safe_check
logger.debug("Patched _check_received_keys for rope validation compat")
return
except (ImportError, AttributeError):
pass

# Fallback: patch the static method on PreTrainedConfig if available
try:
from transformers import PreTrainedConfig

if hasattr(PreTrainedConfig, "_check_received_keys"):
_orig_check = PreTrainedConfig._check_received_keys

@staticmethod
def _safe_check(
rope_type,
received_keys,
required_keys,
optional_keys=None,
ignore_keys=None,
):
if ignore_keys is not None and isinstance(ignore_keys, list):
ignore_keys = set(ignore_keys)
return _orig_check(
rope_type, received_keys, required_keys, optional_keys, ignore_keys
)

PreTrainedConfig._check_received_keys = _safe_check
logger.debug(
"Patched PreTrainedConfig._check_received_keys for rope compat"
)
except (ImportError, AttributeError):
pass
13 changes: 13 additions & 0 deletions vllm_metal/metal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,15 @@ def _build_v2_paged_attention_source() -> str:
return "\n".join(parts)


def _build_gdn_source() -> str:
    """Assemble the Metal source for the GDN linear attention library.

    Returns:
        The contents of ``utils.metal`` followed by
        ``gdn_linear_attention.metal``, joined with a newline.
    """
    kernel_files = ("utils.metal", "gdn_linear_attention.metal")
    return "\n".join(
        _read_metal_source(_KERNELS_V2_DIR / filename) for filename in kernel_files
    )


def metal_unified_attention(
q, # [total_q_tokens, num_q_heads, head_size]
k, # [num_blocks, block_size, num_kv_heads, head_size]
Expand Down Expand Up @@ -212,6 +221,10 @@ def get_ops() -> ModuleType:
v2_src = _build_v2_paged_attention_source()
mod.init_v2_library(v2_src)

# 5. Initialise GDN linear attention library
gdn_src = _build_gdn_source()
mod.init_gdn_library(gdn_src)

_ops_module = mod
logger.info("Native paged-attention Metal kernels loaded")
return mod
Loading
Loading