
[Kernel] Fused GDN linear attention kernel for Qwen3.5#186

Closed
ricky-chaoju wants to merge 5 commits into vllm-project:main from ricky-chaoju:gdn-linear-attention-kernel

Conversation

@ricky-chaoju (Collaborator) commented Mar 21, 2026

Summary

  • Add fused Metal kernel for GatedDeltaNet (GDN) linear attention used in Qwen3.5's linear attention layers
  • Fuses gating computation (A_log, a, b, dt_bias → g, beta) with the recurrent state update into a single mx.fast.metal_kernel dispatch
  • Add benchmark tool (tools/bench_gdn_kernel.py) comparing 4 backends: fused, mlx_lm Metal, mlx_lm precomputed, and ops reference
  • Add 7 deterministic golden tests validating fused kernel against mlx_lm Metal kernel

Ref: #148 (Roadmap — "After Stage 3: Linear attention kernel for Qwen3.5")

Kernel Design

Follows the same threading model as mlx_lm's gated_delta_kernel:

  • Grid: (32, Dv, B * Hv), Threadgroup: (32, 4, 1)
  • Each thread handles n_per_t = Dk/32 = 4 elements
  • simd_sum() for cross-thread dot-product reduction
  • State kept in registers across the time loop

Difference from mlx_lm: computes gating on-the-fly in the shader (softplus + exp for decay, sigmoid for beta) instead of requiring pre-computed g and beta tensors. Matches upstream vLLM's fused_sigmoid_gating_delta_rule_update_kernel approach.
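
For reference, a minimal ops-level sketch of one decode step of the recurrence the kernel fuses, with the gating computed on the fly. Shapes, argument names, and the exact gating formula are assumptions inferred from this description, not the kernel's actual interface:

```python
import mlx.core as mx

def gdn_reference_step(q, k, v, a, b, A_log, dt_bias, state):
    # Assumed shapes: q, k: (B, Hv, Dk)  v: (B, Hv, Dv)  a, b: (B, Hv)
    #                 A_log, dt_bias: (Hv,)  state: (B, Hv, Dk, Dv)
    # Gating computed on the fly: softplus + exp for the decay, sigmoid for beta.
    g = mx.exp(-mx.exp(A_log) * mx.logaddexp(a + dt_bias, 0.0))  # decay in (0, 1]
    beta = mx.sigmoid(b)                                          # write gate

    # Delta rule: read the current memory along k, write back the gated residual.
    mem = (state * k[..., None]).sum(axis=-2)                     # (B, Hv, Dv)
    delta = beta[..., None] * (v - mem)
    state = g[..., None, None] * state + k[..., None] * delta[:, :, None, :]

    y = (state * q[..., None]).sum(axis=-2)                       # (B, Hv, Dv)
    return y, state
```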

Benchmark Results

Isolated kernel benchmark (Dk=128, Dv=128, Hk=16, float16):

Correctness check (Hv=32, dtype=mlx.core.float16)...
  B=1 T=1: y_maxabs=0.000031 s_maxabs=0.000061 [PASS]
  B=1 T=16: y_maxabs=0.000031 s_maxabs=0.000061 [PASS]
  B=4 T=1: y_maxabs=0.000031 s_maxabs=0.000061 [PASS]
  B=2 T=8: y_maxabs=0.000031 s_maxabs=0.000061 [PASS]

  Hv |   B |     T |  fused(ms) |  metal(ms) | precomp(ms) |    ops(ms) |    f/m
--------------------------------------------------------------------------------
  32 |   1 |     1 |      0.232 |      0.240 |      0.202 |      0.320 |   0.97x
  32 |   1 |    16 |      0.285 |      0.270 |      0.232 |      1.261 |   1.06x
  32 |   1 |    64 |      0.477 |      0.339 |      0.285 |      4.002 |   1.41x
  32 |   4 |     1 |      0.247 |      0.258 |      0.232 |      0.370 |   0.96x
  32 |   8 |     1 |      0.278 |      0.268 |      0.226 |      0.381 |   1.04x
  48 |   1 |     1 |      0.239 |      0.251 |      0.211 |      0.324 |   0.96x
  48 |   4 |     1 |      0.256 |      0.257 |      0.219 |      0.366 |   0.99x
  48 |   8 |     1 |      0.296 |      0.286 |      0.245 |      0.434 |   1.03x

  • Decode (T=1): fused is roughly on par with mlx_lm full path (f/m 0.96-1.04x). Saves one dispatch but adds ALU for on-the-fly gating — net effect is within noise
  • Prefill (T>1): fused is slower (1.4-2.5x) due to per-timestep ALU overhead for on-the-fly gating computation
  • End-to-end model impact: ~0.5% on both Qwen3.5-4B and 27B — models are memory-bandwidth-bound, MLX lazy eval already fuses the separate dispatch path

Signed-off-by: RickyChen / 陳昭儒 <rickychen@infinirc.com>
ricky-chaoju marked this pull request as ready for review March 21, 2026 05:39
@WindChimeRan (Collaborator)

Quick thought: mx.fast.metal_kernel allocates fresh output buffers every call (by-value), which means the recurrent state can't be updated in-place — this won't work for paged/continuous-batching integration where we need by-reference writes into a managed state pool (like reshape_and_cache does via the C++ path). Worth keeping in mind as we plan the migration from prototype to production.
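
Concretely, the call site has to look something like this (gdn_kernel, state_pool, and slot are hypothetical names, just to illustrate the shape of the problem):

```python
# By-value semantics: the kernel returns fresh arrays, so the recurrent state
# has to be copied / re-bound into the pool rather than written in place.
y, new_state = gdn_kernel(q, k, v, a, b, A_log, dt_bias, state_pool[slot])
state_pool[slot] = new_state  # extra write-back; no by-reference update into the cache
```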

@WindChimeRan (Collaborator)

BTW, we need to design a new abstraction for attention backend interception and dispatch.

The current behavior intercepts all attention into the unified v2 attention path. That is inherently incompatible with linear attention and MLA, so we need a somewhat more complex dispatch abstraction.
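
One possible shape for that abstraction, sketched with hypothetical names (a per-layer registry keyed by attention kind instead of a single global interception):

```python
from typing import Callable, Dict

ATTENTION_BACKENDS: Dict[str, Callable] = {}

def register_backend(kind: str):
    # Decorator that records a forward function for one attention kind.
    def wrap(fn: Callable) -> Callable:
        ATTENTION_BACKENDS[kind] = fn
        return fn
    return wrap

@register_backend("full")
def unified_v2_attention_forward(layer, *args, **kwargs):
    ...

@register_backend("linear")
def linear_attention_forward(layer, *args, **kwargs):
    ...

def dispatch_attention(layer, *args, **kwargs):
    # Hybrid models declare an attention kind per layer instead of being
    # globally intercepted into the unified v2 path.
    return ATTENTION_BACKENDS[layer.attn_kind](layer, *args, **kwargs)
```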

@ricky-chaoju (Collaborator, Author)

Agreed on both points. This PR is a prototype; the C++ nanobind dispatch and the backend abstraction are follow-up work.

@WindChimeRan (Collaborator)

Thanks! Will take a closer look tomorrow.

@LxYuan0420 (Collaborator) left a comment

MLX lazy eval already fuses the dispatch path, so what exactly were you fusing that MLX wasn't already doing?

@ricky-chaoju (Collaborator, Author) commented Mar 22, 2026

> MLX lazy eval already fuses the dispatch path, so what exactly were you fusing that MLX wasn't already doing?

The fused kernel puts the gating math (softplus, exp, sigmoid) inside the shader instead of running it as separate MLX ops.
Isolated kernel benchmarks show a ~4% improvement for decode, but in the full model forward the difference is within noise (~0.4%), since MLX lazy evaluation already batches these ops across layers.
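
For contrast, the precomputed path looks roughly like this (a sketch; gated_delta_kernel is an illustrative name, not mlx_lm's exact API); the fused kernel moves the first two lines into the shader:

```python
import mlx.core as mx

g = mx.exp(-mx.exp(A_log) * mx.logaddexp(a + dt_bias, 0.0))  # separate MLX ops
beta = mx.sigmoid(b)
y, state = gated_delta_kernel(q, k, v, g, beta, state)       # kernel expects precomputed g/beta
```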

@LxYuan0420 (Collaborator)

So what existing path do we delete or simplify after this merges? If none, we are only increasing the maintenance surface.

@WindChimeRan (Collaborator) left a comment

Thanks for the kernel work! However, this can't integrate with the paged attention pipeline as-is — mx.fast.metal_kernel allocates fresh output buffers per call, so it can't do in-place state updates into a managed cache. The linear attention kernel needs to work like paged_attention_v2_online: read/write via block tables, handle varlen batches, dispatch through the C++ nanobind path.

We're landing an attention dispatch refactor #201 that adds the linear_attention_forward integration point and per-layer patching for hybrid models. Let's build the GDN kernel on top of that foundation.

ricky-chaoju deleted the gdn-linear-attention-kernel branch May 3, 2026 09:49