
Add CPU kernels for linear attention contrib ops #27835

Open
OmarAzizi wants to merge 1 commit into microsoft:rama/linear-attn from OmarAzizi:cpu/linear-attention-kernels

Conversation

@OmarAzizi

Description

Implemented CPU execution provider kernels for the three linear attention contrib ops, covering the linear attention / recurrent state-update mechanisms used by modern hybrid LLMs (Qwen3.5, Jamba, RWKV-6, FalconMamba, etc.); a rough sketch of the update rules follows this list:

  • LinearAttentionRecurrent: Single-token recurrent decode step supporting linear, gated, delta, and gated_delta update rules. Computes the full state update (decay, retrieve, delta, write, readout) in float32 for numerical stability across all sequence lengths.
  • LinearAttentionChunkParallel: Prefill kernel that processes a full input sequence by running the recurrent step sequentially for all T tokens. The CUDA chunk-parallel WY decomposition is not used on the CPU.
  • CausalConv1DWithState: Depthwise causal 1D convolution with carry state and optional SiLU activation.
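
For readers newer to these update rules, here is a minimal NumPy sketch of the gated-delta recurrence and the sequential prefill loop; the function names, scalar gates, and shapes are illustrative assumptions, not the ops' normative signatures:

```python
import numpy as np

def gated_delta_step(S, q, k, v, alpha, beta):
    """One recurrent decode step with the gated-delta update rule.

    S: (d_k, d_v) carry state; q, k: (d_k,); v: (d_v,).
    alpha (decay gate) and beta (write strength) are assumed scalar here.
    """
    S = alpha * S                # decay the previous state
    v_old = S.T @ k              # retrieve the value currently stored under k
    dv = beta * (v - v_old)      # delta: correction toward the new value
    S = S + np.outer(k, dv)      # write the correction into the state
    o = S.T @ q                  # readout for this token
    return S, o

def prefill(S, Q, K, V, alphas, betas):
    """Sequential prefill: one recurrent step per token, for all T tokens."""
    outs = []
    for t in range(Q.shape[0]):
        S, o = gated_delta_step(S, Q[t], K[t], V[t], alphas[t], betas[t])
        outs.append(o)
    return S, np.stack(outs)
```

And a similar sketch of the depthwise causal conv decode step with carry state (again, names and shapes are assumptions):

```python
import numpy as np

def causal_conv1d_step(state, x, w, use_silu=True):
    """Depthwise causal 1D conv decode step with carry state.

    state: (C, K-1) last K-1 inputs per channel; x: (C,) new token; w: (C, K).
    Returns the updated state and the activated output for this token.
    """
    window = np.concatenate([state, x[:, None]], axis=1)  # (C, K) causal window
    y = (window * w).sum(axis=1)                          # per-channel (depthwise) dot
    if use_silu:
        y = y / (1.0 + np.exp(-y))                        # SiLU: y * sigmoid(y)
    return window[:, 1:], y                               # drop the oldest column
```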

All ops support float32, float16, and bfloat16. fp16/bf16 inputs are converted to float32 internally for accumulation, matching the precision behavior of the CUDA kernels.
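
Roughly, that precision handling follows the pattern below (a sketch only; the exact cast points in the real kernels may differ):

```python
import numpy as np

x16 = np.random.randn(4, 8).astype(np.float16)  # fp16 (or bf16) input tensor
x32 = x16.astype(np.float32)                    # upcast once at the kernel boundary
acc = x32 @ x32.T                               # all accumulation happens in float32
y16 = acc.astype(np.float16)                    # downcast only the final output
```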

Note: The kernels compile, and the kernel symbols are present in the built binary. However, end-to-end Python testing with ONNX Runtime was blocked: the ops do not register at runtime even though the kernels are linked. The root cause appears to be in the schema registration in bert_defs.cc, which affects both the CPU and CUDA kernels. Any input on how this could be fixed would be appreciated.

Motivation and Context

The CUDA kernels for these ops were added in commit 3966afb without CPU counterparts, so inference fails for models like Qwen3.5 and Jamba on CPU-only machines.

Ref: onnx/onnx#7689

Signed-off-by: OmarAzizi <oalazizi75@gmail.com>
@OmarAzizi
Author

@microsoft-github-policy-service agree

@guschmue
Contributor

While this was being implemented, we have been discussing some changes to the op signatures to make them more practical:

  1. Instead of using two ops for linear attention, there will be just one, called LinearAttention().

  2. Still under discussion is whether the inputs to LinearAttention() should be transposed, to avoid the per-input transpose that is otherwise needed.

The signatures we expect for webgpu are here; we are trying to see if we can finalize the signature today:

https://github.com/microsoft/onnxruntime/blob/gs/wgpu-lattn/onnxruntime/core/graph/contrib_ops/bert_defs.cc#L2236
https://github.com/microsoft/onnxruntime/blob/gs/wgpu-lattn/onnxruntime/core/graph/contrib_ops/bert_defs.cc#L2323

@guschmue
Contributor

we are working on the signature here:
#27842

@OmarAzizi
Author

OmarAzizi commented Mar 25, 2026

@guschmue Thanks for the heads up! I'll hold off on any further changes until the signatures are finalized. I'm happy to update the CPU kernels to match the new LinearAttention op once the interface is settled. Let me know if there's anything I can do in the meantime.

@guschmue
Contributor

ok, I updated the contrib ops PR to the latest signature
from onnx/onnx#7767
Also added the unit tests for the ops in the same PR.
#27842

A working fp16/q4 model is here:
https://huggingface.co/schmuell/Qwen3.5-0.8B

A working implementation for webgpu is in this PR:
#27896

@OmarAzizi
Author

Hi @guschmue, thanks a lot for the update and for sharing the latest changes and tests.

I’ve been a bit busy over the past few days, but I’ll review the updated signatures and update my PR this week.

Since I’m still relatively new to the codebase and the contrib ops workflow, I’d really appreciate any notes or tips I should keep in mind.

