
ORT 1.25.1 release: version bump and cherry-pick #27907#28149

Merged
sanaa-hamel-microsoft merged 8 commits into rel-1.25.1 from vraspar/bump-version-1.25.1
Apr 24, 2026

Conversation

@vraspar
Contributor

@vraspar vraspar commented Apr 21, 2026

Version bump to 1.25.1.

This cherry-picks the following commits for the release:

| Commit ID | PR Number | Commit Title |
| --- | --- | --- |
| e532c21 | #27842 | linear attention signature |
| 410f5a8 | #27752 | +rotemb, +rmsnorm, reshape->opset-25, transpose->opset-24 |
| 0fedb26 | #27907 | Add LinearAttention and CausalConvState ops for Qwen3.5 |
| 3ac6040 | #27996 | webgpu support for qwen3.5 |
| c36c422 | #27998 | [WebGPU EP] Fuse QMoE 1-token decode path to reduce GPU dispatches |
| 94f32ec | #27289 | [CORE]: Improve filesystem error messages during Linux device discovery |
| dce77a3 | #28118 | Fix lack of auth on python packaging |

vraspar and others added 2 commits April 21, 2026 02:24
Adds custom CUDA and CPU kernels for linear attention and causal 1D
convolution with state, enabling efficient inference of Qwen3.5 hybrid
decoder models in ONNX Runtime.

### New Operators

**`LinearAttention`** — Implements the GatedDeltaNet recurrent linear
attention mechanism:
- Fused kernel computing gated delta-rule update of a recurrent state
matrix
- Supports both prefill (multi-token) and decode (single-token) paths
- Inputs: Q, K, V, decay (alpha), beta gating, optional initial
recurrent state
- Outputs: attention output, updated recurrent state
- CUDA implementation with per-head parallelism; CPU implementation with
Eigen
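As a rough illustration of the decode path, here is a minimal NumPy sketch of one common gated delta-rule formulation for a single head and a single token. This is not the kernel's actual interface or layout; function and argument names are illustrative only:

```python
import numpy as np

def gated_delta_decode_step(S, q, k, v, alpha, beta):
    """One single-token step of a gated delta-rule recurrent update (sketch).

    S:      (d_k, d_v) per-head recurrent state
    q, k:   (d_k,) query / key for this token
    v:      (d_v,) value for this token
    alpha:  scalar decay gate in (0, 1]
    beta:   scalar write-strength gate in (0, 1]
    """
    S = alpha * S                        # decay the previous state
    delta = v - S.T @ k                  # delta rule: error vs. current prediction
    S = S + beta * np.outer(k, delta)    # write the correction into the state
    o = S.T @ q                          # read out the attention output with the query
    return o, S
```

The prefill path would apply this step over every token in sequence (or use a fused chunked formulation); the decode path runs exactly one such step per generated token, carrying `S` between calls as the op's recurrent-state input/output.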

**`CausalConvWithState`** — Implements causal 1D convolution with
persistent state for autoregressive decoding:
- Supports prefill (full convolution) and decode (state-based sliding
window)
- Inputs: input tensor, conv weights, optional bias, optional initial
conv state
- Outputs: convolution output, updated conv state
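A minimal NumPy sketch of the decode-path state update (illustrative shapes and names; the real kernel's memory layout may differ):

```python
import numpy as np

def causal_conv_decode_step(state, x_t, weight, bias=None):
    """One single-token step of causal depthwise 1D conv with state (sketch).

    state:  (channels, kernel_size) window of the most recent inputs
    x_t:    (channels,) newest input for this step
    weight: (channels, kernel_size) per-channel filter taps
    """
    state = np.roll(state, -1, axis=1)   # slide the window one step left
    state[:, -1] = x_t                   # append the newest input
    y = (state * weight).sum(axis=1)     # depthwise causal convolution
    if bias is not None:
        y = y + bias
    return y, state
```

Prefill runs a full causal convolution over the prompt and emits the trailing window as the initial state; each decode step then consumes and produces that state as shown.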

### Op Definitions
- Registered in `com.microsoft` domain (opset 1)
- Full shape inference and type constraints in `bert_defs.cc`

### Testing
- Parity test (`test_parity_linear_attention_causal_conv.py`) validates
CUDA and CPU kernels against PyTorch reference implementations from the
FLA (Flash Linear Attention) library

---------

Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
@sanaa-hamel-microsoft
Contributor

We'll need to cherry-pick #28118 as well for packaging.

### Description

Add feed authentication to additional ADO pipelines.

### Motivation and Context

CFS clean policy requires it.

---------

Co-authored-by: Sanaa Hamel <sanaahamel@microsoft.com>

Copilot AI left a comment


Pull request overview

Prepares the ORT 1.25.1 patch release by bumping version numbers, cherry-picking new Qwen3.5-related contrib ops (LinearAttention + CausalConvWithState) with CPU/CUDA kernels and parity tests, and applying CI pipeline feed-auth fixes.

Changes:

  • Bump ORT version from 1.25.0 → 1.25.1 across C/C++/Python/JS/docs.
  • Add contrib op kernels (CPU + CUDA) and a PyTorch parity test for LinearAttention and CausalConvWithState.
  • Update Azure Pipelines templates to authenticate package feeds and normalize python “architecture” handling.

Reviewed changes

Copilot reviewed 35 out of 39 changed files in this pull request and generated 6 comments.

| File | Description |
| --- | --- |
| tools/ci_build/github/azure-pipelines/templates/setup-feeds-and-python-steps.yml | Map x86_64 to x64 for UsePythonVersion architecture handling |
| tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cuda.yml | Invoke feed+python setup template in CUDA wheel test job |
| tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cpu.yml | Invoke feed+python setup template in CPU wheel test job |
| tools/ci_build/github/azure-pipelines/templates/py-linux.yml | Invoke feed+python setup template in Linux wheel build job |
| tools/ci_build/github/azure-pipelines/stages/py-linux-webgpu-stage.yml | Invoke feed+python setup template in WebGPU stage |
| tools/ci_build/github/azure-pipelines/stages/py-linux-gpu-stage.yml | Invoke feed+python setup template in GPU stage |
| onnxruntime/test/python/transformers/test_parity_linear_attention_causal_conv.py | New CUDA/CPU parity coverage vs PyTorch reference implementations |
| onnxruntime/core/session/onnxruntime_c_api.cc | Update version static_assert to 1.25.1 |
| onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc | Register CUDA kernels for LinearAttention + CausalConvWithState |
| onnxruntime/contrib_ops/cuda/bert/linear_attention_impl.h | CUDA kernel launcher declaration for LinearAttention |
| onnxruntime/contrib_ops/cuda/bert/linear_attention_impl.cu | CUDA fused recurrent LinearAttention implementation |
| onnxruntime/contrib_ops/cuda/bert/linear_attention.h | CUDA kernel class declaration for LinearAttention |
| onnxruntime/contrib_ops/cuda/bert/linear_attention.cc | CUDA kernel implementation for LinearAttention |
| onnxruntime/contrib_ops/cuda/bert/causal_conv_with_state_impl.h | CUDA kernel launcher declaration for CausalConvWithState |
| onnxruntime/contrib_ops/cuda/bert/causal_conv_with_state_impl.cu | CUDA fused causal depthwise conv + state implementation |
| onnxruntime/contrib_ops/cuda/bert/causal_conv_with_state.h | CUDA kernel class declaration for CausalConvWithState |
| onnxruntime/contrib_ops/cuda/bert/causal_conv_with_state.cc | CUDA kernel implementation for CausalConvWithState |
| onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc | Register CPU kernels for LinearAttention + CausalConvWithState |
| onnxruntime/contrib_ops/cpu/bert/linear_attention.h | CPU kernel class declaration for LinearAttention |
| onnxruntime/contrib_ops/cpu/bert/linear_attention.cc | CPU kernel implementation for LinearAttention |
| onnxruntime/contrib_ops/cpu/bert/causal_conv_with_state.h | CPU kernel class declaration for CausalConvWithState |
| onnxruntime/contrib_ops/cpu/bert/causal_conv_with_state.cc | CPU kernel implementation for CausalConvWithState |
| onnxruntime/__init__.py | Python package version bump to 1.25.1 |
| js/web/package.json | JS web package version bump to 1.25.1 |
| js/web/package-lock.json | JS web lockfile version bump to 1.25.1 |
| js/web/lib/version.ts | JS web generated version bump to 1.25.1 |
| js/react_native/package.json | React Native package version bump to 1.25.1 |
| js/react_native/package-lock.json | React Native lockfile version bump to 1.25.1 |
| js/react_native/lib/version.ts | React Native generated version bump to 1.25.1 |
| js/node/script/install-metadata-versions.js | Node install metadata version bump to 1.25.1 |
| js/node/package.json | Node package version bump to 1.25.1 |
| js/node/package-lock.json | Node lockfile version bump to 1.25.1 |
| js/node/lib/version.ts | Node generated version bump to 1.25.1 |
| js/common/package.json | JS common package version bump to 1.25.1 |
| js/common/package-lock.json | JS common lockfile version bump to 1.25.1 |
| js/common/lib/version.ts | JS common generated version bump to 1.25.1 |
| docs/python/README.rst | Add 1.25.1 entry / release link |
| docs/OperatorKernels.md | Add kernel table entries for new ops |
| VERSION_NUMBER | Root version bump to 1.25.1 |
Files not reviewed (4)
  • js/common/package-lock.json: Language not supported
  • js/node/package-lock.json: Language not supported
  • js/react_native/package-lock.json: Language not supported
  • js/web/package-lock.json: Language not supported


Comment thread onnxruntime/contrib_ops/cuda/bert/linear_attention.cc
Comment thread onnxruntime/contrib_ops/cuda/bert/linear_attention.cc
Comment thread onnxruntime/contrib_ops/cpu/bert/linear_attention.cc
Comment thread onnxruntime/contrib_ops/cuda/bert/causal_conv_with_state_impl.cu
Comment thread onnxruntime/contrib_ops/cpu/bert/causal_conv_with_state.cc
guschmue and others added 5 commits April 21, 2026 14:22
Proposal for CausalConvWithState and LinearAttention onnxruntime custom
operator.
This follows the proposal in onnx/onnx#7767.
…ry (#27289)

### Description
<!-- Describe your changes. -->

This is a follow-up to
#26210,
addressing
#26210 (comment)
and reviewdog lints.

ErrorCodeToStatus currently does not include the filesystem path that
caused the error, which can make it difficult to identify the root cause
of a reported filesystem error.

Reviewdog lints:
https://github.com/microsoft/onnxruntime/pull/26210/changes
Plus a typo: `dit` -> `did`


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

Clean up discussed issues and lints of #26210
For the WebGPU EP:
+ ONNX RotaryEmbedding op
+ ONNX RMSNorm op
+ Reshape -> opset-25
+ Transpose -> opset-24

WebGPU support for Qwen3.5, adding the LinearAttention and
CausalConvWithState ops based on the proposal in onnx/onnx#7767.

The model can be created with model builder from
https://github.com/microsoft/onnxruntime-genai/blob/main/src/python/py/models/builder.py.

For example for the text only flavor:
```
python builder.py -m Qwen/Qwen3.5-0.8B  -o Qwen3.5-0.8B -e webgpu -p int4 --extra_options int4_accuracy_level=4 exclude_embeds=False
```
…27998)

Fuse the QMoE 1-token decode path to reduce GPU dispatches from 17 (1 +
k×4) to 5 (gate + fc1 + swiglu + fc2 + mix), improving token generation
throughput by ~21% on Meteor Lake for the gpt-oss-20b MoE model (19 → 23
tps).
The QMoE operator processes Mixture-of-Experts layers by selecting top-k
experts (k=4) per token. In the original 1-token decode path, each
expert is processed serially with 4 dispatches (gather + fc1 + swiglu +
fc2 + mix), totaling 17 GPU dispatches per QMoE call. Since each
dispatch has M=1, the GPU is underutilized and CPU dispatch overhead
dominates.
For the 1-token path (num_rows == 1):

**Gate1Token** — Select top-k experts and output an `indirect_experts`
buffer mapping row index → expert index

**Batched fc1 MatMulNBits** — Run a single M=k matmul with
`per_row_weight_indirect` mode, where each row selects a different
expert's weights via the indirect buffer

**SwiGLU** — Apply activation on all k rows at once

**Batched fc2 MatMulNBits** — Same per-row expert selection for the down
projection

**FusedFinalMix** — Accumulate all k weighted expert results into the
output
Two further fusions are then applied:

- Batched fc1 MatMulNBits + SwiGLU
- Batched fc2 MatMulNBits + FusedFinalMix

Finally, only three shaders are needed: Gate1Token, fused batched fc1
MatMulNBits, and fused batched fc2 MatMulNBits.
@sanaa-hamel-microsoft sanaa-hamel-microsoft merged commit 6fd52e4 into rel-1.25.1 Apr 24, 2026
92 of 100 checks passed
@sanaa-hamel-microsoft sanaa-hamel-microsoft deleted the vraspar/bump-version-1.25.1 branch April 24, 2026 14:07