
linear attention signature #27842

Merged
justinchuby merged 10 commits into main from gs/linear-attention-signature on Apr 6, 2026

Conversation

@guschmue (Contributor) commented on Mar 25, 2026:

Proposal for the CausalConvWithState and LinearAttention onnxruntime custom operators.
This follows the proposal in onnx/onnx#7767.

@guschmue changed the title from "Gs/linear attention signature" to "linear attention signature" on Mar 25, 2026
@guschmue (Contributor, Author) commented:

A working end-to-end implementation for WebGPU with this signature can be found here:
https://github.com/microsoft/onnxruntime/tree/gs/wgpu-lattn

Possible changes at this point:

  1. Maybe the inputs should be transposed; otherwise the model will have Transpose operators in front of LinearAttention (sketched below).
  2. Maybe CausalConv1DWithState should be CausalConvWithState.
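For point 1, a minimal numpy sketch of the layout conversion involved (shapes assumed from the packed 3D signature discussed in the reviews below, not taken from the actual ORT kernel):

```python
# Hypothetical illustration: a model that computes per-head tensors of
# shape [B, H, T, D] must insert Transpose + Reshape in front of
# LinearAttention to reach the packed [B, T, H*D] layout the proposed
# signature expects.
import numpy as np

B, H, T, D = 2, 4, 8, 16
q_per_head = np.random.randn(B, H, T, D).astype(np.float32)

# Transpose [B, H, T, D] -> [B, T, H, D], then flatten heads to [B, T, H*D]
q_packed = q_per_head.transpose(0, 2, 1, 3).reshape(B, T, H * D)
assert q_packed.shape == (B, T, H * D)
```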

@github-actions bot left a comment:

You can commit the suggested changes from lintrunner.

guschmue added a commit that referenced this pull request Mar 31, 2026
@justinchuby (Contributor) commented:

Cross-referencing with the ONNX proposal (onnx/onnx#7767) — review by AI agent team

I compared the LinearAttention and CausalConvWithState schemas between this ORT contrib PR and the ONNX proposal in onnx/onnx#7767. Summary below.


LinearAttention — comparison

✅ Matches

| Item | Both PRs |
|------|----------|
| Input names (0–5) | query, key, value, past_state, decay, beta — identical order and names |
| Output names (0–1) | output, present_state — identical |
| Input shapes | All match: 3D packed [B,T,H*D] for Q/K/V/decay/beta; 4D [B,H_kv,d_k,d_v] for state |
| Attribute names | q_num_heads, kv_num_heads, update_rule, scale, chunk_size — all five present |
| Attribute types & defaults | update_rule="gated_delta", scale=0.0, chunk_size=64; both head counts required |
| TypeConstraint T | {tensor(float), tensor(float16), tensor(bfloat16)} — identical |
| Mathematical semantics | linear, gated, delta, gated_delta update rules described identically (sketched below) |
| Optional input semantics | past_state defaults to zeros; decay/beta required by their respective modes |
| Namespace | ORT uses com.microsoft, ONNX proposal uses ai.onnx — expected and correct per the 3-phase adoption path |
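For reference, a per-head, per-token numpy sketch of the four update rules named in the table above, based on the standard linear-attention, gated, and delta-rule formulations from the literature; the normative pseudocode is in onnx/onnx#7767, and exact scaling and ordering there may differ (the scale attribute is omitted here):

```python
import numpy as np

def step(S, q, k, v, rule, decay=None, beta=None):
    """One recurrent step. S: (d_k, d_v) state; q, k: (d_k,); v: (d_v,)."""
    if rule in ("gated", "gated_delta"):
        S = decay * S                       # decay the state first
    if rule in ("delta", "gated_delta"):
        v = v - S.T @ k                     # delta rule: write only the residual
        S = S + beta * np.outer(k, v)
    else:
        S = S + np.outer(k, v)              # plain accumulation
    return S.T @ q, S                       # output (d_v,), updated state

# Example: one gated step with scalar decay
S = np.zeros((4, 4))
out, S = step(S, np.ones(4), np.ones(4), np.ones(4), "gated", decay=0.9)
```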

❌ Key mismatch — state type precision (S vs T)

The ONNX proposal introduces a second type parameter S for the recurrent state tensors, following the stash_type convention used by LayerNormalization and GroupNormalization:

  • past_state → type S (not T)
  • present_state → type S (not T)
  • S is constrained to float32, or the same as T

This ORT PR uses a single type T for all inputs/outputs including state. This means:

  • In this ORT schema, running with T=float16 forces the recurrent state to also be float16
  • The ONNX proposal explicitly supports T=float16, S=float32 — float16 activations with float32 state accumulation — which is important for numerical stability during long-sequence decoding (the state matrix is accumulated over hundreds or thousands of tokens)
  • The ONNX proposal notes: "Using S = float32 with T = float16/bfloat16 is the recommended configuration for long sequences; runtimes handle any necessary casting internally"

Recommendation: Consider adding a stash_type attribute (integer, default 1 = float32) analogous to LayerNormalization, so callers can opt into float32 state accumulation independently of the activation dtype. This would align with the ONNX proposal and avoid a breaking schema change later.
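A small numpy illustration of the stability point (magnitudes and shapes are made up for the demo): accumulating many small outer-product updates into a float16 state drifts away from the same accumulation done in float32, which is what a stash_type-style attribute would let callers avoid:

```python
import numpy as np

rng = np.random.default_rng(0)
d_k, d_v, steps = 64, 64, 4096
S16 = np.zeros((d_k, d_v), dtype=np.float16)
S32 = np.zeros((d_k, d_v), dtype=np.float32)
for _ in range(steps):
    k = np.abs(rng.standard_normal(d_k)).astype(np.float16) * np.float16(0.01)
    v = np.abs(rng.standard_normal(d_v)).astype(np.float16) * np.float16(0.01)
    upd = np.outer(k, v)                 # fp16 update, as with a single-T schema
    S16 += upd                           # fp16 state: updates round away as S grows
    S32 += upd.astype(np.float32)        # fp32 state: what stash_type=1 would express
print(np.abs(S16.astype(np.float32) - S32).max())  # nonzero drift
```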

⚠️ Minor gap — input validation rules not in doc string

The ONNX proposal has an explicit table of required/forbidden optional inputs per update_rule:

| update_rule | decay | beta |
|-------------|-------|------|
| "linear" | must be omitted | must be omitted |
| "gated" | required | must be omitted |
| "delta" | must be omitted | required |
| "gated_delta" | required | required |

The ORT doc string describes the semantics but doesn't explicitly state that providing a forbidden input is a model validation error. Worth adding to the doc for clarity — helps implementors know they should validate at model-load time, not silently ignore extra inputs.
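A sketch of what that model-load-time check could look like (the function name and error type are illustrative, not part of either PR):

```python
# update_rule -> (decay required?, beta required?), per the table above
RULES = {
    "linear":      (False, False),
    "gated":       (True,  False),
    "delta":       (False, True),
    "gated_delta": (True,  True),
}

def validate_linear_attention(update_rule: str, has_decay: bool, has_beta: bool) -> None:
    needs_decay, needs_beta = RULES[update_rule]
    if has_decay != needs_decay:
        raise ValueError(f"update_rule={update_rule!r}: decay must be "
                         f"{'provided' if needs_decay else 'omitted'}")
    if has_beta != needs_beta:
        raise ValueError(f"update_rule={update_rule!r}: beta must be "
                         f"{'provided' if needs_beta else 'omitted'}")
```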


CausalConvWithState — comparison

✅ Full match

All inputs (input, weight, bias, past_state), outputs (output, present_state), attributes (ndim, activation), type constraints, and shapes are identical between the two PRs. No differences found.
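For readers without the schema open, a hedged numpy sketch of the ndim=1 semantics as described here: a depthwise causal conv where past_state carries the last k-1 input columns between calls (activation omitted; the normative definition is in bert_defs.cc and onnx/onnx#7767):

```python
import numpy as np

def causal_conv1d_with_state(x, w, b, past_state):
    """x: (B, C, T), w: (C, 1, k) depthwise, b: (C,), past_state: (B, C, k-1)."""
    B, C, T = x.shape
    k = w.shape[-1]
    assert k > 1, "state is only meaningful for k > 1"
    ctx = np.concatenate([past_state, x], axis=-1)   # (B, C, T + k - 1)
    out = np.empty_like(x)
    for t in range(T):
        # each output step sees exactly the k most recent inputs (causal)
        out[:, :, t] = np.einsum("bck,ck->bc", ctx[:, :, t:t + k], w[:, 0, :])
    out += b[None, :, None]
    present_state = ctx[:, :, -(k - 1):]             # last k-1 columns for next call
    return out, present_state
```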


Summary

| Op | Status |
|----|--------|
| LinearAttention | Mostly aligned — 1 structural gap (state type S vs T), 1 minor doc gap |
| CausalConvWithState | Fully aligned |

The state precision gap is the only item that would cause a future breaking schema change if not addressed here. Everything else looks well-aligned with the ONNX proposal.

@justinchuby (Contributor) commented:

Cross-referencing with the ONNX proposal (onnx/onnx#7767) — review by AI agent team

I systematically compared the LinearAttention and CausalConvWithState op schemas in this PR against the formal schema defined in onnx/onnx#7767. The namespace difference (com.microsoft here vs ai.onnx in the proposal) is expected per the contrib-first adoption path and is not flagged below.


LinearAttention

✅ Matches

| Element | ORT | ONNX proposal |
|---------|-----|---------------|
| Input names | query, key, value, past_state, decay, beta | same |
| Input order | 0..5 | same |
| Input optionality | past_state, decay, beta optional | same |
| Input shapes (3D packed) | (B, T, H*D) for Q/K/V; (B, H_kv, d_k, d_v) for state | same |
| Decay shape variants | (B, T, H_kv*d_k) or (B, T, H_kv) | same |
| Output names | output, present_state | same |
| Output shapes | (B, T, H_q*d_v) and (B, H_kv, d_k, d_v) | same |
| Attribute names | update_rule, scale, q_num_heads, kv_num_heads, chunk_size | same |
| Attribute types | all match | same |
| Attribute defaults | update_rule="gated_delta", scale=0.0, chunk_size=64 | same |
| q_num_heads / kv_num_heads | required, no default value | same |
| Activation dtype constraint | float, float16, bfloat16 | same |
| Zero-initialized state semantics | implied (optional input) | explicitly stated |
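On the decay shape variants row: a tiny sketch of how the per-head layout relates to the per-channel one, assuming the packed channel order groups the d_k entries of each head contiguously (an assumption here; the proposal defines the exact packing):

```python
import numpy as np

B, T, H_kv, d_k = 2, 8, 4, 16
decay_per_head = np.random.rand(B, T, H_kv).astype(np.float32)   # (B, T, H_kv)
decay_per_channel = np.repeat(decay_per_head, d_k, axis=-1)      # (B, T, H_kv*d_k)
assert decay_per_channel.shape == (B, T, H_kv * d_k)
```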

❌ Mismatch: State type parameter (S vs single T)

This is the most significant schema difference.

The ONNX proposal defines two separate type parameters:

  • T — activation dtype for query, key, value, decay, beta, output
  • S — state dtype for past_state and present_state; must be float32 or same as T

This allows the recommended configuration of fp32 state + fp16/bf16 activations — important for numerical accuracy in long sequences where state accumulation in fp16 diverges. The ONNX proposal explicitly calls this out: "Using S = float32 with T = float16/bfloat16 is the recommended configuration for long sequences."

This PR uses a single T type for all inputs including state, which prevents expressing this mixed-precision configuration in the op signature.

Suggestion: Consider adding a second type constraint S for past_state/present_state, or at minimum add an attribute state_dtype that allows the runtime to accumulate state in float32 even when activations are fp16.
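To make the suggestion concrete, a sketch using onnx.helper (which exists in the onnx package) of a node carrying this PR's attribute set; the T/S split described above is the ONNX-proposal behavior and is exactly what the current single-T schema cannot express:

```python
from onnx import helper

node = helper.make_node(
    "LinearAttention",
    inputs=["query", "key", "value", "past_state", "decay", "beta"],
    outputs=["output", "present_state"],
    domain="com.microsoft",        # contrib namespace used by this PR
    update_rule="gated_delta",
    q_num_heads=8,
    kv_num_heads=8,
    chunk_size=64,
)
# Under the ONNX proposal: query/key/value/decay/beta/output would be
# FLOAT16 (T) while past_state/present_state are FLOAT (S); with a single
# T constraint, all eight tensors must share one dtype.
```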

⚠️ Minor: Input combination validation not documented in schema

The ONNX proposal formally specifies that the presence/absence of decay and beta depends on update_rule and must be validated at model-load time (e.g., "linear" requires neither, "gated_delta" requires both; providing a forbidden input is a schema error).

This PR marks both as Optional in the schema without documenting this constraint. The kernel may still validate at runtime, but adding a note to the docstring would make the contract explicit to model builders.


CausalConvWithState

✅ Matches (essentially identical)

| Element | ORT | ONNX proposal |
|---------|-----|---------------|
| Input names | input, weight, bias, past_state | same |
| Input shapes | (B, C, ...) input, (C, 1, k, ...) weight, (C,) bias, (B, C, k-1) state | same |
| Input optionality | bias and past_state optional | same |
| Output names | output, present_state | same |
| Output shapes | same as input; same as past_state | same |
| Attributes | activation (default "none"), ndim (default 1) | same |
| Activation values | "silu", "swish", "none" ("silu" and "swish" are aliases) | same |
| Type constraint | single T: float, float16, bfloat16 | same |

No divergence found for CausalConvWithState. The single-type constraint is appropriate here since the conv state dtype naturally matches the input dtype.
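On the activation-values row, a one-line sketch of the alias behavior (illustrative only), which could be appended to the conv sketch above:

```python
import numpy as np

def apply_activation(y, activation="none"):
    if activation in ("silu", "swish"):        # aliases for x * sigmoid(x)
        return y * (1.0 / (1.0 + np.exp(-y)))
    return y                                   # "none": identity
```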


Summary

The two schemas are structurally aligned. The one actionable difference is the missing S state-dtype type parameter in LinearAttention. Everything else matches: all 5 attributes, all 6 inputs, both outputs, all defaults, and the full CausalConvWithState schema.

The ONNX proposal is at onnx/onnx#7767 if you want to review the reference-level pseudocode and the formal input-combination validation table.

justinchuby previously approved these changes Mar 31, 2026

@justinchuby (Contributor) left a comment:

LGTM w/ agreement w/ the AI comments

justinchuby previously approved these changes Apr 2, 2026
@justinchuby merged commit e532c21 into main on Apr 6, 2026 (98 of 99 checks passed)
@justinchuby deleted the gs/linear-attention-signature branch on April 6, 2026 at 18:23
sanaa-hamel-microsoft pushed a commit that referenced this pull request Apr 21, 2026
Proposal for CausalConvWithState and LinearAttention onnxruntime custom
operator.
This follows the proposal in onnx/onnx#7767.
sanaa-hamel-microsoft added a commit that referenced this pull request Apr 24, 2026
Version bump to 1.25.1.

This cherry-picks the following commits for the release:

| Commit ID | PR Number | Commit Title |
|-----------|-----------|--------------|
| e532c21 | #27842 | linear attention signature |
| 410f5a8 | #27752 | +rotemb, +rmsnorm, reshape->opset-25, transpose->opset-24 |
| 0fedb26 | #27907 | Add LinearAttention and CausalConvState ops for Qwen3.5 |
| 3ac6040 | #27996 | webgpu support for qwen3.5 |
| c36c422 | #27998 | [WebGPU EP] Fuse QMoE 1-token decode path to reduce GPU dispatches |
| 94f32ec | #27289 | [CORE]: Improve filesystem error messages during Linux device discovery |
| dce77a3 | #28118 | Fix lack of auth on python packaging |

---------

Co-authored-by: Akshay Sonawane <111780983+apsonawane@users.noreply.github.com>
Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
Co-authored-by: eserscor <erscor@microsoft.com>
Co-authored-by: Sanaa Hamel <sanaahamel@microsoft.com>
Co-authored-by: Guenther Schmuelling <guschmue@microsoft.com>
Co-authored-by: Stephan Seitz <sseitz@nvidia.com>
Co-authored-by: Jiajia Qin <jiajiaqin@microsoft.com>