Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions onnxruntime/core/providers/webgpu/math/gemm_packed.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ Status GemmProgram::GenerateShaderCode(ShaderHelper& shader) const {
MatMulReadFnSource(shader, a, b, nullptr, transA_, transB_);
}
if (is_vec4_) {
ORT_RETURN_IF_ERROR(MakeMatMulPackedVec4Source(shader, elements_per_thread, WorkgroupSizeX(), WorkgroupSizeY(), data_type, nullptr, transA_, transB_, alpha_, need_handle_matmul_, output_components_, /*tile_inner*/ 32, need_split_k, split_dim_inner_));
ORT_RETURN_IF_ERROR(MakeMatMulPackedVec4Source(shader, elements_per_thread, WorkgroupSizeX(), WorkgroupSizeY(), data_type, /* batch_dims = */ nullptr, transA_, transB_, alpha_, need_handle_matmul_, output_components_, /*tile_inner*/ 32, need_split_k, split_dim_inner_));
} else {
ORT_RETURN_IF_ERROR(MakeMatMulPackedSource(shader, elements_per_thread, WorkgroupSizeX(), WorkgroupSizeY(), data_type, nullptr, transA_, transB_, alpha_, need_handle_matmul_));
ORT_RETURN_IF_ERROR(MakeMatMulPackedSource(shader, elements_per_thread, WorkgroupSizeX(), WorkgroupSizeY(), data_type, /* batch_dims = */ nullptr, transA_, transB_, alpha_, need_handle_matmul_));
}

