WebGPU: Support Split-K with batch size > 1#28151
Conversation
This patch adds support for Split-K with batch size > 1 by encoding both the batch index and the Split-K index in dispatch_z and decomposing them in the shader via batch = logical_global_id.z / num_k_splits and split_index = logical_global_id.z % num_k_splits. This patch also adds batch size to the criteria for using Split-K, since increasing the batch size also increases parallelism, which reduces the effectiveness of Split-K.
This reverts commit 0d79450.
Pull request overview
Adds WebGPU Split‑K support for MatMul/Conv|MatMul when batch_size > 1 by packing (batch, split_index) into dispatch_z, and updates Split‑K heuristics to consider batch size (plus a small MSVC workaround in version checking).
Changes:
- Encode both batch index and Split‑K index into `dispatch_z` and decode in WGSL using `splits_per_batch`.
- Add a batch-size threshold and incorporate batch size into Split‑K gating heuristics.
- Add new unit tests intended to exercise Split‑K with batched inputs; replace `consteval` with `constexpr` in version validation helpers.
Reviewed changes
Copilot reviewed 11 out of 11 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| onnxruntime/core/providers/webgpu/math/gemm_utils.cc | WGSL Split‑K indexing updated to decode batch/split from logical_global_id.z. |
| onnxruntime/core/providers/webgpu/math/gemm_packed.cc | Minor callsite update (comment/arg annotation) for MakeMatMulPackedVec4Source. |
| onnxruntime/core/providers/webgpu/math/matmul.cc | Compute-side Split‑K dispatch now encodes (batch * splits_per_batch) and passes splits_per_batch as a uniform; batched fill-bias/zero init. |
| onnxruntime/core/providers/webgpu/math/matmul.h | Extend fill-bias/zero program factory to accept batch_size. |
| onnxruntime/core/providers/webgpu/math/matmul_packed.h | Add splits_per_batch and batch_size uniforms to relevant programs. |
| onnxruntime/core/providers/webgpu/math/matmul_packed.cc | Fill-bias/zero shader updated to initialize outputs across all batches. |
| onnxruntime/core/providers/webgpu/webgpu_utils.h | Add max_batch_size_ to Split‑K config. |
| onnxruntime/core/providers/webgpu/webgpu_utils.cc | Split‑K gating updated to include batch size and enforce a max batch threshold. |
| onnxruntime/core/session/ort_version_check.h | Switch consteval to constexpr to work around VS2022/MSVC issues. |
| onnxruntime/test/providers/cpu/math/matmul_test.cc | Add a batched MatMul test intended to hit the Split‑K path. |
| onnxruntime/test/providers/cpu/nn/conv_op_test.cc | Add batched Conv2D “matmul-like” tests for Split‑K (with/without bias). |
Description
This patch adds support for Split-K with batch size > 1 by
encoding both the batch index and the Split-K index in dispatch_z and
decomposing them in the shader via:
batch = logical_global_id.z / num_k_splits
split_index = logical_global_id.z % num_k_splits
This patch also adds batch size to the criteria for using Split-K,
since increasing the batch size also increases parallelism,
which reduces the effectiveness of Split-K.
This patch also replaces consteval with constexpr in
ort_version_check.h to work around a compilation error with VS2022.
Motivation and Context
With this patch we improve the performance of
sam-vit-b-decoder-static-fp16-demo by 7.5% on Intel PTL.