WebGPU: Support Split-K with batch size > 1#28151

Merged
guschmue merged 15 commits into microsoft:main from Jiawei-Shao:split-k-multi-batch
Apr 29, 2026

Conversation

@Jiawei-Shao
Contributor

@Jiawei-Shao Jiawei-Shao commented Apr 21, 2026

Description

This patch adds support for Split-K with batch size > 1 by
encoding both the batch index and the Split-K index in dispatch_z and
decoding them in the shader via:
batch = logical_global_id.z / num_k_splits
split_index = logical_global_id.z % num_k_splits

This patch also adds the batch size to the criteria for using Split-K,
since a larger batch size already increases parallelism, which
reduces the effectiveness of Split-K.

This patch also replaces consteval with constexpr in
ort_version_check.h to work around a compilation error
with VS2022.

Motivation and Context

With this patch, the performance of
sam-vit-b-decoder-static-fp16-demo improves by 7.5% on Intel PTL.

This reverts commit 0d79450.
Contributor

Copilot AI left a comment


Pull request overview

Adds WebGPU Split‑K support for MatMul/Conv|MatMul when batch_size > 1 by packing (batch, split_index) into dispatch_z, and updates Split‑K heuristics to consider batch size (plus a small MSVC workaround in version checking).

Changes:

  • Encode both batch index and Split‑K index into dispatch_z and decode in WGSL using splits_per_batch.
  • Add a batch-size threshold and incorporate batch size into Split‑K gating heuristics.
  • Add new unit tests intended to exercise Split‑K with batched inputs; replace consteval with constexpr in version validation helpers.

Reviewed changes

Copilot reviewed 11 out of 11 changed files in this pull request and generated 5 comments.

Per-file summary:
  • onnxruntime/core/providers/webgpu/math/gemm_utils.cc: WGSL Split‑K indexing updated to decode batch/split from logical_global_id.z.
  • onnxruntime/core/providers/webgpu/math/gemm_packed.cc: Minor callsite update (comment/arg annotation) for MakeMatMulPackedVec4Source.
  • onnxruntime/core/providers/webgpu/math/matmul.cc: Compute-side Split‑K dispatch now encodes (batch * splits_per_batch) and passes splits_per_batch as a uniform; batched fill-bias/zero init.
  • onnxruntime/core/providers/webgpu/math/matmul.h: Extends the fill-bias/zero program factory to accept batch_size.
  • onnxruntime/core/providers/webgpu/math/matmul_packed.h: Adds splits_per_batch and batch_size uniforms to relevant programs.
  • onnxruntime/core/providers/webgpu/math/matmul_packed.cc: Fill-bias/zero shader updated to initialize outputs across all batches.
  • onnxruntime/core/providers/webgpu/webgpu_utils.h: Adds max_batch_size_ to the Split‑K config.
  • onnxruntime/core/providers/webgpu/webgpu_utils.cc: Split‑K gating updated to include batch size and enforce a max batch threshold.
  • onnxruntime/core/session/ort_version_check.h: Switches consteval to constexpr to work around VS2022/MSVC issues.
  • onnxruntime/test/providers/cpu/math/matmul_test.cc: Adds a batched MatMul test intended to hit the Split‑K path.
  • onnxruntime/test/providers/cpu/nn/conv_op_test.cc: Adds batched Conv2D “matmul-like” tests for Split‑K (with/without bias).


Comment thread onnxruntime/core/providers/webgpu/math/gemm_utils.cc Outdated
Comment thread onnxruntime/test/providers/cpu/nn/conv_op_test.cc Outdated
Comment thread onnxruntime/test/providers/cpu/math/matmul_test.cc Outdated
Comment thread onnxruntime/test/providers/cpu/nn/conv_op_test.cc
Comment thread onnxruntime/test/providers/cpu/nn/conv_op_test.cc
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 11 out of 11 changed files in this pull request and generated 2 comments.



Comment thread onnxruntime/core/providers/webgpu/webgpu_utils.cc Outdated
Comment thread onnxruntime/core/providers/webgpu/math/matmul.cc Outdated
@guschmue guschmue added the ep:WebGPU ort-web webgpu provider label Apr 22, 2026
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 11 out of 11 changed files in this pull request and generated 4 comments.



Comment thread onnxruntime/test/providers/cpu/nn/conv_op_test.cc Outdated
Comment thread onnxruntime/core/providers/webgpu/webgpu_utils.h
Comment thread onnxruntime/core/providers/webgpu/webgpu_utils.cc
Comment thread onnxruntime/test/providers/cpu/nn/conv_op_test.cc Outdated
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 11 out of 11 changed files in this pull request and generated 2 comments.



Comment thread onnxruntime/core/providers/webgpu/math/matmul.cc Outdated
Comment thread onnxruntime/test/providers/cpu/math/matmul_test.cc
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 11 out of 11 changed files in this pull request and generated 3 comments.



Comment thread onnxruntime/core/session/ort_version_check.h Outdated
Comment thread onnxruntime/core/providers/webgpu/webgpu_utils.cc Outdated
Comment thread onnxruntime/core/providers/webgpu/math/matmul_packed.cc
Contributor

@qjia7 qjia7 left a comment


LGTM with one nit.

Comment thread onnxruntime/core/providers/webgpu/math/gemm_utils.cc
@Jiawei-Shao Jiawei-Shao requested a review from qjia7 April 27, 2026 07:17
Comment thread onnxruntime/core/providers/webgpu/nn/conv2d_mm.cc Outdated
qjia7
qjia7 previously approved these changes Apr 27, 2026
@Jiawei-Shao Jiawei-Shao requested a review from qjia7 April 29, 2026 00:22

Labels

ep:WebGPU ort-web webgpu provider


4 participants