const ShaderVariableHelper* c = nullptr;
Expand Down Expand Up @@ -158,7 +158,8 @@ Status ApplyGemmPacked(const Tensor* a,
const uint32_t dispatch_x = (N + TILE_SIZE - 1) / TILE_SIZE;
const uint32_t dispatch_y = (M + TILE_SIZE - 1) / TILE_SIZE;

program.CacheHint(alpha, transA, transB, c_is_scalar, split_dim_inner)
const bool has_batch_dims = false; // GemmProgram::GenerateShaderCode passes nullptr for batch_dims.
program.CacheHint(alpha, transA, transB, c_is_scalar, split_dim_inner, has_batch_dims)
.AddOutput(std::move(output))
.SetDispatchGroupSize(dispatch_x, dispatch_y, dispatch_z)
.SetWorkgroupSize(GemmProgram::MATMUL_PACKED_WORKGROUP_SIZE_X, GemmProgram::MATMUL_PACKED_WORKGROUP_SIZE_Y, GemmProgram::MATMUL_PACKED_WORKGROUP_SIZE_Z)
Expand Down
28 changes: 21 additions & 7 deletions onnxruntime/core/providers/webgpu/math/gemm_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -309,13 +309,27 @@ Status MakeMatMulPackedVec4Source(ShaderHelper& shader,
// atomic built-in functions in `HandleMatMulWithSplitK()`.
shader.MainFunctionBody()
<< "const kSplitK = " << split_dim_inner << ";\n"
<< " let num_tiles = (kSplitK - 1) / tileInner + 1;\n"
<< " var kStart = kSplitK * i32(logical_global_id.z);\n"

// When Split-K is used, `batch` should always be 0 and `logical_global_id.z` is used to indicate
// the index of split-k instead of batch.
<< " let batch = 0;\n"
<< " let batchIndices = 0u;\n";
<< " let num_tiles = (kSplitK - 1) / tileInner + 1;\n";
if (nullptr != batch_dims) {
Comment thread
Jiawei-Shao marked this conversation as resolved.
// With Split-K and batch (in MatMul and Conv2D|MatMul), `dispatch_z` is
// `splits_per_batch * batch_size`, and `logical_global_id.z` encodes both the
// batch index and the Split-K index within that range.
// We decompose it as:
// split_index = logical_global_id.z % splits_per_batch
// batch = logical_global_id.z / splits_per_batch
shader.MainFunctionBody()
<< " let splits_per_batch = uniforms.splits_per_batch;\n"
<< " let split_index = i32(logical_global_id.z) % i32(splits_per_batch);\n"
<< " var kStart = kSplitK * split_index;\n"
<< " let batch = i32(logical_global_id.z) / i32(splits_per_batch);\n"
<< (nullptr != batch_dims ? " let batchIndices = " + batch_dims->OffsetToIndices("u32(batch)") + ";\n" : "");
} else {
// With Split-K without batch (in Gemm), `logical_global_id.z` is exactly the Split-K index.
shader.MainFunctionBody()
<< " var kStart = kSplitK * i32(logical_global_id.z);\n"
<< " let batch = 0;\n"
<< " let batchIndices = 0u;\n";
}
} else {
shader.MainFunctionBody()
<< " let num_tiles = (uniforms.dim_inner - 1) / tileInner + 1;\n"
Expand Down
43 changes: 26 additions & 17 deletions onnxruntime/core/providers/webgpu/math/matmul.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
// Licensed under the MIT License.

#include "core/providers/webgpu/math/matmul.h"

#include <limits>

#include "core/common/inlined_containers.h"
#include "core/providers/cpu/tensor/utils.h"
#include "core/providers/webgpu/shader_helper.h"
Expand Down Expand Up @@ -244,44 +247,49 @@ Status ComputeMatMul(ComputeContext* context,
const Tensor* bias = has_bias ? inputs[2] : nullptr;
bool use_bias_in_matmul = has_bias;
uint32_t split_dim_inner = 1;
uint32_t splits_per_batch = 1;

// Current Split-K implementation relies on atomic operations, which are not deterministic.
if (!context->KernelContext().GetUseDeterministicCompute()) {
const SplitKConfig& split_k_config = context->GetSplitKConfig();
const bool need_split_k = split_k_config.UseSplitK(is_vec4, activation.activation_kind_, batch_size, dim_a_outer, dim_b_outer, dim_inner, is_channels_last);
if (need_split_k) {
ORT_ENFORCE(batch_size == 1, "Split-K MatMul only supports batch_size == 1.");
ORT_ENFORCE(is_vec4, "Split-K MatMul requires vec4 packing.");

if (has_bias) {
ORT_ENFORCE(is_channels_last, "Split-K MatMul only supports channels-last format.");
}

// Initialize `output_tensor` with 0 or bias before MatMulProgram with Split-K enabled.
const auto fill_bias_program = CreateMatMulFillBiasOrZeroBeforeSplitKProgram(bias, output_tensor, /*is_gemm*/ false, /*beta*/ 1.0f, /*bias_components*/ 4, output_shape_temp);
const auto fill_bias_program = CreateMatMulFillBiasOrZeroBeforeSplitKProgram(bias, output_tensor, /*is_gemm*/ false, /*beta*/ 1.0f, /*bias_components*/ 4, output_shape_temp, narrow<uint32_t>(batch_size));
ORT_RETURN_IF_ERROR(context->RunProgram(fill_bias_program));

// `bias` has been handled in the execution of `fill_bias_program` so we don't need to set
// `bias` again in `MatMulProgram`.
use_bias_in_matmul = false;

// With Split-K, `dim_inner` will be split into multiple parts and `dispatch_z` will be the
// number of splits along `dim_inner`.
// With Split-K, `dim_inner` will be split into multiple parts. `dispatch_z` encodes
// both the split-k index and the batch index: dispatch_z = splits_per_batch * batch_size.
split_dim_inner = split_k_config.GetSplitDimInner();
dispatch_z = (dim_inner + split_dim_inner - 1) / split_dim_inner;
splits_per_batch = (dim_inner + split_dim_inner - 1) / split_dim_inner;
const uint64_t dispatch_z_u64 = static_cast<uint64_t>(batch_size) * static_cast<uint64_t>(splits_per_batch);
ORT_ENFORCE(dispatch_z_u64 <= static_cast<uint64_t>(std::numeric_limits<uint32_t>::max()),
"dispatch_z exceeds uint32_t range: ", dispatch_z_u64);
dispatch_z = narrow<uint32_t>(dispatch_z_u64);

// The output should be declared in atomic types in `MatMulProgram` for the use of atomic
// built-in functions.
output.is_atomic = true;
}
}

const bool has_batch_dims = true; // MatMulProgram::GenerateShaderCode always passes &batch_dims.
MatMulProgram matmul_program{activation, use_bias_in_matmul, is_vec4, elements_per_thread, is_channels_last, split_dim_inner};
matmul_program
.CacheHint(activation.ToString(), absl::StrJoin(elements_per_thread, "-"), std::to_string(is_vec4), components, is_channels_last, split_dim_inner)
.CacheHint(activation.ToString(), absl::StrJoin(elements_per_thread, "-"), std::to_string(is_vec4), components, is_channels_last, split_dim_inner, has_batch_dims)
.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, a_shape_temp, components},
{b, ProgramTensorMetadataDependency::TypeAndRank, b_shape_temp, components}})
.AddUniformVariables({{dim_a_outer}, {dim_b_outer}, {dim_inner}, {dispatch_x}, {dispatch_y}, {dispatch_z}})
.AddUniformVariables({{dim_a_outer}, {dim_b_outer}, {dim_inner}, {dispatch_x}, {dispatch_y}, {dispatch_z}, {splits_per_batch}})
.AddIndices(outer_dims)
.SetDispatchGroupSize(dispatch_x, dispatch_y, dispatch_z)
.SetWorkgroupSize(MatMul::MATMUL_PACKED_WORKGROUP_SIZE_X, MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Y, MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z)
Expand All @@ -302,31 +310,32 @@ MatMulFillBiasOrZeroBeforeSplitKProgram CreateMatMulFillBiasOrZeroBeforeSplitKPr
bool is_gemm,
float beta,
uint32_t output_components,
const TensorShape& output_shape) {
const TensorShape& output_shape,
uint32_t batch_size) {
const bool has_bias = bias != nullptr;
const bool bias_is_scalar = has_bias ? bias->Shape().Size() == 1 : false;

// Currently we only support GEMM and channels last format for MatMul with Split-K.
MatMulFillBiasOrZeroBeforeSplitKProgram program(is_gemm, has_bias, output_components, bias_is_scalar);

const uint32_t dim_a_outer = narrow<uint32_t>(output_shape[output_shape.NumDimensions() - 2]);
const uint32_t dim_b_outer = narrow<uint32_t>(output_shape[output_shape.NumDimensions() - 1]);

// Fill one value per invocation. Now we use default workgroup size (64) for this program.
const uint32_t total_outputs = dim_a_outer * dim_b_outer;
const uint32_t dispatch_x = (total_outputs + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE;
// Fill one value per invocation across all batches.
const uint64_t total_outputs = static_cast<uint64_t>(batch_size) *
static_cast<uint64_t>(dim_a_outer) *
static_cast<uint64_t>(dim_b_outer);
const uint64_t dispatch_x_u64 = CeilDiv(total_outputs, static_cast<uint64_t>(WORKGROUP_SIZE));
ORT_ENFORCE(dispatch_x_u64 <= static_cast<uint64_t>(std::numeric_limits<uint32_t>::max()),
"dispatch_x exceeds uint32_t range: ", dispatch_x_u64);
const uint32_t dispatch_x = narrow<uint32_t>(dispatch_x_u64);

// To reuse `MatMulWriteFnSourceForGemm()` or `MatMulWriteFnSourceForMatMul()` we need to set
// `dim_b_outer` in components when `output_shape` is in `vec4`, while use `output_shape` directly
// as the output shape.
const uint32_t dim_b_outer_components = narrow<uint32_t>(dim_b_outer * output_components);
program.CacheHint(is_gemm, has_bias, output_components, bias_is_scalar)
.AddOutput({output, ProgramTensorMetadataDependency::TypeAndRank, output_shape, static_cast<int32_t>(output_components)})
.AddUniformVariables({{dim_a_outer}, {dim_b_outer_components}, {beta}})
.AddUniformVariables({{dim_a_outer}, {dim_b_outer_components}, {beta}, {batch_size}})
.SetDispatchGroupSize(dispatch_x);

if (has_bias) {
// We always use `c_components` as `output_components` in GEMM, and 4 in MatMul.
const TensorShape reduced_bias_shape = ReduceShapeByComponents(bias->Shape(), output_components);
program.AddInput({bias, ProgramTensorMetadataDependency::TypeAndRank, reduced_bias_shape, static_cast<int32_t>(output_components)});
}
Expand Down
3 changes: 2 additions & 1 deletion onnxruntime/core/providers/webgpu/math/matmul.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ MatMulFillBiasOrZeroBeforeSplitKProgram CreateMatMulFillBiasOrZeroBeforeSplitKPr
bool is_gemm,
float beta,
uint32_t output_components,
const TensorShape& output_shape);
const TensorShape& output_shape,
uint32_t batch_size = 1);

