
Add LinearAttention and CausalConvState ops for Qwen3.5 #27907

Merged
apsonawane merged 17 commits into main from asonawane/linearattention
Apr 7, 2026

Conversation

@apsonawane
Contributor

Adds custom CUDA and CPU kernels for linear attention and causal 1D convolution with state, enabling efficient inference of Qwen3.5 hybrid decoder models in ONNX Runtime.

New Operators

LinearAttention — Implements the GatedDeltaNet recurrent linear attention mechanism:

  • Fused kernel computing gated delta-rule update of a recurrent state matrix
  • Supports both prefill (multi-token) and decode (single-token) paths
  • Inputs: Q, K, V, decay (alpha), beta gating, optional initial recurrent state
  • Outputs: attention output, updated recurrent state
  • CUDA implementation with per-head parallelism; CPU implementation with Eigen
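The recurrence above can be sketched per head in plain numpy. This is a minimal reference for one common GatedDeltaNet formulation with scalar per-step decay and beta gates; the actual op also covers variants such as per-key-dim decay, so treat this as an illustration rather than the kernel's exact semantics:

```python
import numpy as np

def gated_delta_recurrence(q, k, v, alpha, beta, S0=None):
    """Illustrative per-head gated delta-rule recurrence (sketch only).

    q, k: (T, d_k); v: (T, d_v); alpha, beta: (T,) scalar gates.
    S: (d_k, d_v) recurrent state. Returns (T, d_v) outputs and final state.
    """
    T, d_k = k.shape
    d_v = v.shape[1]
    S = np.zeros((d_k, d_v)) if S0 is None else S0.copy()
    out = np.empty((T, d_v))
    for t in range(T):
        S = alpha[t] * S                       # gated decay of the state
        retrieved = S.T @ k[t]                 # value currently stored under k_t
        delta = beta[t] * (v[t] - retrieved)   # delta-rule correction
        S = S + np.outer(k[t], delta)          # write corrected value back
        out[t] = S.T @ q[t]                    # read out with the query
    return out, S
```

The prefill path runs this loop over all tokens of the prompt; the decode path is a single iteration seeded from the incoming recurrent state.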

CausalConvWithState — Implements causal 1D convolution with persistent state for autoregressive decoding:

  • Supports prefill (full convolution) and decode (state-based sliding window)
  • Inputs: input tensor, conv weights, optional bias, optional initial conv state
  • Outputs: convolution output, updated conv state
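The decode path amounts to a depthwise sliding-window dot product over the cached trailing inputs. A minimal numpy sketch of one decode step, assuming a (C, K) depthwise weight layout and omitting any fused activation:

```python
import numpy as np

def causal_conv_decode_step(x_t, conv_state, weight, bias=None):
    """One decode step of depthwise causal conv1d with persistent state (sketch).

    x_t: (C,) new token's features; conv_state: (C, K-1) trailing inputs;
    weight: (C, K) depthwise filter; bias: (C,) optional.
    Returns (y_t, new_state).
    """
    window = np.concatenate([conv_state, x_t[:, None]], axis=1)  # (C, K)
    y_t = (window * weight).sum(axis=1)                          # depthwise dot product
    if bias is not None:
        y_t = y_t + bias
    new_state = window[:, 1:]                                    # slide window by one token
    return y_t, new_state
```

Prefill is the same computation unrolled over the full sequence, with the last K-1 columns of input retained as the outgoing state.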

Op Definitions

  • Registered in com.microsoft domain (opset 1)
  • Full shape inference and type constraints in bert_defs.cc

Testing

  • Parity test (test_parity_linear_attention_causal_conv.py) validates CUDA and CPU kernels against PyTorch reference implementations from the FLA (Flash Linear Attention) library

apsonawane and others added 3 commits March 31, 2026 02:47
…ith 'import' and 'import from'

Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
Contributor

@github-actions bot left a comment


You can commit the suggested changes from lintrunner.

Comment thread onnxruntime/test/python/transformers/test_parity_linear_attention_causal_conv.py Outdated
Contributor

Copilot AI left a comment


Pull request overview

This PR adds new com.microsoft contrib operators and implementations to support efficient inference for Qwen3.5-style hybrid decoder models in ONNX Runtime.

Changes:

  • Adds operator schemas for LinearAttention and CausalConvWithState (opset 1, com.microsoft) including type/shape inference.
  • Adds CUDA kernels for LinearAttention and CausalConvWithState, and registers them in the CUDA contrib registry.
  • Adds CPU kernels for LinearAttention and CausalConvWithState, registers them in the CPU contrib registry, and introduces a Python parity test against PyTorch.

Reviewed changes

Copilot reviewed 17 out of 17 changed files in this pull request and generated 14 comments.

Summary per file:

| File | Description |
|------|-------------|
| onnxruntime/test/python/transformers/test_parity_linear_attention_causal_conv.py | Adds CUDA/CPU parity tests comparing ORT outputs vs PyTorch reference implementations. |
| onnxruntime/core/graph/contrib_ops/ms_opset.h | Exposes new schemas in the Microsoft opset v1 registry. |
| onnxruntime/core/graph/contrib_ops/bert_defs.cc | Defines schemas + docstrings + type/shape inference for both new ops. |
| onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc | Registers CUDA kernels for the two new ops. |
| onnxruntime/contrib_ops/cuda/bert/linear_attention.h | Declares CUDA LinearAttention kernel class. |
| onnxruntime/contrib_ops/cuda/bert/linear_attention.cc | Implements CUDA LinearAttention kernel wiring and launch. |
| onnxruntime/contrib_ops/cuda/bert/linear_attention_impl.h | Declares CUDA LinearAttention kernel launcher. |
| onnxruntime/contrib_ops/cuda/bert/linear_attention_impl.cu | Implements CUDA fused recurrent linear attention kernel(s). |
| onnxruntime/contrib_ops/cuda/bert/causal_conv_with_state.h | Declares CUDA CausalConvWithState kernel class. |
| onnxruntime/contrib_ops/cuda/bert/causal_conv_with_state.cc | Implements CUDA CausalConvWithState kernel wiring and launch. |
| onnxruntime/contrib_ops/cuda/bert/causal_conv_with_state_impl.h | Declares CUDA causal conv kernel launcher. |
| onnxruntime/contrib_ops/cuda/bert/causal_conv_with_state_impl.cu | Implements CUDA fused causal depthwise conv1d + state management (+ SiLU). |
| onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc | Registers CPU kernels for the two new ops. |
| onnxruntime/contrib_ops/cpu/bert/linear_attention.h | Declares CPU LinearAttention kernel class. |
| onnxruntime/contrib_ops/cpu/bert/linear_attention.cc | Implements CPU LinearAttention kernel (thread-parallel over heads). |
| onnxruntime/contrib_ops/cpu/bert/causal_conv_with_state.h | Declares CPU CausalConvWithState kernel class. |
| onnxruntime/contrib_ops/cpu/bert/causal_conv_with_state.cc | Implements CPU CausalConvWithState for ndim=1 with decode/prefill paths. |


Comment thread onnxruntime/core/graph/contrib_ops/bert_defs.cc
Comment thread onnxruntime/core/graph/contrib_ops/bert_defs.cc
Comment thread onnxruntime/core/graph/contrib_ops/bert_defs.cc
Comment thread onnxruntime/core/graph/contrib_ops/bert_defs.cc
Comment thread onnxruntime/core/graph/contrib_ops/bert_defs.cc
Comment thread onnxruntime/contrib_ops/cuda/bert/causal_conv_with_state.cc
Comment thread onnxruntime/contrib_ops/cpu/bert/linear_attention.cc
Comment thread onnxruntime/contrib_ops/cpu/bert/linear_attention.cc
Comment thread onnxruntime/contrib_ops/cpu/bert/causal_conv_with_state.cc Outdated
Comment thread onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
@justinchuby
Contributor

CausalConvWithState: Is it only 1D? I recall we also need 2D for Nemotron.

Contributor

@tianleiwu tianleiwu left a comment


Two correctness issues need fixing before merge; remaining items are suggestions/nitpicks.

High priority:

  1. Divergent __syncthreads() in CausalConvPrefillKernelBatched (see inline)
  2. Silent numeric corruption when max(d_k, d_v) > max_threads_per_block in the generic LinearAttentionRecurrentKernel (see inline)

Suggestions / nitpicks: see inline comments on chunk_size_, CPU fp16 registration, redundant cudaMemsetAsync, and test coverage gaps.

Comment thread onnxruntime/contrib_ops/cuda/bert/causal_conv_with_state_impl.cu Outdated
Comment thread onnxruntime/contrib_ops/cuda/bert/linear_attention_impl.cu Outdated
Comment thread onnxruntime/contrib_ops/cuda/bert/linear_attention.cc
Comment thread onnxruntime/contrib_ops/cuda/bert/causal_conv_with_state.cc Outdated
Comment thread onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
Comment thread onnxruntime/test/python/transformers/test_parity_linear_attention_causal_conv.py Outdated
Contributor

@tianleiwu tianleiwu left a comment


Follow-up review: all 6 prior comments addressed. 3 new suggestions remain (shared memory validation, inverse GQA test, per-key-dim decay test coverage).

Comment thread onnxruntime/contrib_ops/cuda/bert/causal_conv_with_state_impl.cu
Comment thread onnxruntime/test/python/transformers/test_parity_linear_attention_causal_conv.py Outdated
Contributor

@tianleiwu tianleiwu left a comment


One new correctness issue found (generic kernel s_scratch race). Previous three threads (smem validation, inverse GQA test, per-key-dim decay test) are all addressed and have been resolved above.


⚠️ High: Inter-warp race on s_scratch in generic LinearAttentionRecurrentKernel

In LinearAttentionRecurrentKernel (the fallback for head dims not in {64,128}), the standalone if (needs_retrieval) block stores k into s_scratch and then overwrites it with retrieval results — without using the separately-allocated k_buf:

if (needs_retrieval) {
    if (tid < d_k) {
        s_scratch[tid] = kt_val;  // ← k stored in s_scratch
    }
    __syncthreads();
    if (tid < d_v) {
        float acc = 0.0f;
        for (int i = 0; i < d_k; ++i)
            acc += S_smem[i * d_v + tid] * s_scratch[i];
        s_scratch[tid] = acc;  // ← overwrites s_scratch[tid] — races with other warps reading s_scratch[tid]
    }
    __syncthreads();
}

When d_k > 32 (multiple warps), warp W1 can finish its inner-product loop and write s_scratch[32..63] before warp W0 reaches loop index i=32, causing W0 to read corrupted k values. The shared memory layout already has k_buf (smem + d_k * d_v) allocated for exactly this purpose — it just isn't used here. The fixed-shape kernel avoids the race by consistently using k_buf.

Fix (mirror the fixed-shape kernel):

if (needs_retrieval) {
    if (tid < d_k) {
        k_buf[tid] = kt_val;  // k_buf is already allocated at smem + d_k*d_v
    }
    __syncthreads();
    if (tid < d_v) {
        float acc = 0.0f;
        for (int i = 0; i < d_k; ++i)
            acc += S_smem[i * d_v + tid] * k_buf[i];
        s_scratch[tid] = acc;
    }
    __syncthreads();
}

Affected rule: delta (needs_retrieval=true, needs_decay=false; the fused path that correctly uses k_buf is not taken for this rule). gated_delta is unaffected (fused path). linear/gated are unaffected (needs_retrieval=false).

Practical exposure: Qwen3.5 (d_k=d_v=64/128) hits the fixed-shape kernels — this race is only triggered by callers using the delta rule with non-standard head dimensions.


Remaining items (suggestions/nitpicks) captured in PR_27907_REVIEW_48ab7e9_claude_1.md in the workspace.

Contributor

@tianleiwu tianleiwu left a comment


One new high-priority finding: inter-warp race on s_scratch in the generic LinearAttentionRecurrentKernel. Previous three threads (smem validation, inverse GQA test, per-key-dim decay test) are addressed and resolved. Full details in PR_27907_REVIEW_48ab7e9_claude_1.md.

High: In LinearAttentionRecurrentKernel (fallback for head dims not in {64,128}), the standalone needs_retrieval block stores k into s_scratch then overwrites it with retrieval results. When d_k > 32, warp W1 can finish its inner-product loop and write s_scratch[32..63] before warp W0 reaches loop index i=32, causing silent numeric corruption. The fix is to use k_buf (already allocated at smem + d_k*d_v) as the fixed-shape kernel does. Affected rule: delta with non-standard head dimensions.

Comment thread onnxruntime/contrib_ops/cuda/bert/linear_attention_impl.cu
Contributor

@tianleiwu tianleiwu left a comment


Follow-up review on commit cdf987d:

All findings from both prior rounds are now addressed:

Resolved in this commit (cdf987d):

  • ✅ Inter-warp race on s_scratch in generic LinearAttentionRecurrentKernel: fixed by switching to k_buf in the standalone needs_retrieval path. Inline comment posted on the fix.

Resolved in prior commits:

  • ✅ Divergent __syncthreads() in CausalConvPrefillKernelBatched
  • ✅ Silent overflow when max(d_k, d_v) > max_threads_per_block in generic kernel
  • ✅ chunk_size_ stored but never used (removed / documented)
  • ✅ Missing cudaFuncSetAttribute for >48 KB shared memory in both prefill kernels
  • ✅ test_linear_attention_inverse_gqa now uses q=8, kv=16 (true inverse GQA)
  • ✅ Per-key-dim decay path now covered by test_linear_attention_per_key_dim_decay
  • ✅ CPU fp16 schema vs kernel registration gap documented in comments
  • ✅ All four update rules (linear, gated, delta, gated_delta) tested on CPU and CUDA

Remaining (suggestions — not blocking):

  • Lintrunner RUFF formatting on the test file still open (threads 3–4 from github-advanced-security); run lintrunner -a to auto-fix.
  • Consider documenting the seq_len <= 128 threshold in LaunchCausalConvWithStateKernel to explain the occupancy trade-off for the batched vs. single-channel kernel selection.

LGTM otherwise.

Contributor

@tianleiwu tianleiwu left a comment


All concerns addressed. The inter-warp race on s_scratch (the last high-priority item) is fixed in cdf987d. Approving.

@apsonawane apsonawane enabled auto-merge (squash) April 7, 2026 18:11
@apsonawane apsonawane merged commit 0fedb26 into main Apr 7, 2026
99 of 100 checks passed
@apsonawane apsonawane deleted the asonawane/linearattention branch April 7, 2026 19:49
sanaa-hamel-microsoft added a commit that referenced this pull request Apr 24, 2026
Version bump to 1.25.1.

This cherry-picks the following commits for the release:

| Commit ID | PR Number | Commit Title |
|-----------|-----------|--------------|
| e532c21 | #27842 | linear attention signature |
| 410f5a8 | #27752 | +rotemb, +rmsnorm, reshape->opset-25, transpose->opset-24 |
| 0fedb26 | #27907 | Add LinearAttention and CausalConvState ops for Qwen3.5 |
| 3ac6040 | #27996 | webgpu support for qwen3.5 |
| c36c422 | #27998 | [WebGPU EP] Fuse QMoE 1-token decode path to reduce GPU dispatches |
| 94f32ec | #27289 | [CORE]: Improve filesystem error messages during Linux device discovery |
| dce77a3 | #28118 | Fix lack of auth on python packaging |

---------

Co-authored-by: Akshay Sonawane <111780983+apsonawane@users.noreply.github.com>
Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
Co-authored-by: eserscor <erscor@microsoft.com>
Co-authored-by: Sanaa Hamel <sanaahamel@microsoft.com>
Co-authored-by: Guenther Schmuelling <guschmue@microsoft.com>
Co-authored-by: Stephan Seitz <sseitz@nvidia.com>
Co-authored-by: Jiajia Qin <jiajiaqin@microsoft.com>