[WebGPU EP] Fuse QMoE 1-token decode path to reduce GPU dispatches#27998
Conversation
For MoE models with k=4 top experts and 32 total experts, the previous 1-token decode path dispatched 17 GPU programs per QMoE call (1 gate + 4 experts × (MatMul_fc1 + SwiGLU + MatMul_fc2 + FinalMix)). Each expert was processed sequentially with separate dispatches. This change fuses the 1-token path into 6 dispatches:

1. Gate (select top-k experts)
2. Replicate hidden_state k times
3. Single batched MatMul_fc1 with M=k, per-row expert selection
4. Single SwiGLU on all k rows
5. Single batched MatMul_fc2 with M=k, per-row expert selection
6. Single fused FinalMix for all k experts

Key changes:
- Add per_row_weight_indirect mode to MatMulNBitsProgram: each M-row reads its expert index from weight_index_indirect[a_global] instead of using a single uniform weight_index for all rows. This allows one MatMul dispatch to process k different experts simultaneously.
- Skip DP4A/WideTile/Subgroup optimized paths when per_row mode is active (they don't support per-row expert selection yet).
- Add ReplicateHiddenProgram shader to copy 1 hidden row to k rows.
- Add FusedFinalMix1TokenProgram shader to accumulate all k expert results in one dispatch.

Verified: the model produces correct, coherent text output.
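The core enabler is the per-row indirect lookup. Below is a minimal WGSL sketch of that idea; buffer names, layouts, and the scalar f32 weights are illustrative assumptions (the actual templates dequantize packed n-bit blocks), not the real shader code:

```wgsl
// Sketch only: scalar matmul with per-row expert selection.
// a                     : [M=k, K] activations
// b                     : all experts' weights, flattened (dequantization elided)
// weight_index_indirect : [k] mapping output row -> expert id
struct Uniforms { M : u32, N : u32, K : u32, expert_stride : u32 };

@group(0) @binding(0) var<storage, read> a : array<f32>;
@group(0) @binding(1) var<storage, read> b : array<f32>;
@group(0) @binding(2) var<storage, read> weight_index_indirect : array<u32>;
@group(0) @binding(3) var<storage, read_write> y : array<f32>;
@group(0) @binding(4) var<uniform> uniforms : Uniforms;

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
  let col = gid.x;
  let row = gid.y;
  if (row >= uniforms.M || col >= uniforms.N) { return; }
  // Each output row picks its own expert, so one dispatch covers k experts.
  let expert = weight_index_indirect[row];
  let b_base = expert * uniforms.expert_stride + col * uniforms.K;
  var acc = 0.0;
  for (var i = 0u; i < uniforms.K; i = i + 1u) {
    acc = acc + a[row * uniforms.K + i] * b[b_base + i];
  }
  y[row * uniforms.N + col] = acc;
}
```

The only difference from the uniform-index mode is that `expert` comes from a per-row buffer lookup rather than a single uniform, which is what lets one dispatch serve k different experts.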
…x fused QMoE correctness

- Add per_row_weight_indirect to DP4AMatMulNBitsSmallMProgram for fused MoE per-row expert selection in the DP4A SmallM shader path
- Fix has_weight_idx in ApplyDP4AMatrixMatMulNBits to account for weight_index_indirect (was only checking weight_index != 0)
- Force the SmallM path when per_row_weight_indirect is true (regular DP4A does not support per-row)
- Skip the WideTile path when per_row_weight_indirect is true (no per-row support)
- Fix FusedFinalMix1Token race condition: dispatch across hidden_size instead of k; each thread accumulates all k experts to avoid concurrent writes to the same output locations
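For the race-condition fix, the decomposition is what matters: parallelizing over experts gives k writers per output element, while parallelizing over hidden elements gives each output exactly one writer. A minimal WGSL sketch under assumed names and layouts:

```wgsl
// Sketch only: one thread per hidden element; each thread is the sole
// writer of its output slot, so no atomics are required.
struct Uniforms { k : u32, hidden_size : u32 };

@group(0) @binding(0) var<storage, read> expert_out : array<f32>;      // [k, hidden_size]
@group(0) @binding(1) var<storage, read> routing_weights : array<f32>; // [k]
@group(0) @binding(2) var<storage, read_write> y : array<f32>;         // [hidden_size]
@group(0) @binding(3) var<uniform> uniforms : Uniforms;

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
  let h = gid.x;
  if (h >= uniforms.hidden_size) { return; }
  var acc = 0.0;
  // Accumulate all k experts locally; the racy variant dispatched per
  // expert instead, so k threads wrote the same output element.
  for (var e = 0u; e < uniforms.k; e = e + 1u) {
    acc = acc + routing_weights[e] * expert_out[e * uniforms.hidden_size + h];
  }
  y[h] = acc;
}
```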
# Conflicts:
#   onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc
#   onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
#   onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h
- Add override_M param to ApplyMatMulNBits; compute dispatch_M and broadcast_a at the top level instead of overriding M
- Shaders use broadcast_a_row (decoupled from per_row_weight_indirect) to read A row 0 when dispatch_M > M
- DP4A quantize uses the actual M for quantization; the SmallM dispatch uses dispatch_M
- fc1 in QMoE passes override_M=k (A is 1 row, broadcast); fc2 passes 0 (A has k real rows)
- Eliminates the ReplicateHidden dispatch and the (k, hidden_size) allocation
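A short WGSL-style sketch of the broadcast idea, with the caveat that in the real templates `broadcast_a_row` is a shader-generation parameter rather than the module constant shown here:

```wgsl
// Sketch only: broadcast_a_row shown as a constant for brevity.
const broadcast_a_row : bool = true;

// Map an output row (0..dispatch_M-1) to the A row it should read.
// With override_M = k and a 1-row A, every output row reads row 0,
// so no ReplicateHidden copy or (k, hidden_size) buffer is needed.
fn a_row_for(output_row : u32) -> u32 {
  return select(output_row, 0u, broadcast_a_row);
}
```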
- has_weight_idx_indirect now always means per-row expert selection (weight_index_indirect[a_global]) in shaders, removing the 3-way #if
- Removed per_row_weight_indirect_ from the MatMulNBitsProgram and DP4AMatMulNBitsSmallMProgram classes
- Removed the per_row_weight_indirect param from ApplyMatMulNBits and ApplyDP4AMatrixMatMulNBits
- Routing uses has_weight_idx_indirect to skip the WideTile/Subgroup paths and force the DP4A SmallM path
Pull request overview
This PR optimizes the WebGPU QMoE operator’s 1-token decode fast path by batching the top‑k experts’ FC1/activation/FC2 work and replacing the per‑expert final mix with a single fused mix dispatch, aiming to reduce GPU dispatch overhead and improve generation throughput.
Changes:
- Add an `override_M`/broadcast capability to WebGPU `MatMulNBits` (and the DP4A small‑M path) to support dispatching `M=k` while reading A from a single row.
- Rework the QMoE `num_rows == 1` path to: Gate1Token → batched fc1 → batched SwiGLU → batched fc2 → fused final mix shader.
- Introduce `fused_final_mix_1token.wgsl.template` and remove the old per‑expert `final_mix_1token.wgsl.template`.
Reviewed changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.wgsl.template | Adds broadcast_a_row shader param to allow A-row broadcasting for override_M dispatch. |
| onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h | Extends MatMulNBitsProgram and ApplyMatMulNBits API to support broadcast/override_M. |
| onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc | Wires broadcast_a_row into shader generation and uses dispatch_M for dispatch sizing/uniforms. |
| onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_small_m.wgsl.template | Adds broadcast_a_row support and switches indirect weight indexing to be per output row. |
| onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h | Updates DP4A matmul API to accept dispatch_M and adds broadcast flag to program. |
| onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc | Threads dispatch_M through small‑M DP4A path and adds shader param plumbing. |
| onnxruntime/contrib_ops/webgpu/moe/qmoe.cc | Implements the fused 1-token QMoE path using batched MatMulNBits and a new fused final mix program. |
| onnxruntime/contrib_ops/webgpu/moe/fused_final_mix_1token.wgsl.template | New shader that accumulates all k expert outputs for a single token in one dispatch. |
| onnxruntime/contrib_ops/webgpu/moe/final_mix_1token.wgsl.template | Removes the old per‑expert final mix shader used by the previous 1-token path. |
…token path

- Add dispatch_M as a separate uniform instead of overloading uniforms.M. uniforms.M now always reflects the real A row count (correct batch stride), while uniforms.dispatch_M controls grid decomposition and output addressing. This fixes the potential out-of-bounds issue when batch_count > 1.
- Skip ZeroTensorProgram for the 1-token fused path since FusedFinalMix writes all output elements directly (saves 1 GPU dispatch).
- Fix a misleading comment in fused_final_mix_1token.wgsl.template to match the actual dispatch strategy (per hidden_size element, not per expert).
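A hedged WGSL sketch of the two-uniform scheme described above (names and the exact addressing are illustrative, not the template code):

```wgsl
// Sketch only: uniforms.M keeps the real A row count while dispatch_M
// drives the grid and output addressing.
struct Uniforms {
  M : u32,           // real rows per batch in A -> per-batch stride
  dispatch_M : u32,  // rows in the dispatch grid -> bounds and output rows
  N : u32,
  K : u32,
};
@group(0) @binding(0) var<uniform> uniforms : Uniforms;

fn a_offset(batch : u32, out_row : u32) -> u32 {
  // Batch stride uses the real M, so batch_count > 1 stays in bounds
  // when dispatch_M > M; a broadcast row falls back to row 0.
  let a_row = select(out_row, 0u, uniforms.dispatch_M > uniforms.M);
  return (batch * uniforms.M + a_row) * uniforms.K;
}
```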
…rosoft/onnxruntime into opt/webgpu-qmoe-fused-1token
Pull request overview
Copilot reviewed 9 out of 9 changed files in this pull request and generated 1 comment.
Comments suppressed due to low confidence (1)
onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc:251
`override_M` can make `dispatch_M != M`, but several fast paths selected above (subgroup-matrix, wide-tile, and any path that doesn't use `broadcast_a_row`/`dispatch_M`) are not broadcast-aware. If a caller passes `override_M > M` when `M >= kMinMForTileOptimization`, the code can select these kernels and then read/write using an M-stride that doesn't match the actual input/output buffers, leading to incorrect results or OOB GPU accesses. Consider enforcing constraints on `override_M` (e.g., only allow `dispatch_M != M` for the intended broadcast case such as `M==1`, and require `y` to be sized for `batch_count*dispatch_M*N`), and/or gate the non-broadcast-aware kernels with `dispatch_M == M`.
const uint32_t batch_count = onnxruntime::narrow<uint32_t>(helper.OutputOffsets().size());
const uint32_t M = onnxruntime::narrow<uint32_t>(helper.M());
const uint32_t dispatch_M = (override_M > 0) ? override_M : M;
const bool broadcast_a = dispatch_M > M;
const uint32_t N = onnxruntime::narrow<uint32_t>(helper.N());
const uint32_t K = onnxruntime::narrow<uint32_t>(helper.K());
const uint32_t block_size = onnxruntime::narrow<uint32_t>(block_size_op);
// Special case matrix used by bitnets where there is a single scale for the entire matrix.
const bool single_scale_weights = (block_size == K * N);
const uint32_t block_size_per_col = single_scale_weights ? K : block_size;
const uint32_t n_blocks_per_col = (K + block_size_per_col - 1) / block_size_per_col;
const uint32_t blob_size = (block_size_per_col / 8) * static_cast<uint32_t>(nbits);
const uint32_t blob_size_in_words = blob_size / 4;
const uint32_t components_a = GetMaxComponents(K);
const uint32_t components_b = GetMaxComponents(blob_size_in_words);
uint32_t components = GetMaxComponents(N);
// zero_points has shape[N * CeilDiv(n_blocks_per_col * bits, 8)].
// The shader uses a flat linear index to address individual n-bit zero point values.
// Since each column's zero points are byte-aligned in the packed buffer, we must round
// n_blocks_per_col up to the next multiple of (8/nbits) — the number of zero point
// values per byte — so that the linear stride correctly skips byte-boundary padding.
const uint32_t zp_elements_per_byte = 8 / static_cast<uint32_t>(nbits);
uint32_t zero_blocks_per_col = (n_blocks_per_col + zp_elements_per_byte - 1) / zp_elements_per_byte * zp_elements_per_byte;
#if !defined(__wasm__)
int32_t subgroup_matrix_config_index = -1;
// apple|intel - Experimental dawn support for subgroup matrix matmul.
if ((M >= kMinMForTileOptimization && !has_weight_idx_indirect) && (context.AdapterInfo().vendor == std::string_view{"apple"} || context.AdapterInfo().vendor == std::string_view{"intel"}) &&
CanApplySubgroupMatrixMatMulNBits(context, accuracy_level, block_size, batch_count, N, K, static_cast<uint32_t>(nbits), y->DataType() == DataTypeImpl::GetType<MLFloat16>(), subgroup_matrix_config_index)) {
return ApplySubgroupMatrixMatMulNBits(a, b, scales, zero_points, bias, M, N, K, static_cast<uint32_t>(nbits), zero_blocks_per_col, subgroup_matrix_config_index, context, y, weight_index, weight_index_indirect);
}
#endif
// On FP32 only GPUs and Qualcomm GPUs, integer math is faster than FP32 therefore always use DP4A independent of length of M.
// DP4A Q2 path now supports custom zero points via a 1024-entry LUT (4 zero-point sections × 256 byte values).
if (((M >= kMinMForTileOptimization && !has_weight_idx_indirect) || y->DataType() == DataTypeImpl::GetType<float>() || context.AdapterInfo().vendor == std::string_view{"qualcomm"}) &&
CanApplyDP4AMatrixMatMulNBits(context, accuracy_level, block_size, N, K, components_a)) {
return ApplyDP4AMatrixMatMulNBits(a, b, scales, zero_points, bias, batch_count, M, dispatch_M, N, K, block_size, zero_blocks_per_col, kMinMForTileOptimization, static_cast<uint32_t>(nbits), context, y, weight_index, weight_index_indirect);
}
// WideTileProgram
// This program is optimized for Block32 prefill using Tile16x128.
const bool use_wide_tile_program = !has_weight_idx_indirect &&
block_size == 32 &&
components_a == 4 &&
components_b == 4 &&
nbits != 2 &&
M >= kMinMForTileOptimization;
- Add a #if defined(USE_WEBGPU) block to the RunQMoETest helper to run QMoE tests on WebGpuExecutionProvider (skipped when FC3 weights are present).
- Add QMoETest_WebGPU_SingleToken: a num_rows=1 test that exercises the fused 1-token decode path on WebGPU with SwiGLU activation and zero int4 weights. Verified via printf that it enters the fused 1-token code path.
cool, shows 14% gain for decode on an RTX 5060 Ti
…27998)

Fuse the QMoE 1-token decode path to reduce GPU dispatches from 17 (1 + k×4) to 5 (gate + fc1 + swiglu + fc2 + mix), improving token generation throughput by ~21% on Meteor Lake for the gpt-oss-20b MoE model (19 → 23 tps).
Version bump to 1.25.1. This cherry-picks the following commits for the release:

| Commit ID | PR Number | Commit Title |
|-----------|-----------|-------------|
| e532c21 | #27842 | linear attention signature |
| 410f5a8 | #27752 | +rotemb, +rmsnorm, reshape->opset-25, transpose->opset-24 |
| 0fedb26 | #27907 | Add LinearAttention and CausalConvState ops for Qwen3.5 |
| 3ac6040 | #27996 | webgpu support for qwen3.5 |
| c36c422 | #27998 | [WebGPU EP] Fuse QMoE 1-token decode path to reduce GPU dispatches |
| 94f32ec | #27289 | [CORE]: Improve filesystem error messages during Linux device discovery |
| dce77a3 | #28118 | Fix lack of auth on python packaging |

---------

Co-authored-by: Akshay Sonawane <111780983+apsonawane@users.noreply.github.com>
Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
Co-authored-by: eserscor <erscor@microsoft.com>
Co-authored-by: Sanaa Hamel <sanaahamel@microsoft.com>
Co-authored-by: Guenther Schmuelling <guschmue@microsoft.com>
Co-authored-by: Stephan Seitz <sseitz@nvidia.com>
Co-authored-by: Jiajia Qin <jiajiaqin@microsoft.com>
Description:
Summary
Fuse the QMoE 1-token decode path to reduce GPU dispatches from 17 (1 + k×4) to 5 (gate + fc1 + swiglu + fc2 + mix), improving token generation throughput by ~21% on Meteor Lake for the gpt-oss-20b MoE model (19 → 23 tps).
Motivation
The QMoE operator processes Mixture-of-Experts layers by selecting top-k experts (k=4) per token. In the original 1-token decode path, each expert is processed serially with 4 dispatches (fc1 + swiglu + fc2 + mix), totaling 17 GPU dispatches per QMoE call. Since each dispatch has M=1, the GPU is underutilized and CPU dispatch overhead dominates.
Approach
For the 1-token path (num_rows == 1):
1. Gate1Token — Select top-k experts and output an indirect_experts buffer mapping row index → expert index (a sketch follows this list)
2. Batched fc1 MatMulNBits — Run a single M=k matmul with per_row_weight_indirect mode, where each row selects a different expert's weights via the indirect buffer
3. SwiGLU — Apply activation on all k rows at once
4. Batched fc2 MatMulNBits — Same per-row expert selection for the down projection
5. FusedFinalMix — Accumulate all k weighted expert results into the output
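A simplified WGSL sketch of the gate step under assumed buffer names; single-thread selection is shown for clarity, and normalization of the top-k routing weights (e.g., softmax over the selected scores) is elided:

```wgsl
// Sketch only: pick top-k expert ids for one token and emit the
// row -> expert indirect buffer consumed by the batched fc1/fc2 matmuls.
struct Uniforms { num_experts : u32, k : u32 };

@group(0) @binding(0) var<storage, read> logits : array<f32>;                 // [num_experts]
@group(0) @binding(1) var<storage, read_write> indirect_experts : array<u32>; // [k]
@group(0) @binding(2) var<storage, read_write> routing_weights : array<f32>;  // [k]
@group(0) @binding(3) var<uniform> uniforms : Uniforms;

@compute @workgroup_size(1)
fn main() {
  var taken = array<bool, 32>();  // assumes num_experts <= 32
  for (var slot = 0u; slot < uniforms.k; slot = slot + 1u) {
    var best = 0u;
    var best_val = -3.4e38;
    for (var e = 0u; e < uniforms.num_experts; e = e + 1u) {
      if (!taken[e] && logits[e] > best_val) {
        best = e;
        best_val = logits[e];
      }
    }
    taken[best] = true;
    indirect_experts[slot] = best;    // per-row expert id for the batched matmuls
    routing_weights[slot] = best_val; // pre-normalization score
  }
}
```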
Follow-ups
- Fuse Batched fc1 MatMulNBits + SwiGLU
- Fuse Batched fc2 MatMulNBits + FusedFinalMix

After those, only three shaders would remain: Gate1Token, fused Batched fc1 MatMulNBits, and fused Batched fc2 MatMulNBits.