class MatMul final : public WebGpuKernel {
public:
Expand Down
12 changes: 7 additions & 5 deletions onnxruntime/core/providers/webgpu/math/matmul_packed.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ Status MatMulFillBiasOrZeroBeforeSplitKProgram::GenerateShaderCode(ShaderHelper&
}

// Handle bias with `MatMulWriteFnSourceForGemm() or MatMulWriteFnSourceForMatMul()`.
// const uint32_t bias_components = output_components_;
if (is_gemm_) {
MatMulWriteFnSourceForGemm(shader, output, bias, bias_is_scalar_);
} else {
Expand All @@ -77,15 +76,18 @@ Status MatMulFillBiasOrZeroBeforeSplitKProgram::GenerateShaderCode(ShaderHelper&
shader.MainFunctionBody() << R"(
let output_id = i32(global_idx);

let batch_size = i32(uniforms.batch_size);
let dim_a_outer = i32(uniforms.dim_a_outer);
let dim_b_outer = i32(uniforms.dim_b_outer) / output_components;
if (output_id >= dim_a_outer * dim_b_outer) {
let elements_per_batch = dim_a_outer * dim_b_outer;
if (output_id >= batch_size * elements_per_batch) {
return;
Comment thread
Jiawei-Shao marked this conversation as resolved.
}

let output_row = output_id / dim_b_outer;
let output_col = output_id % dim_b_outer;
let output_batch = 0;
let output_batch = output_id / elements_per_batch;
let remaining = output_id % elements_per_batch;
let output_row = remaining / dim_b_outer;
let output_col = remaining % dim_b_outer;
let output_value = output_value_t();
mm_write(output_batch, output_row, output_col, output_value);
)";
Expand Down
6 changes: 4 additions & 2 deletions onnxruntime/core/providers/webgpu/math/matmul_packed.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ class MatMulProgram final : public Program<MatMulProgram> {
{"dim_inner", ProgramUniformVariableDataType::Uint32},
{"logical_dispatch_x", ProgramUniformVariableDataType::Uint32},
{"logical_dispatch_y", ProgramUniformVariableDataType::Uint32},
{"logical_dispatch_z", ProgramUniformVariableDataType::Uint32});
{"logical_dispatch_z", ProgramUniformVariableDataType::Uint32},
{"splits_per_batch", ProgramUniformVariableDataType::Uint32});

bool NeedSplitK() const;

Expand Down Expand Up @@ -58,7 +59,8 @@ class MatMulFillBiasOrZeroBeforeSplitKProgram final : public Program<MatMulFillB

WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"dim_a_outer", ProgramUniformVariableDataType::Uint32},
{"dim_b_outer", ProgramUniformVariableDataType::Uint32},
{"beta", ProgramUniformVariableDataType::Float32});
{"beta", ProgramUniformVariableDataType::Float32},
{"batch_size", ProgramUniformVariableDataType::Uint32});

