linear attention signature #27842
A working end-to-end WebGPU implementation with this signature can be found here:

Possible changes at this point:
Cross-referencing with the ONNX proposal (onnx/onnx#7767) — review by AI agent team

I compared the two schemas element by element.

✅ Matches
| Item | Both PRs |
|---|---|
| Input names (0–5) | query, key, value, past_state, decay, beta — identical order and names |
| Output names (0–1) | output, present_state — identical |
| Input shapes | All match: 3D packed [B,T,H*D] for Q/K/V/decay/beta; 4D [B,H_kv,d_k,d_v] for state |
| Attribute names | q_num_heads, kv_num_heads, update_rule, scale, chunk_size — all five present |
| Attribute types & defaults | update_rule="gated_delta", scale=0.0, chunk_size=64, both heads required |
| TypeConstraint T | {tensor(float), tensor(float16), tensor(bfloat16)} — identical |
| Mathematical semantics | linear, gated, delta, gated_delta update rules described identically |
| Optional input semantics | past_state defaults to zeros; decay/beta required by respective modes |
| Namespace | ORT uses com.microsoft, ONNX proposal uses ai.onnx — expected and correct per the 3-phase adoption path |
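The update-rule semantics referenced in the table can be sketched as a per-token recurrence on the state matrix. The formulas below are the standard linear / gated / delta / gated-delta rules from the linear-attention literature and are illustrative only; the normative definitions (and the chunked formulation controlled by `chunk_size`) live in the schema doc strings of the two PRs.

```python
import numpy as np

def linear_attention_step(q, k, v, S, update_rule="gated_delta",
                          decay=None, beta=None):
    """One per-token state update for a single head (readability sketch).

    S has shape (d_k, d_v); q and k have shape (d_k,), v has shape (d_v,).
    decay and beta are per-token scalars here. Real kernels compute the
    same recurrence blockwise over chunks of chunk_size tokens.
    """
    if update_rule == "linear":
        S = S + np.outer(k, v)
    elif update_rule == "gated":
        S = decay * S + np.outer(k, v)
    elif update_rule == "delta":
        # delta rule: write only the part of v not already predicted by S
        S = S + beta * np.outer(k, v - S.T @ k)
    elif update_rule == "gated_delta":
        S = decay * S + beta * np.outer(k, v - decay * (S.T @ k))
    else:
        raise ValueError(f"unknown update_rule {update_rule!r}")
    return S.T @ q, S  # per-token output, updated state
```

With `update_rule="linear"` the state is simply the running sum of `k vᵀ` outer products, which is what makes the float16-vs-float32 state-precision question below matter.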
❌ Key mismatch — state type precision (S vs T)
The ONNX proposal introduces a second type parameter S for the recurrent state tensors, following the stash_type convention used by LayerNormalization and GroupNormalization:
- past_state → type S (not T)
- present_state → type S (not T)
- S is constrained to float32, or the same as T
This ORT PR uses a single type T for all inputs/outputs including state. This means:
- In this ORT schema, running with T=float16 forces the recurrent state to also be float16
- The ONNX proposal explicitly supports T=float16, S=float32 — float16 activations with float32 state accumulation — which is important for numerical stability during long-sequence decoding (the state matrix is accumulated over hundreds or thousands of tokens)
- The ONNX proposal notes: "Using S = float32 with T = float16/bfloat16 is the recommended configuration for long sequences; runtimes handle any necessary casting internally"
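The stability point is easy to demonstrate numerically. The sketch below (sizes and scales are arbitrary for the demo) accumulates the same fp16 outer products into a float16 state versus a float32 state using the linear rule, S += k vᵀ, over many decode steps, and measures the drift against a float64 reference:

```python
import numpy as np

rng = np.random.default_rng(0)
d_k, d_v, num_tokens = 8, 8, 4096

# fp16 activations, as with T=float16
k = (0.1 * rng.standard_normal((num_tokens, d_k))).astype(np.float16)
v = (0.1 * rng.standard_normal((num_tokens, d_v))).astype(np.float16)

def accumulate(state_dtype):
    """Run the linear update rule with the state held in state_dtype."""
    S = np.zeros((d_k, d_v), dtype=state_dtype)
    for t in range(num_tokens):
        S = S + np.outer(k[t], v[t]).astype(state_dtype)
    return S.astype(np.float64)

ref = accumulate(np.float64)
err_fp16 = np.abs(accumulate(np.float16) - ref).max()  # state in T
err_fp32 = np.abs(accumulate(np.float32) - ref).max()  # state in S
```

Since the products themselves are identical fp16 values in all three runs, the error gap isolates the accumulation dtype, which is exactly what the separate S parameter (or a `stash_type` attribute) controls.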
Recommendation: Consider adding a stash_type attribute (integer, default 1 = float32) analogous to LayerNormalization, so callers can opt into float32 state accumulation independently of the activation dtype. This would align with the ONNX proposal and avoid a breaking schema change later.
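As a concrete illustration of that recommendation (hypothetical: no such attribute exists in this PR), `stash_type` would carry an ONNX `TensorProto` dtype code, defaulting to 1 (float32) as in LayerNormalization, and the runtime would resolve the state dtype from it independently of T:

```python
import numpy as np

# ONNX TensorProto dtype codes (subset): 1 = FLOAT, 10 = FLOAT16,
# 16 = BFLOAT16. numpy has no native bfloat16, so only the first two
# are mapped in this sketch.
_CODE_TO_NP = {1: np.float32, 10: np.float16}

def resolve_state_dtype(stash_type: int = 1) -> np.dtype:
    """Hypothetical resolution of the recurrent-state dtype from a
    stash_type attribute, decoupled from the activation dtype T."""
    if stash_type not in _CODE_TO_NP:
        raise ValueError(f"unsupported stash_type code {stash_type}")
    return np.dtype(_CODE_TO_NP[stash_type])
```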
⚠️ Minor gap — input validation rules not in doc string
The ONNX proposal has an explicit table of required/forbidden optional inputs per update_rule:
| update_rule | decay | beta |
|---|---|---|
| "linear" | must be omitted | must be omitted |
| "gated" | required | must be omitted |
| "delta" | must be omitted | required |
| "gated_delta" | required | required |
The ORT doc string describes the semantics but doesn't explicitly state that providing a forbidden input is a model validation error. Worth adding to the doc for clarity — helps implementors know they should validate at model-load time, not silently ignore extra inputs.
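A model-load-time check implementing that table could look like the following sketch (function and table names are illustrative, not ORT's actual validator):

```python
# (decay_required, beta_required) per update_rule, from the table above
RULES = {
    "linear":      (False, False),
    "gated":       (True,  False),
    "delta":       (False, True),
    "gated_delta": (True,  True),
}

def validate_optional_inputs(update_rule: str,
                             has_decay: bool, has_beta: bool) -> None:
    """Raise ValueError if a required input is missing or a forbidden
    input is provided — i.e. validate, rather than silently ignore."""
    decay_req, beta_req = RULES[update_rule]
    errors = []
    for name, present, required in (("decay", has_decay, decay_req),
                                    ("beta", has_beta, beta_req)):
        if required and not present:
            errors.append(f"{name} is required for update_rule={update_rule!r}")
        if present and not required:
            errors.append(f"{name} must be omitted for update_rule={update_rule!r}")
    if errors:
        raise ValueError("; ".join(errors))
```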
CausalConvWithState — comparison
✅ Full match
All inputs (input, weight, bias, past_state), outputs (output, present_state), attributes (ndim, activation), type constraints, and shapes are identical between the two PRs. No differences found.
Summary
| Op | Status |
|---|---|
| LinearAttention | Mostly aligned — 1 structural gap (state type S vs T), 1 minor doc gap |
| CausalConvWithState | Fully aligned ✅ |
The state precision gap is the only item that would cause a future breaking schema change if not addressed here. Everything else looks well-aligned with the ONNX proposal.
Cross-referencing with the ONNX proposal (onnx/onnx#7767) — review by AI agent team

I systematically compared the LinearAttention schema (✅ matches throughout, with one ❌ mismatch: the state type parameter, S vs. T) and the full CausalConvWithState schema:
| Element | ORT | ONNX proposal |
|---|---|---|
| Input names | input, weight, bias, past_state | same |
| Input shapes | (B, C, ...) input, (C, 1, k, ...) weight, (C,) bias, (B, C, k-1) state | same |
| Input optionality | bias and past_state optional | same |
| Output names | output, present_state | same |
| Output shapes | same as input; same as past_state | same |
| Attributes | activation (default "none"), ndim (default 1) | same |
| Activation values | "silu", "swish", "none" ("silu" and "swish" are aliases) | same |
| Type constraint | single T: float, float16, bfloat16 | same |
No divergence found for CausalConvWithState. The single-type constraint is appropriate here since the conv state dtype naturally matches the input dtype.
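For reference, the single-step decode behavior implied by the shapes in the table (ndim=1 case) can be sketched as a depthwise causal convolution over a sliding window held in the state; names and the exact window layout here are assumptions for illustration:

```python
import numpy as np

def causal_conv_step(x, weight, bias, past_state):
    """One decode step of a depthwise causal 1-D conv with state.

    Shapes follow the comparison table: x (B, C, 1), weight (C, 1, k),
    bias (C,), past_state (B, C, k-1). The state holds the previous
    k-1 inputs per channel, so each step sees a full window of k.
    """
    window = np.concatenate([past_state, x], axis=-1)       # (B, C, k)
    out = np.einsum("bck,ck->bc", window, weight[:, 0, :]) + bias
    present_state = window[..., 1:]                         # (B, C, k-1)
    return out[..., None], present_state                    # (B, C, 1)
```

Because the state is just a buffer of past inputs, its dtype naturally matches the input dtype, which is why a single type constraint T suffices here.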
Summary
The two schemas are structurally aligned. The one actionable difference is the missing S state-dtype type parameter in LinearAttention. Everything else matches: all 5 attributes, all 6 inputs, both outputs, all defaults, and the full CausalConvWithState schema.
The ONNX proposal is at onnx/onnx#7767 if you want to review the reference-level pseudocode and the formal input-combination validation table.
justinchuby
left a comment
LGTM w/ agreement w/ the AI comments
Proposal for CausalConvWithState and LinearAttention onnxruntime custom operator. This follows the proposal in onnx/onnx#7767.
Version bump to 1.25.1. This cherry-picks the following commits for the release:

| Commit ID | PR Number | Commit Title |
|-----------|-----------|--------------|
| e532c21 | #27842 | linear attention signature |
| 410f5a8 | #27752 | +rotemb, +rmsnorm, reshape->opset-25, transpose->opset-24 |
| 0fedb26 | #27907 | Add LinearAttention and CausalConvState ops for Qwen3.5 |
| 3ac6040 | #27996 | webgpu support for qwen3.5 |
| c36c422 | #27998 | [WebGPU EP] Fuse QMoE 1-token decode path to reduce GPU dispatches |
| 94f32ec | #27289 | [CORE]: Improve filesystem error messages during Linux device discovery |
| dce77a3 | #28118 | Fix lack of auth on python packaging |

---------

Co-authored-by: Akshay Sonawane <111780983+apsonawane@users.noreply.github.com>
Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
Co-authored-by: eserscor <erscor@microsoft.com>
Co-authored-by: Sanaa Hamel <sanaahamel@microsoft.com>
Co-authored-by: Guenther Schmuelling <guschmue@microsoft.com>
Co-authored-by: Stephan Seitz <sseitz@nvidia.com>
Co-authored-by: Jiajia Qin <jiajiaqin@microsoft.com>