Add LinearAttention and CausalConvState ops for Qwen3.5 #27907
apsonawane merged 17 commits into main
Conversation
Pull request overview
This PR adds new com.microsoft contrib operators and implementations to support efficient inference for Qwen3.5-style hybrid decoder models in ONNX Runtime.
Changes:
- Adds operator schemas for `LinearAttention` and `CausalConvWithState` (opset 1, `com.microsoft`), including type/shape inference.
- Adds CUDA kernels for `LinearAttention` and `CausalConvWithState`, and registers them in the CUDA contrib registry.
- Adds CPU kernels for `LinearAttention` and `CausalConvWithState`, registers them in the CPU contrib registry, and introduces a Python parity test against PyTorch.
Reviewed changes
Copilot reviewed 17 out of 17 changed files in this pull request and generated 14 comments.
| File | Description |
|---|---|
| onnxruntime/test/python/transformers/test_parity_linear_attention_causal_conv.py | Adds CUDA/CPU parity tests comparing ORT outputs vs PyTorch reference implementations. |
| onnxruntime/core/graph/contrib_ops/ms_opset.h | Exposes new schemas in the Microsoft opset v1 registry. |
| onnxruntime/core/graph/contrib_ops/bert_defs.cc | Defines schemas + docstrings + type/shape inference for both new ops. |
| onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc | Registers CUDA kernels for the two new ops. |
| onnxruntime/contrib_ops/cuda/bert/linear_attention.h | Declares CUDA LinearAttention kernel class. |
| onnxruntime/contrib_ops/cuda/bert/linear_attention.cc | Implements CUDA LinearAttention kernel wiring and launch. |
| onnxruntime/contrib_ops/cuda/bert/linear_attention_impl.h | Declares CUDA LinearAttention kernel launcher. |
| onnxruntime/contrib_ops/cuda/bert/linear_attention_impl.cu | Implements CUDA fused recurrent linear attention kernel(s). |
| onnxruntime/contrib_ops/cuda/bert/causal_conv_with_state.h | Declares CUDA CausalConvWithState kernel class. |
| onnxruntime/contrib_ops/cuda/bert/causal_conv_with_state.cc | Implements CUDA CausalConvWithState kernel wiring and launch. |
| onnxruntime/contrib_ops/cuda/bert/causal_conv_with_state_impl.h | Declares CUDA causal conv kernel launcher. |
| onnxruntime/contrib_ops/cuda/bert/causal_conv_with_state_impl.cu | Implements CUDA fused causal depthwise conv1d + state management (+ SiLU). |
| onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc | Registers CPU kernels for the two new ops. |
| onnxruntime/contrib_ops/cpu/bert/linear_attention.h | Declares CPU LinearAttention kernel class. |
| onnxruntime/contrib_ops/cpu/bert/linear_attention.cc | Implements CPU LinearAttention kernel (thread-parallel over heads). |
| onnxruntime/contrib_ops/cpu/bert/causal_conv_with_state.h | Declares CPU CausalConvWithState kernel class. |
| onnxruntime/contrib_ops/cpu/bert/causal_conv_with_state.cc | Implements CPU CausalConvWithState for ndim=1 with decode/prefill paths. |
CausalConvWithState: Is it only 1D? I recall we also need 2D for Nemotron.
tianleiwu left a comment
Two correctness issues need fixing before merge; remaining items are suggestions/nitpicks.
High priority:
- Divergent `__syncthreads()` in `CausalConvPrefillKernelBatched` (see inline).
- Silent numeric corruption when `max(d_k, d_v) > max_threads_per_block` in the generic `LinearAttentionRecurrentKernel` (see inline).
Suggestions / nitpicks: see inline comments on `chunk_size_`, CPU fp16 registration, redundant `cudaMemsetAsync`, and test coverage gaps.
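For context on the first item: `__syncthreads()` must be reached by every thread of a block, so a barrier guarded by a thread-dependent condition is undefined behavior. Below is a minimal sketch of the hazard pattern and the usual fix; the kernel is hypothetical and not the PR's code.

```cpp
#include <cuda_runtime.h>

__global__ void DivergentBarrierSketch(const float* in, float* out, int n) {
  __shared__ float tile[256];
  int i = blockIdx.x * blockDim.x + threadIdx.x;

  // BUG: if any thread in the block has i >= n, it skips the barrier below
  // while the remaining threads wait at it; the result is undefined behavior.
  //
  //   if (i < n) {
  //     tile[threadIdx.x] = in[i];
  //     __syncthreads();
  //     out[i] = tile[blockDim.x - 1 - threadIdx.x];
  //   }

  // FIX: keep every thread alive through the barrier and guard only the
  // out-of-range memory accesses.
  tile[threadIdx.x] = (i < n) ? in[i] : 0.0f;
  __syncthreads();
  if (i < n) out[i] = tile[blockDim.x - 1 - threadIdx.x];
}
```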
tianleiwu left a comment
Follow-up review: all 6 prior comments addressed. 3 new suggestions remain (shared memory validation, inverse GQA test, per-key-dim decay test coverage).
tianleiwu left a comment
One new correctness issue found (generic kernel `s_scratch` race). The previous three threads (smem validation, inverse GQA test, per-key-dim decay test) have all been addressed and resolved above.
⚠️ High: Inter-warp race on `s_scratch` in the generic `LinearAttentionRecurrentKernel`
In `LinearAttentionRecurrentKernel` (the fallback for head dims not in {64, 128}), the standalone `if (needs_retrieval)` block stores k into `s_scratch` and then overwrites it with retrieval results, without using the separately allocated `k_buf`:
```cpp
if (needs_retrieval) {
  if (tid < d_k) {
    s_scratch[tid] = kt_val;  // ← k stored in s_scratch
  }
  __syncthreads();
  if (tid < d_v) {
    float acc = 0.0f;
    for (int i = 0; i < d_k; ++i)
      acc += S_smem[i * d_v + tid] * s_scratch[i];
    s_scratch[tid] = acc;  // ← overwrites s_scratch[tid]; races with other warps still reading s_scratch
  }
  __syncthreads();
}
```

When d_k > 32 (multiple warps), warp W1 can finish its inner-product loop and write s_scratch[32..63] before warp W0 reaches loop index i=32, causing W0 to read corrupted k values. The shared memory layout already has k_buf (smem + d_k * d_v) allocated for exactly this purpose; it just isn't used here. The fixed-shape kernel avoids the race by consistently using k_buf.
Fix (mirror the fixed-shape kernel):
```cpp
if (needs_retrieval) {
  if (tid < d_k) {
    k_buf[tid] = kt_val;  // k_buf is already allocated at smem + d_k * d_v
  }
  __syncthreads();
  if (tid < d_v) {
    float acc = 0.0f;
    for (int i = 0; i < d_k; ++i)
      acc += S_smem[i * d_v + tid] * k_buf[i];
    s_scratch[tid] = acc;
  }
  __syncthreads();
}
```

Affected rule: delta (needs_retrieval=true, needs_decay=false; the fused path that correctly uses k_buf is not taken for this rule). gated_delta is unaffected (fused path). linear/gated are unaffected (needs_retrieval=false).
Practical exposure: Qwen3.5 (d_k=d_v=64/128) hits the fixed-shape kernels — this race is only triggered by callers using the delta rule with non-standard head dimensions.
Remaining items (suggestions/nitpicks) captured in PR_27907_REVIEW_48ab7e9_claude_1.md in the workspace.
tianleiwu left a comment
One new high-priority finding: inter-warp race on s_scratch in the generic LinearAttentionRecurrentKernel. Previous three threads (smem validation, inverse GQA test, per-key-dim decay test) are addressed and resolved. Full details in PR_27907_REVIEW_48ab7e9_claude_1.md.
High: In LinearAttentionRecurrentKernel (fallback for head dims not in {64,128}), the standalone needs_retrieval block stores k into s_scratch then overwrites it with retrieval results. When d_k > 32, warp W1 can finish its inner-product loop and write s_scratch[32..63] before warp W0 reaches loop index i=32, causing silent numeric corruption. The fix is to use k_buf (already allocated at smem + d_k*d_v) as the fixed-shape kernel does. Affected rule: delta with non-standard head dimensions.
tianleiwu left a comment
Follow-up review on commit cdf987d:
All findings from both prior rounds are now addressed:
Resolved in this commit (cdf987d):
- ✅ Inter-warp race on `s_scratch` in the generic `LinearAttentionRecurrentKernel`: fixed by switching to `k_buf` in the standalone `needs_retrieval` path. Inline comment posted on the fix.
Resolved in prior commits:
- ✅ Divergent `__syncthreads()` in `CausalConvPrefillKernelBatched`
- ✅ Silent overflow when `max(d_k, d_v) > max_threads_per_block` in the generic kernel
- ✅ `chunk_size_` stored but never used (removed / documented)
- ✅ Missing `cudaFuncSetAttribute` for >48 KB shared memory in both prefill kernels (see the sketch after this list)
- ✅ `test_linear_attention_inverse_gqa` now uses `q=8, kv=16` (true inverse GQA)
- ✅ Per-key-dim decay path now covered by `test_linear_attention_per_key_dim_decay`
- ✅ CPU fp16 schema vs. kernel registration gap documented in comments
- ✅ All four update rules (linear, gated, delta, gated_delta) tested on CPU and CUDA
Remaining (suggestions — not blocking):
- Lintrunner RUFF formatting on the test file is still open (threads 3–4 from github-advanced-security); run `lintrunner -a` to auto-fix.
- Consider documenting the `seq_len <= 128` threshold in `LaunchCausalConvWithStateKernel` to explain the occupancy trade-off in the batched vs. single-channel kernel selection (see the dispatch sketch below).
LGTM otherwise.
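To illustrate the trade-off behind the `seq_len <= 128` suggestion above, here is a hypothetical launcher dispatch; only the 128 cutoff comes from the comment, while the kernel names, signatures, and grid shapes are placeholder assumptions.

```cpp
#include <cuda_runtime.h>

// Hypothetical kernel variants standing in for the batched and
// single-channel prefill kernels.
__global__ void BatchedPrefillSketch(const float* x, float* y, int seq_len) { /* ... */ }
__global__ void PerChannelPrefillSketch(const float* x, float* y, int seq_len) { /* ... */ }

// Short prefills: one block can cover all channels of a batch entry, so the
// batched kernel amortizes block overhead. Long prefills: one block per
// channel exposes more parallelism and sustains higher occupancy.
void LaunchCausalConvPrefillSketch(const float* x, float* y, int batch, int channels,
                                   int seq_len, cudaStream_t stream) {
  const int threads = 256;
  if (seq_len <= 128) {
    BatchedPrefillSketch<<<batch, threads, 0, stream>>>(x, y, seq_len);
  } else {
    PerChannelPrefillSketch<<<batch * channels, threads, 0, stream>>>(x, y, seq_len);
  }
}
```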
Version bump to 1.25.1. This cherry-picks the following commits for the release:

| Commit ID | PR Number | Commit Title |
|---|---|---|
| e532c21 | #27842 | linear attention signature |
| 410f5a8 | #27752 | +rotemb, +rmsnorm, reshape->opset-25, transpose->opset-24 |
| 0fedb26 | #27907 | Add LinearAttention and CausalConvState ops for Qwen3.5 |
| 3ac6040 | #27996 | webgpu support for qwen3.5 |
| c36c422 | #27998 | [WebGPU EP] Fuse QMoE 1-token decode path to reduce GPU dispatches |
| 94f32ec | #27289 | [CORE]: Improve filesystem error messages during Linux device discovery |
| dce77a3 | #28118 | Fix lack of auth on python packaging |

Co-authored-by: Akshay Sonawane <111780983+apsonawane@users.noreply.github.com>
Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
Co-authored-by: eserscor <erscor@microsoft.com>
Co-authored-by: Sanaa Hamel <sanaahamel@microsoft.com>
Co-authored-by: Guenther Schmuelling <guschmue@microsoft.com>
Co-authored-by: Stephan Seitz <sseitz@nvidia.com>
Co-authored-by: Jiajia Qin <jiajiaqin@microsoft.com>
Adds custom CUDA and CPU kernels for linear attention and causal 1D convolution with state, enabling efficient inference of Qwen3.5 hybrid decoder models in ONNX Runtime.
New Operators

- `LinearAttention`: implements the GatedDeltaNet recurrent linear attention mechanism.
- `CausalConvWithState`: implements causal 1D convolution with persistent state for autoregressive decoding.
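As a reference for the recurrence `LinearAttention` computes, the four update rules named in the reviews are commonly written as below. The notation is illustrative (state S_t ∈ R^{d_k×d_v}, queries/keys q_t, k_t ∈ R^{d_k}, values v_t ∈ R^{d_v}, decay gate g_t, delta-rule rate β_t); the authoritative parameterization is the schema in bert_defs.cc.

```latex
% Common formulations of the four update rules (illustrative, not the PR's exact spec).
\begin{aligned}
\text{linear:}       \quad & S_t = S_{t-1} + k_t v_t^\top \\
\text{gated:}        \quad & S_t = g_t \odot S_{t-1} + k_t v_t^\top \\
\text{delta:}        \quad & S_t = S_{t-1} + \beta_t\, k_t \,\bigl(v_t - S_{t-1}^\top k_t\bigr)^\top \\
\text{gated\_delta:} \quad & S_t = g_t \odot S_{t-1} + \beta_t\, k_t \,\bigl(v_t - (g_t \odot S_{t-1})^\top k_t\bigr)^\top \\
\text{output:}       \quad & o_t = S_t^\top q_t
\end{aligned}
```

The retrieval term S_{t-1}^\top k_t in the delta rules is exactly the `needs_retrieval` inner product discussed in the race-condition review above.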
Op Definitions

- Registered in the `com.microsoft` domain (opset 1), with schemas and type/shape inference defined in `bert_defs.cc`.

Testing

- A parity test (`test_parity_linear_attention_causal_conv.py`) validates the CUDA and CPU kernels against PyTorch reference implementations from the FLA (Flash Linear Attention) library.
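As a rough illustration of the decode-step semantics of CausalConvWithState, a minimal CUDA sketch follows. The tensor layout, argument names, and the fused SiLU are assumptions pieced together from the file summaries above, not the operator's actual I/O contract.

```cpp
#include <cuda_runtime.h>
#include <math.h>

// Hypothetical decode step: one new token per call, depthwise causal conv1d
// over a persistent trailing-input state, with fused SiLU.
//   x:      [batch, channels]              new token's features (updated in place)
//   state:  [batch, channels, kernel - 1]  trailing inputs from prior steps (updated in place)
//   weight: [channels, kernel]             per-channel depthwise filter
__global__ void CausalConvDecodeStepSketch(float* x, float* state, const float* weight,
                                           int channels, int kernel) {
  int b = blockIdx.x;  // one block per batch entry
  for (int c = threadIdx.x; c < channels; c += blockDim.x) {
    float* s = state + (b * channels + c) * (kernel - 1);
    const float* w = weight + c * kernel;
    float x_new = x[b * channels + c];

    // Causal window is [s[0], ..., s[kernel-2], x_new]; take the depthwise dot product.
    float acc = w[kernel - 1] * x_new;
    for (int i = 0; i < kernel - 1; ++i) acc += w[i] * s[i];

    // Shift the state window left by one and append the new input.
    for (int i = 0; i + 1 < kernel - 1; ++i) s[i] = s[i + 1];
    if (kernel > 1) s[kernel - 2] = x_new;

    // Fused SiLU: silu(a) = a * sigmoid(a) = a / (1 + exp(-a)).
    x[b * channels + c] = acc / (1.0f + expf(-acc));
  }
}
```

Launched as `CausalConvDecodeStepSketch<<<batch, 256, 0, stream>>>(x, state, weight, channels, kernel)`; each (batch, channel) pair is owned by a single thread, so the in-place state shift needs no synchronization.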