From d4c42eddcd5070da264aad2721d2761a703e2ac6 Mon Sep 17 00:00:00 2001 From: Jiawei Shao Date: Fri, 5 Dec 2025 14:29:39 +0800 Subject: [PATCH 1/6] [WebGPU] Implement Split-K on GEMM --- .../core/providers/webgpu/math/gemm_packed.cc | 67 +++++- .../core/providers/webgpu/math/gemm_packed.h | 8 +- .../core/providers/webgpu/math/gemm_utils.cc | 13 +- .../core/providers/webgpu/math/matmul.cc | 53 +++-- .../core/providers/webgpu/math/matmul.h | 6 +- .../providers/webgpu/math/matmul_packed.cc | 6 +- .../providers/webgpu/math/matmul_packed.h | 13 +- .../core/providers/webgpu/webgpu_utils.cc | 5 +- .../core/providers/webgpu/webgpu_utils.h | 2 +- .../providers/cpu/math/gemm_large_test.cc | 216 ++++++++++++++++++ 10 files changed, 346 insertions(+), 43 deletions(-) create mode 100644 onnxruntime/test/providers/cpu/math/gemm_large_test.cc diff --git a/onnxruntime/core/providers/webgpu/math/gemm_packed.cc b/onnxruntime/core/providers/webgpu/math/gemm_packed.cc index b81977883dd70..8570cad55dadf 100644 --- a/onnxruntime/core/providers/webgpu/math/gemm_packed.cc +++ b/onnxruntime/core/providers/webgpu/math/gemm_packed.cc @@ -5,6 +5,7 @@ #include "core/providers/webgpu/webgpu_utils.h" +#include "core/providers/webgpu/math/matmul.h" #include "core/providers/webgpu/math/matmul_utils.h" #include "core/providers/webgpu/math/gemm_utils.h" @@ -12,7 +13,14 @@ namespace onnxruntime { namespace webgpu { Status GemmProgram::GenerateShaderCode(ShaderHelper& shader) const { - const ShaderVariableHelper& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + const bool need_split_k = NeedSplitK(); + ShaderUsage output_usage = ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias; + if (need_split_k) { + // When Split-K is enabled, we will declare output as `atomic` to call atomic built-in + // functions on it, so we need below information to correctly compute the index on the output. + output_usage |= ShaderUsage::UseIndicesToOffset | ShaderUsage::UseShapeAndStride; + } + const ShaderVariableHelper& output = shader.AddOutput("output", output_usage); // Each thread compute 4*4 elements InlinedVector elements_per_thread = InlinedVector({4, 4, 1}); @@ -26,7 +34,7 @@ Status GemmProgram::GenerateShaderCode(ShaderHelper& shader) const { MatMulReadFnSource(shader, a, b, nullptr, transA_, transB_, is_vec4_); } if (is_vec4_) { - ORT_RETURN_IF_ERROR(MakeMatMulPackedVec4Source(shader, elements_per_thread, WorkgroupSizeX(), WorkgroupSizeY(), data_type, nullptr, transA_, transB_, alpha_, need_handle_matmul_, output_components_)); + ORT_RETURN_IF_ERROR(MakeMatMulPackedVec4Source(shader, elements_per_thread, WorkgroupSizeX(), WorkgroupSizeY(), data_type, nullptr, transA_, transB_, alpha_, need_handle_matmul_, output_components_, /*tile_inner*/ 32, need_split_k, split_dim_inner_)); } else { ORT_RETURN_IF_ERROR(MakeMatMulPackedSource(shader, elements_per_thread, WorkgroupSizeX(), WorkgroupSizeY(), data_type, nullptr, transA_, transB_, alpha_, need_handle_matmul_)); } @@ -35,11 +43,17 @@ Status GemmProgram::GenerateShaderCode(ShaderHelper& shader) const { if (need_handle_bias_) { c = &shader.AddInput("c", ShaderUsage::UseUniform); } - MatMulWriteFnSource(shader, output, c, /* is_gemm = */ true, c_components_, output_components_, c_is_scalar_); + + const ProgramVariableDataType output_var_type = this->Outputs()[0].var_type; + MatMulWriteFnSource(shader, output, c, /* is_gemm = */ true, c_components_, output_components_, c_is_scalar_, /*activation_snippet*/ "", /*is_channels_last*/ false, need_split_k, output_var_type); return Status::OK(); } +bool GemmProgram::NeedSplitK() const { + return split_dim_inner_ > 1; +} + Status ApplyGemmPacked(const Tensor* a, const Tensor* b, const Tensor* c, @@ -86,7 +100,44 @@ Status ApplyGemmPacked(const Tensor* a, c_is_scalar = c_shape.Size() == 1; } - GemmProgram program{transA, transB, alpha, need_handle_bias, need_handle_matmul, c_components, c_is_scalar, output_components, is_vec4}; + ProgramOutput output(y, ProgramTensorMetadataDependency::TypeAndRank, output_components); + uint32_t dispatch_z = 1; + uint32_t split_dim_inner = 1; + + const SplitKConfig& split_k_config = context.GetSplitKConfig(); + // Currently we require the components for Y must be a multiple of 4. + const bool output_is_vec4 = output_components == 4; + // The parameter `is_channel_last` is not used for GEMM. + const bool need_split_k = split_k_config.UseSplitK( + is_vec4 && output_is_vec4, ActivationKind::None, /*batch_size*/ 1, /*is_gemm*/ true, /*is_channels_last*/ true, M, N, K); + if (need_split_k) { + const Tensor* bias = nullptr; + uint32_t output_components_in_fill_bias_program = 4; + if (need_handle_bias) { + bias = c; + output_components_in_fill_bias_program = c_components; + } + const TensorShape output_shape = TensorShape{1, M, N / output_components_in_fill_bias_program}; + + auto fill_bias_program = CreateMatMulFillBiasOrZeroBeforeSplitKProgram( + bias, y, /*is_gemm*/ true, beta, output_components_in_fill_bias_program, c_is_scalar, output_shape); + ORT_RETURN_IF_ERROR(context.RunProgram(fill_bias_program)); + + // When Split-K is used, `bias` will be handled in `MatMulFillBiasOrZeroBeforeSplitKProgram` + // instead of here. + need_handle_bias = false; + + // With Split-K, `dim_inner` will be split into multiple parts and `dispatch_z` will be the + // number of splits along `dim_inner`. + split_dim_inner = split_k_config.GetSplitDimInner(); + dispatch_z = (K + split_dim_inner - 1) / split_dim_inner; + + // The output should be declared in atomic types in `MatMulProgram` for the use of atomic + // built-in functions. + output.is_atomic = true; + } + + GemmProgram program{transA, transB, alpha, need_handle_bias, need_handle_matmul, c_components, c_is_scalar, output_components, is_vec4, split_dim_inner}; if (need_handle_matmul) { program.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, components}, @@ -101,9 +152,9 @@ Status ApplyGemmPacked(const Tensor* a, const uint32_t dispatch_x = (N + TILE_SIZE - 1) / TILE_SIZE; const uint32_t dispatch_y = (M + TILE_SIZE - 1) / TILE_SIZE; - program.CacheHint(alpha, transA, transB, c_is_scalar) - .AddOutputs({{y, ProgramTensorMetadataDependency::TypeAndRank, output_components}}) - .SetDispatchGroupSize(dispatch_x, dispatch_y, 1u) + program.CacheHint(alpha, transA, transB, c_is_scalar, split_dim_inner) + .AddOutput(std::move(output)) + .SetDispatchGroupSize(dispatch_x, dispatch_y, dispatch_z) .SetWorkgroupSize(GemmProgram::MATMUL_PACKED_WORKGROUP_SIZE_X, GemmProgram::MATMUL_PACKED_WORKGROUP_SIZE_Y, GemmProgram::MATMUL_PACKED_WORKGROUP_SIZE_Z) .AddUniformVariables({{alpha}, {beta}, @@ -112,7 +163,7 @@ Status ApplyGemmPacked(const Tensor* a, {K}, /*dim_inner */ {dispatch_x}, /* logical_dispatch_x */ {dispatch_y}, /* logical_dispatch_y */ - {1u}} /* logical_dispatch_z */ + {dispatch_z}} /* logical_dispatch_z */ ); return context.RunProgram(program); diff --git a/onnxruntime/core/providers/webgpu/math/gemm_packed.h b/onnxruntime/core/providers/webgpu/math/gemm_packed.h index cb89ccefba313..f81da43e3fe36 100644 --- a/onnxruntime/core/providers/webgpu/math/gemm_packed.h +++ b/onnxruntime/core/providers/webgpu/math/gemm_packed.h @@ -13,7 +13,7 @@ namespace webgpu { class GemmProgram final : public Program { public: - GemmProgram(bool transA, bool transB, float alpha, bool need_handle_bias, bool need_handle_matmul, int c_components, bool c_is_scalar, int output_components, bool is_vec4 = false) + GemmProgram(bool transA, bool transB, float alpha, bool need_handle_bias, bool need_handle_matmul, int c_components, bool c_is_scalar, int output_components, bool is_vec4 = false, uint32_t split_dim_inner = 1) : Program{"Gemm"}, transA_{transA}, transB_{transB}, @@ -23,10 +23,13 @@ class GemmProgram final : public Program { c_components_(c_components), c_is_scalar_(c_is_scalar), output_components_(output_components), - is_vec4_(is_vec4) {} + is_vec4_(is_vec4), + split_dim_inner_(split_dim_inner) {} Status GenerateShaderCode(ShaderHelper& sh) const override; + bool NeedSplitK() const; + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( {"alpha", ProgramUniformVariableDataType::Float32}, {"beta", ProgramUniformVariableDataType::Float32}, @@ -51,6 +54,7 @@ class GemmProgram final : public Program { bool c_is_scalar_ = false; int output_components_; bool is_vec4_ = false; + uint32_t split_dim_inner_ = 1; }; Status ApplyGemmPacked(const Tensor* a, diff --git a/onnxruntime/core/providers/webgpu/math/gemm_utils.cc b/onnxruntime/core/providers/webgpu/math/gemm_utils.cc index 73242ed3ff1ba..ba7e9290f8455 100644 --- a/onnxruntime/core/providers/webgpu/math/gemm_utils.cc +++ b/onnxruntime/core/providers/webgpu/math/gemm_utils.cc @@ -54,9 +54,14 @@ void HandleMaybeBiasForMatMul(ShaderHelper& shader, void HandleMatMulWithSplitK( ShaderHelper& shader, + bool is_gemm, const ShaderVariableHelper& output, ProgramVariableDataType output_variable_type) { - shader.AdditionalImplementation() << " let coords = vec3(u32(batch), u32(row), u32(colIn));\n"; + if (is_gemm) { + shader.AdditionalImplementation() << " let coords = vec2(u32(row), u32(colIn));"; + } else { + shader.AdditionalImplementation() << " let coords = vec3(u32(batch), u32(row), u32(colIn));\n"; + } // With Split-K, the final output will be the sum of the sub-outputs from multiple workgroups, // so we must add them with atomic built-in functions. Because currently WebGPU doesn't support @@ -205,8 +210,10 @@ void MatMulWriteFnSource(ShaderHelper& shader, // still need to handle `bias` (and `is_channels_last` in the future) in // `MatMulFillBiasOrZeroBeforeSplitKProgram`. ORT_ENFORCE(bias == nullptr, "Bias is not supported in MatMulProgram when Split-K is enabled."); - ORT_ENFORCE(is_channels_last, "Only channels-last is supported in MatMulProgram when Split-K is enabled."); - HandleMatMulWithSplitK(shader, output, output_variable_type); + if (!is_gemm) { + ORT_ENFORCE(is_channels_last, "Only channels-last is supported in MatMulProgram when Split-K is enabled in non-GEMM ops."); + } + HandleMatMulWithSplitK(shader, is_gemm, output, output_variable_type); } else if (is_gemm) { HandleMaybeHaveBiasForGEMM(shader, output, bias, c_components, output_components, c_is_scalar); } else { diff --git a/onnxruntime/core/providers/webgpu/math/matmul.cc b/onnxruntime/core/providers/webgpu/math/matmul.cc index 72dd235eb820a..cd73613ad070a 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul.cc +++ b/onnxruntime/core/providers/webgpu/math/matmul.cc @@ -240,14 +240,14 @@ Status ComputeMatMul(ComputeContext* context, uint32_t split_dim_inner = 1; const SplitKConfig& split_k_config = context->GetSplitKConfig(); - const bool need_split_k = split_k_config.UseSplitK(is_vec4, activation.activation_kind_, batch_size, is_channels_last, dim_a_outer, dim_b_outer, dim_inner); + const bool need_split_k = split_k_config.UseSplitK(is_vec4, activation.activation_kind_, batch_size, /*is_gemm*/ false, is_channels_last, dim_a_outer, dim_b_outer, dim_inner); if (need_split_k) { ORT_ENFORCE(batch_size == 1, "Split-K MatMul only supports batch_size == 1."); ORT_ENFORCE(is_vec4, "Split-K MatMul only supports bias in vec4 format."); ORT_ENFORCE(is_channels_last, "Split-K MatMul only supports channels-last format."); // Initialize `output_tensor` with 0 or bias before MatMulProgram with Split-K enabled. - const auto fill_bias_program = CreateMatMulFillBiasOrZeroBeforeSplitKProgram(bias, output_tensor, output_shape_temp); + const auto fill_bias_program = CreateMatMulFillBiasOrZeroBeforeSplitKProgram(bias, output_tensor, /*is_gemm*/ false, /*beta*/ 1.0f, /*bias_components*/ 4, /*bias_is_scalar*/ false, output_shape_temp); ORT_RETURN_IF_ERROR(context->RunProgram(fill_bias_program)); // `bias` has been handled in the execution of `fill_bias_program` so we don't need to set @@ -287,32 +287,45 @@ Status ComputeMatMul(ComputeContext* context, MatMulFillBiasOrZeroBeforeSplitKProgram CreateMatMulFillBiasOrZeroBeforeSplitKProgram( const Tensor* bias, Tensor* output, - const TensorShape& output_shape_vec4) { + bool is_gemm, + float beta, + uint32_t output_components, + bool bias_is_scalar, + const TensorShape& output_shape) { const bool has_bias = bias != nullptr; - // Currently we only support bias in vec4 and channels last format for Split-K MatMul. - constexpr uint32_t bias_components = 4; - MatMulFillBiasOrZeroBeforeSplitKProgram program(has_bias); + // Currently we only support GEMM and channels last format for MatMul with Split-K. + MatMulFillBiasOrZeroBeforeSplitKProgram program(is_gemm, has_bias, output_components, bias_is_scalar); - const uint32_t dim_a_outer = narrow(output_shape_vec4[output_shape_vec4.NumDimensions() - 2]); - const uint32_t dim_b_outer_vec4 = narrow(output_shape_vec4[output_shape_vec4.NumDimensions() - 1]); + const uint32_t dim_a_outer = narrow(output_shape[output_shape.NumDimensions() - 2]); + const uint32_t dim_b_outer = narrow(output_shape[output_shape.NumDimensions() - 1]); - // Fill one value (currently only vec4) per invocation. Now we use default workgroup size (64) for + // Fill one value per invocation. Now we use default workgroup size (64) for // this program. - const uint32_t total_outputs_vec4 = dim_a_outer * dim_b_outer_vec4; - const uint32_t dispatch_x = (total_outputs_vec4 + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE; - - // To reuse `MatMulWriteFnSource()` we need to set `dim_a_outer` and `dim_b_outer` in scalar - // instead of vec4, while use `output_shape_vec4` directly as the output shape. - const uint32_t dim_b_outer = narrow(dim_b_outer_vec4 * bias_components); - program.CacheHint(has_bias) - .AddOutput({output, ProgramTensorMetadataDependency::TypeAndRank, output_shape_vec4, static_cast(bias_components)}) - .AddUniformVariables({{dim_a_outer}, {dim_b_outer}}) + const uint32_t total_outputs = dim_a_outer * dim_b_outer; + const uint32_t dispatch_x = (total_outputs + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE; + + TensorShape output_shape_temp; + if (is_gemm) { + // GEMM doesn't have `batch` in its output shape. + output_shape_temp = TensorShape({dim_a_outer, dim_b_outer}); + } else { + const uint32_t batch_size = 1; + output_shape_temp = TensorShape({batch_size, dim_a_outer, dim_b_outer}); + } + + // To reuse `MatMulWriteFnSource()` we need to set `dim_b_outer` in components when `output_shape` + // is in `vec4`, while use `output_shape` directly as the output shape. + const uint32_t dim_b_outer_components = narrow(dim_b_outer * output_components); + program.CacheHint(is_gemm, has_bias, output_components, bias_is_scalar) + .AddOutput({output, ProgramTensorMetadataDependency::TypeAndRank, output_shape_temp, static_cast(output_components)}) + .AddUniformVariables({{dim_a_outer}, {dim_b_outer_components}, {beta}}) .SetDispatchGroupSize(dispatch_x); if (has_bias) { - const TensorShape reduced_bias_shape = ReduceShapeByComponents(bias->Shape(), bias_components); - program.AddInput({bias, ProgramTensorMetadataDependency::TypeAndRank, reduced_bias_shape, static_cast(bias_components)}); + // We always use `c_components` as `output_components` in GEMM, and 4 in MatMul. + const TensorShape reduced_bias_shape = ReduceShapeByComponents(bias->Shape(), output_components); + program.AddInput({bias, ProgramTensorMetadataDependency::TypeAndRank, reduced_bias_shape, static_cast(output_components)}); } return program; diff --git a/onnxruntime/core/providers/webgpu/math/matmul.h b/onnxruntime/core/providers/webgpu/math/matmul.h index 0b65827be7f17..a1b7a1d34f2ca 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul.h +++ b/onnxruntime/core/providers/webgpu/math/matmul.h @@ -21,7 +21,11 @@ Status ComputeMatMul(ComputeContext* context, const Activation& activation, std: MatMulFillBiasOrZeroBeforeSplitKProgram CreateMatMulFillBiasOrZeroBeforeSplitKProgram( const Tensor* bias, Tensor* output, - const TensorShape& output_shape_vec4); + bool is_gemm, + float beta, + uint32_t output_components, + bool bias_is_scalar, + const TensorShape& output_shape); class MatMul final : public WebGpuKernel { public: diff --git a/onnxruntime/core/providers/webgpu/math/matmul_packed.cc b/onnxruntime/core/providers/webgpu/math/matmul_packed.cc index 80a110c3b505c..e97e0fd6f1058 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul_packed.cc +++ b/onnxruntime/core/providers/webgpu/math/matmul_packed.cc @@ -63,13 +63,13 @@ Status MatMulFillBiasOrZeroBeforeSplitKProgram::GenerateShaderCode(ShaderHelper& // Handle bias with `MatMulWriteFnSource()`. // Here `use_split_k` is false because we just initialize `output` with bias. // `use_split_k` is true only when we do the actual MatMul with Split-K. - // Currently we only support bias in vec4 and channels last format for Split-K MatMul. + const uint32_t bias_components = output_components_; MatMulWriteFnSource( - shader, output, bias, /*is_gemm*/ false, /*c_components*/ 4, /*output_components*/ 4, /*c_is_scalar*/ false, + shader, output, bias, is_gemm_, bias_components, output_components_, bias_is_scalar_, /*activation_snippet*/ "", /*is_channels_last*/ true, /*use_split_k*/ false); + shader.MainFunctionBody() << " let output_components = " << output_components_ << ";\n"; shader.MainFunctionBody() << R"( - let output_components = 4; let output_id = i32(global_idx); let dim_a_outer = i32(uniforms.dim_a_outer); diff --git a/onnxruntime/core/providers/webgpu/math/matmul_packed.h b/onnxruntime/core/providers/webgpu/math/matmul_packed.h index dbd193bc38f58..c7dd57776272b 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul_packed.h +++ b/onnxruntime/core/providers/webgpu/math/matmul_packed.h @@ -46,18 +46,25 @@ class MatMulProgram final : public Program { // the output with 0 or bias first to make sure `atomicLoad` won't return garbage data. class MatMulFillBiasOrZeroBeforeSplitKProgram final : public Program { public: - explicit MatMulFillBiasOrZeroBeforeSplitKProgram(bool has_bias) + explicit MatMulFillBiasOrZeroBeforeSplitKProgram(bool is_gemm, bool has_bias, uint32_t output_components, bool bias_is_scalar) : Program{"MatMul_Fill_Bias_Or_Zero_Before_Split_K"}, - has_bias_(has_bias) { + is_gemm_(is_gemm), + has_bias_(has_bias), + output_components_(output_components), + bias_is_scalar_(bias_is_scalar) { } Status GenerateShaderCode(ShaderHelper& sh) const override; WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"dim_a_outer", ProgramUniformVariableDataType::Uint32}, - {"dim_b_outer", ProgramUniformVariableDataType::Uint32}); + {"dim_b_outer", ProgramUniformVariableDataType::Uint32}, + {"beta", ProgramUniformVariableDataType::Float32}); private: + bool is_gemm_ = false; bool has_bias_ = false; + uint32_t output_components_ = 0; + bool bias_is_scalar_ = false; }; } // namespace webgpu diff --git a/onnxruntime/core/providers/webgpu/webgpu_utils.cc b/onnxruntime/core/providers/webgpu/webgpu_utils.cc index 824cfb02c22f0..4386bdcc94056 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_utils.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_utils.cc @@ -49,6 +49,7 @@ bool SplitKConfig::UseSplitK( bool is_vec4, ActivationKind activation_kind, uint64_t batch_size, + bool is_gemm, bool is_channels_last, uint32_t dim_a_outer, uint32_t dim_b_outer, @@ -64,8 +65,8 @@ bool SplitKConfig::UseSplitK( use_split_k &= is_vec4; use_split_k &= batch_size == 1; // Now `is_channels_last` is only supported because we only generate vec4 shaders in - // `MatMulFillBiasOrZeroBeforeSplitKProgram`. - use_split_k &= is_channels_last; + // `MatMulFillBiasOrZeroBeforeSplitKProgram` when `is_gemm` is false. + use_split_k &= (is_channels_last || is_gemm); // Split-K works best when `dim_inner` is relatively large compared with `dim_a_outer` and // `dim_b_outer`. Currently we use the factor between `(dim_a_outer * dim_b_outer)` and diff --git a/onnxruntime/core/providers/webgpu/webgpu_utils.h b/onnxruntime/core/providers/webgpu/webgpu_utils.h index 0aa47371f6752..960c0b565fa66 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_utils.h +++ b/onnxruntime/core/providers/webgpu/webgpu_utils.h @@ -101,7 +101,7 @@ class SplitKConfig { explicit SplitKConfig(const wgpu::AdapterInfo& adapter_info); bool UseSplitK( - bool is_vec4, ActivationKind activation_kind, uint64_t batch_size, + bool is_vec4, ActivationKind activation_kind, uint64_t batch_size, bool is_gemm, bool is_channels_last, uint32_t dim_a_outer, uint32_t dim_b_outer, uint32_t dim_inner) const; diff --git a/onnxruntime/test/providers/cpu/math/gemm_large_test.cc b/onnxruntime/test/providers/cpu/math/gemm_large_test.cc new file mode 100644 index 0000000000000..54c2491a0f8a1 --- /dev/null +++ b/onnxruntime/test/providers/cpu/math/gemm_large_test.cc @@ -0,0 +1,216 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "gtest/gtest.h" + +#include "test/providers/provider_test_utils.h" +#include "test/common/tensor_op_test_utils.h" +#include "default_providers.h" + +namespace onnxruntime { +namespace test { + +bool IsValidBroadcast(const TensorShape& bias_shape, int64_t M, int64_t N) { + // valid shapes are (,) , (1, N) , (M, 1) , (M, N) + if (bias_shape.NumDimensions() > 2) + return false; + // shape is (1,) or (1, 1), or (,) + if (bias_shape.Size() == 1) + return true; + // valid bias_shape (s) are (N,) or (1, N) or (M, 1) or (M, N), + // In last case no broadcasting needed, so don't fail it + return ((bias_shape.NumDimensions() == 1 && bias_shape[0] == N) || + (bias_shape.NumDimensions() == 2 && bias_shape[0] == M && (bias_shape[1] == 1 || bias_shape[1] == N)) || + (bias_shape.NumDimensions() == 2 && bias_shape[0] == 1 && bias_shape[1] == N)); +} + +Status ComputeGemmOutputShape(const TensorShape& left, int64_t trans_left, const TensorShape& right, + int64_t trans_right, const TensorShape& bias, int64_t& M, int64_t& K, int64_t& N) { + // dimension check + ORT_ENFORCE(left.NumDimensions() == 2 || left.NumDimensions() == 1); + ORT_ENFORCE(right.NumDimensions() == 2); + + for (size_t i = 0; i != left.NumDimensions(); ++i) { + ORT_ENFORCE(left[i] >= 0); + ORT_ENFORCE(left[i] <= std::numeric_limits::max()); + } + + for (size_t i = 0; i != right.NumDimensions(); ++i) { + ORT_ENFORCE(right[i] >= 0); + ORT_ENFORCE(right[i] <= std::numeric_limits::max()); + } + + if (trans_left == 1) { + M = left.NumDimensions() == 2 ? left[1] : left[0]; + K = left.NumDimensions() == 2 ? left[0] : 1; + } else { + M = left.NumDimensions() == 2 ? left[0] : 1; + K = left.NumDimensions() == 2 ? left[1] : left[0]; + } + + N = trans_right == 1 ? N = right[0] : N = right[1]; + int k_dim = trans_right == 1 ? 1 : 0; + + Status status; + if (right[k_dim] != K) { + status = ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "GEMM: Dimension mismatch, W: ", right.ToString(), + " K: " + std::to_string(K), " N:" + std::to_string(N)); + return status; + } + + if (!IsValidBroadcast(bias, M, N)) { + status = common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Gemm: Invalid bias shape for broadcast"); + return status; + } + + // it is possible the input is empty tensor, for example the output of roipool in fast rcnn. + // it is also possible that K == 0 + ORT_ENFORCE(M >= 0 && K >= 0 && N >= 0); + + return status; +} + +float GetScale(const std::vector& c_vals, const TensorShape& c_shape, int64_t M, int64_t N, int64_t m, int64_t n) { + if (c_vals.empty()) + return 0.0f; + if (c_shape.Size() == 1) + return c_vals[0]; + // valid c_shape (s) are (N,) or (1, N) or (M, 1) or (M, N), + // In last case no broadcasting needed, so don't fail it + if (c_shape.NumDimensions() == 1 && c_shape[0] == N) { + return c_vals[n]; + } + + if (c_shape.NumDimensions() == 2 && c_shape[0] == M) { + if (c_shape[1] == 1) { + return c_vals[m]; + } else if (c_shape[1] == N) { + return c_vals[m * N + n]; + } + } + + if (c_shape.NumDimensions() == 2 && c_shape[0] == 1 && c_shape[1] == N) { + return c_vals[n]; + } + return 0.0f; +} + +Status GetExpectedResult(const int64_t M, const int64_t K, const int64_t N, const std::vector& a_vals, + const std::vector& b_vals, const std::vector& c_vals, + std::vector& expected_vals, const TensorShape& a_shape, int64_t a_trans, + const TensorShape& b_shape, int64_t b_trans, + const TensorShape& c_shape, float alpha, float beta) { + for (int64_t m = 0; m < M; m++) { + for (int64_t n = 0; n < N; n++) { + float sum = 0.0f; + for (int64_t k = 0; k < K; k++) { + if (a_trans == 0 && b_trans == 0) { + sum += a_vals[m * K + k] * b_vals[k * N + n]; + } else if (a_trans == 0 && b_trans == 1) { + sum += a_vals[m * K + k] * b_vals[n * K + k]; + } else if (a_trans == 1 && b_trans == 0) { + sum += a_vals[k * M + m] * b_vals[k * N + n]; + } else { + sum += a_vals[k * M + m] * b_vals[n * K + k]; + } + } + expected_vals[m * N + n] = sum * alpha + GetScale(c_vals, c_shape, M, N, m, n) * beta; + } + } + + return Status::OK(); +} + +template +void RunTestTyped(std::initializer_list a_dims, int64_t a_trans, std::initializer_list b_dims, + int64_t b_trans, std::initializer_list c_dims, float alpha = 1.0f, float beta = 1.0f) { + static_assert(std::is_same_v || std::is_same_v, "unexpected type for T1"); + + int64_t M = 0; + int64_t K = 0; + int64_t N = 0; + TensorShape a_shape = TensorShape(a_dims); + TensorShape b_shape = TensorShape(b_dims); + TensorShape c_shape = TensorShape(c_dims); + ComputeGemmOutputShape(a_shape, a_trans, b_shape, b_trans, c_shape, M, K, N); + + RandomValueGenerator random{1234}; + std::vector a_vals(random.Gaussian(AsSpan(a_dims), 0.0f, 0.25f)); + std::vector b_vals(random.Gaussian(AsSpan(b_dims), 0.0f, 0.25f)); + std::vector c_vals; + if (c_dims.size() > 0) { + c_vals = std::vector(random.Gaussian(AsSpan(c_dims), 0.0f, 0.25f)); + } + std::vector expected_vals(M * N); + GetExpectedResult(M, K, N, a_vals, b_vals, c_vals, expected_vals, a_shape, a_trans, b_shape, b_trans, c_shape, alpha, beta); + + OpTester test("Gemm", version); + test.AddAttribute("transA", a_trans); + test.AddAttribute("transB", b_trans); + test.AddAttribute("alpha", alpha); + test.AddAttribute("beta", beta); + if constexpr (std::is_same_v) { + test.AddInput("A", a_dims, a_vals); + test.AddInput("B", b_dims, b_vals); + if (c_dims.size() != 0) { + test.AddInput("C", c_dims, c_vals); + } + test.AddOutput("Y", {M, N}, expected_vals); + } else if constexpr (std::is_same::value) { + test.AddInput("A", a_dims, FloatsToMLFloat16s(a_vals)); + test.AddInput("B", b_dims, FloatsToMLFloat16s(b_vals)); + if (c_dims.size() != 0) { + test.AddInput("C", c_dims, FloatsToMLFloat16s(c_vals)); + } + test.AddOutput("Y", {M, N}, FloatsToMLFloat16s(expected_vals)); + test.SetOutputAbsErr("Y", 0.055f); + test.SetOutputRelErr("Y", 0.02f); + } + + test.RunWithConfig(); +} + +TEST(Gemm_Large, Float32_SplitK) { + RunTestTyped({16, 1024}, 0, {1024, 191}, 0, {1, 191}); + RunTestTyped({15, 1024}, 0, {1024, 191}, 0, {15, 191}); + RunTestTyped({15, 1024}, 0, {1024, 192}, 0, {15, 1}); + + RunTestTyped({16, 1024}, 0, {1024, 192}, 0, {}); + RunTestTyped({16, 1024}, 0, {1024, 192}, 0, {16, 1}); + RunTestTyped({16, 1024}, 0, {1024, 192}, 0, {1, 192}); + RunTestTyped({16, 1024}, 0, {1024, 192}, 0, {192}); + RunTestTyped({16, 1024}, 0, {1024, 192}, 0, {16, 192}); + + RunTestTyped({16, 1024}, 0, {1024, 192}, 0, {16, 1}, 1.5f, 1.3f); + RunTestTyped({16, 1024}, 0, {1024, 192}, 0, {1, 192}, 1.5f, 1.3f); + RunTestTyped({16, 1024}, 0, {1024, 192}, 0, {192}, 1.5f, 1.3f); + RunTestTyped({16, 1024}, 0, {1024, 192}, 0, {16, 192}, 1.5f, 1.3f); + + RunTestTyped({1024, 16}, 1, {1024, 192}, 0, {192}); + RunTestTyped({16, 1024}, 0, {192, 1024}, 1, {192}); + RunTestTyped({1024, 16}, 1, {192, 1024}, 1, {192}); +} + +TEST(Gemm_Large, Float16_SplitK) { + RunTestTyped({16, 1024}, 0, {1024, 191}, 0, {1, 191}); + RunTestTyped({15, 1024}, 0, {1024, 191}, 0, {15, 191}); + RunTestTyped({15, 1024}, 0, {1024, 192}, 0, {15, 1}); + + RunTestTyped({16, 1024}, 0, {1024, 192}, 0, {}); + RunTestTyped({16, 1024}, 0, {1024, 192}, 0, {16, 1}); + RunTestTyped({16, 1024}, 0, {1024, 192}, 0, {1, 192}); + RunTestTyped({16, 1024}, 0, {1024, 192}, 0, {192}); + RunTestTyped({16, 1024}, 0, {1024, 192}, 0, {16, 192}); + + RunTestTyped({16, 1024}, 0, {1024, 192}, 0, {16, 1}, 1.5f, 1.3f); + RunTestTyped({16, 1024}, 0, {1024, 192}, 0, {1, 192}, 1.5f, 1.3f); + RunTestTyped({16, 1024}, 0, {1024, 192}, 0, {192}, 1.5f, 1.3f); + RunTestTyped({16, 1024}, 0, {1024, 192}, 0, {16, 192}, 1.5f, 1.3f); + + RunTestTyped({1024, 16}, 1, {1024, 192}, 0, {192}); + RunTestTyped({16, 1024}, 0, {192, 1024}, 1, {192}); + RunTestTyped({1024, 16}, 1, {192, 1024}, 1, {192}); +} + +} // namespace test +} // namespace onnxruntime From 702e1acbcd8a370fbbf18295112b3e81d277f04d Mon Sep 17 00:00:00 2001 From: Jiawei Shao Date: Mon, 8 Dec 2025 15:24:21 +0800 Subject: [PATCH 2/6] Fix lint --- onnxruntime/core/providers/webgpu/math/matmul.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/math/matmul.cc b/onnxruntime/core/providers/webgpu/math/matmul.cc index cd73613ad070a..062775e557164 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul.cc +++ b/onnxruntime/core/providers/webgpu/math/matmul.cc @@ -305,10 +305,10 @@ MatMulFillBiasOrZeroBeforeSplitKProgram CreateMatMulFillBiasOrZeroBeforeSplitKPr const uint32_t total_outputs = dim_a_outer * dim_b_outer; const uint32_t dispatch_x = (total_outputs + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE; - TensorShape output_shape_temp; + TensorShape output_shape_temp; if (is_gemm) { - // GEMM doesn't have `batch` in its output shape. - output_shape_temp = TensorShape({dim_a_outer, dim_b_outer}); + // GEMM doesn't have `batch` in its output shape. + output_shape_temp = TensorShape({dim_a_outer, dim_b_outer}); } else { const uint32_t batch_size = 1; output_shape_temp = TensorShape({batch_size, dim_a_outer, dim_b_outer}); From 147ecb0aa43911660b3bd753a68621b7ad228e3c Mon Sep 17 00:00:00 2001 From: Jiawei Shao Date: Mon, 8 Dec 2025 16:26:07 +0800 Subject: [PATCH 3/6] Improve the comments --- onnxruntime/core/providers/webgpu/math/gemm_packed.cc | 2 +- onnxruntime/core/providers/webgpu/math/matmul.cc | 3 +-- onnxruntime/core/providers/webgpu/math/matmul_packed.h | 2 +- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/math/gemm_packed.cc b/onnxruntime/core/providers/webgpu/math/gemm_packed.cc index 8570cad55dadf..276547bbba5c9 100644 --- a/onnxruntime/core/providers/webgpu/math/gemm_packed.cc +++ b/onnxruntime/core/providers/webgpu/math/gemm_packed.cc @@ -105,7 +105,7 @@ Status ApplyGemmPacked(const Tensor* a, uint32_t split_dim_inner = 1; const SplitKConfig& split_k_config = context.GetSplitKConfig(); - // Currently we require the components for Y must be a multiple of 4. + // Currently we require the components for Y must also be a multiple of 4 when Split-K is used. const bool output_is_vec4 = output_components == 4; // The parameter `is_channel_last` is not used for GEMM. const bool need_split_k = split_k_config.UseSplitK( diff --git a/onnxruntime/core/providers/webgpu/math/matmul.cc b/onnxruntime/core/providers/webgpu/math/matmul.cc index 062775e557164..5cbbff04ac819 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul.cc +++ b/onnxruntime/core/providers/webgpu/math/matmul.cc @@ -300,8 +300,7 @@ MatMulFillBiasOrZeroBeforeSplitKProgram CreateMatMulFillBiasOrZeroBeforeSplitKPr const uint32_t dim_a_outer = narrow(output_shape[output_shape.NumDimensions() - 2]); const uint32_t dim_b_outer = narrow(output_shape[output_shape.NumDimensions() - 1]); - // Fill one value per invocation. Now we use default workgroup size (64) for - // this program. + // Fill one value per invocation. Now we use default workgroup size (64) for this program. const uint32_t total_outputs = dim_a_outer * dim_b_outer; const uint32_t dispatch_x = (total_outputs + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE; diff --git a/onnxruntime/core/providers/webgpu/math/matmul_packed.h b/onnxruntime/core/providers/webgpu/math/matmul_packed.h index c7dd57776272b..618fc97d72fe0 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul_packed.h +++ b/onnxruntime/core/providers/webgpu/math/matmul_packed.h @@ -46,7 +46,7 @@ class MatMulProgram final : public Program { // the output with 0 or bias first to make sure `atomicLoad` won't return garbage data. class MatMulFillBiasOrZeroBeforeSplitKProgram final : public Program { public: - explicit MatMulFillBiasOrZeroBeforeSplitKProgram(bool is_gemm, bool has_bias, uint32_t output_components, bool bias_is_scalar) + MatMulFillBiasOrZeroBeforeSplitKProgram(bool is_gemm, bool has_bias, uint32_t output_components, bool bias_is_scalar) : Program{"MatMul_Fill_Bias_Or_Zero_Before_Split_K"}, is_gemm_(is_gemm), has_bias_(has_bias), From de2ee936d7f82504a367b35a15bf9e0ec109333a Mon Sep 17 00:00:00 2001 From: Jiawei Shao Date: Tue, 9 Dec 2025 10:30:43 +0800 Subject: [PATCH 4/6] Remove test --- .../providers/cpu/math/gemm_large_test.cc | 216 ------------------ 1 file changed, 216 deletions(-) delete mode 100644 onnxruntime/test/providers/cpu/math/gemm_large_test.cc diff --git a/onnxruntime/test/providers/cpu/math/gemm_large_test.cc b/onnxruntime/test/providers/cpu/math/gemm_large_test.cc deleted file mode 100644 index 54c2491a0f8a1..0000000000000 --- a/onnxruntime/test/providers/cpu/math/gemm_large_test.cc +++ /dev/null @@ -1,216 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "gtest/gtest.h" - -#include "test/providers/provider_test_utils.h" -#include "test/common/tensor_op_test_utils.h" -#include "default_providers.h" - -namespace onnxruntime { -namespace test { - -bool IsValidBroadcast(const TensorShape& bias_shape, int64_t M, int64_t N) { - // valid shapes are (,) , (1, N) , (M, 1) , (M, N) - if (bias_shape.NumDimensions() > 2) - return false; - // shape is (1,) or (1, 1), or (,) - if (bias_shape.Size() == 1) - return true; - // valid bias_shape (s) are (N,) or (1, N) or (M, 1) or (M, N), - // In last case no broadcasting needed, so don't fail it - return ((bias_shape.NumDimensions() == 1 && bias_shape[0] == N) || - (bias_shape.NumDimensions() == 2 && bias_shape[0] == M && (bias_shape[1] == 1 || bias_shape[1] == N)) || - (bias_shape.NumDimensions() == 2 && bias_shape[0] == 1 && bias_shape[1] == N)); -} - -Status ComputeGemmOutputShape(const TensorShape& left, int64_t trans_left, const TensorShape& right, - int64_t trans_right, const TensorShape& bias, int64_t& M, int64_t& K, int64_t& N) { - // dimension check - ORT_ENFORCE(left.NumDimensions() == 2 || left.NumDimensions() == 1); - ORT_ENFORCE(right.NumDimensions() == 2); - - for (size_t i = 0; i != left.NumDimensions(); ++i) { - ORT_ENFORCE(left[i] >= 0); - ORT_ENFORCE(left[i] <= std::numeric_limits::max()); - } - - for (size_t i = 0; i != right.NumDimensions(); ++i) { - ORT_ENFORCE(right[i] >= 0); - ORT_ENFORCE(right[i] <= std::numeric_limits::max()); - } - - if (trans_left == 1) { - M = left.NumDimensions() == 2 ? left[1] : left[0]; - K = left.NumDimensions() == 2 ? left[0] : 1; - } else { - M = left.NumDimensions() == 2 ? left[0] : 1; - K = left.NumDimensions() == 2 ? left[1] : left[0]; - } - - N = trans_right == 1 ? N = right[0] : N = right[1]; - int k_dim = trans_right == 1 ? 1 : 0; - - Status status; - if (right[k_dim] != K) { - status = ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "GEMM: Dimension mismatch, W: ", right.ToString(), - " K: " + std::to_string(K), " N:" + std::to_string(N)); - return status; - } - - if (!IsValidBroadcast(bias, M, N)) { - status = common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Gemm: Invalid bias shape for broadcast"); - return status; - } - - // it is possible the input is empty tensor, for example the output of roipool in fast rcnn. - // it is also possible that K == 0 - ORT_ENFORCE(M >= 0 && K >= 0 && N >= 0); - - return status; -} - -float GetScale(const std::vector& c_vals, const TensorShape& c_shape, int64_t M, int64_t N, int64_t m, int64_t n) { - if (c_vals.empty()) - return 0.0f; - if (c_shape.Size() == 1) - return c_vals[0]; - // valid c_shape (s) are (N,) or (1, N) or (M, 1) or (M, N), - // In last case no broadcasting needed, so don't fail it - if (c_shape.NumDimensions() == 1 && c_shape[0] == N) { - return c_vals[n]; - } - - if (c_shape.NumDimensions() == 2 && c_shape[0] == M) { - if (c_shape[1] == 1) { - return c_vals[m]; - } else if (c_shape[1] == N) { - return c_vals[m * N + n]; - } - } - - if (c_shape.NumDimensions() == 2 && c_shape[0] == 1 && c_shape[1] == N) { - return c_vals[n]; - } - return 0.0f; -} - -Status GetExpectedResult(const int64_t M, const int64_t K, const int64_t N, const std::vector& a_vals, - const std::vector& b_vals, const std::vector& c_vals, - std::vector& expected_vals, const TensorShape& a_shape, int64_t a_trans, - const TensorShape& b_shape, int64_t b_trans, - const TensorShape& c_shape, float alpha, float beta) { - for (int64_t m = 0; m < M; m++) { - for (int64_t n = 0; n < N; n++) { - float sum = 0.0f; - for (int64_t k = 0; k < K; k++) { - if (a_trans == 0 && b_trans == 0) { - sum += a_vals[m * K + k] * b_vals[k * N + n]; - } else if (a_trans == 0 && b_trans == 1) { - sum += a_vals[m * K + k] * b_vals[n * K + k]; - } else if (a_trans == 1 && b_trans == 0) { - sum += a_vals[k * M + m] * b_vals[k * N + n]; - } else { - sum += a_vals[k * M + m] * b_vals[n * K + k]; - } - } - expected_vals[m * N + n] = sum * alpha + GetScale(c_vals, c_shape, M, N, m, n) * beta; - } - } - - return Status::OK(); -} - -template -void RunTestTyped(std::initializer_list a_dims, int64_t a_trans, std::initializer_list b_dims, - int64_t b_trans, std::initializer_list c_dims, float alpha = 1.0f, float beta = 1.0f) { - static_assert(std::is_same_v || std::is_same_v, "unexpected type for T1"); - - int64_t M = 0; - int64_t K = 0; - int64_t N = 0; - TensorShape a_shape = TensorShape(a_dims); - TensorShape b_shape = TensorShape(b_dims); - TensorShape c_shape = TensorShape(c_dims); - ComputeGemmOutputShape(a_shape, a_trans, b_shape, b_trans, c_shape, M, K, N); - - RandomValueGenerator random{1234}; - std::vector a_vals(random.Gaussian(AsSpan(a_dims), 0.0f, 0.25f)); - std::vector b_vals(random.Gaussian(AsSpan(b_dims), 0.0f, 0.25f)); - std::vector c_vals; - if (c_dims.size() > 0) { - c_vals = std::vector(random.Gaussian(AsSpan(c_dims), 0.0f, 0.25f)); - } - std::vector expected_vals(M * N); - GetExpectedResult(M, K, N, a_vals, b_vals, c_vals, expected_vals, a_shape, a_trans, b_shape, b_trans, c_shape, alpha, beta); - - OpTester test("Gemm", version); - test.AddAttribute("transA", a_trans); - test.AddAttribute("transB", b_trans); - test.AddAttribute("alpha", alpha); - test.AddAttribute("beta", beta); - if constexpr (std::is_same_v) { - test.AddInput("A", a_dims, a_vals); - test.AddInput("B", b_dims, b_vals); - if (c_dims.size() != 0) { - test.AddInput("C", c_dims, c_vals); - } - test.AddOutput("Y", {M, N}, expected_vals); - } else if constexpr (std::is_same::value) { - test.AddInput("A", a_dims, FloatsToMLFloat16s(a_vals)); - test.AddInput("B", b_dims, FloatsToMLFloat16s(b_vals)); - if (c_dims.size() != 0) { - test.AddInput("C", c_dims, FloatsToMLFloat16s(c_vals)); - } - test.AddOutput("Y", {M, N}, FloatsToMLFloat16s(expected_vals)); - test.SetOutputAbsErr("Y", 0.055f); - test.SetOutputRelErr("Y", 0.02f); - } - - test.RunWithConfig(); -} - -TEST(Gemm_Large, Float32_SplitK) { - RunTestTyped({16, 1024}, 0, {1024, 191}, 0, {1, 191}); - RunTestTyped({15, 1024}, 0, {1024, 191}, 0, {15, 191}); - RunTestTyped({15, 1024}, 0, {1024, 192}, 0, {15, 1}); - - RunTestTyped({16, 1024}, 0, {1024, 192}, 0, {}); - RunTestTyped({16, 1024}, 0, {1024, 192}, 0, {16, 1}); - RunTestTyped({16, 1024}, 0, {1024, 192}, 0, {1, 192}); - RunTestTyped({16, 1024}, 0, {1024, 192}, 0, {192}); - RunTestTyped({16, 1024}, 0, {1024, 192}, 0, {16, 192}); - - RunTestTyped({16, 1024}, 0, {1024, 192}, 0, {16, 1}, 1.5f, 1.3f); - RunTestTyped({16, 1024}, 0, {1024, 192}, 0, {1, 192}, 1.5f, 1.3f); - RunTestTyped({16, 1024}, 0, {1024, 192}, 0, {192}, 1.5f, 1.3f); - RunTestTyped({16, 1024}, 0, {1024, 192}, 0, {16, 192}, 1.5f, 1.3f); - - RunTestTyped({1024, 16}, 1, {1024, 192}, 0, {192}); - RunTestTyped({16, 1024}, 0, {192, 1024}, 1, {192}); - RunTestTyped({1024, 16}, 1, {192, 1024}, 1, {192}); -} - -TEST(Gemm_Large, Float16_SplitK) { - RunTestTyped({16, 1024}, 0, {1024, 191}, 0, {1, 191}); - RunTestTyped({15, 1024}, 0, {1024, 191}, 0, {15, 191}); - RunTestTyped({15, 1024}, 0, {1024, 192}, 0, {15, 1}); - - RunTestTyped({16, 1024}, 0, {1024, 192}, 0, {}); - RunTestTyped({16, 1024}, 0, {1024, 192}, 0, {16, 1}); - RunTestTyped({16, 1024}, 0, {1024, 192}, 0, {1, 192}); - RunTestTyped({16, 1024}, 0, {1024, 192}, 0, {192}); - RunTestTyped({16, 1024}, 0, {1024, 192}, 0, {16, 192}); - - RunTestTyped({16, 1024}, 0, {1024, 192}, 0, {16, 1}, 1.5f, 1.3f); - RunTestTyped({16, 1024}, 0, {1024, 192}, 0, {1, 192}, 1.5f, 1.3f); - RunTestTyped({16, 1024}, 0, {1024, 192}, 0, {192}, 1.5f, 1.3f); - RunTestTyped({16, 1024}, 0, {1024, 192}, 0, {16, 192}, 1.5f, 1.3f); - - RunTestTyped({1024, 16}, 1, {1024, 192}, 0, {192}); - RunTestTyped({16, 1024}, 0, {192, 1024}, 1, {192}); - RunTestTyped({1024, 16}, 1, {192, 1024}, 1, {192}); -} - -} // namespace test -} // namespace onnxruntime From e0c15fb34fb0e34ad4d6f036b94799f6045c5746 Mon Sep 17 00:00:00 2001 From: Jiawei Shao Date: Tue, 9 Dec 2025 11:40:55 +0800 Subject: [PATCH 5/6] Add cases to `GemmOptimizePackedTest` for Split-K --- onnxruntime/test/providers/cpu/math/gemm_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/test/providers/cpu/math/gemm_test.cc b/onnxruntime/test/providers/cpu/math/gemm_test.cc index d7d9d2994afa1..f55bb2ef4aad8 100644 --- a/onnxruntime/test/providers/cpu/math/gemm_test.cc +++ b/onnxruntime/test/providers/cpu/math/gemm_test.cc @@ -1508,7 +1508,7 @@ TEST_P(GemmOptimizePackedTest, TestVariants) { std::vector GenerateGemmParams() { std::vector params; - std::vector> test_sizes = {{1, 1, 1}, {1, 64, 448}, {2, 3, 4}, {8, 8, 8}, {31, 31, 31}, {32, 32, 32}, {33, 67, 99}, {37, 64, 256}, {48, 48, 120}, {60, 16, 92}, {63, 64, 65}, {64, 64, 64}, {64, 64, 65}, {72, 80, 84}, {96, 24, 48}, {128, 32, 64}, {128, 128, 128}, {129, 129, 129}, {256, 64, 1024}}; + std::vector> test_sizes = {{1, 1, 1}, {1, 64, 448}, {2, 3, 4}, {8, 8, 8}, {31, 31, 31}, {32, 32, 32}, {33, 67, 99}, {37, 64, 256}, {48, 48, 120}, {60, 16, 92}, {63, 64, 65}, {64, 64, 64}, {64, 64, 65}, {72, 80, 84}, {96, 24, 48}, {128, 32, 64}, {128, 128, 128}, {129, 129, 129}, {256, 64, 1024}, {16, 768, 192}, {15, 768, 192}, {16, 768, 191}}; std::vector bias_types = {BiasType::noBias, BiasType::MBias, BiasType::ScalarBias, BiasType::MNBias, BiasType::NBias}; From a86ba66d8ff10b7fa9e72dbeff4dc48b1ed3ca65 Mon Sep 17 00:00:00 2001 From: Jiawei Shao Date: Fri, 12 Dec 2025 09:46:07 +0800 Subject: [PATCH 6/6] Address reviewer's comments --- onnxruntime/core/providers/webgpu/math/gemm_packed.cc | 2 +- onnxruntime/core/providers/webgpu/math/matmul.cc | 11 +---------- 2 files changed, 2 insertions(+), 11 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/math/gemm_packed.cc b/onnxruntime/core/providers/webgpu/math/gemm_packed.cc index 276547bbba5c9..1a0ad7a843ec4 100644 --- a/onnxruntime/core/providers/webgpu/math/gemm_packed.cc +++ b/onnxruntime/core/providers/webgpu/math/gemm_packed.cc @@ -117,7 +117,7 @@ Status ApplyGemmPacked(const Tensor* a, bias = c; output_components_in_fill_bias_program = c_components; } - const TensorShape output_shape = TensorShape{1, M, N / output_components_in_fill_bias_program}; + const TensorShape output_shape = TensorShape{M, N / output_components_in_fill_bias_program}; auto fill_bias_program = CreateMatMulFillBiasOrZeroBeforeSplitKProgram( bias, y, /*is_gemm*/ true, beta, output_components_in_fill_bias_program, c_is_scalar, output_shape); diff --git a/onnxruntime/core/providers/webgpu/math/matmul.cc b/onnxruntime/core/providers/webgpu/math/matmul.cc index 5cbbff04ac819..5dc2dfdfb013a 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul.cc +++ b/onnxruntime/core/providers/webgpu/math/matmul.cc @@ -304,20 +304,11 @@ MatMulFillBiasOrZeroBeforeSplitKProgram CreateMatMulFillBiasOrZeroBeforeSplitKPr const uint32_t total_outputs = dim_a_outer * dim_b_outer; const uint32_t dispatch_x = (total_outputs + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE; - TensorShape output_shape_temp; - if (is_gemm) { - // GEMM doesn't have `batch` in its output shape. - output_shape_temp = TensorShape({dim_a_outer, dim_b_outer}); - } else { - const uint32_t batch_size = 1; - output_shape_temp = TensorShape({batch_size, dim_a_outer, dim_b_outer}); - } - // To reuse `MatMulWriteFnSource()` we need to set `dim_b_outer` in components when `output_shape` // is in `vec4`, while use `output_shape` directly as the output shape. const uint32_t dim_b_outer_components = narrow(dim_b_outer * output_components); program.CacheHint(is_gemm, has_bias, output_components, bias_is_scalar) - .AddOutput({output, ProgramTensorMetadataDependency::TypeAndRank, output_shape_temp, static_cast(output_components)}) + .AddOutput({output, ProgramTensorMetadataDependency::TypeAndRank, output_shape, static_cast(output_components)}) .AddUniformVariables({{dim_a_outer}, {dim_b_outer_components}, {beta}}) .SetDispatchGroupSize(dispatch_x);