diff --git a/onnxruntime/core/providers/webgpu/math/gemm_packed.cc b/onnxruntime/core/providers/webgpu/math/gemm_packed.cc
index b81977883dd70..1a0ad7a843ec4 100644
--- a/onnxruntime/core/providers/webgpu/math/gemm_packed.cc
+++ b/onnxruntime/core/providers/webgpu/math/gemm_packed.cc
@@ -5,6 +5,7 @@
 
 #include "core/providers/webgpu/webgpu_utils.h"
 
+#include "core/providers/webgpu/math/matmul.h"
 #include "core/providers/webgpu/math/matmul_utils.h"
 #include "core/providers/webgpu/math/gemm_utils.h"
 
@@ -12,7 +13,14 @@ namespace onnxruntime {
 namespace webgpu {
 
 Status GemmProgram::GenerateShaderCode(ShaderHelper& shader) const {
-  const ShaderVariableHelper& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
+  const bool need_split_k = NeedSplitK();
+  ShaderUsage output_usage = ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias;
+  if (need_split_k) {
+    // When Split-K is enabled, we will declare the output as `atomic` to call atomic built-in
+    // functions on it, so we need the information below to correctly compute the index into the output.
+    output_usage |= ShaderUsage::UseIndicesToOffset | ShaderUsage::UseShapeAndStride;
+  }
+  const ShaderVariableHelper& output = shader.AddOutput("output", output_usage);
 
   // Each thread compute 4*4 elements
   InlinedVector<int64_t> elements_per_thread = InlinedVector<int64_t>({4, 4, 1});
@@ -26,7 +34,7 @@ Status GemmProgram::GenerateShaderCode(ShaderHelper& shader) const {
     MatMulReadFnSource(shader, a, b, nullptr, transA_, transB_, is_vec4_);
   }
   if (is_vec4_) {
-    ORT_RETURN_IF_ERROR(MakeMatMulPackedVec4Source(shader, elements_per_thread, WorkgroupSizeX(), WorkgroupSizeY(), data_type, nullptr, transA_, transB_, alpha_, need_handle_matmul_, output_components_));
+    ORT_RETURN_IF_ERROR(MakeMatMulPackedVec4Source(shader, elements_per_thread, WorkgroupSizeX(), WorkgroupSizeY(), data_type, nullptr, transA_, transB_, alpha_, need_handle_matmul_, output_components_, /*tile_inner*/ 32, need_split_k, split_dim_inner_));
   } else {
     ORT_RETURN_IF_ERROR(MakeMatMulPackedSource(shader, elements_per_thread, WorkgroupSizeX(), WorkgroupSizeY(), data_type, nullptr, transA_, transB_, alpha_, need_handle_matmul_));
   }
@@ -35,11 +43,17 @@ Status GemmProgram::GenerateShaderCode(ShaderHelper& shader) const {
   if (need_handle_bias_) {
     c = &shader.AddInput("c", ShaderUsage::UseUniform);
   }
-  MatMulWriteFnSource(shader, output, c, /* is_gemm = */ true, c_components_, output_components_, c_is_scalar_);
+
+  const ProgramVariableDataType output_var_type = this->Outputs()[0].var_type;
+  MatMulWriteFnSource(shader, output, c, /* is_gemm = */ true, c_components_, output_components_, c_is_scalar_, /*activation_snippet*/ "", /*is_channels_last*/ false, need_split_k, output_var_type);
 
   return Status::OK();
 }
 
+bool GemmProgram::NeedSplitK() const {
+  return split_dim_inner_ > 1;
+}
+
 Status ApplyGemmPacked(const Tensor* a,
                        const Tensor* b,
                        const Tensor* c,
@@ -86,7 +100,44 @@ Status ApplyGemmPacked(const Tensor* a,
     c_is_scalar = c_shape.Size() == 1;
   }
 
-  GemmProgram program{transA, transB, alpha, need_handle_bias, need_handle_matmul, c_components, c_is_scalar, output_components, is_vec4};
+  ProgramOutput output(y, ProgramTensorMetadataDependency::TypeAndRank, output_components);
+  uint32_t dispatch_z = 1;
+  uint32_t split_dim_inner = 1;
+
+  const SplitKConfig& split_k_config = context.GetSplitKConfig();
+  // Currently we also require the components of Y to be a multiple of 4
+  // when Split-K is used.
+  const bool output_is_vec4 = output_components == 4;
+  // The parameter `is_channels_last` is not used for GEMM.
+  const bool need_split_k = split_k_config.UseSplitK(
+      is_vec4 && output_is_vec4, ActivationKind::None, /*batch_size*/ 1, /*is_gemm*/ true, /*is_channels_last*/ true, M, N, K);
+  if (need_split_k) {
+    const Tensor* bias = nullptr;
+    uint32_t output_components_in_fill_bias_program = 4;
+    if (need_handle_bias) {
+      bias = c;
+      output_components_in_fill_bias_program = c_components;
+    }
+    const TensorShape output_shape = TensorShape{M, N / output_components_in_fill_bias_program};
+
+    auto fill_bias_program = CreateMatMulFillBiasOrZeroBeforeSplitKProgram(
+        bias, y, /*is_gemm*/ true, beta, output_components_in_fill_bias_program, c_is_scalar, output_shape);
+    ORT_RETURN_IF_ERROR(context.RunProgram(fill_bias_program));
+
+    // When Split-K is used, `bias` will be handled in `MatMulFillBiasOrZeroBeforeSplitKProgram`
+    // instead of here.
+    need_handle_bias = false;
+
+    // With Split-K, `dim_inner` will be split into multiple parts and `dispatch_z` will be the
+    // number of splits along `dim_inner`.
+    split_dim_inner = split_k_config.GetSplitDimInner();
+    dispatch_z = (K + split_dim_inner - 1) / split_dim_inner;
+
+    // The output should be declared with atomic types in `GemmProgram` for the use of atomic
+    // built-in functions.
+    output.is_atomic = true;
+  }
+
+  GemmProgram program{transA, transB, alpha, need_handle_bias, need_handle_matmul, c_components, c_is_scalar, output_components, is_vec4, split_dim_inner};
 
   if (need_handle_matmul) {
     program.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, components},
@@ -101,9 +152,9 @@ Status ApplyGemmPacked(const Tensor* a,
   const uint32_t dispatch_x = (N + TILE_SIZE - 1) / TILE_SIZE;
   const uint32_t dispatch_y = (M + TILE_SIZE - 1) / TILE_SIZE;
 
-  program.CacheHint(alpha, transA, transB, c_is_scalar)
-      .AddOutputs({{y, ProgramTensorMetadataDependency::TypeAndRank, output_components}})
-      .SetDispatchGroupSize(dispatch_x, dispatch_y, 1u)
+  program.CacheHint(alpha, transA, transB, c_is_scalar, split_dim_inner)
+      .AddOutput(std::move(output))
+      .SetDispatchGroupSize(dispatch_x, dispatch_y, dispatch_z)
       .SetWorkgroupSize(GemmProgram::MATMUL_PACKED_WORKGROUP_SIZE_X, GemmProgram::MATMUL_PACKED_WORKGROUP_SIZE_Y, GemmProgram::MATMUL_PACKED_WORKGROUP_SIZE_Z)
       .AddUniformVariables({{alpha},
                             {beta},
@@ -112,7 +163,7 @@ Status ApplyGemmPacked(const Tensor* a,
                             {K}, /*dim_inner */
                             {dispatch_x}, /* logical_dispatch_x */
                             {dispatch_y}, /* logical_dispatch_y */
-                            {1u}} /* logical_dispatch_z */
+                            {dispatch_z}} /* logical_dispatch_z */
       );
 
   return context.RunProgram(program);
diff --git a/onnxruntime/core/providers/webgpu/math/gemm_packed.h b/onnxruntime/core/providers/webgpu/math/gemm_packed.h
index cb89ccefba313..f81da43e3fe36 100644
--- a/onnxruntime/core/providers/webgpu/math/gemm_packed.h
+++ b/onnxruntime/core/providers/webgpu/math/gemm_packed.h
@@ -13,7 +13,7 @@ namespace webgpu {
 
 class GemmProgram final : public Program<GemmProgram> {
  public:
-  GemmProgram(bool transA, bool transB, float alpha, bool need_handle_bias, bool need_handle_matmul, int c_components, bool c_is_scalar, int output_components, bool is_vec4 = false)
+  GemmProgram(bool transA, bool transB, float alpha, bool need_handle_bias, bool need_handle_matmul, int c_components, bool c_is_scalar, int output_components, bool is_vec4 = false, uint32_t split_dim_inner = 1)
       : Program{"Gemm"},
         transA_{transA},
         transB_{transB},
@@ -23,10 +23,13 @@ class GemmProgram final : public Program<GemmProgram> {
         c_components_(c_components),
         c_is_scalar_(c_is_scalar),
         output_components_(output_components),
-        is_vec4_(is_vec4) {}
+        is_vec4_(is_vec4),
+        split_dim_inner_(split_dim_inner) {}
 
   Status GenerateShaderCode(ShaderHelper& sh) const override;
 
+  bool NeedSplitK() const;
+
   WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
       {"alpha", ProgramUniformVariableDataType::Float32},
       {"beta", ProgramUniformVariableDataType::Float32},
@@ -51,6 +54,7 @@ class GemmProgram final : public Program<GemmProgram> {
   bool c_is_scalar_ = false;
   int output_components_;
   bool is_vec4_ = false;
+  uint32_t split_dim_inner_ = 1;
 };
 
 Status ApplyGemmPacked(const Tensor* a,
diff --git a/onnxruntime/core/providers/webgpu/math/gemm_utils.cc b/onnxruntime/core/providers/webgpu/math/gemm_utils.cc
index 73242ed3ff1ba..ba7e9290f8455 100644
--- a/onnxruntime/core/providers/webgpu/math/gemm_utils.cc
+++ b/onnxruntime/core/providers/webgpu/math/gemm_utils.cc
@@ -54,9 +54,14 @@ void HandleMaybeBiasForMatMul(ShaderHelper& shader,
 
 void HandleMatMulWithSplitK(
     ShaderHelper& shader,
+    bool is_gemm,
    const ShaderVariableHelper& output,
     ProgramVariableDataType output_variable_type) {
-  shader.AdditionalImplementation() << "  let coords = vec3(u32(batch), u32(row), u32(colIn));\n";
+  if (is_gemm) {
+    shader.AdditionalImplementation() << "  let coords = vec2(u32(row), u32(colIn));\n";
+  } else {
+    shader.AdditionalImplementation() << "  let coords = vec3(u32(batch), u32(row), u32(colIn));\n";
+  }
 
   // With Split-K, the final output will be the sum of the sub-outputs from multiple workgroups,
   // so we must add them with atomic built-in functions. Because currently WebGPU doesn't support
@@ -205,8 +210,10 @@ void MatMulWriteFnSource(ShaderHelper& shader,
     // still need to handle `bias` (and `is_channels_last` in the future) in
     // `MatMulFillBiasOrZeroBeforeSplitKProgram`.
     ORT_ENFORCE(bias == nullptr, "Bias is not supported in MatMulProgram when Split-K is enabled.");
-    ORT_ENFORCE(is_channels_last, "Only channels-last is supported in MatMulProgram when Split-K is enabled.");
-    HandleMatMulWithSplitK(shader, output, output_variable_type);
+    if (!is_gemm) {
+      ORT_ENFORCE(is_channels_last, "Only channels-last is supported in MatMulProgram when Split-K is enabled in non-GEMM ops.");
+    }
+    HandleMatMulWithSplitK(shader, is_gemm, output, output_variable_type);
   } else if (is_gemm) {
     HandleMaybeHaveBiasForGEMM(shader, output, bias, c_components, output_components, c_is_scalar);
   } else {
diff --git a/onnxruntime/core/providers/webgpu/math/matmul.cc b/onnxruntime/core/providers/webgpu/math/matmul.cc
index 72dd235eb820a..5dc2dfdfb013a 100644
--- a/onnxruntime/core/providers/webgpu/math/matmul.cc
+++ b/onnxruntime/core/providers/webgpu/math/matmul.cc
@@ -240,14 +240,14 @@ Status ComputeMatMul(ComputeContext* context,
   uint32_t split_dim_inner = 1;
 
   const SplitKConfig& split_k_config = context->GetSplitKConfig();
-  const bool need_split_k = split_k_config.UseSplitK(is_vec4, activation.activation_kind_, batch_size, is_channels_last, dim_a_outer, dim_b_outer, dim_inner);
+  const bool need_split_k = split_k_config.UseSplitK(is_vec4, activation.activation_kind_, batch_size, /*is_gemm*/ false, is_channels_last, dim_a_outer, dim_b_outer, dim_inner);
   if (need_split_k) {
     ORT_ENFORCE(batch_size == 1, "Split-K MatMul only supports batch_size == 1.");
     ORT_ENFORCE(is_vec4, "Split-K MatMul only supports bias in vec4 format.");
     ORT_ENFORCE(is_channels_last, "Split-K MatMul only supports channels-last format.");
 
     // Initialize `output_tensor` with 0 or bias before MatMulProgram with Split-K enabled.
-    const auto fill_bias_program = CreateMatMulFillBiasOrZeroBeforeSplitKProgram(bias, output_tensor, output_shape_temp);
+    const auto fill_bias_program = CreateMatMulFillBiasOrZeroBeforeSplitKProgram(bias, output_tensor, /*is_gemm*/ false, /*beta*/ 1.0f, /*bias_components*/ 4, /*bias_is_scalar*/ false, output_shape_temp);
     ORT_RETURN_IF_ERROR(context->RunProgram(fill_bias_program));
 
     // `bias` has been handled in the execution of `fill_bias_program` so we don't need to set
@@ -287,32 +287,35 @@ Status ComputeMatMul(ComputeContext* context,
 MatMulFillBiasOrZeroBeforeSplitKProgram CreateMatMulFillBiasOrZeroBeforeSplitKProgram(
     const Tensor* bias,
     Tensor* output,
-    const TensorShape& output_shape_vec4) {
+    bool is_gemm,
+    float beta,
+    uint32_t output_components,
+    bool bias_is_scalar,
+    const TensorShape& output_shape) {
   const bool has_bias = bias != nullptr;
 
-  // Currently we only support bias in vec4 and channels last format for Split-K MatMul.
-  constexpr uint32_t bias_components = 4;
-  MatMulFillBiasOrZeroBeforeSplitKProgram program(has_bias);
+  // Currently we only support GEMM and channels-last format for MatMul with Split-K.
+  MatMulFillBiasOrZeroBeforeSplitKProgram program(is_gemm, has_bias, output_components, bias_is_scalar);
 
-  const uint32_t dim_a_outer = narrow<uint32_t>(output_shape_vec4[output_shape_vec4.NumDimensions() - 2]);
-  const uint32_t dim_b_outer_vec4 = narrow<uint32_t>(output_shape_vec4[output_shape_vec4.NumDimensions() - 1]);
+  const uint32_t dim_a_outer = narrow<uint32_t>(output_shape[output_shape.NumDimensions() - 2]);
+  const uint32_t dim_b_outer = narrow<uint32_t>(output_shape[output_shape.NumDimensions() - 1]);
 
-  // Fill one value (currently only vec4) per invocation. Now we use default workgroup size (64) for
-  // this program.
-  const uint32_t total_outputs_vec4 = dim_a_outer * dim_b_outer_vec4;
-  const uint32_t dispatch_x = (total_outputs_vec4 + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE;
+  // Fill one value per invocation. Now we use the default workgroup size (64) for this program.
+  const uint32_t total_outputs = dim_a_outer * dim_b_outer;
+  const uint32_t dispatch_x = (total_outputs + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE;
 
-  // To reuse `MatMulWriteFnSource()` we need to set `dim_a_outer` and `dim_b_outer` in scalar
-  // instead of vec4, while use `output_shape_vec4` directly as the output shape.
-  const uint32_t dim_b_outer = narrow<uint32_t>(dim_b_outer_vec4 * bias_components);
-  program.CacheHint(has_bias)
-      .AddOutput({output, ProgramTensorMetadataDependency::TypeAndRank, output_shape_vec4, static_cast<int>(bias_components)})
-      .AddUniformVariables({{dim_a_outer}, {dim_b_outer}})
+  // To reuse `MatMulWriteFnSource()` we need to set `dim_b_outer` in components when `output_shape`
+  // is in `vec4`, while using `output_shape` directly as the output shape.
+  const uint32_t dim_b_outer_components = narrow<uint32_t>(dim_b_outer * output_components);
+  program.CacheHint(is_gemm, has_bias, output_components, bias_is_scalar)
+      .AddOutput({output, ProgramTensorMetadataDependency::TypeAndRank, output_shape, static_cast<int>(output_components)})
+      .AddUniformVariables({{dim_a_outer}, {dim_b_outer_components}, {beta}})
       .SetDispatchGroupSize(dispatch_x);
 
   if (has_bias) {
+    // We always use `c_components` as `output_components` in GEMM, and 4 in MatMul.
+    const TensorShape reduced_bias_shape = ReduceShapeByComponents(bias->Shape(), output_components);
+    program.AddInput({bias, ProgramTensorMetadataDependency::TypeAndRank, reduced_bias_shape, static_cast<int>(output_components)});
   }
 
   return program;
diff --git a/onnxruntime/core/providers/webgpu/math/matmul.h b/onnxruntime/core/providers/webgpu/math/matmul.h
index 0b65827be7f17..a1b7a1d34f2ca 100644
--- a/onnxruntime/core/providers/webgpu/math/matmul.h
+++ b/onnxruntime/core/providers/webgpu/math/matmul.h
@@ -21,7 +21,11 @@ Status ComputeMatMul(ComputeContext* context, const Activation& activation, std:
 MatMulFillBiasOrZeroBeforeSplitKProgram CreateMatMulFillBiasOrZeroBeforeSplitKProgram(
     const Tensor* bias,
     Tensor* output,
-    const TensorShape& output_shape_vec4);
+    bool is_gemm,
+    float beta,
+    uint32_t output_components,
+    bool bias_is_scalar,
+    const TensorShape& output_shape);
 
 class MatMul final : public WebGpuKernel {
  public:
diff --git a/onnxruntime/core/providers/webgpu/math/matmul_packed.cc b/onnxruntime/core/providers/webgpu/math/matmul_packed.cc
index 80a110c3b505c..e97e0fd6f1058 100644
--- a/onnxruntime/core/providers/webgpu/math/matmul_packed.cc
+++ b/onnxruntime/core/providers/webgpu/math/matmul_packed.cc
@@ -63,13 +63,13 @@ Status MatMulFillBiasOrZeroBeforeSplitKProgram::GenerateShaderCode(ShaderHelper&
   // Handle bias with `MatMulWriteFnSource()`.
   // Here `use_split_k` is false because we just initialize `output` with bias.
   // `use_split_k` is true only when we do the actual MatMul with Split-K.
-  // Currently we only support bias in vec4 and channels last format for Split-K MatMul.
+  const uint32_t bias_components = output_components_;
   MatMulWriteFnSource(
-      shader, output, bias, /*is_gemm*/ false, /*c_components*/ 4, /*output_components*/ 4, /*c_is_scalar*/ false,
+      shader, output, bias, is_gemm_, bias_components, output_components_, bias_is_scalar_,
       /*activation_snippet*/ "", /*is_channels_last*/ true, /*use_split_k*/ false);
 
+  shader.MainFunctionBody() << "  let output_components = " << output_components_ << ";\n";
   shader.MainFunctionBody() << R"(
-  let output_components = 4;
   let output_id = i32(global_idx);
 
   let dim_a_outer = i32(uniforms.dim_a_outer);
diff --git a/onnxruntime/core/providers/webgpu/math/matmul_packed.h b/onnxruntime/core/providers/webgpu/math/matmul_packed.h
index dbd193bc38f58..618fc97d72fe0 100644
--- a/onnxruntime/core/providers/webgpu/math/matmul_packed.h
+++ b/onnxruntime/core/providers/webgpu/math/matmul_packed.h
@@ -46,18 +46,25 @@ class MatMulProgram final : public Program<MatMulProgram> {
   // the output with 0 or bias first to make sure `atomicLoad` won't return garbage data.
 class MatMulFillBiasOrZeroBeforeSplitKProgram final : public Program<MatMulFillBiasOrZeroBeforeSplitKProgram> {
  public:
-  explicit MatMulFillBiasOrZeroBeforeSplitKProgram(bool has_bias)
+  MatMulFillBiasOrZeroBeforeSplitKProgram(bool is_gemm, bool has_bias, uint32_t output_components, bool bias_is_scalar)
       : Program{"MatMul_Fill_Bias_Or_Zero_Before_Split_K"},
-        has_bias_(has_bias) {
+        is_gemm_(is_gemm),
+        has_bias_(has_bias),
+        output_components_(output_components),
+        bias_is_scalar_(bias_is_scalar) {
   }
 
   Status GenerateShaderCode(ShaderHelper& sh) const override;
 
   WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"dim_a_outer", ProgramUniformVariableDataType::Uint32},
-                                          {"dim_b_outer", ProgramUniformVariableDataType::Uint32});
+                                          {"dim_b_outer", ProgramUniformVariableDataType::Uint32},
+                                          {"beta", ProgramUniformVariableDataType::Float32});
 
  private:
+  bool is_gemm_ = false;
   bool has_bias_ = false;
+  uint32_t output_components_ = 0;
+  bool bias_is_scalar_ = false;
 };
 
 }  // namespace webgpu
diff --git a/onnxruntime/core/providers/webgpu/webgpu_utils.cc b/onnxruntime/core/providers/webgpu/webgpu_utils.cc
index 824cfb02c22f0..4386bdcc94056 100644
--- a/onnxruntime/core/providers/webgpu/webgpu_utils.cc
+++ b/onnxruntime/core/providers/webgpu/webgpu_utils.cc
@@ -49,6 +49,7 @@ bool SplitKConfig::UseSplitK(
     bool is_vec4,
     ActivationKind activation_kind,
     uint64_t batch_size,
+    bool is_gemm,
     bool is_channels_last,
     uint32_t dim_a_outer,
     uint32_t dim_b_outer,
     uint32_t dim_inner) const {
@@ -64,8 +65,8 @@ bool SplitKConfig::UseSplitK(
   use_split_k &= is_vec4;
   use_split_k &= batch_size == 1;
   // Now `is_channels_last` is only supported because we only generate vec4 shaders in
-  // `MatMulFillBiasOrZeroBeforeSplitKProgram`.
-  use_split_k &= is_channels_last;
+  // `MatMulFillBiasOrZeroBeforeSplitKProgram` when `is_gemm` is false.
+  use_split_k &= (is_channels_last || is_gemm);
 
   // Split-K works best when `dim_inner` is relatively large compared with `dim_a_outer` and
   // `dim_b_outer`. Currently we use the factor between `(dim_a_outer * dim_b_outer)` and
diff --git a/onnxruntime/core/providers/webgpu/webgpu_utils.h b/onnxruntime/core/providers/webgpu/webgpu_utils.h
index 0aa47371f6752..960c0b565fa66 100644
--- a/onnxruntime/core/providers/webgpu/webgpu_utils.h
+++ b/onnxruntime/core/providers/webgpu/webgpu_utils.h
@@ -101,7 +101,7 @@ class SplitKConfig {
   explicit SplitKConfig(const wgpu::AdapterInfo& adapter_info);
 
   bool UseSplitK(
-      bool is_vec4, ActivationKind activation_kind, uint64_t batch_size,
+      bool is_vec4, ActivationKind activation_kind, uint64_t batch_size, bool is_gemm,
       bool is_channels_last, uint32_t dim_a_outer, uint32_t dim_b_outer,
       uint32_t dim_inner) const;
 
diff --git a/onnxruntime/test/providers/cpu/math/gemm_test.cc b/onnxruntime/test/providers/cpu/math/gemm_test.cc
index d7d9d2994afa1..f55bb2ef4aad8 100644
--- a/onnxruntime/test/providers/cpu/math/gemm_test.cc
+++ b/onnxruntime/test/providers/cpu/math/gemm_test.cc
@@ -1508,7 +1508,7 @@ TEST_P(GemmOptimizePackedTest, TestVariants) {
 
 std::vector GenerateGemmParams() {
   std::vector params;
-  std::vector<std::vector<int64_t>> test_sizes = {{1, 1, 1}, {1, 64, 448}, {2, 3, 4}, {8, 8, 8}, {31, 31, 31}, {32, 32, 32}, {33, 67, 99}, {37, 64, 256}, {48, 48, 120}, {60, 16, 92}, {63, 64, 65}, {64, 64, 64}, {64, 64, 65}, {72, 80, 84}, {96, 24, 48}, {128, 32, 64}, {128, 128, 128}, {129, 129, 129}, {256, 64, 1024}};
+  std::vector<std::vector<int64_t>> test_sizes = {{1, 1, 1}, {1, 64, 448}, {2, 3, 4}, {8, 8, 8}, {31, 31, 31}, {32, 32, 32}, {33, 67, 99}, {37, 64, 256}, {48, 48, 120}, {60, 16, 92}, {63, 64, 65}, {64, 64, 64}, {64, 64, 65}, {72, 80, 84}, {96, 24, 48}, {128, 32, 64}, {128, 128, 128}, {129, 129, 129}, {256, 64, 1024}, {16, 768, 192}, {15, 768, 192}, {16, 768, 191}};
 
   std::vector<BiasType> bias_types = {BiasType::noBias, BiasType::MBias, BiasType::ScalarBias, BiasType::MNBias, BiasType::NBias};
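
Note on the dispatch arithmetic in this change: the Split-K path splits the inner (K) dimension across `dispatch_z` workgroups, each of which accumulates its partial sum into the shared output with atomic built-in functions, which is why `MatMulFillBiasOrZeroBeforeSplitKProgram` must first seed the output with `beta * C` (or zero). The following is a minimal standalone sketch of that host-side math, using the same formulas as `ApplyGemmPacked` above; `SplitKDispatch`, `ComputeSplitKDispatch`, and the example `split_dim_inner` value are hypothetical and not part of the ONNX Runtime API.

// Sketch only: mirrors the dispatch computations in ApplyGemmPacked.
// `SplitKDispatch` / `ComputeSplitKDispatch` are illustrative names, not ORT APIs.
#include <cstdint>

struct SplitKDispatch {
  uint32_t x;  // output tiles along N
  uint32_t y;  // output tiles along M
  uint32_t z;  // number of K splits; 1 means Split-K is disabled
};

SplitKDispatch ComputeSplitKDispatch(uint32_t M, uint32_t N, uint32_t K,
                                     uint32_t tile_size, uint32_t split_dim_inner) {
  SplitKDispatch d;
  d.x = (N + tile_size - 1) / tile_size;              // dispatch_x
  d.y = (M + tile_size - 1) / tile_size;              // dispatch_y
  d.z = (K + split_dim_inner - 1) / split_dim_inner;  // dispatch_z
  return d;
}

// Example: M = 16, N = 768, K = 192 (one of the new test sizes) with
// tile_size = 32 and a hypothetical split_dim_inner = 64 gives a 24 x 1 x 3
// grid: three workgroups cooperate on each output tile, each covering 64
// elements of K and atomically adding its partial sum into the output.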