Add CPU kernels for linear attention contrib ops #27835
OmarAzizi wants to merge 1 commit into microsoft:rama/linear-attn
Conversation
Signed-off-by: OmarAzizi <oalazizi75@gmail.com>
@microsoft-github-policy-service agree
While implementing this, we are discussing some changes to the op signatures to make them more practical.
The signatures we expect for WebGPU are here; we're trying to see if we can finalize the signature today: https://github.com/microsoft/onnxruntime/blob/gs/wgpu-lattn/onnxruntime/core/graph/contrib_ops/bert_defs.cc#L2236
We are working on the signature here:
@guschmue Thanks for the heads up! I'll hold off on any further changes until the signatures are finalized. I'm happy to update the CPU kernels to match the new signatures.
OK, I updated the contrib ops PR to the latest signature.
A working fp16/q4 model is here:
A working implementation for WebGPU is in this PR:
Hi @guschmue, thanks a lot for the update and for sharing the latest changes and tests. I've been a bit busy over the past few days, but I'll review the updated signatures and update my PR this week. Since I'm still relatively new to the codebase and the contrib-ops workflow, I'd really appreciate any notes or tips I should keep in mind.
Description
Implemented CPU execution provider kernels for the three linear attention contrib ops that implement the linear attention / recurrent state-update mechanisms used by modern hybrid LLMs (Qwen3.5, Jamba, RWKV-6, FalconMamba, etc.):
- `LinearAttentionRecurrent`: Single-token recurrent decode step supporting the linear, gated, delta, and gated_delta update rules. Computes the full state update (decay, retrieve, delta, write, readout) in float32 for numerical stability across all sequence lengths; see the sketch after this list.
- `LinearAttentionChunkParallel`: Prefill kernel that processes a full input sequence by running the recurrent step sequentially over all T tokens. The CUDA chunk-parallel WY decomposition is not used on the CPU.
- `CausalConv1DWithState`: Depthwise causal 1D convolution with carry state and optional SiLU activation.
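For orientation, here is a minimal sketch of one decode step under the gated_delta rule, following the decay → retrieve → delta/write → readout order listed above. The [d_k, d_v] row-major state layout and the names `g` (decay gate) and `beta` (write strength) are illustrative assumptions, not the actual kernel's I/O contract:

```cpp
#include <cstddef>
#include <vector>

// One gated_delta decode step on a [d_k, d_v] row-major state S.
// Hypothetical sketch: names and layout are assumptions, not the ORT kernel.
void gated_delta_step(std::vector<float>& S,           // state, d_k * d_v
                      const float* q, const float* k,  // [d_k]
                      const float* v,                  // [d_v]
                      float g,                         // decay gate in (0, 1]
                      float beta,                      // delta-rule write strength
                      float* out,                      // [d_v]
                      std::size_t d_k, std::size_t d_v) {
  // decay: S <- g * S
  for (auto& s : S) s *= g;
  // retrieve: r = S^T k (what the state currently predicts for this key)
  std::vector<float> r(d_v, 0.0f);
  for (std::size_t i = 0; i < d_k; ++i)
    for (std::size_t j = 0; j < d_v; ++j)
      r[j] += S[i * d_v + j] * k[i];
  // delta + write: S <- S + beta * k (v - r)^T
  for (std::size_t i = 0; i < d_k; ++i)
    for (std::size_t j = 0; j < d_v; ++j)
      S[i * d_v + j] += beta * k[i] * (v[j] - r[j]);
  // readout: out = S^T q
  for (std::size_t j = 0; j < d_v; ++j) out[j] = 0.0f;
  for (std::size_t i = 0; i < d_k; ++i)
    for (std::size_t j = 0; j < d_v; ++j)
      out[j] += S[i * d_v + j] * q[i];
}
```

The plain linear rule is the special case g = 1 with the retrieve/delta terms dropped, and the prefill kernel simply loops this step over the T tokens of the sequence.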
All ops support `float32`, `float16`, and `bfloat16`; fp16/bf16 inputs are upcast to `float32` internally for accumulation, matching the precision behavior of the CUDA kernels (a sketch of this pattern follows).
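As a concrete illustration of that upcast-and-accumulate pattern (the bit manipulation below is standard bfloat16 behavior, not code from this PR):

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>

// bfloat16 is the top 16 bits of an IEEE-754 float32, so widening is a shift.
static float bf16_to_fp32(uint16_t b) {
  uint32_t bits = static_cast<uint32_t>(b) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

// Dot product of two bf16 vectors with a float32 accumulator; the caller
// narrows the result back to the output dtype at the end.
float dot_bf16(const uint16_t* a, const uint16_t* b, std::size_t n) {
  float acc = 0.0f;  // accumulate in float32 for numerical stability
  for (std::size_t i = 0; i < n; ++i)
    acc += bf16_to_fp32(a[i]) * bf16_to_fp32(b[i]);
  return acc;
}
```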
Note: The kernels compile, and the kernel symbols are present in the built binary. However, end-to-end Python testing with ONNX Runtime is currently blocked: the ops do not register at runtime despite the kernels being linked. The root cause appears to be in the schema registration in `bert_defs.cc`, which affects both the CPU and CUDA kernels; any input on how this could be fixed would be appreciated. For context, the registration pieces that have to agree with the schema are sketched below.
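These are the three places a CPU contrib kernel normally has to appear, sketched from the standard ORT pattern rather than copied from this PR; a runtime "op not registered" failure usually means one of them disagrees with the `bert_defs.cc` schema on op name, domain, or since-version:

```cpp
// Illustrative sketch of the usual CPU contrib-op registration pattern
// (assumed names/versions; not the exact code in this PR).

// 1) In the kernel's .cc file: define the kernel and bind its type constraint.
ONNX_OPERATOR_TYPED_KERNEL_EX(
    LinearAttentionRecurrent, kMSDomain, 1, float, kCpuExecutionProvider,
    KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
    LinearAttentionRecurrent<float>);

// 2) In cpu_contrib_kernels.cc: forward-declare the generated kernel class...
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(
    kCpuExecutionProvider, kMSDomain, 1, float, LinearAttentionRecurrent);

// 3) ...and list it in the registration table so the registry can create it.
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(
    kCpuExecutionProvider, kMSDomain, 1, float, LinearAttentionRecurrent)>,
```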
Motivation and Context
The CUDA kernels for these ops were added in commit `3966afb` without CPU counterparts, which would cause inference to fail for models like Qwen3.5 and Jamba on CPU-only machines.
Ref: onnx/onnx#7689