From 66de1b1efbd58c41944d4771aeba48806bbe141e Mon Sep 17 00:00:00 2001 From: Jiawei Shao Date: Wed, 11 Mar 2026 13:21:11 +0800 Subject: [PATCH 1/5] [webgpu] Pass `is_channels_last` with `std::optional` --- .../contrib_ops/webgpu/bert/attention.cc | 13 +++++++--- .../core/providers/webgpu/math/gemm_packed.cc | 2 +- .../core/providers/webgpu/math/gemm_utils.cc | 11 ++++---- .../core/providers/webgpu/math/gemm_utils.h | 4 +-- .../core/providers/webgpu/math/matmul.cc | 25 ++++++++++++++----- .../core/providers/webgpu/math/matmul.h | 3 ++- .../providers/webgpu/math/matmul_packed.cc | 11 +++++--- .../providers/webgpu/math/matmul_packed.h | 16 ++++++------ onnxruntime/core/providers/webgpu/nn/conv.cc | 6 ++++- .../webgpu/vendor/intel/math/matmul.cc | 16 ++---------- .../webgpu/vendor/intel/math/matmul.h | 3 --- .../core/providers/webgpu/webgpu_utils.cc | 14 +++++++---- .../core/providers/webgpu/webgpu_utils.h | 5 ++-- 13 files changed, 74 insertions(+), 55 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc index 122cb5396f5c1..7b2a7dfd4f47a 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc @@ -653,10 +653,17 @@ Status PrepareQKV(onnxruntime::webgpu::ComputeContext& context, const WebgpuAtte Tensor packed_qkv = context.CreateGPUTensor(input->DataType(), TensorShape(packed_qkv_shape)); // Prepare inputs for MatMul - std::vector matmul_inputs = {input, weights, bias}; - + bool has_bias = bias != nullptr; + std::vector matmul_inputs(has_bias ? 3 : 2); + std::optional is_channels_last; + matmul_inputs[0] = input; + matmul_inputs[1] = weights; + if (has_bias) { + matmul_inputs[2] = bias; + is_channels_last = true; + } // Call MatMul: packed_qkv = input * weights + bias - ORT_RETURN_IF_ERROR(onnxruntime::webgpu::ComputeMatMul(&context, Activation(), matmul_inputs, &packed_qkv, true)); + ORT_RETURN_IF_ERROR(onnxruntime::webgpu::ComputeMatMul(&context, Activation(), matmul_inputs, &packed_qkv, is_channels_last)); // Output Q, K, V in BSD format return SplitPackedQKV(context, parameters, &packed_qkv, q, k, v, parameters.hidden_size_); diff --git a/onnxruntime/core/providers/webgpu/math/gemm_packed.cc b/onnxruntime/core/providers/webgpu/math/gemm_packed.cc index 45c462209e30b..f1a58a600fa4f 100644 --- a/onnxruntime/core/providers/webgpu/math/gemm_packed.cc +++ b/onnxruntime/core/providers/webgpu/math/gemm_packed.cc @@ -114,7 +114,7 @@ Status ApplyGemmPacked(const Tensor* a, // Currently we require the components for Y must also be a multiple of 4 when Split-K is used. const bool output_is_vec4 = output_components == 4; // The parameter `is_channel_last` is not used for GEMM. - const bool need_split_k = split_k_config.UseSplitK(is_vec4 && output_is_vec4, ActivationKind::None, /*batch_size*/ 1, /*is_gemm*/ true, /*is_channels_last*/ true, M, N, K); + const bool need_split_k = split_k_config.UseSplitK(is_vec4 && output_is_vec4, ActivationKind::None, /*batch_size*/ 1, /*is_channels_last*/ {}, M, N, K); if (need_split_k) { const Tensor* bias = nullptr; uint32_t output_components_in_fill_bias_program = 4; diff --git a/onnxruntime/core/providers/webgpu/math/gemm_utils.cc b/onnxruntime/core/providers/webgpu/math/gemm_utils.cc index 573d7b016310f..4af6cb70b7a32 100644 --- a/onnxruntime/core/providers/webgpu/math/gemm_utils.cc +++ b/onnxruntime/core/providers/webgpu/math/gemm_utils.cc @@ -41,11 +41,12 @@ void HandleMaybeHaveBiasForGEMM(ShaderHelper& shader, void HandleMaybeBiasForMatMul(ShaderHelper& shader, const ShaderVariableHelper& output, const ShaderVariableHelper* bias, - std::string activation_snippet, - bool is_channels_last) { + std::string_view activation_snippet, + std::optional is_channels_last) { shader.AdditionalImplementation() << " let coords = vec3(u32(batch), u32(row), u32(colIn));\n"; if (bias != nullptr) { - shader.AdditionalImplementation() << " value = value + output_value_t(" << (is_channels_last ? bias->GetByOffset("colIn") : bias->GetByOffset("row")) << ");\n"; + ORT_ENFORCE(is_channels_last.has_value(), "is_channels_last must be set when bias is used"); + shader.AdditionalImplementation() << " value = value + output_value_t(" << (is_channels_last.value() ? bias->GetByOffset("colIn") : bias->GetByOffset("row")) << ");\n"; } shader.AdditionalImplementation() << " " << activation_snippet << "\n" << " " << output.SetByIndices("coords", "value") << "\n"; @@ -192,8 +193,8 @@ void MatMulReadFnSource(ShaderHelper& shader, void MatMulWriteFnSourceForMatMul(ShaderHelper& shader, const ShaderVariableHelper& output, const ShaderVariableHelper* bias, - std::string activation_snippet, - bool is_channels_last) { + std::string_view activation_snippet, + std::optional is_channels_last) { EmitMatMulWriteFnHeader(shader, output); HandleMaybeBiasForMatMul(shader, output, bias, activation_snippet, is_channels_last); EmitMatMulWriteFnFooter(shader); diff --git a/onnxruntime/core/providers/webgpu/math/gemm_utils.h b/onnxruntime/core/providers/webgpu/math/gemm_utils.h index 13c298919194a..71da39a73519e 100644 --- a/onnxruntime/core/providers/webgpu/math/gemm_utils.h +++ b/onnxruntime/core/providers/webgpu/math/gemm_utils.h @@ -18,8 +18,8 @@ void MatMulReadFnSource(ShaderHelper& shader, void MatMulWriteFnSourceForMatMul(ShaderHelper& shader, const ShaderVariableHelper& output, const ShaderVariableHelper* bias, - std::string activation_snippet, - bool is_channels_last); + std::string_view activation_snippet, + std::optional is_channels_last); void MatMulWriteFnSourceForGemm(ShaderHelper& shader, const ShaderVariableHelper& output, diff --git a/onnxruntime/core/providers/webgpu/math/matmul.cc b/onnxruntime/core/providers/webgpu/math/matmul.cc index 7bbc84d2c3c87..d84204b14f6a4 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul.cc +++ b/onnxruntime/core/providers/webgpu/math/matmul.cc @@ -168,16 +168,19 @@ Status MatMul::ComputeInternal(ComputeContext& context) const { return intel::ApplyMatMulIntel(context, Activation(), inputs, output_tensor); } - return ComputeMatMul(&context, Activation(), inputs, output_tensor, false); + return ComputeMatMul(&context, Activation(), inputs, output_tensor); } Status ComputeMatMul(ComputeContext* context, - const Activation& activation, std::vector& inputs, Tensor* output_tensor, bool is_channels_last, + const Activation& activation, std::vector& inputs, Tensor* output_tensor, + std::optional is_channels_last, const TensorShape& input_a_reshape, const TensorShape& input_b_reshape) { const auto* a = inputs[0]; const auto* b = inputs[1]; bool has_bias = inputs.size() > 2; + ORT_ENFORCE(is_channels_last.has_value() == has_bias, "is_channels_last must be set when bias is used, and won't be set when bias is not used"); + TensorShape a_shape = input_a_reshape.NumDimensions() > 0 ? input_a_reshape : a->Shape(); TensorShape b_shape = input_b_reshape.NumDimensions() > 0 ? input_b_reshape : b->Shape(); @@ -248,11 +251,14 @@ Status ComputeMatMul(ComputeContext* context, // Current Split-K implementation relies on atomic operations, which are not deterministic. if (!context->KernelContext().GetUseDeterministicCompute()) { const SplitKConfig& split_k_config = context->GetSplitKConfig(); - const bool need_split_k = split_k_config.UseSplitK(is_vec4, activation.activation_kind_, batch_size, /*is_gemm*/ false, is_channels_last, dim_a_outer, dim_b_outer, dim_inner); + const bool need_split_k = split_k_config.UseSplitK(is_vec4, activation.activation_kind_, batch_size, is_channels_last, dim_a_outer, dim_b_outer, dim_inner); if (need_split_k) { ORT_ENFORCE(batch_size == 1, "Split-K MatMul only supports batch_size == 1."); ORT_ENFORCE(is_vec4, "Split-K MatMul only supports bias in vec4 format."); - ORT_ENFORCE(is_channels_last, "Split-K MatMul only supports channels-last format."); + + if (is_channels_last.has_value()) { + ORT_ENFORCE(is_channels_last, "Split-K MatMul only supports channels-last format."); + } // Initialize `output_tensor` with 0 or bias before MatMulProgram with Split-K enabled. const auto fill_bias_program = CreateMatMulFillBiasOrZeroBeforeSplitKProgram(bias, output_tensor, /*is_gemm*/ false, /*beta*/ 1.0f, /*bias_components*/ 4, output_shape_temp); @@ -273,9 +279,16 @@ Status ComputeMatMul(ComputeContext* context, } } - MatMulProgram matmul_program{activation, use_bias_in_matmul, is_vec4, elements_per_thread, is_channels_last, split_dim_inner}; + bool is_channels_last_in_cache_hint = true; + std::optional is_channels_last_in_matmul_program; + if (use_bias_in_matmul) { + // `is_channels_last_in_matmul_program` has valid value only when `bias` is used in MatMul. + is_channels_last_in_matmul_program = is_channels_last; + is_channels_last_in_cache_hint = is_channels_last.value(); + } + MatMulProgram matmul_program{activation, is_vec4, elements_per_thread, is_channels_last_in_matmul_program, split_dim_inner}; matmul_program - .CacheHint(activation.ToString(), absl::StrJoin(elements_per_thread, "-"), std::to_string(is_vec4), components, is_channels_last, split_dim_inner) + .CacheHint(activation.ToString(), absl::StrJoin(elements_per_thread, "-"), std::to_string(is_vec4), components, is_channels_last_in_cache_hint, split_dim_inner) .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, a_shape_temp, components}, {b, ProgramTensorMetadataDependency::TypeAndRank, b_shape_temp, components}}) .AddUniformVariables({{dim_a_outer}, {dim_b_outer}, {dim_inner}, {dispatch_x}, {dispatch_y}, {dispatch_z}}) diff --git a/onnxruntime/core/providers/webgpu/math/matmul.h b/onnxruntime/core/providers/webgpu/math/matmul.h index 3135c1f8a2457..fa8abb6476037 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul.h +++ b/onnxruntime/core/providers/webgpu/math/matmul.h @@ -14,7 +14,8 @@ namespace onnxruntime { namespace webgpu { -Status ComputeMatMul(ComputeContext* context, const Activation& activation, std::vector& inputs, Tensor* output, bool is_channels_last, +Status ComputeMatMul(ComputeContext* context, const Activation& activation, std::vector& inputs, Tensor* output, + std::optional is_channels_last = {}, const TensorShape& input_a_reshape = TensorShape(), const TensorShape& input_b_reshape = TensorShape()); diff --git a/onnxruntime/core/providers/webgpu/math/matmul_packed.cc b/onnxruntime/core/providers/webgpu/math/matmul_packed.cc index 0883c8ddb95b5..babe00a969e1c 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul_packed.cc +++ b/onnxruntime/core/providers/webgpu/math/matmul_packed.cc @@ -26,8 +26,10 @@ Status MatMulProgram::GenerateShaderCode(ShaderHelper& shader) const { const auto& batch_dims = shader.AddIndices("batch_dims", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); + // `is_channels_last_` must be set when `bias` is used, and won't be set when `bias` is not used. + const bool has_bias = is_channels_last_.has_value(); const ShaderVariableHelper* bias = nullptr; - if (has_bias_) { + if (has_bias) { bias = &shader.AddInput("bias", ShaderUsage::UseUniform); } std::string apply_activation = GetActivationSnippet(activation_, "output_value_t", "output_element_t"); @@ -65,12 +67,15 @@ Status MatMulFillBiasOrZeroBeforeSplitKProgram::GenerateShaderCode(ShaderHelper& } // Handle bias with `MatMulWriteFnSourceForGemm() or MatMulWriteFnSourceForMatMul()`. - // const uint32_t bias_components = output_components_; if (is_gemm_) { MatMulWriteFnSourceForGemm(shader, output, bias, bias_is_scalar_); } else { // Currently we only support `is_channels_last` to be true and no activation. - MatMulWriteFnSourceForMatMul(shader, output, bias, /*activation_snippet*/ "", /*is_channels_last*/ true); + std::optional is_channels_last; + if (has_bias_) { + is_channels_last = true; + } + MatMulWriteFnSourceForMatMul(shader, output, bias, /*activation_snippet*/ "", is_channels_last); } shader.MainFunctionBody() << " let output_components = " << output_components_ << ";\n"; diff --git a/onnxruntime/core/providers/webgpu/math/matmul_packed.h b/onnxruntime/core/providers/webgpu/math/matmul_packed.h index 618fc97d72fe0..945425e281d28 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul_packed.h +++ b/onnxruntime/core/providers/webgpu/math/matmul_packed.h @@ -13,13 +13,12 @@ namespace onnxruntime { namespace webgpu { class MatMulProgram final : public Program { public: - MatMulProgram(const Activation& activation, bool bias, bool is_vec4, const gsl::span& elements_per_thread, bool is_channels_last = false, uint32_t split_dim_inner = 1) : Program{"MatMul"}, - activation_(activation), - has_bias_{bias}, - is_vec4_{is_vec4}, - elements_per_thread_(elements_per_thread.begin(), elements_per_thread.end()), - is_channels_last_(is_channels_last), - split_dim_inner_(split_dim_inner) {} + MatMulProgram(const Activation& activation, bool is_vec4, const gsl::span& elements_per_thread, std::optional is_channels_last, uint32_t split_dim_inner = 1) : Program{"MatMul"}, + activation_(activation), + is_vec4_{is_vec4}, + elements_per_thread_(elements_per_thread.begin(), elements_per_thread.end()), + is_channels_last_(is_channels_last), + split_dim_inner_(split_dim_inner) {} Status GenerateShaderCode(ShaderHelper& sh) const override; WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"dim_a_outer", ProgramUniformVariableDataType::Uint32}, @@ -33,10 +32,9 @@ class MatMulProgram final : public Program { private: const Activation activation_; - const bool has_bias_; const bool is_vec4_; const InlinedVector elements_per_thread_; - bool is_channels_last_ = false; + std::optional is_channels_last_; uint32_t split_dim_inner_ = 1; }; diff --git a/onnxruntime/core/providers/webgpu/nn/conv.cc b/onnxruntime/core/providers/webgpu/nn/conv.cc index 697428e1ce140..b2624799116bb 100644 --- a/onnxruntime/core/providers/webgpu/nn/conv.cc +++ b/onnxruntime/core/providers/webgpu/nn/conv.cc @@ -235,7 +235,11 @@ Status Conv::ComputeInternal(ComputeContext& context .AddUniformVariables({{output_size}, {static_cast(matmul_output_shape[1])}, {static_cast(matmul_output_shape[2])}, {static_cast(K)}}); return context.RunProgram(program); } else { - return ComputeMatMul(&context, activation_, matmul_inputs, output, is_channels_last, matmul_input_reshapes[0], matmul_input_reshapes[1]); + std::optional applied_is_channels_last; + if (has_bias) { + applied_is_channels_last = is_channels_last; + } + return ComputeMatMul(&context, activation_, matmul_inputs, output, applied_is_channels_last, matmul_input_reshapes[0], matmul_input_reshapes[1]); } } // Transpose weights when necessary diff --git a/onnxruntime/core/providers/webgpu/vendor/intel/math/matmul.cc b/onnxruntime/core/providers/webgpu/vendor/intel/math/matmul.cc index 0362deb0fbd6a..12236b775b472 100644 --- a/onnxruntime/core/providers/webgpu/vendor/intel/math/matmul.cc +++ b/onnxruntime/core/providers/webgpu/vendor/intel/math/matmul.cc @@ -21,14 +21,10 @@ Status MatMulSubgroupProgram::GenerateShaderCode(ShaderHelper& shader) const { ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); const auto& batch_dims = shader.AddIndices("batch_dims", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); - const ShaderVariableHelper* bias = nullptr; - if (has_bias_) { - bias = &shader.AddInput("bias", ShaderUsage::UseUniform); - } std::string apply_activation = GetActivationSnippet(activation_, "output_value_t", "output_element_t"); // declare the read and write functions MatMulReadFnSource(shader, a, b, &batch_dims, /*transA = */ false, /*transB = */ false); - MatMulWriteFnSourceForMatMul(shader, output, bias, apply_activation, /*is_channels_last = */ false); + MatMulWriteFnSourceForMatMul(shader, output, nullptr, apply_activation, {}); // generate the main function ORT_RETURN_IF_ERROR(MakeMatMulSubgroupSource(shader, elements_per_thread_, &batch_dims, is_vec4_)); return Status::OK(); @@ -44,7 +40,6 @@ Status ApplyMatMulIntel(ComputeContext& context, Tensor* output) { const auto* a = inputs[0]; const auto* b = inputs[1]; - bool has_bias = inputs.size() > 2; TensorShape a_shape = a->Shape(); TensorShape b_shape = b->Shape(); @@ -108,7 +103,7 @@ Status ApplyMatMulIntel(ComputeContext& context, const TensorShape b_shape_temp = CreateMatMulIntermediateShape(outer_dims_b, dim_inner, dim_b_outer, b_components); const TensorShape output_shape_temp = TensorShape({batch_size, dim_a_outer, dim_b_outer / components}); - MatMulSubgroupProgram program{activation, has_bias, is_vec4, elements_per_thread}; + MatMulSubgroupProgram program{activation, is_vec4, elements_per_thread}; program .CacheHint(activation.ToString(), absl::StrJoin(elements_per_thread, "-")) .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, a_shape_temp, a_components}, @@ -119,13 +114,6 @@ Status ApplyMatMulIntel(ComputeContext& context, .SetDispatchGroupSize(dispatch_x, dispatch_y, dispatch_z) .SetWorkgroupSize(kSubgroupLogicalWorkGroupSizeX * kSubgroupLogicalWorkGroupSizeY, 1, 1); - if (has_bias) { - auto bias_components = 1; - const auto* bias = inputs[2]; - TensorShape reduced_bias_shape = ReduceShapeByComponents(bias->Shape(), bias_components); - program.AddInput({bias, ProgramTensorMetadataDependency::Rank, reduced_bias_shape, bias_components}); - } - return context.RunProgram(program); } diff --git a/onnxruntime/core/providers/webgpu/vendor/intel/math/matmul.h b/onnxruntime/core/providers/webgpu/vendor/intel/math/matmul.h index 2a8333e3e912b..4b326784927cc 100644 --- a/onnxruntime/core/providers/webgpu/vendor/intel/math/matmul.h +++ b/onnxruntime/core/providers/webgpu/vendor/intel/math/matmul.h @@ -15,12 +15,10 @@ namespace intel { class MatMulSubgroupProgram final : public Program { public: MatMulSubgroupProgram(const Activation& activation, - bool bias, bool is_vec4, const gsl::span& elements_per_thread) : Program{"MatMulSubgroup"}, activation_(activation), - has_bias_{bias}, is_vec4_{is_vec4}, elements_per_thread_(elements_per_thread.begin(), elements_per_thread.end()) {} @@ -31,7 +29,6 @@ class MatMulSubgroupProgram final : public Program { private: const Activation activation_; - const bool has_bias_; const bool is_vec4_; const InlinedVector elements_per_thread_; }; diff --git a/onnxruntime/core/providers/webgpu/webgpu_utils.cc b/onnxruntime/core/providers/webgpu/webgpu_utils.cc index e48c8c833c7e3..a9b06a7e8e8a5 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_utils.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_utils.cc @@ -74,8 +74,7 @@ bool SplitKConfig::UseSplitK( bool is_vec4, ActivationKind activation_kind, uint64_t batch_size, - bool is_gemm, - bool is_channels_last, + std::optional is_channels_last, uint32_t dim_a_outer, uint32_t dim_b_outer, uint32_t dim_inner) const { @@ -89,9 +88,14 @@ bool SplitKConfig::UseSplitK( use_split_k &= activation_kind == ActivationKind::None; use_split_k &= is_vec4; use_split_k &= batch_size == 1; - // Now `is_channels_last` is only supported because we only generate vec4 shaders in - // `MatMulFillBiasOrZeroBeforeSplitKProgram` when `is_gemm` is false. - use_split_k &= (is_channels_last || is_gemm); + + // Now we only need `is_channels_last` in `Conv|MatMul` with `bias`. We don't need to care about + // it in other places (`GEMM`, `MatMul` and `Conv|MatMul` without `bias`). + // When `is_channels_last` has valid value we only accept `is_channels_last` to be true because + // we only generate `vec4` shaders in `MatMulFillBiasOrZeroBeforeSplitKProgram`. + if (is_channels_last.has_value()) { + use_split_k &= is_channels_last.value(); + } // Split-K works best when `dim_inner` is relatively large compared with `dim_a_outer` and // `dim_b_outer`. Currently we use the factor between `(dim_a_outer * dim_b_outer)` and diff --git a/onnxruntime/core/providers/webgpu/webgpu_utils.h b/onnxruntime/core/providers/webgpu/webgpu_utils.h index 6c72fd07938d5..730728fc67e62 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_utils.h +++ b/onnxruntime/core/providers/webgpu/webgpu_utils.h @@ -4,6 +4,7 @@ #pragma once #include +#include #include "core/common/common.h" #include "core/framework/tensor.h" #include "core/framework/tensor_shape.h" @@ -106,8 +107,8 @@ class SplitKConfig { explicit SplitKConfig(const wgpu::AdapterInfo& adapter_info); bool UseSplitK( - bool is_vec4, ActivationKind activation_kind, uint64_t batch_size, bool is_gemm, - bool is_channels_last, uint32_t dim_a_outer, + bool is_vec4, ActivationKind activation_kind, uint64_t batch_size, + std::optional is_channels_last, uint32_t dim_a_outer, uint32_t dim_b_outer, uint32_t dim_inner) const; uint32_t GetSplitDimInner() const; From 0b37c1d68cb76de3da009b11ecd7933113b2886f Mon Sep 17 00:00:00 2001 From: Jiawei Shao Date: Fri, 27 Mar 2026 15:33:50 +0800 Subject: [PATCH 2/5] [webgpu] Set `is_channels_last` to true by default in `ComputeMatMul` This patch sets `is_channels_last` to true by default in the parameter of `ComputeMatMul` and ignores it in `UseSplitK` when there is no `bias`. --- .../contrib_ops/webgpu/bert/attention.cc | 13 +++--------- .../core/providers/webgpu/math/gemm_packed.cc | 2 +- .../core/providers/webgpu/math/gemm_utils.cc | 11 +++++----- .../core/providers/webgpu/math/gemm_utils.h | 4 ++-- .../core/providers/webgpu/math/matmul.cc | 20 +++++-------------- .../core/providers/webgpu/math/matmul.h | 3 +-- .../providers/webgpu/math/matmul_packed.cc | 11 +++------- .../providers/webgpu/math/matmul_packed.h | 16 ++++++++------- onnxruntime/core/providers/webgpu/nn/conv.cc | 6 +----- .../webgpu/vendor/intel/math/matmul.cc | 16 +++++++++++++-- .../webgpu/vendor/intel/math/matmul.h | 3 +++ .../core/providers/webgpu/webgpu_utils.cc | 10 ++++++---- .../core/providers/webgpu/webgpu_utils.h | 5 ++--- 13 files changed, 55 insertions(+), 65 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc index 7b2a7dfd4f47a..755bd0c60452f 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc @@ -653,17 +653,10 @@ Status PrepareQKV(onnxruntime::webgpu::ComputeContext& context, const WebgpuAtte Tensor packed_qkv = context.CreateGPUTensor(input->DataType(), TensorShape(packed_qkv_shape)); // Prepare inputs for MatMul - bool has_bias = bias != nullptr; - std::vector matmul_inputs(has_bias ? 3 : 2); - std::optional is_channels_last; - matmul_inputs[0] = input; - matmul_inputs[1] = weights; - if (has_bias) { - matmul_inputs[2] = bias; - is_channels_last = true; - } + std::vector matmul_inputs = {input, weights, bias}; + // Call MatMul: packed_qkv = input * weights + bias - ORT_RETURN_IF_ERROR(onnxruntime::webgpu::ComputeMatMul(&context, Activation(), matmul_inputs, &packed_qkv, is_channels_last)); + ORT_RETURN_IF_ERROR(onnxruntime::webgpu::ComputeMatMul(&context, Activation(), matmul_inputs, &packed_qkv)); // Output Q, K, V in BSD format return SplitPackedQKV(context, parameters, &packed_qkv, q, k, v, parameters.hidden_size_); diff --git a/onnxruntime/core/providers/webgpu/math/gemm_packed.cc b/onnxruntime/core/providers/webgpu/math/gemm_packed.cc index f1a58a600fa4f..154326f68b7e9 100644 --- a/onnxruntime/core/providers/webgpu/math/gemm_packed.cc +++ b/onnxruntime/core/providers/webgpu/math/gemm_packed.cc @@ -114,7 +114,7 @@ Status ApplyGemmPacked(const Tensor* a, // Currently we require the components for Y must also be a multiple of 4 when Split-K is used. const bool output_is_vec4 = output_components == 4; // The parameter `is_channel_last` is not used for GEMM. - const bool need_split_k = split_k_config.UseSplitK(is_vec4 && output_is_vec4, ActivationKind::None, /*batch_size*/ 1, /*is_channels_last*/ {}, M, N, K); + const bool need_split_k = split_k_config.UseSplitK(is_vec4 && output_is_vec4, ActivationKind::None, /*batch_size*/ 1, need_handle_bias, /*is_gemm*/ true, /*is_channels_last*/ true, M, N, K); if (need_split_k) { const Tensor* bias = nullptr; uint32_t output_components_in_fill_bias_program = 4; diff --git a/onnxruntime/core/providers/webgpu/math/gemm_utils.cc b/onnxruntime/core/providers/webgpu/math/gemm_utils.cc index 4af6cb70b7a32..573d7b016310f 100644 --- a/onnxruntime/core/providers/webgpu/math/gemm_utils.cc +++ b/onnxruntime/core/providers/webgpu/math/gemm_utils.cc @@ -41,12 +41,11 @@ void HandleMaybeHaveBiasForGEMM(ShaderHelper& shader, void HandleMaybeBiasForMatMul(ShaderHelper& shader, const ShaderVariableHelper& output, const ShaderVariableHelper* bias, - std::string_view activation_snippet, - std::optional is_channels_last) { + std::string activation_snippet, + bool is_channels_last) { shader.AdditionalImplementation() << " let coords = vec3(u32(batch), u32(row), u32(colIn));\n"; if (bias != nullptr) { - ORT_ENFORCE(is_channels_last.has_value(), "is_channels_last must be set when bias is used"); - shader.AdditionalImplementation() << " value = value + output_value_t(" << (is_channels_last.value() ? bias->GetByOffset("colIn") : bias->GetByOffset("row")) << ");\n"; + shader.AdditionalImplementation() << " value = value + output_value_t(" << (is_channels_last ? bias->GetByOffset("colIn") : bias->GetByOffset("row")) << ");\n"; } shader.AdditionalImplementation() << " " << activation_snippet << "\n" << " " << output.SetByIndices("coords", "value") << "\n"; @@ -193,8 +192,8 @@ void MatMulReadFnSource(ShaderHelper& shader, void MatMulWriteFnSourceForMatMul(ShaderHelper& shader, const ShaderVariableHelper& output, const ShaderVariableHelper* bias, - std::string_view activation_snippet, - std::optional is_channels_last) { + std::string activation_snippet, + bool is_channels_last) { EmitMatMulWriteFnHeader(shader, output); HandleMaybeBiasForMatMul(shader, output, bias, activation_snippet, is_channels_last); EmitMatMulWriteFnFooter(shader); diff --git a/onnxruntime/core/providers/webgpu/math/gemm_utils.h b/onnxruntime/core/providers/webgpu/math/gemm_utils.h index 71da39a73519e..13c298919194a 100644 --- a/onnxruntime/core/providers/webgpu/math/gemm_utils.h +++ b/onnxruntime/core/providers/webgpu/math/gemm_utils.h @@ -18,8 +18,8 @@ void MatMulReadFnSource(ShaderHelper& shader, void MatMulWriteFnSourceForMatMul(ShaderHelper& shader, const ShaderVariableHelper& output, const ShaderVariableHelper* bias, - std::string_view activation_snippet, - std::optional is_channels_last); + std::string activation_snippet, + bool is_channels_last); void MatMulWriteFnSourceForGemm(ShaderHelper& shader, const ShaderVariableHelper& output, diff --git a/onnxruntime/core/providers/webgpu/math/matmul.cc b/onnxruntime/core/providers/webgpu/math/matmul.cc index d84204b14f6a4..e405745dcf925 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul.cc +++ b/onnxruntime/core/providers/webgpu/math/matmul.cc @@ -172,15 +172,12 @@ Status MatMul::ComputeInternal(ComputeContext& context) const { } Status ComputeMatMul(ComputeContext* context, - const Activation& activation, std::vector& inputs, Tensor* output_tensor, - std::optional is_channels_last, + const Activation& activation, std::vector& inputs, Tensor* output_tensor, bool is_channels_last, const TensorShape& input_a_reshape, const TensorShape& input_b_reshape) { const auto* a = inputs[0]; const auto* b = inputs[1]; bool has_bias = inputs.size() > 2; - ORT_ENFORCE(is_channels_last.has_value() == has_bias, "is_channels_last must be set when bias is used, and won't be set when bias is not used"); - TensorShape a_shape = input_a_reshape.NumDimensions() > 0 ? input_a_reshape : a->Shape(); TensorShape b_shape = input_b_reshape.NumDimensions() > 0 ? input_b_reshape : b->Shape(); @@ -251,12 +248,12 @@ Status ComputeMatMul(ComputeContext* context, // Current Split-K implementation relies on atomic operations, which are not deterministic. if (!context->KernelContext().GetUseDeterministicCompute()) { const SplitKConfig& split_k_config = context->GetSplitKConfig(); - const bool need_split_k = split_k_config.UseSplitK(is_vec4, activation.activation_kind_, batch_size, is_channels_last, dim_a_outer, dim_b_outer, dim_inner); + const bool need_split_k = split_k_config.UseSplitK(is_vec4, activation.activation_kind_, batch_size, has_bias, /*is_gemm*/ false, is_channels_last, dim_a_outer, dim_b_outer, dim_inner); if (need_split_k) { ORT_ENFORCE(batch_size == 1, "Split-K MatMul only supports batch_size == 1."); ORT_ENFORCE(is_vec4, "Split-K MatMul only supports bias in vec4 format."); - if (is_channels_last.has_value()) { + if (has_bias) { ORT_ENFORCE(is_channels_last, "Split-K MatMul only supports channels-last format."); } @@ -279,16 +276,9 @@ Status ComputeMatMul(ComputeContext* context, } } - bool is_channels_last_in_cache_hint = true; - std::optional is_channels_last_in_matmul_program; - if (use_bias_in_matmul) { - // `is_channels_last_in_matmul_program` has valid value only when `bias` is used in MatMul. - is_channels_last_in_matmul_program = is_channels_last; - is_channels_last_in_cache_hint = is_channels_last.value(); - } - MatMulProgram matmul_program{activation, is_vec4, elements_per_thread, is_channels_last_in_matmul_program, split_dim_inner}; + MatMulProgram matmul_program{activation, use_bias_in_matmul, is_vec4, elements_per_thread, is_channels_last, split_dim_inner}; matmul_program - .CacheHint(activation.ToString(), absl::StrJoin(elements_per_thread, "-"), std::to_string(is_vec4), components, is_channels_last_in_cache_hint, split_dim_inner) + .CacheHint(activation.ToString(), absl::StrJoin(elements_per_thread, "-"), std::to_string(is_vec4), components, is_channels_last, split_dim_inner) .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, a_shape_temp, components}, {b, ProgramTensorMetadataDependency::TypeAndRank, b_shape_temp, components}}) .AddUniformVariables({{dim_a_outer}, {dim_b_outer}, {dim_inner}, {dispatch_x}, {dispatch_y}, {dispatch_z}}) diff --git a/onnxruntime/core/providers/webgpu/math/matmul.h b/onnxruntime/core/providers/webgpu/math/matmul.h index fa8abb6476037..d15e36ffa3d85 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul.h +++ b/onnxruntime/core/providers/webgpu/math/matmul.h @@ -14,8 +14,7 @@ namespace onnxruntime { namespace webgpu { -Status ComputeMatMul(ComputeContext* context, const Activation& activation, std::vector& inputs, Tensor* output, - std::optional is_channels_last = {}, +Status ComputeMatMul(ComputeContext* context, const Activation& activation, std::vector& inputs, Tensor* output, bool is_channels_last = true, const TensorShape& input_a_reshape = TensorShape(), const TensorShape& input_b_reshape = TensorShape()); diff --git a/onnxruntime/core/providers/webgpu/math/matmul_packed.cc b/onnxruntime/core/providers/webgpu/math/matmul_packed.cc index babe00a969e1c..0883c8ddb95b5 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul_packed.cc +++ b/onnxruntime/core/providers/webgpu/math/matmul_packed.cc @@ -26,10 +26,8 @@ Status MatMulProgram::GenerateShaderCode(ShaderHelper& shader) const { const auto& batch_dims = shader.AddIndices("batch_dims", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); - // `is_channels_last_` must be set when `bias` is used, and won't be set when `bias` is not used. - const bool has_bias = is_channels_last_.has_value(); const ShaderVariableHelper* bias = nullptr; - if (has_bias) { + if (has_bias_) { bias = &shader.AddInput("bias", ShaderUsage::UseUniform); } std::string apply_activation = GetActivationSnippet(activation_, "output_value_t", "output_element_t"); @@ -67,15 +65,12 @@ Status MatMulFillBiasOrZeroBeforeSplitKProgram::GenerateShaderCode(ShaderHelper& } // Handle bias with `MatMulWriteFnSourceForGemm() or MatMulWriteFnSourceForMatMul()`. + // const uint32_t bias_components = output_components_; if (is_gemm_) { MatMulWriteFnSourceForGemm(shader, output, bias, bias_is_scalar_); } else { // Currently we only support `is_channels_last` to be true and no activation. - std::optional is_channels_last; - if (has_bias_) { - is_channels_last = true; - } - MatMulWriteFnSourceForMatMul(shader, output, bias, /*activation_snippet*/ "", is_channels_last); + MatMulWriteFnSourceForMatMul(shader, output, bias, /*activation_snippet*/ "", /*is_channels_last*/ true); } shader.MainFunctionBody() << " let output_components = " << output_components_ << ";\n"; diff --git a/onnxruntime/core/providers/webgpu/math/matmul_packed.h b/onnxruntime/core/providers/webgpu/math/matmul_packed.h index 945425e281d28..618fc97d72fe0 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul_packed.h +++ b/onnxruntime/core/providers/webgpu/math/matmul_packed.h @@ -13,12 +13,13 @@ namespace onnxruntime { namespace webgpu { class MatMulProgram final : public Program { public: - MatMulProgram(const Activation& activation, bool is_vec4, const gsl::span& elements_per_thread, std::optional is_channels_last, uint32_t split_dim_inner = 1) : Program{"MatMul"}, - activation_(activation), - is_vec4_{is_vec4}, - elements_per_thread_(elements_per_thread.begin(), elements_per_thread.end()), - is_channels_last_(is_channels_last), - split_dim_inner_(split_dim_inner) {} + MatMulProgram(const Activation& activation, bool bias, bool is_vec4, const gsl::span& elements_per_thread, bool is_channels_last = false, uint32_t split_dim_inner = 1) : Program{"MatMul"}, + activation_(activation), + has_bias_{bias}, + is_vec4_{is_vec4}, + elements_per_thread_(elements_per_thread.begin(), elements_per_thread.end()), + is_channels_last_(is_channels_last), + split_dim_inner_(split_dim_inner) {} Status GenerateShaderCode(ShaderHelper& sh) const override; WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"dim_a_outer", ProgramUniformVariableDataType::Uint32}, @@ -32,9 +33,10 @@ class MatMulProgram final : public Program { private: const Activation activation_; + const bool has_bias_; const bool is_vec4_; const InlinedVector elements_per_thread_; - std::optional is_channels_last_; + bool is_channels_last_ = false; uint32_t split_dim_inner_ = 1; }; diff --git a/onnxruntime/core/providers/webgpu/nn/conv.cc b/onnxruntime/core/providers/webgpu/nn/conv.cc index b2624799116bb..697428e1ce140 100644 --- a/onnxruntime/core/providers/webgpu/nn/conv.cc +++ b/onnxruntime/core/providers/webgpu/nn/conv.cc @@ -235,11 +235,7 @@ Status Conv::ComputeInternal(ComputeContext& context .AddUniformVariables({{output_size}, {static_cast(matmul_output_shape[1])}, {static_cast(matmul_output_shape[2])}, {static_cast(K)}}); return context.RunProgram(program); } else { - std::optional applied_is_channels_last; - if (has_bias) { - applied_is_channels_last = is_channels_last; - } - return ComputeMatMul(&context, activation_, matmul_inputs, output, applied_is_channels_last, matmul_input_reshapes[0], matmul_input_reshapes[1]); + return ComputeMatMul(&context, activation_, matmul_inputs, output, is_channels_last, matmul_input_reshapes[0], matmul_input_reshapes[1]); } } // Transpose weights when necessary diff --git a/onnxruntime/core/providers/webgpu/vendor/intel/math/matmul.cc b/onnxruntime/core/providers/webgpu/vendor/intel/math/matmul.cc index 12236b775b472..0362deb0fbd6a 100644 --- a/onnxruntime/core/providers/webgpu/vendor/intel/math/matmul.cc +++ b/onnxruntime/core/providers/webgpu/vendor/intel/math/matmul.cc @@ -21,10 +21,14 @@ Status MatMulSubgroupProgram::GenerateShaderCode(ShaderHelper& shader) const { ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); const auto& batch_dims = shader.AddIndices("batch_dims", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); + const ShaderVariableHelper* bias = nullptr; + if (has_bias_) { + bias = &shader.AddInput("bias", ShaderUsage::UseUniform); + } std::string apply_activation = GetActivationSnippet(activation_, "output_value_t", "output_element_t"); // declare the read and write functions MatMulReadFnSource(shader, a, b, &batch_dims, /*transA = */ false, /*transB = */ false); - MatMulWriteFnSourceForMatMul(shader, output, nullptr, apply_activation, {}); + MatMulWriteFnSourceForMatMul(shader, output, bias, apply_activation, /*is_channels_last = */ false); // generate the main function ORT_RETURN_IF_ERROR(MakeMatMulSubgroupSource(shader, elements_per_thread_, &batch_dims, is_vec4_)); return Status::OK(); @@ -40,6 +44,7 @@ Status ApplyMatMulIntel(ComputeContext& context, Tensor* output) { const auto* a = inputs[0]; const auto* b = inputs[1]; + bool has_bias = inputs.size() > 2; TensorShape a_shape = a->Shape(); TensorShape b_shape = b->Shape(); @@ -103,7 +108,7 @@ Status ApplyMatMulIntel(ComputeContext& context, const TensorShape b_shape_temp = CreateMatMulIntermediateShape(outer_dims_b, dim_inner, dim_b_outer, b_components); const TensorShape output_shape_temp = TensorShape({batch_size, dim_a_outer, dim_b_outer / components}); - MatMulSubgroupProgram program{activation, is_vec4, elements_per_thread}; + MatMulSubgroupProgram program{activation, has_bias, is_vec4, elements_per_thread}; program .CacheHint(activation.ToString(), absl::StrJoin(elements_per_thread, "-")) .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, a_shape_temp, a_components}, @@ -114,6 +119,13 @@ Status ApplyMatMulIntel(ComputeContext& context, .SetDispatchGroupSize(dispatch_x, dispatch_y, dispatch_z) .SetWorkgroupSize(kSubgroupLogicalWorkGroupSizeX * kSubgroupLogicalWorkGroupSizeY, 1, 1); + if (has_bias) { + auto bias_components = 1; + const auto* bias = inputs[2]; + TensorShape reduced_bias_shape = ReduceShapeByComponents(bias->Shape(), bias_components); + program.AddInput({bias, ProgramTensorMetadataDependency::Rank, reduced_bias_shape, bias_components}); + } + return context.RunProgram(program); } diff --git a/onnxruntime/core/providers/webgpu/vendor/intel/math/matmul.h b/onnxruntime/core/providers/webgpu/vendor/intel/math/matmul.h index 4b326784927cc..2a8333e3e912b 100644 --- a/onnxruntime/core/providers/webgpu/vendor/intel/math/matmul.h +++ b/onnxruntime/core/providers/webgpu/vendor/intel/math/matmul.h @@ -15,10 +15,12 @@ namespace intel { class MatMulSubgroupProgram final : public Program { public: MatMulSubgroupProgram(const Activation& activation, + bool bias, bool is_vec4, const gsl::span& elements_per_thread) : Program{"MatMulSubgroup"}, activation_(activation), + has_bias_{bias}, is_vec4_{is_vec4}, elements_per_thread_(elements_per_thread.begin(), elements_per_thread.end()) {} @@ -29,6 +31,7 @@ class MatMulSubgroupProgram final : public Program { private: const Activation activation_; + const bool has_bias_; const bool is_vec4_; const InlinedVector elements_per_thread_; }; diff --git a/onnxruntime/core/providers/webgpu/webgpu_utils.cc b/onnxruntime/core/providers/webgpu/webgpu_utils.cc index a9b06a7e8e8a5..d1a3529c2094a 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_utils.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_utils.cc @@ -74,7 +74,9 @@ bool SplitKConfig::UseSplitK( bool is_vec4, ActivationKind activation_kind, uint64_t batch_size, - std::optional is_channels_last, + bool has_bias, + bool is_gemm, + bool is_channels_last, uint32_t dim_a_outer, uint32_t dim_b_outer, uint32_t dim_inner) const { @@ -91,10 +93,10 @@ bool SplitKConfig::UseSplitK( // Now we only need `is_channels_last` in `Conv|MatMul` with `bias`. We don't need to care about // it in other places (`GEMM`, `MatMul` and `Conv|MatMul` without `bias`). - // When `is_channels_last` has valid value we only accept `is_channels_last` to be true because + // When `is_channels_last` has valid value `is_channels_last` is required to be true because // we only generate `vec4` shaders in `MatMulFillBiasOrZeroBeforeSplitKProgram`. - if (is_channels_last.has_value()) { - use_split_k &= is_channels_last.value(); + if (has_bias && !is_gemm) { + use_split_k &= is_channels_last; } // Split-K works best when `dim_inner` is relatively large compared with `dim_a_outer` and diff --git a/onnxruntime/core/providers/webgpu/webgpu_utils.h b/onnxruntime/core/providers/webgpu/webgpu_utils.h index 730728fc67e62..0df3515842336 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_utils.h +++ b/onnxruntime/core/providers/webgpu/webgpu_utils.h @@ -4,7 +4,6 @@ #pragma once #include -#include #include "core/common/common.h" #include "core/framework/tensor.h" #include "core/framework/tensor_shape.h" @@ -107,8 +106,8 @@ class SplitKConfig { explicit SplitKConfig(const wgpu::AdapterInfo& adapter_info); bool UseSplitK( - bool is_vec4, ActivationKind activation_kind, uint64_t batch_size, - std::optional is_channels_last, uint32_t dim_a_outer, + bool is_vec4, ActivationKind activation_kind, uint64_t batch_size, bool has_bias, + bool is_gemm, bool is_channels_last, uint32_t dim_a_outer, uint32_t dim_b_outer, uint32_t dim_inner) const; uint32_t GetSplitDimInner() const; From dc6170c0fef94b7f0d5e16374db2d4156628308a Mon Sep 17 00:00:00 2001 From: Jiawei Shao Date: Mon, 30 Mar 2026 15:10:17 +0800 Subject: [PATCH 3/5] Address reviewer's comments --- onnxruntime/core/providers/webgpu/math/gemm_packed.cc | 2 +- onnxruntime/core/providers/webgpu/math/matmul.cc | 2 +- onnxruntime/core/providers/webgpu/webgpu_utils.cc | 11 ++++++----- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/math/gemm_packed.cc b/onnxruntime/core/providers/webgpu/math/gemm_packed.cc index 154326f68b7e9..aca00da30d9bc 100644 --- a/onnxruntime/core/providers/webgpu/math/gemm_packed.cc +++ b/onnxruntime/core/providers/webgpu/math/gemm_packed.cc @@ -113,7 +113,7 @@ Status ApplyGemmPacked(const Tensor* a, const SplitKConfig& split_k_config = context.GetSplitKConfig(); // Currently we require the components for Y must also be a multiple of 4 when Split-K is used. const bool output_is_vec4 = output_components == 4; - // The parameter `is_channel_last` is not used for GEMM. + // We need to use `true` as `is_channels_last` to meet the requirement in `UseSplitK`. const bool need_split_k = split_k_config.UseSplitK(is_vec4 && output_is_vec4, ActivationKind::None, /*batch_size*/ 1, need_handle_bias, /*is_gemm*/ true, /*is_channels_last*/ true, M, N, K); if (need_split_k) { const Tensor* bias = nullptr; diff --git a/onnxruntime/core/providers/webgpu/math/matmul.cc b/onnxruntime/core/providers/webgpu/math/matmul.cc index e405745dcf925..65affb0c2d024 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul.cc +++ b/onnxruntime/core/providers/webgpu/math/matmul.cc @@ -251,7 +251,7 @@ Status ComputeMatMul(ComputeContext* context, const bool need_split_k = split_k_config.UseSplitK(is_vec4, activation.activation_kind_, batch_size, has_bias, /*is_gemm*/ false, is_channels_last, dim_a_outer, dim_b_outer, dim_inner); if (need_split_k) { ORT_ENFORCE(batch_size == 1, "Split-K MatMul only supports batch_size == 1."); - ORT_ENFORCE(is_vec4, "Split-K MatMul only supports bias in vec4 format."); + ORT_ENFORCE(is_vec4, "Split-K MatMul requires vec4 packing."); if (has_bias) { ORT_ENFORCE(is_channels_last, "Split-K MatMul only supports channels-last format."); diff --git a/onnxruntime/core/providers/webgpu/webgpu_utils.cc b/onnxruntime/core/providers/webgpu/webgpu_utils.cc index d1a3529c2094a..6e47e0fd25817 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_utils.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_utils.cc @@ -91,11 +91,12 @@ bool SplitKConfig::UseSplitK( use_split_k &= is_vec4; use_split_k &= batch_size == 1; - // Now we only need `is_channels_last` in `Conv|MatMul` with `bias`. We don't need to care about - // it in other places (`GEMM`, `MatMul` and `Conv|MatMul` without `bias`). - // When `is_channels_last` has valid value `is_channels_last` is required to be true because - // we only generate `vec4` shaders in `MatMulFillBiasOrZeroBeforeSplitKProgram`. - if (has_bias && !is_gemm) { + // `is_channels_last` should only affect Split-K gating when bias is applied in the non-GEMM + // MatMul/Conv|MatMul path. For GEMM and for MatMul or Conv|MatMul without bias, we need to + // use `true` as `is_channels_last` to make `UseSplitK` ignore `is_channels_last`. + // When `is_channels_last` has a valid value here, it is required to be true because we only + // generate `vec4` shaders in `MatMulFillBiasOrZeroBeforeSplitKProgram`. + if (has_bias) { use_split_k &= is_channels_last; } From b2849145dcd53a3f9c4b1b0c8e0bbdd3d970c3cb Mon Sep 17 00:00:00 2001 From: Jiawei Shao Date: Mon, 30 Mar 2026 15:13:57 +0800 Subject: [PATCH 4/5] Remove `is_gemm` in `UseSplitK` --- onnxruntime/core/providers/webgpu/math/gemm_packed.cc | 2 +- onnxruntime/core/providers/webgpu/math/matmul.cc | 2 +- onnxruntime/core/providers/webgpu/webgpu_utils.cc | 1 - onnxruntime/core/providers/webgpu/webgpu_utils.h | 3 +-- 4 files changed, 3 insertions(+), 5 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/math/gemm_packed.cc b/onnxruntime/core/providers/webgpu/math/gemm_packed.cc index aca00da30d9bc..ed9f7b9366fd4 100644 --- a/onnxruntime/core/providers/webgpu/math/gemm_packed.cc +++ b/onnxruntime/core/providers/webgpu/math/gemm_packed.cc @@ -114,7 +114,7 @@ Status ApplyGemmPacked(const Tensor* a, // Currently we require the components for Y must also be a multiple of 4 when Split-K is used. const bool output_is_vec4 = output_components == 4; // We need to use `true` as `is_channels_last` to meet the requirement in `UseSplitK`. - const bool need_split_k = split_k_config.UseSplitK(is_vec4 && output_is_vec4, ActivationKind::None, /*batch_size*/ 1, need_handle_bias, /*is_gemm*/ true, /*is_channels_last*/ true, M, N, K); + const bool need_split_k = split_k_config.UseSplitK(is_vec4 && output_is_vec4, ActivationKind::None, /*batch_size*/ 1, need_handle_bias, /*is_channels_last*/ true, M, N, K); if (need_split_k) { const Tensor* bias = nullptr; uint32_t output_components_in_fill_bias_program = 4; diff --git a/onnxruntime/core/providers/webgpu/math/matmul.cc b/onnxruntime/core/providers/webgpu/math/matmul.cc index 65affb0c2d024..34a6eac4af8d7 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul.cc +++ b/onnxruntime/core/providers/webgpu/math/matmul.cc @@ -248,7 +248,7 @@ Status ComputeMatMul(ComputeContext* context, // Current Split-K implementation relies on atomic operations, which are not deterministic. if (!context->KernelContext().GetUseDeterministicCompute()) { const SplitKConfig& split_k_config = context->GetSplitKConfig(); - const bool need_split_k = split_k_config.UseSplitK(is_vec4, activation.activation_kind_, batch_size, has_bias, /*is_gemm*/ false, is_channels_last, dim_a_outer, dim_b_outer, dim_inner); + const bool need_split_k = split_k_config.UseSplitK(is_vec4, activation.activation_kind_, batch_size, has_bias, is_channels_last, dim_a_outer, dim_b_outer, dim_inner); if (need_split_k) { ORT_ENFORCE(batch_size == 1, "Split-K MatMul only supports batch_size == 1."); ORT_ENFORCE(is_vec4, "Split-K MatMul requires vec4 packing."); diff --git a/onnxruntime/core/providers/webgpu/webgpu_utils.cc b/onnxruntime/core/providers/webgpu/webgpu_utils.cc index 6e47e0fd25817..04a1c16090928 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_utils.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_utils.cc @@ -75,7 +75,6 @@ bool SplitKConfig::UseSplitK( ActivationKind activation_kind, uint64_t batch_size, bool has_bias, - bool is_gemm, bool is_channels_last, uint32_t dim_a_outer, uint32_t dim_b_outer, diff --git a/onnxruntime/core/providers/webgpu/webgpu_utils.h b/onnxruntime/core/providers/webgpu/webgpu_utils.h index 0df3515842336..1abe527abcec4 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_utils.h +++ b/onnxruntime/core/providers/webgpu/webgpu_utils.h @@ -107,8 +107,7 @@ class SplitKConfig { bool UseSplitK( bool is_vec4, ActivationKind activation_kind, uint64_t batch_size, bool has_bias, - bool is_gemm, bool is_channels_last, uint32_t dim_a_outer, - uint32_t dim_b_outer, uint32_t dim_inner) const; + bool is_channels_last, uint32_t dim_a_outer, uint32_t dim_b_outer, uint32_t dim_inner) const; uint32_t GetSplitDimInner() const; From cb0de61ad5ac4b0a301fdc1891d863afd6ddd03b Mon Sep 17 00:00:00 2001 From: Jiawei Shao Date: Tue, 31 Mar 2026 11:13:10 +0800 Subject: [PATCH 5/5] Address reviewer's comments --- onnxruntime/core/providers/webgpu/math/gemm_packed.cc | 2 +- onnxruntime/core/providers/webgpu/math/matmul.cc | 2 +- onnxruntime/core/providers/webgpu/webgpu_utils.cc | 9 +++------ onnxruntime/core/providers/webgpu/webgpu_utils.h | 4 ++-- 4 files changed, 7 insertions(+), 10 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/math/gemm_packed.cc b/onnxruntime/core/providers/webgpu/math/gemm_packed.cc index ed9f7b9366fd4..79a4f1f73902b 100644 --- a/onnxruntime/core/providers/webgpu/math/gemm_packed.cc +++ b/onnxruntime/core/providers/webgpu/math/gemm_packed.cc @@ -114,7 +114,7 @@ Status ApplyGemmPacked(const Tensor* a, // Currently we require the components for Y must also be a multiple of 4 when Split-K is used. const bool output_is_vec4 = output_components == 4; // We need to use `true` as `is_channels_last` to meet the requirement in `UseSplitK`. - const bool need_split_k = split_k_config.UseSplitK(is_vec4 && output_is_vec4, ActivationKind::None, /*batch_size*/ 1, need_handle_bias, /*is_channels_last*/ true, M, N, K); + const bool need_split_k = split_k_config.UseSplitK(is_vec4 && output_is_vec4, ActivationKind::None, /*batch_size*/ 1, M, N, K); if (need_split_k) { const Tensor* bias = nullptr; uint32_t output_components_in_fill_bias_program = 4; diff --git a/onnxruntime/core/providers/webgpu/math/matmul.cc b/onnxruntime/core/providers/webgpu/math/matmul.cc index 34a6eac4af8d7..af488f2c23a30 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul.cc +++ b/onnxruntime/core/providers/webgpu/math/matmul.cc @@ -248,7 +248,7 @@ Status ComputeMatMul(ComputeContext* context, // Current Split-K implementation relies on atomic operations, which are not deterministic. if (!context->KernelContext().GetUseDeterministicCompute()) { const SplitKConfig& split_k_config = context->GetSplitKConfig(); - const bool need_split_k = split_k_config.UseSplitK(is_vec4, activation.activation_kind_, batch_size, has_bias, is_channels_last, dim_a_outer, dim_b_outer, dim_inner); + const bool need_split_k = split_k_config.UseSplitK(is_vec4, activation.activation_kind_, batch_size, dim_a_outer, dim_b_outer, dim_inner, is_channels_last); if (need_split_k) { ORT_ENFORCE(batch_size == 1, "Split-K MatMul only supports batch_size == 1."); ORT_ENFORCE(is_vec4, "Split-K MatMul requires vec4 packing."); diff --git a/onnxruntime/core/providers/webgpu/webgpu_utils.cc b/onnxruntime/core/providers/webgpu/webgpu_utils.cc index 04a1c16090928..5127801ca8451 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_utils.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_utils.cc @@ -74,11 +74,10 @@ bool SplitKConfig::UseSplitK( bool is_vec4, ActivationKind activation_kind, uint64_t batch_size, - bool has_bias, - bool is_channels_last, uint32_t dim_a_outer, uint32_t dim_b_outer, - uint32_t dim_inner) const { + uint32_t dim_inner, + bool is_channels_last) const { if (!enable_split_k_) { return false; } @@ -95,9 +94,7 @@ bool SplitKConfig::UseSplitK( // use `true` as `is_channels_last` to make `UseSplitK` ignore `is_channels_last`. // When `is_channels_last` has a valid value here, it is required to be true because we only // generate `vec4` shaders in `MatMulFillBiasOrZeroBeforeSplitKProgram`. - if (has_bias) { - use_split_k &= is_channels_last; - } + use_split_k &= is_channels_last; // Split-K works best when `dim_inner` is relatively large compared with `dim_a_outer` and // `dim_b_outer`. Currently we use the factor between `(dim_a_outer * dim_b_outer)` and diff --git a/onnxruntime/core/providers/webgpu/webgpu_utils.h b/onnxruntime/core/providers/webgpu/webgpu_utils.h index 1abe527abcec4..cbceaf2be120d 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_utils.h +++ b/onnxruntime/core/providers/webgpu/webgpu_utils.h @@ -106,8 +106,8 @@ class SplitKConfig { explicit SplitKConfig(const wgpu::AdapterInfo& adapter_info); bool UseSplitK( - bool is_vec4, ActivationKind activation_kind, uint64_t batch_size, bool has_bias, - bool is_channels_last, uint32_t dim_a_outer, uint32_t dim_b_outer, uint32_t dim_inner) const; + bool is_vec4, ActivationKind activation_kind, uint64_t batch_size, + uint32_t dim_a_outer, uint32_t dim_b_outer, uint32_t dim_inner, bool is_channels_last = true) const; uint32_t GetSplitDimInner() const;