private:
bool is_gemm_ = false;
Expand Down
4 changes: 3 additions & 1 deletion onnxruntime/core/providers/webgpu/nn/conv2d_mm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,9 @@ Conv2dMMProgram CreateConv2dMMProgram(const Activation& activation, const std::v
std::transform(vec.begin(), vec.end(), std::ostream_iterator<std::string>(oss, ","), [](uint32_t i) { return std::to_string(i); });
return oss.str();
};
program.CacheHint(activation.ToString(), is_channels_last, stringify({inner_element_size, static_cast<uint32_t>(is_vec4 ? 1 : 0), fit_a_outer, fit_b_outer, fit_inner, tile_a_outer, tile_a_outer, tile_inner, static_cast<uint32_t>(components)}))

const bool has_batch_dims = false; // Conv2dMMProgram::GenerateShaderCode passes nullptr for batch_dims.
program.CacheHint(activation.ToString(), is_channels_last, has_batch_dims, stringify({inner_element_size, static_cast<uint32_t>(is_vec4 ? 1 : 0), fit_a_outer, fit_b_outer, fit_inner, tile_a_outer, tile_a_outer, tile_inner, static_cast<uint32_t>(components)}))
Comment thread
Jiawei-Shao marked this conversation as resolved.
Outdated
.AddOutput({output, ProgramTensorMetadataDependency::TypeAndRank, reduced_output_shape, components})
.SetDispatchGroupSize(dispatch[0], dispatch[1], dispatch[2])
.SetWorkgroupSize(workgroup_size[0], workgroup_size[1], workgroup_size[2])
Expand Down
37 changes: 21 additions & 16 deletions onnxruntime/core/providers/webgpu/webgpu_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,34 +36,36 @@ SplitKConfig::SplitKConfig(const wgpu::AdapterInfo& adapter_info) {
} else if (adapter_info.architecture == std::string_view{"xe-2lpg"} ||
adapter_info.architecture == std::string_view{"xe-2hpg"} ||
adapter_info.architecture == std::string_view{"gen-12hp"}) {
// Below thresholds are only verified on Intel discreate GPUs and Lunar Lake iGPUs.
// Below thresholds are only verified on Intel discrete GPUs and Lunar Lake iGPUs.
enable_split_k_ = true;

max_batch_size_ = 8;
split_dim_inner_ = 256;
min_dim_inner_with_split_k_ = split_dim_inner_ * 2;

configs_per_dim_inner_range_.emplace_back(768, 52.0f);
configs_per_dim_inner_range_.emplace_back(2304, 35.0f);
configs_per_dim_inner_range_.emplace_back(3072, 21.5f);
configs_per_dim_inner_range_.emplace_back(4096, 16.0f);
configs_per_dim_inner_range_.emplace_back(768, 52.0);
configs_per_dim_inner_range_.emplace_back(2304, 35.0);
configs_per_dim_inner_range_.emplace_back(3072, 21.5);
configs_per_dim_inner_range_.emplace_back(4096, 16.0);
} else {
// Below are the default thresholds on newer Intel GPUs. These values are chosen on
// Intel "gen-12lp" GPU with 32EUs.
enable_split_k_ = true;

max_batch_size_ = 8;
split_dim_inner_ = 256;
min_dim_inner_with_split_k_ = split_dim_inner_ * 2;

configs_per_dim_inner_range_.emplace_back(768, 20.0f);
configs_per_dim_inner_range_.emplace_back(1792, 13.0f);
configs_per_dim_inner_range_.emplace_back(3072, 8.0f);
configs_per_dim_inner_range_.emplace_back(4096, 6.0f);
configs_per_dim_inner_range_.emplace_back(768, 20.0);
configs_per_dim_inner_range_.emplace_back(1792, 13.0);
configs_per_dim_inner_range_.emplace_back(3072, 8.0);
configs_per_dim_inner_range_.emplace_back(4096, 6.0);
}
}
}

