67 changes: 59 additions & 8 deletions onnxruntime/core/providers/webgpu/math/gemm_packed.cc
@@ -5,14 +5,22 @@

#include "core/providers/webgpu/webgpu_utils.h"

#include "core/providers/webgpu/math/matmul.h"
#include "core/providers/webgpu/math/matmul_utils.h"
#include "core/providers/webgpu/math/gemm_utils.h"

namespace onnxruntime {
namespace webgpu {

Status GemmProgram::GenerateShaderCode(ShaderHelper& shader) const {
const ShaderVariableHelper& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
const bool need_split_k = NeedSplitK();
ShaderUsage output_usage = ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias;
if (need_split_k) {
// When Split-K is enabled, we will declare the output as `atomic<i32>` so we can call atomic
// built-in functions on it, and we need the information below to correctly compute indices into the output.
output_usage |= ShaderUsage::UseIndicesToOffset | ShaderUsage::UseShapeAndStride;
}
const ShaderVariableHelper& output = shader.AddOutput("output", output_usage);

// Each thread computes 4*4 elements
InlinedVector<int64_t> elements_per_thread = InlinedVector<int64_t>({4, 4, 1});
@@ -26,7 +34,7 @@
MatMulReadFnSource(shader, a, b, nullptr, transA_, transB_, is_vec4_);
}
if (is_vec4_) {
ORT_RETURN_IF_ERROR(MakeMatMulPackedVec4Source(shader, elements_per_thread, WorkgroupSizeX(), WorkgroupSizeY(), data_type, nullptr, transA_, transB_, alpha_, need_handle_matmul_, output_components_));
ORT_RETURN_IF_ERROR(MakeMatMulPackedVec4Source(shader, elements_per_thread, WorkgroupSizeX(), WorkgroupSizeY(), data_type, nullptr, transA_, transB_, alpha_, need_handle_matmul_, output_components_, /*tile_inner*/ 32, need_split_k, split_dim_inner_));
} else {
ORT_RETURN_IF_ERROR(MakeMatMulPackedSource(shader, elements_per_thread, WorkgroupSizeX(), WorkgroupSizeY(), data_type, nullptr, transA_, transB_, alpha_, need_handle_matmul_));
}
@@ -35,11 +43,17 @@
if (need_handle_bias_) {
c = &shader.AddInput("c", ShaderUsage::UseUniform);
}
MatMulWriteFnSource(shader, output, c, /* is_gemm = */ true, c_components_, output_components_, c_is_scalar_);

const ProgramVariableDataType output_var_type = this->Outputs()[0].var_type;
MatMulWriteFnSource(shader, output, c, /* is_gemm = */ true, c_components_, output_components_, c_is_scalar_, /*activation_snippet*/ "", /*is_channels_last*/ false, need_split_k, output_var_type);

return Status::OK();
}

bool GemmProgram::NeedSplitK() const {
return split_dim_inner_ > 1;
}

Status ApplyGemmPacked(const Tensor* a,
const Tensor* b,
const Tensor* c,
@@ -86,7 +100,44 @@
c_is_scalar = c_shape.Size() == 1;
}

GemmProgram program{transA, transB, alpha, need_handle_bias, need_handle_matmul, c_components, c_is_scalar, output_components, is_vec4};
ProgramOutput output(y, ProgramTensorMetadataDependency::TypeAndRank, output_components);
uint32_t dispatch_z = 1;
uint32_t split_dim_inner = 1;

const SplitKConfig& split_k_config = context.GetSplitKConfig();
// Currently we also require the component count of Y to be 4 when Split-K is used.
const bool output_is_vec4 = output_components == 4;
// The parameter `is_channels_last` is not used for GEMM.
const bool need_split_k = split_k_config.UseSplitK(
is_vec4 && output_is_vec4, ActivationKind::None, /*batch_size*/ 1, /*is_gemm*/ true, /*is_channels_last*/ true, M, N, K);
if (need_split_k) {
const Tensor* bias = nullptr;
uint32_t output_components_in_fill_bias_program = 4;
if (need_handle_bias) {
bias = c;
output_components_in_fill_bias_program = c_components;
}
const TensorShape output_shape = TensorShape{M, N / output_components_in_fill_bias_program};

auto fill_bias_program = CreateMatMulFillBiasOrZeroBeforeSplitKProgram(
bias, y, /*is_gemm*/ true, beta, output_components_in_fill_bias_program, c_is_scalar, output_shape);
ORT_RETURN_IF_ERROR(context.RunProgram(fill_bias_program));

// When Split-K is used, `bias` will be handled in `MatMulFillBiasOrZeroBeforeSplitKProgram`
// instead of here.
need_handle_bias = false;

// With Split-K, `dim_inner` will be split into multiple parts and `dispatch_z` will be the
// number of splits along `dim_inner`.
split_dim_inner = split_k_config.GetSplitDimInner();
dispatch_z = (K + split_dim_inner - 1) / split_dim_inner;

// The output should be declared with an atomic type in `GemmProgram` so that atomic
// built-in functions can be used on it.
output.is_atomic = true;
}
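To make the two-pass structure above concrete, here is a CPU-side sketch of the arithmetic that the fill program and the Split-K GEMM jointly perform. This is an illustrative model only, not code from this PR; it assumes a dense, row-major C of the full output size (the real code also handles scalar and component-reduced bias).

#include <algorithm>
#include <cstdint>
#include <vector>

// Reference model: pass 1 seeds Y with beta * C (the fill program),
// pass 2 adds alpha * (A x B) one K-slice at a time -- one slice per
// dispatch_z, accumulated with atomics on the GPU.
void SplitKGemmReference(const std::vector<float>& A, const std::vector<float>& B,
                         const std::vector<float>& C, std::vector<float>& Y,
                         uint32_t M, uint32_t N, uint32_t K,
                         float alpha, float beta, uint32_t split_dim_inner) {
  for (uint32_t i = 0; i < M * N; ++i) Y[i] = beta * C[i];  // fill pass
  const uint32_t dispatch_z = (K + split_dim_inner - 1) / split_dim_inner;
  for (uint32_t z = 0; z < dispatch_z; ++z) {  // one workgroup slice per z
    const uint32_t k_begin = z * split_dim_inner;
    const uint32_t k_end = std::min(K, k_begin + split_dim_inner);
    for (uint32_t m = 0; m < M; ++m) {
      for (uint32_t n = 0; n < N; ++n) {
        float partial = 0.0f;
        for (uint32_t k = k_begin; k < k_end; ++k) {
          partial += A[m * K + k] * B[k * N + n];
        }
        Y[m * N + n] += alpha * partial;  // an atomic add in the shader
      }
    }
  }
}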

GemmProgram program{transA, transB, alpha, need_handle_bias, need_handle_matmul, c_components, c_is_scalar, output_components, is_vec4, split_dim_inner};

if (need_handle_matmul) {
program.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, components},
@@ -101,9 +152,9 @@
const uint32_t dispatch_x = (N + TILE_SIZE - 1) / TILE_SIZE;
const uint32_t dispatch_y = (M + TILE_SIZE - 1) / TILE_SIZE;

program.CacheHint(alpha, transA, transB, c_is_scalar)
.AddOutputs({{y, ProgramTensorMetadataDependency::TypeAndRank, output_components}})
.SetDispatchGroupSize(dispatch_x, dispatch_y, 1u)
program.CacheHint(alpha, transA, transB, c_is_scalar, split_dim_inner)
.AddOutput(std::move(output))

[GitHub Actions / Optional Lint C++ — cpplint via reviewdog] onnxruntime/core/providers/webgpu/math/gemm_packed.cc:156: Add #include <utility> for move [build/include_what_you_use] [4] (the one-line fix is sketched after this function)
.SetDispatchGroupSize(dispatch_x, dispatch_y, dispatch_z)
.SetWorkgroupSize(GemmProgram::MATMUL_PACKED_WORKGROUP_SIZE_X, GemmProgram::MATMUL_PACKED_WORKGROUP_SIZE_Y, GemmProgram::MATMUL_PACKED_WORKGROUP_SIZE_Z)
.AddUniformVariables({{alpha},
{beta},
Expand All @@ -112,7 +163,7 @@
{K}, /*dim_inner */
{dispatch_x}, /* logical_dispatch_x */
{dispatch_y}, /* logical_dispatch_y */
{1u}} /* logical_dispatch_z */
{dispatch_z}} /* logical_dispatch_z */
);

return context.RunProgram(program);
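The cpplint annotation above points at the `std::move` introduced in this hunk; the remedy (not part of this diff) is a single include near the top of gemm_packed.cc:

#include <utility>  // for std::move in .AddOutput(std::move(output))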
8 changes: 6 additions & 2 deletions onnxruntime/core/providers/webgpu/math/gemm_packed.h
@@ -13,7 +13,7 @@ namespace webgpu {

class GemmProgram final : public Program<GemmProgram> {
public:
GemmProgram(bool transA, bool transB, float alpha, bool need_handle_bias, bool need_handle_matmul, int c_components, bool c_is_scalar, int output_components, bool is_vec4 = false)
GemmProgram(bool transA, bool transB, float alpha, bool need_handle_bias, bool need_handle_matmul, int c_components, bool c_is_scalar, int output_components, bool is_vec4 = false, uint32_t split_dim_inner = 1)
: Program{"Gemm"},
transA_{transA},
transB_{transB},
@@ -23,10 +23,13 @@ class GemmProgram final : public Program<GemmProgram> {
c_components_(c_components),
c_is_scalar_(c_is_scalar),
output_components_(output_components),
is_vec4_(is_vec4) {}
is_vec4_(is_vec4),
split_dim_inner_(split_dim_inner) {}

Status GenerateShaderCode(ShaderHelper& sh) const override;

bool NeedSplitK() const;

WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
{"alpha", ProgramUniformVariableDataType::Float32},
{"beta", ProgramUniformVariableDataType::Float32},
@@ -51,6 +54,7 @@ class GemmProgram final : public Program<GemmProgram> {
bool c_is_scalar_ = false;
int output_components_;
bool is_vec4_ = false;
uint32_t split_dim_inner_ = 1;
};

Status ApplyGemmPacked(const Tensor* a,
13 changes: 10 additions & 3 deletions onnxruntime/core/providers/webgpu/math/gemm_utils.cc
@@ -54,9 +54,14 @@ void HandleMaybeBiasForMatMul(ShaderHelper& shader,

void HandleMatMulWithSplitK(
ShaderHelper& shader,
bool is_gemm,
const ShaderVariableHelper& output,
ProgramVariableDataType output_variable_type) {
shader.AdditionalImplementation() << " let coords = vec3(u32(batch), u32(row), u32(colIn));\n";
if (is_gemm) {
shader.AdditionalImplementation() << " let coords = vec2(u32(row), u32(colIn));\n";
} else {
shader.AdditionalImplementation() << " let coords = vec3(u32(batch), u32(row), u32(colIn));\n";
}

// With Split-K, the final output will be the sum of the sub-outputs from multiple workgroups,
// so we must add them with atomic built-in functions. Because currently WebGPU doesn't support
@@ -205,8 +210,10 @@ void MatMulWriteFnSource(ShaderHelper& shader,
// still need to handle `bias` (and `is_channels_last` in the future) in
// `MatMulFillBiasOrZeroBeforeSplitKProgram`.
ORT_ENFORCE(bias == nullptr, "Bias is not supported in MatMulProgram when Split-K is enabled.");
ORT_ENFORCE(is_channels_last, "Only channels-last is supported in MatMulProgram when Split-K is enabled.");
HandleMatMulWithSplitK(shader, output, output_variable_type);
if (!is_gemm) {
ORT_ENFORCE(is_channels_last, "Only channels-last is supported in MatMulProgram when Split-K is enabled in non-GEMM ops.");
}
HandleMatMulWithSplitK(shader, is_gemm, output, output_variable_type);
} else if (is_gemm) {
HandleMaybeHaveBiasForGEMM(shader, output, bias, c_components, output_components, c_is_scalar);
} else {
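WGSL has no atomic add for floating-point types, which is why the output is declared as `atomic<i32>` and why `HandleMatMulWithSplitK` exists at all. The body it emits is collapsed out of this diff; the standard pattern it presumably follows is a compare-exchange loop over the bit pattern, sketched here in this file's own style with assumed names (`output`, `atomic_add_f32`):

// Illustrative only: the canonical CAS emulation of atomic f32 add in WGSL.
// `output` is assumed to be bound as array<atomic<i32>>, as described in
// GemmProgram::GenerateShaderCode above.
shader.AdditionalImplementation() << R"(
  fn atomic_add_f32(offset: u32, value: f32) {
    var old = atomicLoad(&output[offset]);
    loop {
      let new_bits = bitcast<i32>(bitcast<f32>(old) + value);
      let result = atomicCompareExchangeWeak(&output[offset], old, new_bits);
      if (result.exchanged) { break; }
      old = result.old_value;
    }
  }
)";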
43 changes: 23 additions & 20 deletions onnxruntime/core/providers/webgpu/math/matmul.cc
@@ -240,14 +240,14 @@ Status ComputeMatMul(ComputeContext* context,
uint32_t split_dim_inner = 1;

const SplitKConfig& split_k_config = context->GetSplitKConfig();
const bool need_split_k = split_k_config.UseSplitK(is_vec4, activation.activation_kind_, batch_size, is_channels_last, dim_a_outer, dim_b_outer, dim_inner);
const bool need_split_k = split_k_config.UseSplitK(is_vec4, activation.activation_kind_, batch_size, /*is_gemm*/ false, is_channels_last, dim_a_outer, dim_b_outer, dim_inner);
if (need_split_k) {
ORT_ENFORCE(batch_size == 1, "Split-K MatMul only supports batch_size == 1.");
ORT_ENFORCE(is_vec4, "Split-K MatMul only supports bias in vec4 format.");
ORT_ENFORCE(is_channels_last, "Split-K MatMul only supports channels-last format.");

// Initialize `output_tensor` with 0 or bias before MatMulProgram with Split-K enabled.
const auto fill_bias_program = CreateMatMulFillBiasOrZeroBeforeSplitKProgram(bias, output_tensor, output_shape_temp);
const auto fill_bias_program = CreateMatMulFillBiasOrZeroBeforeSplitKProgram(bias, output_tensor, /*is_gemm*/ false, /*beta*/ 1.0f, /*bias_components*/ 4, /*bias_is_scalar*/ false, output_shape_temp);
ORT_RETURN_IF_ERROR(context->RunProgram(fill_bias_program));

// `bias` has been handled in the execution of `fill_bias_program` so we don't need to set
@@ -287,32 +287,35 @@
MatMulFillBiasOrZeroBeforeSplitKProgram CreateMatMulFillBiasOrZeroBeforeSplitKProgram(
const Tensor* bias,
Tensor* output,
const TensorShape& output_shape_vec4) {
bool is_gemm,
float beta,
uint32_t output_components,
bool bias_is_scalar,
const TensorShape& output_shape) {
const bool has_bias = bias != nullptr;

// Currently we only support bias in vec4 and channels last format for Split-K MatMul.
constexpr uint32_t bias_components = 4;
MatMulFillBiasOrZeroBeforeSplitKProgram program(has_bias);
// Currently we only support GEMM and channels last format for MatMul with Split-K.
MatMulFillBiasOrZeroBeforeSplitKProgram program(is_gemm, has_bias, output_components, bias_is_scalar);

const uint32_t dim_a_outer = narrow<uint32_t>(output_shape_vec4[output_shape_vec4.NumDimensions() - 2]);
const uint32_t dim_b_outer_vec4 = narrow<uint32_t>(output_shape_vec4[output_shape_vec4.NumDimensions() - 1]);
const uint32_t dim_a_outer = narrow<uint32_t>(output_shape[output_shape.NumDimensions() - 2]);
const uint32_t dim_b_outer = narrow<uint32_t>(output_shape[output_shape.NumDimensions() - 1]);

// Fill one value (currently only vec4) per invocation. Now we use default workgroup size (64) for
// this program.
const uint32_t total_outputs_vec4 = dim_a_outer * dim_b_outer_vec4;
const uint32_t dispatch_x = (total_outputs_vec4 + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE;
// Fill one value per invocation. Now we use default workgroup size (64) for this program.
const uint32_t total_outputs = dim_a_outer * dim_b_outer;
const uint32_t dispatch_x = (total_outputs + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE;

// To reuse `MatMulWriteFnSource()` we need to set `dim_a_outer` and `dim_b_outer` in scalar
// instead of vec4, while use `output_shape_vec4` directly as the output shape.
const uint32_t dim_b_outer = narrow<uint32_t>(dim_b_outer_vec4 * bias_components);
program.CacheHint(has_bias)
.AddOutput({output, ProgramTensorMetadataDependency::TypeAndRank, output_shape_vec4, static_cast<int32_t>(bias_components)})
.AddUniformVariables({{dim_a_outer}, {dim_b_outer}})
// To reuse `MatMulWriteFnSource()` we need to pass `dim_b_outer` in scalar elements (the
// `output_shape` here is already reduced by `output_components`), while using `output_shape` directly as the output shape.
const uint32_t dim_b_outer_components = narrow<uint32_t>(dim_b_outer * output_components);
program.CacheHint(is_gemm, has_bias, output_components, bias_is_scalar)
.AddOutput({output, ProgramTensorMetadataDependency::TypeAndRank, output_shape, static_cast<int32_t>(output_components)})
.AddUniformVariables({{dim_a_outer}, {dim_b_outer_components}, {beta}})
.SetDispatchGroupSize(dispatch_x);

if (has_bias) {
const TensorShape reduced_bias_shape = ReduceShapeByComponents(bias->Shape(), bias_components);
program.AddInput({bias, ProgramTensorMetadataDependency::TypeAndRank, reduced_bias_shape, static_cast<int32_t>(bias_components)});
// We always use `c_components` as `output_components` in GEMM, and 4 in MatMul.
const TensorShape reduced_bias_shape = ReduceShapeByComponents(bias->Shape(), output_components);
program.AddInput({bias, ProgramTensorMetadataDependency::TypeAndRank, reduced_bias_shape, static_cast<int32_t>(output_components)});
}

return program;
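For the bias bound above, the shape is first reduced by the component count. A minimal stand-in for `ReduceShapeByComponents` as it is used here (an assumption about its behavior: divide the innermost dimension, which must divide evenly):

#include <cassert>
#include <cstdint>
#include <vector>

// Minimal model: a bias of shape {M, N} bound with 4 components becomes
// {M, N / 4}, matching how the vec4 inputs are declared to the shader.
std::vector<int64_t> ReduceShapeByComponentsModel(std::vector<int64_t> shape,
                                                  uint32_t components) {
  assert(!shape.empty() &&
         shape.back() % static_cast<int64_t>(components) == 0);
  shape.back() /= static_cast<int64_t>(components);
  return shape;
}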
6 changes: 5 additions & 1 deletion onnxruntime/core/providers/webgpu/math/matmul.h
@@ -21,7 +21,11 @@ Status ComputeMatMul(ComputeContext* context, const Activation& activation, std:
MatMulFillBiasOrZeroBeforeSplitKProgram CreateMatMulFillBiasOrZeroBeforeSplitKProgram(
const Tensor* bias,
Tensor* output,
const TensorShape& output_shape_vec4);
bool is_gemm,
float beta,
uint32_t output_components,
bool bias_is_scalar,
const TensorShape& output_shape);

class MatMul final : public WebGpuKernel {
public:
6 changes: 3 additions & 3 deletions onnxruntime/core/providers/webgpu/math/matmul_packed.cc
@@ -63,13 +63,13 @@ Status MatMulFillBiasOrZeroBeforeSplitKProgram::GenerateShaderCode(ShaderHelper&
// Handle bias with `MatMulWriteFnSource()`.
// Here `use_split_k` is false because we just initialize `output` with bias.
// `use_split_k` is true only when we do the actual MatMul with Split-K.
// Currently we only support bias in vec4 and channels last format for Split-K MatMul.
const uint32_t bias_components = output_components_;
MatMulWriteFnSource(
shader, output, bias, /*is_gemm*/ false, /*c_components*/ 4, /*output_components*/ 4, /*c_is_scalar*/ false,
shader, output, bias, is_gemm_, bias_components, output_components_, bias_is_scalar_,
/*activation_snippet*/ "", /*is_channels_last*/ true, /*use_split_k*/ false);

shader.MainFunctionBody() << " let output_components = " << output_components_ << ";\n";
shader.MainFunctionBody() << R"(
let output_components = 4;
let output_id = i32(global_idx);

let dim_a_outer = i32(uniforms.dim_a_outer);
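The rest of the main function body is collapsed in this diff. For orientation only, a plausible reconstruction of the guard and index math based on the uniforms and locals shown above; `mm_write_or_bias` is an invented placeholder for whatever write helper `MatMulWriteFnSource` actually emits:

// Hypothetical continuation, in the style of the visible snippet above.
shader.MainFunctionBody() << R"(
  let dim_b_outer = i32(uniforms.dim_b_outer);
  let cols = dim_b_outer / output_components;
  if (output_id < dim_a_outer * cols) {
    let row = output_id / cols;
    let col = output_id % cols;
    // Writes beta * bias (or zero) through the helper generated by
    // MatMulWriteFnSource; the helper name here is a placeholder.
    mm_write_or_bias(row, col);
  }
)";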
13 changes: 10 additions & 3 deletions onnxruntime/core/providers/webgpu/math/matmul_packed.h
@@ -46,18 +46,25 @@ class MatMulProgram final : public Program<MatMulProgram> {
// the output with 0 or bias first to make sure `atomicLoad` won't return garbage data.
class MatMulFillBiasOrZeroBeforeSplitKProgram final : public Program<MatMulFillBiasOrZeroBeforeSplitKProgram> {
public:
explicit MatMulFillBiasOrZeroBeforeSplitKProgram(bool has_bias)
MatMulFillBiasOrZeroBeforeSplitKProgram(bool is_gemm, bool has_bias, uint32_t output_components, bool bias_is_scalar)
: Program{"MatMul_Fill_Bias_Or_Zero_Before_Split_K"},
has_bias_(has_bias) {
is_gemm_(is_gemm),
has_bias_(has_bias),
output_components_(output_components),
bias_is_scalar_(bias_is_scalar) {
}

Status GenerateShaderCode(ShaderHelper& sh) const override;

WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"dim_a_outer", ProgramUniformVariableDataType::Uint32},
{"dim_b_outer", ProgramUniformVariableDataType::Uint32});
{"dim_b_outer", ProgramUniformVariableDataType::Uint32},
{"beta", ProgramUniformVariableDataType::Float32});

private:
bool is_gemm_ = false;
bool has_bias_ = false;
uint32_t output_components_ = 0;
bool bias_is_scalar_ = false;
};

} // namespace webgpu
5 changes: 3 additions & 2 deletions onnxruntime/core/providers/webgpu/webgpu_utils.cc
@@ -49,6 +49,7 @@ bool SplitKConfig::UseSplitK(
bool is_vec4,
ActivationKind activation_kind,
uint64_t batch_size,
bool is_gemm,
bool is_channels_last,
uint32_t dim_a_outer,
uint32_t dim_b_outer,
@@ -64,8 +65,8 @@
use_split_k &= is_vec4;
use_split_k &= batch_size == 1;
// Only `is_channels_last` is supported here because we only generate vec4 shaders in
// `MatMulFillBiasOrZeroBeforeSplitKProgram`.
use_split_k &= is_channels_last;
// `MatMulFillBiasOrZeroBeforeSplitKProgram` when `is_gemm` is false.
use_split_k &= (is_channels_last || is_gemm);

// Split-K works best when `dim_inner` is relatively large compared with `dim_a_outer` and
// `dim_b_outer`. Currently we use the factor between `(dim_a_outer * dim_b_outer)` and
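The ratio heuristic referenced in the comment above is collapsed out of this diff, so it is left as-is here. The gating that is visible can be restated as a standalone predicate (a sketch for readability, not code from the PR):

#include <cstdint>

// Visible Split-K preconditions from SplitKConfig::UseSplitK, restated.
// The collapsed parts (the activation-kind check and the size heuristic
// comparing dim_inner against dim_a_outer * dim_b_outer) are intentionally
// omitted rather than guessed at.
bool SplitKPreconditions(bool is_vec4, uint64_t batch_size,
                         bool is_channels_last, bool is_gemm) {
  return is_vec4 && batch_size == 1 && (is_channels_last || is_gemm);
}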
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/webgpu/webgpu_utils.h
@@ -101,7 +101,7 @@ class SplitKConfig {
explicit SplitKConfig(const wgpu::AdapterInfo& adapter_info);

bool UseSplitK(
bool is_vec4, ActivationKind activation_kind, uint64_t batch_size,
bool is_vec4, ActivationKind activation_kind, uint64_t batch_size, bool is_gemm,
bool is_channels_last, uint32_t dim_a_outer,
uint32_t dim_b_outer, uint32_t dim_inner) const;

2 changes: 1 addition & 1 deletion onnxruntime/test/providers/cpu/math/gemm_test.cc
@@ -1508,7 +1508,7 @@ TEST_P(GemmOptimizePackedTest, TestVariants) {
std::vector<GemmOptimizePackedParams> GenerateGemmParams() {
std::vector<GemmOptimizePackedParams> params;

std::vector<std::tuple<int64_t, int64_t, int64_t>> test_sizes = {{1, 1, 1}, {1, 64, 448}, {2, 3, 4}, {8, 8, 8}, {31, 31, 31}, {32, 32, 32}, {33, 67, 99}, {37, 64, 256}, {48, 48, 120}, {60, 16, 92}, {63, 64, 65}, {64, 64, 64}, {64, 64, 65}, {72, 80, 84}, {96, 24, 48}, {128, 32, 64}, {128, 128, 128}, {129, 129, 129}, {256, 64, 1024}};
std::vector<std::tuple<int64_t, int64_t, int64_t>> test_sizes = {{1, 1, 1}, {1, 64, 448}, {2, 3, 4}, {8, 8, 8}, {31, 31, 31}, {32, 32, 32}, {33, 67, 99}, {37, 64, 256}, {48, 48, 120}, {60, 16, 92}, {63, 64, 65}, {64, 64, 64}, {64, 64, 65}, {72, 80, 84}, {96, 24, 48}, {128, 32, 64}, {128, 128, 128}, {129, 129, 129}, {256, 64, 1024}, {16, 768, 192}, {15, 768, 192}, {16, 768, 191}};

std::vector<BiasType>
bias_types = {BiasType::noBias, BiasType::MBias, BiasType::ScalarBias, BiasType::MNBias, BiasType::NBias};