[Kernel] Fused GDN linear attention kernel for Qwen3.5 #186
ricky-chaoju wants to merge 5 commits into
Conversation
Quick thought:
BTW, we need to design a new abstraction for the attention backend interception & dispatch. The current behavior is to intercept all attention into the unified v2 attention path, which is inherently incompatible with linear attention and MLA, so we need a somewhat more complex dispatch abstraction.
Agreed on both points. This PR is a prototype; the C++ nanobind dispatch and the backend abstraction are follow-up work.
Thanks! Will take a closer look tomorrow.
LxYuan0420
left a comment
MLX lazy eval already fuses the dispatch path, so what exactly were you fusing that MLX wasn't already doing?
The fused kernel puts the gating math (softplus, exp, sigmoid) inside the shader instead of running those operations as separate MLX ops.
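For concreteness, this is roughly the ops-level gating math that now lives in-shader. A minimal sketch: the helper name gdn_gates, the shapes/broadcasting, and whether g is returned in log space or already exponentiated are illustrative, not taken from the PR.

```python
import mlx.core as mx

def gdn_gates(A_log, a, b, dt_bias):
    # Sketch of the per-step gating the fused kernel computes internally.
    # softplus(a + dt_bias), written via logaddexp for numerical stability
    dt = mx.logaddexp(a + dt_bias, 0.0)
    # Decay gate: exp(-exp(A_log) * softplus(...)), a value in (0, 1]
    g = mx.exp(-mx.exp(A_log) * dt)
    # Mixing gate for the delta-rule update
    beta = mx.sigmoid(b)
    return g, beta
```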
So what existing path do we delete or simplify after this merges? If none, we are only increasing the maintenance surface.
WindChimeRan
left a comment
Thanks for the kernel work! However, this can't integrate with the paged attention pipeline as-is: mx.fast.metal_kernel allocates fresh output buffers per call, so it can't do in-place state updates into a managed cache. The linear attention kernel needs to work like paged_attention_v2_online: read/write via block tables, handle varlen batches, and dispatch through the C++ nanobind path.
We're landing an attention dispatch refactor #201 that adds the linear_attention_forward integration point and per-layer patching for hybrid models. Let's build the GDN kernel on top of that foundation.
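To make the cache-ownership requirement concrete, here is a purely hypothetical sketch of the shape such an entry point could take. Apart from the name linear_attention_forward mentioned above, every argument and step here is a placeholder, not the #201 API.

```python
# Hypothetical shape of a cache-aware linear-attention entry point.
# Only the name linear_attention_forward comes from the discussion above;
# all arguments here are placeholders, not the actual #201 signature.
def linear_attention_forward(q, k, v, gate_inputs, *, layer_id,
                             block_tables, seq_lens, state_cache):
    # 1. Look up each sequence's recurrent state in the managed cache
    #    via the block tables (varlen batch: one state per sequence).
    # 2. Run the fused GDN kernel over (q, k, v, gate_inputs) against
    #    those states.
    # 3. Write the updated states back in place, so the cache (not the
    #    kernel) owns the buffers.
    raise NotImplementedError
```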
Summary
- Fuses the gating computation (A_log, a, b, dt_bias → g, beta) with the recurrent state update into a single mx.fast.metal_kernel dispatch (sketched below)
- Adds a benchmark script (tools/bench_gdn_kernel.py) comparing 4 backends: fused, mlx_lm Metal, mlx_lm precomputed, and ops reference
- Ref: #148 (Roadmap, "After Stage 3: Linear attention kernel for Qwen3.5")
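For readers who haven't seen the gated delta rule, the ops-reference backend corresponds roughly to the single-token recurrence below. This is a minimal sketch in plain MLX ops; the helper name gdn_step, the shapes, and the exact update convention are illustrative and may differ from the benchmark's actual reference implementation.

```python
import mlx.core as mx

def gdn_step(S, q, k, v, g, beta):
    # One decode step of the gated delta rule in plain MLX ops.
    # S: (H, Dk, Dv) recurrent state; q, k: (H, Dk); v: (H, Dv); g, beta: (H,)
    S = S * g[:, None, None]                   # apply the decay gate to the state
    mem = (S * k[:, :, None]).sum(axis=1)      # retrieve S^T k -> (H, Dv)
    delta = beta[:, None] * (v - mem)          # delta-rule correction term
    S = S + k[:, :, None] * delta[:, None, :]  # rank-1 state update
    o = (S * q[:, :, None]).sum(axis=1)        # readout S^T q -> (H, Dv)
    return o, S
```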
Kernel Design
Follows the same threading model as mlx_lm's gated_delta_kernel:
- Grid: (32, Dv, B * Hv), Threadgroup: (32, 4, 1)
- Each thread handles n_per_t = Dk/32 = 4 elements
- simd_sum() for the cross-thread dot-product reduction

Difference from mlx_lm: the gating is computed on-the-fly in the shader (softplus + exp for the decay, sigmoid for beta) instead of requiring pre-computed g and beta tensors. This matches upstream vLLM's fused_sigmoid_gating_delta_rule_update_kernel approach.
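For context on the launch geometry above, the sketch below shows how a grid/threadgroup pair is passed to mx.fast.metal_kernel. The shader body is a trivial copy placeholder, not the GDN kernel, and the names (demo_kernel, inp, out) are mine; only the 32-lane x dimension and the (32, 4, 1) threadgroup mirror the real launch.

```python
import mlx.core as mx

# Placeholder shader body: a copy, standing in for the GDN state update.
source = """
    uint col = thread_position_in_grid.x;   // 32 lanes, cf. one simdgroup over Dk
    uint row = thread_position_in_grid.y;   // cf. the Dv dimension
    out[row * 32 + col] = inp[row * 32 + col];
"""

demo_kernel = mx.fast.metal_kernel(
    name="launch_geometry_demo",
    input_names=["inp"],
    output_names=["out"],
    source=source,
)

x = mx.arange(128 * 32, dtype=mx.float32).reshape(128, 32)
(y,) = demo_kernel(
    inputs=[x],
    grid=(32, 128, 1),       # cf. (32, Dv, B * Hv) in the fused kernel
    threadgroup=(32, 4, 1),  # same threadgroup shape as the fused kernel
    output_shapes=[x.shape],
    output_dtypes=[x.dtype],
)
```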
Benchmark Results
Isolated kernel benchmark (Dk=128, Dv=128, Hk=16, float16):