[WebGPU EP] Fuse QMoE 1-token decode path to reduce GPU dispatches#27998

Merged
guschmue merged 12 commits into main from opt/webgpu-qmoe-fused-1token
Apr 10, 2026
Conversation

Contributor

@qjia7 qjia7 commented Apr 7, 2026

Description:

Summary

Fuse the QMoE 1-token decode path to reduce GPU dispatches from 17 (1 + k×4) to 5 (gate + fc1 + swiglu + fc2 + mix), improving token generation throughput by ~21% on Meteor Lake for the gpt-oss-20b MoE model (19 → 23 tps).

Motivation

The QMoE operator processes Mixture-of-Experts layers by selecting top-k experts (k=4) per token. In the original 1-token decode path, each expert is processed serially with 4 dispatches (fc1 + swiglu + fc2 + mix), totaling 17 GPU dispatches per QMoE call (1 gate + 4×4). Since each dispatch has M=1, the GPU is underutilized and CPU dispatch overhead dominates.
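The dispatch arithmetic above can be checked with a trivial sketch (k = 4 per the description):

```python
K_TOP = 4  # top-k experts selected per token

# Original 1-token path: 1 gate dispatch + 4 dispatches per selected expert.
before = 1 + K_TOP * 4

# Fused path: gate + batched fc1 + SwiGLU + batched fc2 + fused final mix.
after = 5

print(before, after)  # 17 5
```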

Approach

For the 1-token path (num_rows == 1):

1. **Gate1Token** — Select top-k experts and output an `indirect_experts` buffer mapping row index → expert index
2. **Batched fc1 MatMulNBits** — Run a single M=k matmul with `per_row_weight_indirect` mode, where each row selects a different expert's weights via the indirect buffer
3. **SwiGLU** — Apply activation on all k rows at once
4. **Batched fc2 MatMulNBits** — Same per-row expert selection for the down projection
5. **FusedFinalMix** — Accumulate all k weighted expert results into the output
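The five steps can be sketched end-to-end in plain Python. This is a toy simulation of the data flow only — the dimensions, `matvec` helper, softmax gating, and the packed [gate | up] fc1 layout are illustrative assumptions, not the actual WebGPU kernels:

```python
import math

H, INTER, E, K_TOP = 4, 3, 5, 2  # hidden, intermediate, experts, top-k (toy)

def matvec(w, x):  # w: rows x cols, x: vector of length cols
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]

def fused_1token(x, gate_w, fc1_w, fc2_w):
    # 1. Gate1Token: score experts, pick top-k, build the indirect buffer.
    scores = matvec(gate_w, x)
    topk = sorted(range(E), key=lambda e: -scores[e])[:K_TOP]
    exps = [math.exp(scores[e]) for e in topk]
    weights = [v / sum(exps) for v in exps]    # softmax over the top-k
    indirect_experts = topk                    # row index -> expert index

    # 2. Batched fc1 (M = k): row r uses expert indirect_experts[r]'s weights
    #    while A is broadcast from the single hidden row.
    fc1 = [matvec(fc1_w[indirect_experts[r]], x) for r in range(K_TOP)]

    # 3. SwiGLU on all k rows at once (fc1 rows pack [gate | up] halves).
    def swiglu(row):
        g, u = row[:INTER], row[INTER:]
        return [gi / (1.0 + math.exp(-gi)) * ui for gi, ui in zip(g, u)]
    act = [swiglu(row) for row in fc1]

    # 4. Batched fc2: same per-row expert selection for the down projection.
    fc2 = [matvec(fc2_w[indirect_experts[r]], act[r]) for r in range(K_TOP)]

    # 5. FusedFinalMix: accumulate all k weighted expert rows into one output.
    return [sum(weights[r] * fc2[r][h] for r in range(K_TOP)) for h in range(H)]
```

With all experts given identical weights, the gate's convex combination collapses to a single expert's result, which makes the sketch easy to sanity-check.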

Follow-ups

- Fuse batched fc1 MatMulNBits + SwiGLU
- Fuse batched fc2 MatMulNBits + FusedFinalMix

After these follow-ups, only three shaders remain: Gate1Token, fused batched fc1 MatMulNBits, and fused batched fc2 MatMulNBits.

qjia7 added 8 commits April 7, 2026 11:19
For MoE models with k=4 top experts and 32 total experts, the previous
1-token decode path dispatched 17 GPU programs per QMoE call (1 gate +
4 experts x (MatMul_fc1 + SwiGLU + MatMul_fc2 + FinalMix)). Each expert
was processed sequentially with separate dispatches.

This change fuses the 1-token path into 6 dispatches:
1. Gate (select top-k experts)
2. Replicate hidden_state k times
3. Single batched MatMul_fc1 with M=k, per-row expert selection
4. Single SwiGLU on all k rows
5. Single batched MatMul_fc2 with M=k, per-row expert selection
6. Single fused FinalMix for all k experts

Key changes:
- Add per_row_weight_indirect mode to MatMulNBitsProgram: each M-row
  reads its expert index from weight_index_indirect[a_global] instead
  of using a single uniform weight_index for all rows. This allows one
  MatMul dispatch to process k different experts simultaneously.
- Skip DP4A/WideTile/Subgroup optimized paths when per_row mode is
  active (they don't support per-row expert selection yet).
- Add ReplicateHiddenProgram shader to copy 1 hidden row to k rows.
- Add FusedFinalMix1TokenProgram shader to accumulate all k expert
  results in one dispatch.

Verified: Model produces correct coherent text output.
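The per_row_weight_indirect addressing described in this commit can be written as a plain-Python sketch (the function and `expert_stride` are illustrative, not ORT API):

```python
def weight_row_base(a_global, per_row_indirect, weight_index_indirect,
                    weight_index, expert_stride):
    # per_row_weight_indirect mode: each M-row reads its expert index from
    # weight_index_indirect[a_global]; otherwise a single uniform weight_index
    # applies to every row of the dispatch.
    expert = weight_index_indirect[a_global] if per_row_indirect else weight_index
    return expert * expert_stride

# Example: 4 rows of one dispatch resolving 4 different experts' weights.
bases = [weight_row_base(r, True, [3, 1, 0, 2], 0, 100) for r in range(4)]
```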
…x fused QMoE correctness

- Add per_row_weight_indirect to DP4AMatMulNBitsSmallMProgram for fused MoE
  per-row expert selection in the DP4A SmallM shader path
- Fix has_weight_idx in ApplyDP4AMatrixMatMulNBits to account for
  weight_index_indirect (was only checking weight_index != 0)
- Force SmallM path when per_row_weight_indirect is true (Regular DP4A
  does not support per-row)
- Skip WideTile path when per_row_weight_indirect is true (no per-row support)
- Fix FusedFinalMix1Token race condition: dispatch across hidden_size
  instead of k, each thread accumulates all k experts to avoid concurrent
  writes to the same output locations
# Conflicts:
#	onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc
#	onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
#	onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h
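The race-condition fix above (dispatch across hidden_size, each thread accumulating all k experts) can be sketched in Python; names and shapes are illustrative:

```python
def fused_final_mix(expert_out, expert_weight, hidden_size, k):
    # Dispatch across hidden_size rather than across k: each "thread" h owns
    # exactly one output slot and sums all k weighted expert results into it,
    # so no two threads ever write the same location (no race).
    out = [0.0] * hidden_size
    for h in range(hidden_size):
        acc = 0.0
        for r in range(k):
            acc += expert_weight[r] * expert_out[r][h]
        out[h] = acc
    return out
```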
- Add override_M param to ApplyMatMulNBits; compute dispatch_M and broadcast_a
  at the top level instead of overriding M
- Shaders use broadcast_a_row (decoupled from per_row_weight_indirect) to read
  A row 0 when dispatch_M > M
- DP4A quantize uses actual M for quantization; SmallM dispatch uses dispatch_M
- fc1 in QMoE passes override_M=k (A is 1 row, broadcast); fc2 passes 0
  (A has k real rows)
- Eliminates ReplicateHidden dispatch and (k, hidden_size) allocation
- has_weight_idx_indirect now always means per-row expert selection
  (weight_index_indirect[a_global]) in shaders, removing the 3-way #if
- Removed per_row_weight_indirect_ from MatMulNBitsProgram and
  DP4AMatMulNBitsSmallMProgram classes
- Removed per_row_weight_indirect param from ApplyMatMulNBits and
  ApplyDP4AMatrixMatMulNBits
- Routing uses has_weight_idx_indirect to skip WideTile/Subgroup and
  force DP4A SmallM path
@qjia7 qjia7 marked this pull request as ready for review April 8, 2026 05:38
@qjia7 qjia7 requested review from fs-eire and guschmue April 8, 2026 05:38
@guschmue guschmue added the ep:WebGPU ort-web webgpu provider label Apr 8, 2026
Copilot AI left a comment

Pull request overview

This PR optimizes the WebGPU QMoE operator’s 1-token decode fast path by batching the top‑k experts’ FC1/activation/FC2 work and replacing the per‑expert final mix with a single fused mix dispatch, aiming to reduce GPU dispatch overhead and improve generation throughput.

Changes:

  • Add an override_M/broadcast capability to WebGPU MatMulNBits (and DP4A small‑M path) to support dispatching M=k while reading A from a single row.
  • Rework QMoE num_rows == 1 path to: Gate1Token → batched fc1 → batched SwiGLU → batched fc2 → fused final mix shader.
  • Introduce fused_final_mix_1token.wgsl.template and remove the old per‑expert final_mix_1token.wgsl.template.

Reviewed changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 3 comments.

Show a summary per file
File Description
onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.wgsl.template Adds broadcast_a_row shader param to allow A-row broadcasting for override_M dispatch.
onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h Extends MatMulNBitsProgram and ApplyMatMulNBits API to support broadcast/override_M.
onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc Wires broadcast_a_row into shader generation and uses dispatch_M for dispatch sizing/uniforms.
onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_small_m.wgsl.template Adds broadcast_a_row support and switches indirect weight indexing to be per output row.
onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h Updates DP4A matmul API to accept dispatch_M and adds broadcast flag to program.
onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc Threads dispatch_M through small‑M DP4A path and adds shader param plumbing.
onnxruntime/contrib_ops/webgpu/moe/qmoe.cc Implements the fused 1-token QMoE path using batched MatMulNBits and a new fused final mix program.
onnxruntime/contrib_ops/webgpu/moe/fused_final_mix_1token.wgsl.template New shader that accumulates all k expert outputs for a single token in one dispatch.
onnxruntime/contrib_ops/webgpu/moe/final_mix_1token.wgsl.template Removes the old per‑expert final mix shader used by the previous 1-token path.


Comment thread onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
Comment thread onnxruntime/contrib_ops/webgpu/moe/qmoe.cc
Comment thread onnxruntime/contrib_ops/webgpu/moe/fused_final_mix_1token.wgsl.template Outdated
qjia7 added 2 commits April 10, 2026 13:28
…token path

- Add dispatch_M as a separate uniform instead of overloading uniforms.M.
  uniforms.M now always reflects the real A row count (correct batch stride),
  while uniforms.dispatch_M controls grid decomposition and output addressing.
  This fixes the potential out-of-bounds issue when batch_count > 1.
- Skip ZeroTensorProgram for the 1-token fused path since FusedFinalMix
  writes all output elements directly (saves 1 GPU dispatch).
- Fix misleading comment in fused_final_mix_1token.wgsl.template to match
  the actual dispatch strategy (per hidden_size element, not per expert).
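A sketch of the addressing this commit fixes (names are illustrative; the real uniforms live in the shader templates). The per-batch stride in A must use the real row count even when dispatch_M differs, otherwise batch_count > 1 reads out of bounds:

```python
def a_read_offset(batch, row, col, real_M, dispatch_M, K):
    # uniforms.M (real_M) keeps the per-batch stride correct; dispatch_M only
    # drives grid decomposition, and broadcast clamps the row read to 0 when
    # dispatch_M exceeds the real row count.
    a_row = 0 if dispatch_M > real_M else row
    return (batch * real_M + a_row) * K + col
```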
Copilot AI left a comment

Pull request overview

Copilot reviewed 9 out of 9 changed files in this pull request and generated 1 comment.

Comments suppressed due to low confidence (1)

onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc:251

  • override_M can make dispatch_M != M, but several fast paths selected above (subgroup-matrix, wide-tile, and any path that doesn’t use broadcast_a_row/dispatch_M) are not broadcast-aware. If a caller passes override_M > M when M >= kMinMForTileOptimization, the code can select these kernels and then read/write using an M-stride that doesn’t match the actual input/output buffers, leading to incorrect results or OOB GPU accesses. Consider enforcing constraints on override_M (e.g., only allow dispatch_M != M for the intended broadcast case such as M==1, and require y to be sized for batch_count*dispatch_M*N), and/or gate the non-broadcast-aware kernels with dispatch_M == M.
  const uint32_t batch_count = onnxruntime::narrow<uint32_t>(helper.OutputOffsets().size());
  const uint32_t M = onnxruntime::narrow<uint32_t>(helper.M());
  const uint32_t dispatch_M = (override_M > 0) ? override_M : M;
  const bool broadcast_a = dispatch_M > M;
  const uint32_t N = onnxruntime::narrow<uint32_t>(helper.N());
  const uint32_t K = onnxruntime::narrow<uint32_t>(helper.K());
  const uint32_t block_size = onnxruntime::narrow<uint32_t>(block_size_op);

  // Special case matrix used by bitnets where there is a single scale for the entire
  const bool single_scale_weights = (block_size == K * N);
  const uint32_t block_size_per_col = single_scale_weights ? K : block_size;
  const uint32_t n_blocks_per_col = (K + block_size_per_col - 1) / block_size_per_col;
  const uint32_t blob_size = (block_size_per_col / 8) * static_cast<uint32_t>(nbits);
  const uint32_t blob_size_in_words = blob_size / 4;
  const uint32_t components_a = GetMaxComponents(K);
  const uint32_t components_b = GetMaxComponents(blob_size_in_words);
  uint32_t components = GetMaxComponents(N);
  // zero_points has shape[N * CeilDiv(n_blocks_per_col * bits, 8)].
  // The shader uses a flat linear index to address individual n-bit zero point values.
  // Since each column's zero points are byte-aligned in the packed buffer, we must round
  // n_blocks_per_col up to the next multiple of (8/nbits) — the number of zero point
  // values per byte — so that the linear stride correctly skips byte-boundary padding.
  const uint32_t zp_elements_per_byte = 8 / static_cast<uint32_t>(nbits);
  uint32_t zero_blocks_per_col = (n_blocks_per_col + zp_elements_per_byte - 1) / zp_elements_per_byte * zp_elements_per_byte;

#if !defined(__wasm__)
  int32_t subgroup_matrix_config_index = -1;
  // apple|intel - Experimental dawn support for subgroup matrix matmul.
  if ((M >= kMinMForTileOptimization && !has_weight_idx_indirect) && (context.AdapterInfo().vendor == std::string_view{"apple"} || context.AdapterInfo().vendor == std::string_view{"intel"}) &&
      CanApplySubgroupMatrixMatMulNBits(context, accuracy_level, block_size, batch_count, N, K, static_cast<uint32_t>(nbits), y->DataType() == DataTypeImpl::GetType<MLFloat16>(), subgroup_matrix_config_index)) {
    return ApplySubgroupMatrixMatMulNBits(a, b, scales, zero_points, bias, M, N, K, static_cast<uint32_t>(nbits), zero_blocks_per_col, subgroup_matrix_config_index, context, y, weight_index, weight_index_indirect);
  }
#endif

  // On FP32 only GPUs and Qualcomm GPUs, integer math is faster than FP32 therefore always use DP4A independent of length of M.
  // DP4A Q2 path now supports custom zero points via a 1024-entry LUT (4 zero-point sections × 256 byte values).
  if (((M >= kMinMForTileOptimization && !has_weight_idx_indirect) || y->DataType() == DataTypeImpl::GetType<float>() || context.AdapterInfo().vendor == std::string_view{"qualcomm"}) &&
      CanApplyDP4AMatrixMatMulNBits(context, accuracy_level, block_size, N, K, components_a)) {
    return ApplyDP4AMatrixMatMulNBits(a, b, scales, zero_points, bias, batch_count, M, dispatch_M, N, K, block_size, zero_blocks_per_col, kMinMForTileOptimization, static_cast<uint32_t>(nbits), context, y, weight_index, weight_index_indirect);
  }

  // WideTileProgram
  // This program is optimized for Block32 prefill using Tile16x128.
  const bool use_wide_tile_program = !has_weight_idx_indirect &&
                                     block_size == 32 &&
                                     components_a == 4 &&
                                     components_b == 4 &&
                                     nbits != 2 &&
                                     M >= kMinMForTileOptimization;


Comment thread onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
- Add #if defined(USE_WEBGPU) block to RunQMoETest helper to run QMoE tests
  on WebGpuExecutionProvider (skipped when FC3 weights present).
- Add QMoETest_WebGPU_SingleToken: num_rows=1 test that exercises the fused
  1-token decode path on WebGPU with SwiGLU activation and zero int4 weights.
  Verified via printf that it enters the fused 1-token code path.
@guschmue
Contributor

cool, shows 14% gain for decode on a rtx5060ti

@guschmue guschmue enabled auto-merge (squash) April 10, 2026 15:11
@guschmue guschmue merged commit c36c422 into main Apr 10, 2026
101 of 102 checks passed
@guschmue guschmue deleted the opt/webgpu-qmoe-fused-1token branch April 10, 2026 16:21
sanaa-hamel-microsoft pushed a commit that referenced this pull request Apr 21, 2026
…27998)

Fuse the QMoE 1-token decode path to reduce GPU dispatches from 17 (1 + k×4) to 5 (gate + fc1 + swiglu + fc2 + mix), improving token generation throughput by ~21% on Meteor Lake for the gpt-oss-20b MoE model (19 → 23 tps).

The QMoE operator processes Mixture-of-Experts layers by selecting top-k experts (k=4) per token. In the original 1-token decode path, each expert is processed serially with 4 dispatches (fc1 + swiglu + fc2 + mix), totaling 17 GPU dispatches per QMoE call. Since each dispatch has M=1, the GPU is underutilized and CPU dispatch overhead dominates.

For the 1-token path (num_rows == 1):

1. **Gate1Token** — Select top-k experts and output an `indirect_experts` buffer mapping row index → expert index
2. **Batched fc1 MatMulNBits** — Run a single M=k matmul with `per_row_weight_indirect` mode, where each row selects a different expert's weights via the indirect buffer
3. **SwiGLU** — Apply activation on all k rows at once
4. **Batched fc2 MatMulNBits** — Same per-row expert selection for the down projection
5. **FusedFinalMix** — Accumulate all k weighted expert results into the output

Follow-ups:

- Fuse batched fc1 MatMulNBits + SwiGLU
- Fuse batched fc2 MatMulNBits + FusedFinalMix

After these, only three shaders remain: Gate1Token, fused batched fc1 MatMulNBits, and fused batched fc2 MatMulNBits.
sanaa-hamel-microsoft added a commit that referenced this pull request Apr 24, 2026
Version bump to 1.25.1.

This cherry-picks the following commits for the release:

| Commit ID | PR Number | Commit Title |
|-----------|-----------|-------------|
| e532c21 | #27842 | linear attention signature |
| 410f5a8 | #27752 | +rotemb, +rmsnorm, reshape->opset-25,
transpose->opset-24 |
| 0fedb26 | #27907 | Add LinearAttention and CausalConvState ops for
Qwen3.5 |
| 3ac6040 | #27996 | webgpu support for qwen3.5 |
| c36c422 | #27998 | [WebGPU EP] Fuse QMoE 1-token decode path to
reduce GPU dispatches |
| 94f32ec | #27289 | [CORE]: Improve filesystem error messages during
Linux device discovery |
| dce77a3 | #28118 | Fix lack of auth on python packaging |

---------

Co-authored-by: Akshay Sonawane <111780983+apsonawane@users.noreply.github.com>
Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
Co-authored-by: eserscor <erscor@microsoft.com>
Co-authored-by: Sanaa Hamel <sanaahamel@microsoft.com>
Co-authored-by: Guenther Schmuelling <guschmue@microsoft.com>
Co-authored-by: Stephan Seitz <sseitz@nvidia.com>
Co-authored-by: Jiajia Qin <jiajiaqin@microsoft.com>

Labels

ep:WebGPU ort-web webgpu provider release:1.25.1


3 participants