SplitKConfig::ConfigAtRange::ConfigAtRange(uint32_t max_dim_inner, float rate)
: max_dim_inner_with_rate(max_dim_inner), max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner(rate) {}
SplitKConfig::ConfigAtRange::ConfigAtRange(uint32_t max_dim_inner, double rate)
: max_dim_inner_with_rate(max_dim_inner), max_dim_a_outer_x_dim_b_outer_x_batch_size_divides_dim_inner(rate) {}

uint32_t SplitKConfig::GetMaxDimInnerWithSplitK() const {
assert(!configs_per_dim_inner_range_.empty());
Expand All @@ -87,7 +89,10 @@ bool SplitKConfig::UseSplitK(
// TODO: support the cases below.
use_split_k &= activation_kind == ActivationKind::None;
use_split_k &= is_vec4;
use_split_k &= batch_size == 1;

// Larger batches increase parallelism on their own, so we temporarily set a batch size threshold
// for using Split-K.
use_split_k &= batch_size <= max_batch_size_;

// `is_channels_last` should only affect Split-K gating when bias is applied in the non-GEMM
// MatMul/Conv|MatMul path. For GEMM and for MatMul or Conv|MatMul without bias, we need to
Expand All @@ -97,19 +102,19 @@ bool SplitKConfig::UseSplitK(
use_split_k &= is_channels_last;

// Split-K works best when `dim_inner` is relatively large compared with `dim_a_outer` and
// `dim_b_outer`. Currently we use the factor between `(dim_a_outer * dim_b_outer)` and
// `dim_inner)` as the metric to decide whether to use Split-K or not.
// `dim_b_outer`. Currently we use the factor between `(dim_a_outer * dim_b_outer * batch_size)`
// and `dim_inner` as the metric to decide whether to use Split-K or not.
use_split_k &= dim_inner >= min_dim_inner_with_split_k_;
use_split_k &= dim_inner <= GetMaxDimInnerWithSplitK();

if (!use_split_k) {
return false;
}

const float rate = dim_a_outer * dim_b_outer * 1.0f / dim_inner;
const double rate = static_cast<double>(dim_a_outer) * static_cast<double>(dim_b_outer) * static_cast<double>(batch_size) / static_cast<double>(dim_inner);
Comment thread
Jiawei-Shao marked this conversation as resolved.
for (const auto& config_at_range : configs_per_dim_inner_range_) {
if (dim_inner <= config_at_range.max_dim_inner_with_rate) {
return rate <= config_at_range.max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner;
return rate <= config_at_range.max_dim_a_outer_x_dim_b_outer_x_batch_size_divides_dim_inner;
}
}
return false;
Expand Down
5 changes: 3 additions & 2 deletions onnxruntime/core/providers/webgpu/webgpu_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,13 +115,14 @@ class SplitKConfig {
bool enable_split_k_ = false;
uint32_t split_dim_inner_ = 0;
uint32_t min_dim_inner_with_split_k_ = 0;
uint32_t max_batch_size_ = 0;

uint32_t GetMaxDimInnerWithSplitK() const;

struct ConfigAtRange {
ConfigAtRange(uint32_t max_dim_inner, float rate);
ConfigAtRange(uint32_t max_dim_inner, double rate);
uint32_t max_dim_inner_with_rate = 0;
float max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner = 0.0f;
double max_dim_a_outer_x_dim_b_outer_x_batch_size_divides_dim_inner = 0.0;
};
Comment thread
Jiawei-Shao marked this conversation as resolved.
std::vector<ConfigAtRange> configs_per_dim_inner_range_;
};
Expand Down
Loading