diff --git a/tests/test_gdn_kernel.py b/tests/test_gdn_kernel.py new file mode 100644 index 00000000..cae1597e --- /dev/null +++ b/tests/test_gdn_kernel.py @@ -0,0 +1,104 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Deterministic golden tests for the fused GDN linear attention kernel. + +Verifies that the vllm-metal fused kernel produces identical output to +mlx_lm's Metal kernel across decode (T=1) and prefill (T>1) configs. +""" + +from __future__ import annotations + +import mlx.core as mx +import pytest +from mlx_lm.models.gated_delta import compute_g, gated_delta_kernel + +from vllm_metal.metal.linear_attention import fused_gdn_decode + +# Qwen3.5 shared dimensions +DK = 128 +DV = 128 +HK = 16 + +# Absolute tolerance for fp16 gating-order differences. +# Empirical max_abs is ~0.000031 for output and ~0.000061 for state. +# Set tolerance at 10x empirical to allow for hardware variance while +# still catching meaningful drift. +ATOL_Y = 0.001 +ATOL_S = 0.001 + + +def _make_inputs(B, T, Hv, dtype=mx.float16): # noqa: N803 + mx.random.seed(42) + sc = 0.1 + q = (mx.random.normal((B, T, HK, DK)) * sc).astype(dtype) + k = (mx.random.normal((B, T, HK, DK)) * sc).astype(dtype) + v = (mx.random.normal((B, T, Hv, DV)) * sc).astype(dtype) + a = (mx.random.normal((B, T, Hv)) * sc).astype(dtype) + b = (mx.random.normal((B, T, Hv)) * sc).astype(dtype) + A_log = (mx.random.normal((Hv,)) * sc).astype(dtype) # noqa: N806 + dt_bias = (mx.random.normal((Hv,)) * sc).astype(dtype) + state = mx.zeros((B, Hv, DV, DK), dtype=dtype) + mx.eval(q, k, v, a, b, A_log, dt_bias, state) + return q, k, v, a, b, A_log, dt_bias, state + + +def _run_reference(q, k, v, a, b, A_log, dt_bias, state): # noqa: N803 + """mlx_lm Metal kernel with pre-computed gating.""" + g = compute_g(A_log, a, dt_bias) + beta = mx.sigmoid(b) + mx.eval(g, beta) + state_copy = mx.array(state) + mx.eval(state_copy) + y, s = gated_delta_kernel(q, k, v, g, beta, state_copy) + mx.eval(y, s) + return y, s + + +def _run_fused(q, k, v, a, b, A_log, dt_bias, state): # noqa: N803 + """vllm-metal fused kernel.""" + y, s = fused_gdn_decode(q, k, v, a, b, A_log, dt_bias, state) + mx.eval(y, s) + return y, s + + +# --- Decode (T=1) --- + + +@pytest.mark.parametrize("B,Hv", [(1, 32), (1, 48), (4, 32), (8, 48)]) +def test_decode_matches_reference(B, Hv): # noqa: N803 + inputs = _make_inputs(B, T=1, Hv=Hv) + y_ref, s_ref = _run_reference(*inputs) + y_fused, s_fused = _run_fused(*inputs) + + y_diff = mx.abs(y_ref.astype(mx.float32) - y_fused.astype(mx.float32)).max().item() + s_diff = mx.abs(s_ref.astype(mx.float32) - s_fused.astype(mx.float32)).max().item() + + assert y_diff < ATOL_Y, f"y max_abs_diff={y_diff}" + assert s_diff < ATOL_S, f"state max_abs_diff={s_diff}" + + +# --- Prefill (T>1) --- + + +@pytest.mark.parametrize("T", [4, 16]) +def test_prefill_matches_reference(T): # noqa: N803 + inputs = _make_inputs(B=1, T=T, Hv=32) + y_ref, s_ref = _run_reference(*inputs) + y_fused, s_fused = _run_fused(*inputs) + + y_diff = mx.abs(y_ref.astype(mx.float32) - y_fused.astype(mx.float32)).max().item() + s_diff = mx.abs(s_ref.astype(mx.float32) - s_fused.astype(mx.float32)).max().item() + + assert y_diff < ATOL_Y, f"y max_abs_diff={y_diff}" + assert s_diff < ATOL_S, f"state max_abs_diff={s_diff}" + + +# --- Output shape --- + + +def test_output_shapes(): + B, T, Hv = 2, 8, 32 # noqa: N806 + inputs = _make_inputs(B, T, Hv) + y, s = _run_fused(*inputs) + + assert y.shape == (B, T, Hv, DV) + assert s.shape == (B, Hv, DV, DK) diff --git a/tools/README.md b/tools/README.md index f7bd1801..a3617f3a 100644 --- a/tools/README.md +++ b/tools/README.md @@ -98,3 +98,30 @@ vllm bench serve \ Key metric is **TTFT** — with prefix caching enabled, requests sharing the same prefix should show lower TTFT on cache hits. + +## GDN Linear Attention Benchmark + +Benchmark for GatedDeltaNet (GDN) linear attention kernels used in Qwen3.5's +linear attention layers. Compares four backends: + +- **fused**: vllm-metal fused kernel (gating + recurrence in 1 Metal dispatch) +- **metal**: mlx_lm Metal kernel (compute_g + sigmoid + kernel) +- **precomp**: mlx_lm kernel only (gating pre-computed, excluded from timing) +- **ops**: mlx_lm ops reference (mx.compile'd Python) + +```bash +# Correctness check only +PYTHONPATH=. python tools/bench_gdn_kernel.py --check + +# Full benchmark (all batch sizes, seq lens, head configs) +PYTHONPATH=. python tools/bench_gdn_kernel.py + +# Decode-only (T=1) with batch scaling +PYTHONPATH=. python tools/bench_gdn_kernel.py --seq-lens 1 --batch 1 4 8 16 32 + +# Qwen3.5-4B config (Hv=32) +PYTHONPATH=. python tools/bench_gdn_kernel.py --hv 32 + +# Qwen3.5-27B config (Hv=48) +PYTHONPATH=. python tools/bench_gdn_kernel.py --hv 48 +``` diff --git a/tools/bench_gdn_kernel.py b/tools/bench_gdn_kernel.py new file mode 100644 index 00000000..522ccd8d --- /dev/null +++ b/tools/bench_gdn_kernel.py @@ -0,0 +1,284 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Benchmark GatedDeltaNet (GDN) linear attention kernels. + +Compares three backends: + - metal: mlx_lm's Metal kernel (pre-computed g, beta + kernel dispatch) + - fused: vllm-metal fused kernel (gating + recurrence in one dispatch) + - ops: mlx_lm ops reference (mx.compile'd Python) + +Usage: + python tools/bench_gdn_kernel.py + python tools/bench_gdn_kernel.py --batch 1 4 8 --seq-lens 1 64 + python tools/bench_gdn_kernel.py --check # correctness only +""" + +from __future__ import annotations + +import argparse +import sys +import time + +import mlx.core as mx + +# Qwen3.5 GDN dimensions (shared across all model sizes) +DK = 128 # key head dim +DV = 128 # value head dim +HK = 16 # key heads + + +def _make_inputs( # noqa: N803 + batch, + seq_len, + n_k_heads, + n_v_heads, + key_dim, + val_dim, + dtype, +): + """Create synthetic inputs for benchmarking. + + Uses small values to avoid fp16 overflow in multi-step recurrence. + """ + mx.random.seed(42) + scale = 0.1 + q = (mx.random.normal((batch, seq_len, n_k_heads, key_dim)) * scale).astype(dtype) + k = (mx.random.normal((batch, seq_len, n_k_heads, key_dim)) * scale).astype(dtype) + v = (mx.random.normal((batch, seq_len, n_v_heads, val_dim)) * scale).astype(dtype) + a = (mx.random.normal((batch, seq_len, n_v_heads)) * scale).astype(dtype) + b = (mx.random.normal((batch, seq_len, n_v_heads)) * scale).astype(dtype) + a_log = (mx.random.normal((n_v_heads,)) * scale).astype(dtype) + dt_bias = (mx.random.normal((n_v_heads,)) * scale).astype(dtype) + state = mx.zeros((batch, n_v_heads, val_dim, key_dim), dtype=dtype) + mx.eval(q, k, v, a, b, a_log, dt_bias, state) + return q, k, v, a, b, a_log, dt_bias, state + + +def bench_one( # noqa: N803 + *, + backend, + batch, + seq_len, + n_k_heads, + n_v_heads, + key_dim, + val_dim, + warmup, + iters, + dtype, +) -> float: + """Run one benchmark config and return median ms per call.""" + from mlx_lm.models.gated_delta import ( + compute_g, + gated_delta_kernel, + gated_delta_ops, + ) + + q, k, v, a, b, a_log, dt_bias, state = _make_inputs( + batch, + seq_len, + n_k_heads, + n_v_heads, + key_dim, + val_dim, + dtype, + ) + + if backend == "fused": + from vllm_metal.metal.linear_attention import fused_gdn_decode + + def _fused_fn(): + return fused_gdn_decode(q, k, v, a, b, a_log, dt_bias, state) + + fn = _fused_fn + elif backend == "metal": + # Include compute_g + sigmoid in timing (fair comparison with fused) + def _metal_fn(): + g = compute_g(a_log, a, dt_bias) + beta = mx.sigmoid(b) + return gated_delta_kernel(q, k, v, g, beta, state) + + fn = _metal_fn + elif backend == "metal_precomp": + # Pre-computed gating (kernel-only timing) + g = compute_g(a_log, a, dt_bias) + beta = mx.sigmoid(b) + mx.eval(g, beta) + + def _precomp_fn(): + return gated_delta_kernel(q, k, v, g, beta, state) + + fn = _precomp_fn + elif backend == "ops": + + def _ops_fn(): + g = compute_g(a_log, a, dt_bias) + beta = mx.sigmoid(b) + return gated_delta_ops(q, k, v, g, beta, state) + + fn = _ops_fn + else: + raise ValueError(f"Unknown backend: {backend}") + + for _ in range(warmup): + y, s = fn() + mx.eval(y, s) + + times = [] + for _ in range(iters): + start = time.perf_counter() + y, s = fn() + mx.eval(y, s) + elapsed = time.perf_counter() - start + times.append(elapsed * 1000) + + times.sort() + return times[len(times) // 2] + + +def check_correctness(n_v_heads=32, dtype=mx.float16): + """Verify fused kernel matches mlx_lm's Metal kernel output.""" + from mlx_lm.models.gated_delta import compute_g, gated_delta_kernel + + from vllm_metal.metal.linear_attention import fused_gdn_decode + + print(f"Correctness check (Hv={n_v_heads}, dtype={dtype})...") + + for batch, seq_len in [(1, 1), (1, 16), (4, 1), (2, 8)]: + q, k, v, a, b, a_log, dt_bias, state = _make_inputs( + batch, + seq_len, + HK, + n_v_heads, + DK, + DV, + dtype, + ) + + # Reference: mlx_lm Metal kernel (pre-computed gating) + g = compute_g(a_log, a, dt_bias) + beta = mx.sigmoid(b) + mx.eval(g, beta) + state_copy = mx.array(state) + mx.eval(state_copy) + y_ref, s_ref = gated_delta_kernel(q, k, v, g, beta, state_copy) + mx.eval(y_ref, s_ref) + + # Fused kernel (use original state, not the copy) + y_fused, s_fused = fused_gdn_decode(q, k, v, a, b, a_log, dt_bias, state) + mx.eval(y_fused, s_fused) + + # Compare + y_abs = mx.abs(y_ref.astype(mx.float32) - y_fused.astype(mx.float32)) + s_abs = mx.abs(s_ref.astype(mx.float32) - s_fused.astype(mx.float32)) + y_diff = y_abs.max().item() + s_diff = s_abs.max().item() + + # fp16 gating order differences and near-zero outputs cause max_rel noise. + # Use absolute tolerance: for scaled inputs (0.1) outputs are O(0.01-0.1). + status = "PASS" if y_diff < 0.05 and s_diff < 0.05 else "FAIL" + print( + f" B={batch} T={seq_len}: " + f"y_maxabs={y_diff:.6f} s_maxabs={s_diff:.6f} [{status}]" + ) + + if status == "FAIL": + return False + return True + + +def main(): + parser = argparse.ArgumentParser(description="Benchmark GDN kernel") + parser.add_argument( + "--batch", + type=int, + nargs="+", + default=[1, 4, 8], + ) + parser.add_argument( + "--seq-lens", + type=int, + nargs="+", + default=[1, 16, 64], + ) + parser.add_argument( + "--hv", + type=int, + nargs="+", + default=[32, 48], + ) + parser.add_argument("--warmup", type=int, default=10) + parser.add_argument("--iters", type=int, default=50) + parser.add_argument( + "--dtype", + choices=["float16", "bfloat16"], + default="float16", + ) + parser.add_argument( + "--check", + action="store_true", + help="Run correctness check only", + ) + args = parser.parse_args() + + dtype = mx.float16 if args.dtype == "float16" else mx.bfloat16 + + # Validate correctness for every requested Hv before benchmarking + for hv in args.hv: + ok = check_correctness(n_v_heads=hv, dtype=dtype) + if not ok: + print(f"CORRECTNESS CHECK FAILED for Hv={hv} — aborting benchmark") + sys.exit(1) + print() + + if args.check: + return + + print(f"GDN Kernel Benchmark (Dk={DK}, Dv={DV}, Hk={HK})") + print(f"dtype={args.dtype} warmup={args.warmup} iters={args.iters}") + print() + + header = ( + f"{'Hv':>4} | {'B':>3} | {'T':>5} | " + f"{'fused(ms)':>10} | {'metal(ms)':>10} | {'precomp(ms)':>11} | " + f"{'ops(ms)':>10} | {'f/m':>6}" + ) + print(header) + print("-" * len(header)) + + for n_v_heads in args.hv: + for batch in args.batch: + for seq_len in args.seq_lens: + common = { + "batch": batch, + "seq_len": seq_len, + "n_k_heads": HK, + "n_v_heads": n_v_heads, + "key_dim": DK, + "val_dim": DV, + "warmup": args.warmup, + "iters": args.iters, + "dtype": dtype, + } + + t_fused = bench_one(backend="fused", **common) + t_metal = bench_one(backend="metal", **common) + t_precomp = bench_one(backend="metal_precomp", **common) + t_ops = bench_one(backend="ops", **common) + + print( + f"{n_v_heads:4d} | {batch:3d} | {seq_len:5d} | " + f"{t_fused:10.3f} | {t_metal:10.3f} | {t_precomp:10.3f} | " + f"{t_ops:10.3f} | {t_fused / t_metal:6.2f}x" + ) + + print() + print("Backends:") + print(" fused = vllm-metal fused kernel (gating + recurrence, 1 dispatch)") + print(" metal = mlx_lm full path (compute_g + sigmoid + kernel)") + print(" precomp = mlx_lm kernel only (gating pre-computed, excluded from timing)") + print(" ops = mlx_lm ops reference (mx.compile'd Python loops)") + print(" f/m = fused / metal ratio (< 1.0 means fused wins)") + + +if __name__ == "__main__": + main() diff --git a/vllm_metal/metal/linear_attention.py b/vllm_metal/metal/linear_attention.py new file mode 100644 index 00000000..8bd0be4e --- /dev/null +++ b/vllm_metal/metal/linear_attention.py @@ -0,0 +1,158 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Fused GDN linear attention kernel for Metal. + +Fuses gating computation (A_log, a, b, dt_bias → g, beta) with the +recurrent state update, eliminating separate kernel dispatches. + +Uses ``mx.fast.metal_kernel`` for rapid prototyping. Once validated, +can be migrated to the nanobind C++ dispatch path (paged_ops.cpp) for +tighter integration with the paged attention pipeline. +""" + +from __future__ import annotations + +import mlx.core as mx + + +def _make_fused_gdn_kernel(): + """Build the fused GDN decode Metal kernel via mx.fast.metal_kernel.""" + if not mx.metal.is_available(): + return None + + source = """ + const int T = T_val; + auto n = thread_position_in_grid.z; + auto b_idx = n / Hv; + auto hv_idx = n % Hv; + auto hk_idx = hv_idx / (Hv / Hk); + constexpr int n_per_t = Dk / 32; + + // q, k: [B, T, Hk, Dk] + auto q_ = q + b_idx * T * Hk * Dk + hk_idx * Dk; + auto k_ = k + b_idx * T * Hk * Dk + hk_idx * Dk; + + // v, y: [B, T, Hv, Dv] + auto v_ = v + b_idx * T * Hv * Dv + hv_idx * Dv; + y += b_idx * T * Hv * Dv + hv_idx * Dv; + + auto dk_idx = thread_position_in_threadgroup.x; + auto dv_idx = thread_position_in_grid.y; + + // state: [B, Hv, Dv, Dk] + auto i_state = state_in + (n * Dv + dv_idx) * Dk; + auto o_state = state_out + (n * Dv + dv_idx) * Dk; + + float state[n_per_t]; + for (int i = 0; i < n_per_t; ++i) { + auto s_idx = n_per_t * dk_idx + i; + state[i] = static_cast(i_state[s_idx]); + } + + // Per-head constants + float A_log_val = static_cast(A_log[hv_idx]); + float dt_bias_val = static_cast(dt_bias[hv_idx]); + + // a, b: [B, T, Hv] + auto a_ = a + b_idx * T * Hv; + auto b_ = b + b_idx * T * Hv; + + for (int t = 0; t < T; ++t) { + // Fused gating: g = exp(-exp(A_log) * softplus(a + dt_bias)) + float a_val = static_cast(a_[hv_idx]); + float x = a_val + dt_bias_val; + float sp = (x > 20.0f) ? x : log(1.0f + exp(x)); + float g_val = exp(-exp(A_log_val) * sp); + + // beta = sigmoid(b) + float b_val = static_cast(b_[hv_idx]); + float beta_val = 1.0f / (1.0f + exp(-b_val)); + + // Recurrence + float kv_mem = 0.0f; + for (int i = 0; i < n_per_t; ++i) { + auto s_idx = n_per_t * dk_idx + i; + state[i] = state[i] * g_val; + kv_mem += state[i] * k_[s_idx]; + } + kv_mem = simd_sum(kv_mem); + + auto delta = (v_[dv_idx] - kv_mem) * beta_val; + + float out = 0.0f; + for (int i = 0; i < n_per_t; ++i) { + auto s_idx = n_per_t * dk_idx + i; + state[i] = state[i] + k_[s_idx] * delta; + out += state[i] * q_[s_idx]; + } + out = simd_sum(out); + if (thread_index_in_simdgroup == 0) { + y[dv_idx] = static_cast(out); + } + + q_ += Hk * Dk; + k_ += Hk * Dk; + v_ += Hv * Dv; + y += Hv * Dv; + a_ += Hv; + b_ += Hv; + } + + for (int i = 0; i < n_per_t; ++i) { + auto s_idx = n_per_t * dk_idx + i; + o_state[s_idx] = static_cast(state[i]); + } + """ + + return mx.fast.metal_kernel( + name="fused_gdn_decode", + input_names=["q", "k", "v", "a", "b", "A_log", "dt_bias", "state_in", "T_val"], + output_names=["y", "state_out"], + source=source, + ) + + +_fused_kernel = _make_fused_gdn_kernel() + + +def fused_gdn_decode( + q: mx.array, + k: mx.array, + v: mx.array, + a: mx.array, + b: mx.array, + A_log: mx.array, # noqa: N803 + dt_bias: mx.array, + state: mx.array, +) -> tuple[mx.array, mx.array]: + """Fused GDN decode: gating + recurrence in one Metal dispatch. + + Args: + q: [B, T, Hk, Dk] queries + k: [B, T, Hk, Dk] keys + v: [B, T, Hv, Dv] values + a: [B, T, Hv] decay parameter + b: [B, T, Hv] gating parameter + A_log: [Hv] log-space decay base + dt_bias: [Hv] bias for decay computation + state: [B, Hv, Dv, Dk] recurrent state + + Returns: + (output [B, T, Hv, Dv], new_state [B, Hv, Dv, Dk]) + """ + B, T, Hk, Dk = k.shape # noqa: N806 + Hv, Dv = v.shape[2], v.shape[3] # noqa: N806 + + return _fused_kernel( + inputs=[q, k, v, a, b, A_log, dt_bias, state, T], + template=[ + ("InT", q.dtype), + ("Dk", Dk), + ("Dv", Dv), + ("Hk", Hk), + ("Hv", Hv), + ], + grid=(32, Dv, B * Hv), + threadgroup=(32, 4, 1), + output_shapes=[(B, T, Hv, Dv), state.shape], + output_dtypes=[q.dtype, q.dtype], + )