webgpu: merge batchA into M dimension when batchB==1 #28197

Merged
guschmue merged 2 commits into microsoft:main from xhcao:merge-all-outer-dims on Apr 29, 2026

Conversation

Contributor

@xhcao xhcao commented Apr 23, 2026

When M is small and batchA is large, each tile contains some invalid elements. Merging batchA into the M dimension reduces the workgroup count.

Description

Motivation and Context

When M is small and batchA is large, each tile contains some
invalid elements. Merging batchA into the M dimension
reduces the workgroup count.
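
The effect on workgroup count can be sketched with a little arithmetic. The snippet below is only an illustration: the 32x32 tile size, the helper names `CeilDiv` and `WorkgroupCount`, and the example shape (batchA=64, M=2, N=128) are assumptions made for this example and are not taken from the PR.

```cpp
#include <cstdint>

// Hypothetical tile size, chosen only for illustration; the real shader
// may use different values.
constexpr uint32_t kTileM = 32;
constexpr uint32_t kTileN = 32;

constexpr uint32_t CeilDiv(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

// Number of workgroups when each workgroup computes one kTileM x kTileN
// output tile and every batch entry is tiled separately.
constexpr uint32_t WorkgroupCount(uint32_t batch, uint32_t m, uint32_t n) {
  return batch * CeilDiv(m, kTileM) * CeilDiv(n, kTileN);
}

// Batched dispatch: 64 batches, each using only 2 of the 32 rows in a tile.
static_assert(WorkgroupCount(64, 2, 128) == 256, "64 * 1 * 4 workgroups");

// Merged dispatch: batchA folded into M, so M becomes 64 * 2 = 128 rows.
static_assert(WorkgroupCount(1, 64 * 2, 128) == 16, "1 * 4 * 4 workgroups");
```

Under these assumed sizes, the batched dispatch leaves 30 of every 32 tile rows unused, while the merged dispatch fills every tile.
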
@guschmue guschmue added the ep:WebGPU ort-web webgpu provider label Apr 23, 2026
@guschmue guschmue requested a review from Copilot April 23, 2026 16:09
Contributor

Copilot AI left a comment

Pull request overview

This PR updates the WebGPU MatMul implementation to flatten A’s batch dimensions into the effective M dimension when B has no batching (batchB==1), aiming to reduce workgroup overhead for cases with small M and large batchA. It also adds WebGPU-specific regression tests for additional 3D batched MatMul shapes.

Changes:

  • WebGPU MatMul: reshape A/B and treat output as {1, batchA*M, N} when batchA != 1 && batchB == 1 (applies to both the generic and Intel subgroup paths).
  • Add WebGPU-only MatMul test cases covering 3D inputs with batchA=3, M=2 and N in {3,4}.
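
Why the reshape preserves results can be shown with a CPU reference. This is a minimal sketch assuming row-major, contiguous storage; the function names `MatMul2D`, `MatMulBatched`, and `MatMulFlattened` are hypothetical and do not appear in the PR, and the code is not the WebGPU implementation.

```cpp
#include <cstddef>
#include <vector>

// Reference 2D matmul: out[m x n] = a[m x k] * b[k x n], all row-major.
void MatMul2D(const float* a, const float* b, float* out,
              std::size_t m, std::size_t k, std::size_t n) {
  for (std::size_t i = 0; i < m; ++i) {
    for (std::size_t j = 0; j < n; ++j) {
      float acc = 0.0f;
      for (std::size_t p = 0; p < k; ++p) acc += a[i * k + p] * b[p * n + j];
      out[i * n + j] = acc;
    }
  }
}

// Batched form: A is {batch, m, k}; B is {k, n} and shared by every batch
// (the batchB == 1 case).
std::vector<float> MatMulBatched(const std::vector<float>& a,
                                 const std::vector<float>& b,
                                 std::size_t batch, std::size_t m,
                                 std::size_t k, std::size_t n) {
  std::vector<float> out(batch * m * n);
  for (std::size_t bi = 0; bi < batch; ++bi) {
    MatMul2D(a.data() + bi * m * k, b.data(), out.data() + bi * m * n, m, k, n);
  }
  return out;
}

// Flattened form: the same buffer for A viewed as {1, batch * m, k}.
// No data is moved; only the logical shape changes.
std::vector<float> MatMulFlattened(const std::vector<float>& a,
                                   const std::vector<float>& b,
                                   std::size_t batch, std::size_t m,
                                   std::size_t k, std::size_t n) {
  std::vector<float> out(batch * m * n);
  MatMul2D(a.data(), b.data(), out.data(), batch * m, k, n);
  return out;
}
```

Because every batch of A multiplies the same B, row r of batch b in the batched form is exactly row b*M + r in the flattened form, so both functions write identical values to identical offsets. The equivalence would not hold if B carried its own batch dimension.
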

Reviewed changes

Copilot reviewed 2 out of 3 changed files in this pull request and generated 2 comments.

| File | Description |
| --- | --- |
| onnxruntime/test/providers/cpu/math/matmul_test.cc | Adds WebGPU-only test cases for 3D batched MatMul with batchA>1 and M>1. |
| onnxruntime/core/providers/webgpu/vendor/intel/math/matmul.cc | Extends the Intel subgroup MatMul reshape optimization from M==1 to all batchA!=1 && batchB==1 cases. |
| onnxruntime/core/providers/webgpu/math/matmul.cc | Extends the generic WebGPU MatMul reshape optimization similarly, flattening batch dims into M when batchB==1. |

Comment thread onnxruntime/core/providers/webgpu/math/matmul.cc Outdated
Comment thread onnxruntime/core/providers/webgpu/math/matmul.cc
Contributor Author

xhcao commented Apr 24, 2026

@qjia7 @jchen10 PTAL. Do you think it is necessary to add a threshold, as Copilot suggested? Are there any correctness issues with changing the dispatch geometry here?

Comment thread onnxruntime/core/providers/webgpu/math/matmul.cc Outdated
Contributor

qjia7 commented Apr 24, 2026

> @qjia7 @jchen10 PTAL. Do you think it is necessary to add a threshold, as Copilot suggested? Are there any correctness issues with changing the dispatch geometry here?

No. It should be a general optimization. Just update your comments in code.

Contributor

@qjia7 qjia7 left a comment

LGTM. Please also update your PR description. Thanks.

@guschmue guschmue enabled auto-merge (squash) April 29, 2026 16:43
@guschmue guschmue merged commit 6f47410 into microsoft:main Apr 29, 2026
85 of 87 checks passed
Labels

ep:WebGPU ort-web webgpu provider

4 participants