webgpu: Refactor SubgroupMatrixMatMulNBits to vendor-agnostic config …#28109
Merged
Conversation
…+ add NVIDIA 16x16x16
Refactor subgroup matrix MatMulNBits support from vendor-specific (Apple/Intel)
to a vendor-agnostic config-based approach. Any GPU reporting a matching
subgroup matrix config from Dawn is now automatically supported.
Key changes:
- Replace vendor-specific config table with SupportedSubgroupMatrixConfig struct
containing {componentType, resultComponentType, M, N, K, subgroupMinSize,
subgroupMaxSize, needsPrepack}. No architecture or backendType required.
- Remove vendor_ member from SubgroupMatrixMatMulNBitsProgram. Shader selection
is now driven by config dimensions (8x8x8, 8x16x16, 16x16x16).
- Remove vendor gate in matmul_nbits.cc call site.
- Rename shader templates: _apple -> _8x8x8, _intel -> _8x16x16.
- Add new 16x16x16 shader template for NVIDIA Blackwell (RTX 5080):
  - 4 subgroups x 32 lanes = 128 threads per workgroup
  - 64x64 tile with 16x16 subgroup matrices
  - Bounds-checked output via a scratch buffer for partial M tiles
- Fix prepack shader OOB reads: add scalar fallback with zero-fill for
partial blocks where M is not a multiple of kSgMatM.
- Prioritize larger configs (16x16x16 > 8x16x16 > 8x8x8) when multiple match.
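The config matching and prioritization described above can be sketched as follows. This is an illustrative Python sketch, not the actual C++ implementation in subgroup_matrix_matmul_nbits.cc; the field names follow the SupportedSubgroupMatrixConfig struct listed above, and the helper names are hypothetical.

```python
# Sketch of matching Dawn-reported subgroup matrix configs against the
# supported-config table and preferring the largest dimensions
# (16x16x16 > 8x16x16 > 8x8x8), per the priority described above.
from dataclasses import dataclass

@dataclass(frozen=True)
class SupportedSubgroupMatrixConfig:
    component_type: str         # e.g. "f16"
    result_component_type: str  # e.g. "f16"
    M: int
    N: int
    K: int
    subgroup_min_size: int
    subgroup_max_size: int
    needs_prepack: bool

def pick_config(reported, supported):
    """Return the largest supported config the adapter also reports, or None."""
    matches = [c for c in supported if c in reported]
    if not matches:
        return None
    # Larger M*N*K wins: 16x16x16 beats 8x16x16 beats 8x8x8.
    return max(matches, key=lambda c: c.M * c.N * c.K)
```

Selecting by M*N*K is one way to express the stated ordering; the real code may simply order the table by preference.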
Verified on NVIDIA RTX 5080 (Blackwell, Vulkan backend):
- Correctness: model-qa.py with phi4-graph-prune produces identical output
to D3D12 baseline
- Prefill (phi4, l=1024):
- D3D12 DP4A baseline: 3,006 tps
- Vulkan DP4A baseline: 6,155 tps
- Vulkan tensor core (this change): 6,759 tps (+10% vs Vulkan DP4A, +125% vs D3D12)
- NVIDIA reports ChromiumExperimentalSubgroupMatrix with F16/F16 16x16x16 config
…barrier placement
- Use fast subgroupMatrixStore directly to output for full M blocks (sg_m_base + kSgMatM <= M), avoiding scratch overhead for the common case.
- Use scratch + scalar write only for partial M blocks at the boundary.
- Move workgroupBarrier outside the if/else to avoid a divergent barrier (WGSL disallows workgroupBarrier in non-uniform control flow).
- Make the scratch array unconditional (needed for both the bias and non-bias paths).
This fixes the Invalid ShaderModule crash that occurred when the barrier sat inside a branch whose two sides different subgroups could take.
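The full-block fast-path test in the commit above is a simple bounds check, sketched here in Python for clarity (the real logic lives in the WGSL shader; the subgroup matrix height of 16 is assumed from the 16x16x16 config):

```python
# Sketch of the full-vs-partial M block decision described above: a subgroup's
# output rows [sg_m_base, sg_m_base + kSgMatM) fit entirely within M, so it
# can use subgroupMatrixStore directly instead of staging through scratch.

K_SG_MAT_M = 16  # subgroup matrix height (assumed 16 for the 16x16x16 config)

def can_store_directly(sg_m_base, M, sg_mat_m=K_SG_MAT_M):
    return sg_m_base + sg_mat_m <= M
```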
Use larger 128x128 tiles (vs 64x64) for the NVIDIA Blackwell 16x16x16 config to improve prefill throughput. Key changes:
- New WGSL template with a 2x2 subgroup grid, each subgroup handling a 64x64 subtile
- Load A directly from prepacked global memory (no shared memory)
- Dequantize B to shared memory with a padded stride (SHMEM_STRIDE=40)
- Update dispatch to ceil(N/128) x ceil(M/128)
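The dispatch sizing from the last bullet is plain ceiling division, sketched below in Python (the actual dispatch is computed in the C++ program setup; tile constants are assumptions named for illustration):

```python
# Sketch of the workgroup dispatch sizing for the 128x128-tile variant:
# one workgroup per 128x128 output tile, so the grid is
# ceil(N / 128) x ceil(M / 128).

TILE_N = 128
TILE_M = 128

def ceil_div(a, b):
    return (a + b - 1) // b

def dispatch_size(M, N):
    return ceil_div(N, TILE_N), ceil_div(M, TILE_M)
```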
…port
- Use exact subgroup size matching (==) instead of range (>=)
- Add an F32 8x8x8 config for Apple parity
- Pass is_fp16 into IsSubgroupMatrixConfigSupported to correctly skip F16 configs when the output is F32
- Simplify the accuracy_level check to apply uniformly
- Fix a missing closing brace in CanApplySubgroupMatrixMatMulNBits
The prepack buffer was sized to ceil(M/sg_mat_m)*sg_mat_m rows, but the matmul shader dispatches workgroups covering ceil(M/tile_size_a)*tile_size_a rows. When M < tile_size_a (e.g. M=32 with 128x128 tiles), subgroups in the matmul shader would read past the end of the prepack buffer, causing a device-lost error.

Fix: move the tile size computation before the prepack allocation and pad the prepack buffer to the workgroup tile size. Also remove the unnecessary zero-fill in the prepack shader for OOB rows: the matmul shader already bounds-checks before storing output, so fully OOB prepack blocks can skip entirely.
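The sizing mismatch and its fix can be sketched numerically. This Python sketch mirrors the arithmetic in the commit message (the sg_mat_m=16 and tile_size_a=128 values are taken from the 16x16x16 config and 128x128 tiles discussed above; the function names are illustrative):

```python
# Sketch of the prepack-buffer sizing fix: the buffer must be padded to the
# matmul workgroup tile height (tile_size_a), not just to the subgroup matrix
# height (sg_mat_m), or workgroups covering a partial M tile read past the
# end of the buffer.

def ceil_to(x, multiple):
    return ((x + multiple - 1) // multiple) * multiple

def prepack_rows_buggy(M, sg_mat_m=16):
    # Old sizing: padded only to the subgroup matrix height.
    return ceil_to(M, sg_mat_m)

def prepack_rows_fixed(M, tile_size_a=128):
    # New sizing: padded to the full workgroup tile height.
    return ceil_to(M, tile_size_a)
```

For M=32 with 128x128 tiles, the old sizing allocates 32 rows while the matmul shader's workgroups cover 128 rows, which is exactly the out-of-bounds read described above.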
Contributor
Pull request overview
Refactors WebGPU SubgroupMatrix MatMulNBits support from vendor-gated (Apple/Intel) logic to a vendor-agnostic, adapter-reported subgroup-matrix-config matching approach, and adds a new 16x16x16 (128x128 tile) shader variant targeting NVIDIA Blackwell-class GPUs.
Changes:
- Introduces a vendor-agnostic SupportedSubgroupMatrixConfig table and matching logic driven by Dawn-reported subgroup matrix configs.
- Renames/rewrites subgroup-matrix WGSL templates into dimension-keyed variants (8x8x8, 8x16x16) and adds a new 16x16x16_128 shader.
- Updates prepack WGSL to avoid OOB reads via a scalar fallback for partial M blocks, and removes vendor gating at the MatMulNBits call site.
Reviewed changes
Copilot reviewed 5 out of 7 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits_prepack.wgsl.template | Adds partial-block handling in prepack to avoid OOB reads and updates documentation about padding behavior. |
| onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits_8x8x8.wgsl.template | Adds the 8x8x8 subgroup-matrix matmul shader template (formerly the Apple-specific template). |
| onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits_8x16x16.wgsl.template | Adds the 8x16x16 subgroup-matrix matmul shader template (formerly Intel-specific). |
| onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits_16x16x16_128.wgsl.template | Adds the 16x16x16 subgroup-matrix matmul shader template using 128x128 tiles and edge-tile bounds-checked stores. |
| onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h | Removes vendor string plumbing from the program API. |
| onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc | Replaces vendor-specific config matching with a vendor-agnostic config table; selects shaders by config dimensions; updates prepack padding logic. |
| onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc | Removes vendor gate so subgroup-matrix path is enabled for any adapter reporting a supported config. |
Add M=100, N=256, K=128, block_size=32 test cases to Float16_4b_Accuracy0 and Float16_4b_Accuracy4. These dimensions meet SubgroupMatrix constraints (block_size=32, N%64==0, K%32==0) and M=100 is non-128-aligned, exercising the prepack buffer padding fix from commit 9581821.
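The shape constraints quoted above can be sketched as a small checker. This is illustrative Python only; the constraint set (block_size=32, N%64==0, K%32==0) is taken verbatim from this commit message, and the function name is hypothetical:

```python
# Sketch of the SubgroupMatrix shape constraints quoted above.
# M may be unaligned (e.g. M=100, not a multiple of 128), which is exactly
# what the new test cases exercise against the prepack padding fix.

def meets_subgroup_matrix_constraints(M, N, K, block_size):
    return block_size == 32 and N % 64 == 0 and K % 32 == 0
```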
qjia7 added a commit that referenced this pull request on Apr 23, 2026:
- Skip F32 subgroup matrix configs when output is FP16 (symmetric with the existing F16-skip-when-F32 filter)
- Fix misleading prepack comment: not all matmul shaders bounds-check edge tiles, so don't promise they do

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
Force-pushed from dab1ba9 to a3f7051.
… N=192 test

Remove 63 workgroupBarrier() calls from the bias and non-bias edge-tile output paths in the 16x16x16_128 shader. These barriers are unnecessary: subgroupMatrixStore and the subsequent scalar reads from coopmat_stage execute within the same subgroup (in lockstep), and each subgroup writes to its own non-overlapping stage_base region.

Also add N=192 (partial N tile, not divisible by 128) test cases and remove a duplicate #include in subgroup_matrix_matmul_nbits.h.
Contributor (Author):
cc @jchen10
guschmue approved these changes on Apr 25, 2026.