Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions onnxruntime/core/providers/webgpu/math/gemm_packed.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ Status GemmProgram::GenerateShaderCode(ShaderHelper& shader) const {
MatMulReadFnSource(shader, a, b, nullptr, transA_, transB_);
}
if (is_vec4_) {
ORT_RETURN_IF_ERROR(MakeMatMulPackedVec4Source(shader, elements_per_thread, WorkgroupSizeX(), WorkgroupSizeY(), data_type, nullptr, transA_, transB_, alpha_, need_handle_matmul_, output_components_, /*tile_inner*/ 32, need_split_k, split_dim_inner_));
ORT_RETURN_IF_ERROR(MakeMatMulPackedVec4Source(shader, elements_per_thread, WorkgroupSizeX(), WorkgroupSizeY(), data_type, /* batch_dims = */ nullptr, transA_, transB_, alpha_, need_handle_matmul_, output_components_, /*tile_inner*/ 32, need_split_k, split_dim_inner_));
} else {
ORT_RETURN_IF_ERROR(MakeMatMulPackedSource(shader, elements_per_thread, WorkgroupSizeX(), WorkgroupSizeY(), data_type, nullptr, transA_, transB_, alpha_, need_handle_matmul_));
ORT_RETURN_IF_ERROR(MakeMatMulPackedSource(shader, elements_per_thread, WorkgroupSizeX(), WorkgroupSizeY(), data_type, /* batch_dims = */ nullptr, transA_, transB_, alpha_, need_handle_matmul_));
}

const ShaderVariableHelper* c = nullptr;
Expand Down
28 changes: 21 additions & 7 deletions onnxruntime/core/providers/webgpu/math/gemm_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -309,13 +309,27 @@ Status MakeMatMulPackedVec4Source(ShaderHelper& shader,
// atomic built-in functions in `HandleMatMulWithSplitK()`.
shader.MainFunctionBody()
<< "const kSplitK = " << split_dim_inner << ";\n"
<< " let num_tiles = (kSplitK - 1) / tileInner + 1;\n"
<< " var kStart = kSplitK * i32(logical_global_id.z);\n"

// When Split-K is used, `batch` should always be 0 and `logical_global_id.z` is used to indicate
// the index of split-k instead of batch.
<< " let batch = 0;\n"
<< " let batchIndices = 0u;\n";
<< " let num_tiles = (kSplitK - 1) / tileInner + 1;\n";
if (nullptr != batch_dims) {
Comment thread
Jiawei-Shao marked this conversation as resolved.
// With Split-K and batch (in MatMul and Conv2D|MatMul), `dispatch_z` is
// `splits_per_batch * batch_size`, and `logical_global_id.z` encodes both the
// batch index and the Split-K index within that range.
// We decompose it as:
// split_index = logical_global_id.z % splits_per_batch
// batch = logical_global_id.z / splits_per_batch
shader.MainFunctionBody()
<< " let splits_per_batch = uniforms.splits_per_batch;\n"
<< " let split_index = i32(logical_global_id.z) % i32(splits_per_batch);\n"
<< " var kStart = kSplitK * split_index;\n"
<< " let batch = i32(logical_global_id.z) / i32(splits_per_batch);\n"
<< (nullptr != batch_dims ? " let batchIndices = " + batch_dims->OffsetToIndices("u32(batch)") + ";\n" : "");
} else {
// With Split-K without batch (in Gemm), `logical_global_id.z` is exactly the Split-K index.
shader.MainFunctionBody()
<< " var kStart = kSplitK * i32(logical_global_id.z);\n"
<< " let batch = 0;\n"
<< " let batchIndices = 0u;\n";
}
} else {
shader.MainFunctionBody()
<< " let num_tiles = (uniforms.dim_inner - 1) / tileInner + 1;\n"
Expand Down
37 changes: 21 additions & 16 deletions onnxruntime/core/providers/webgpu/math/matmul.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
// Licensed under the MIT License.

#include "core/providers/webgpu/math/matmul.h"

#include <limits>

#include "core/common/inlined_containers.h"
#include "core/providers/cpu/tensor/utils.h"
#include "core/providers/webgpu/shader_helper.h"
Expand Down Expand Up @@ -244,31 +247,32 @@ Status ComputeMatMul(ComputeContext* context,
const Tensor* bias = has_bias ? inputs[2] : nullptr;
bool use_bias_in_matmul = has_bias;
uint32_t split_dim_inner = 1;
uint32_t splits_per_batch = 1;

// Current Split-K implementation relies on atomic operations, which are not deterministic.
if (!context->KernelContext().GetUseDeterministicCompute()) {
const SplitKConfig& split_k_config = context->GetSplitKConfig();
const bool need_split_k = split_k_config.UseSplitK(is_vec4, activation.activation_kind_, batch_size, dim_a_outer, dim_b_outer, dim_inner, is_channels_last);
if (need_split_k) {
ORT_ENFORCE(batch_size == 1, "Split-K MatMul only supports batch_size == 1.");
ORT_ENFORCE(is_vec4, "Split-K MatMul requires vec4 packing.");

if (has_bias) {
ORT_ENFORCE(is_channels_last, "Split-K MatMul only supports channels-last format.");
}

// Initialize `output_tensor` with 0 or bias before MatMulProgram with Split-K enabled.
const auto fill_bias_program = CreateMatMulFillBiasOrZeroBeforeSplitKProgram(bias, output_tensor, /*is_gemm*/ false, /*beta*/ 1.0f, /*bias_components*/ 4, output_shape_temp);
const auto fill_bias_program = CreateMatMulFillBiasOrZeroBeforeSplitKProgram(bias, output_tensor, /*is_gemm*/ false, /*beta*/ 1.0f, /*bias_components*/ 4, output_shape_temp, narrow<uint32_t>(batch_size));
ORT_RETURN_IF_ERROR(context->RunProgram(fill_bias_program));

// `bias` has been handled in the execution of `fill_bias_program` so we don't need to set
// `bias` again in `MatMulProgram`.
use_bias_in_matmul = false;

// With Split-K, `dim_inner` will be split into multiple parts and `dispatch_z` will be the
// number of splits along `dim_inner`.
// With Split-K, `dim_inner` will be split into multiple parts. `dispatch_z` encodes
// both the split-k index and the batch index: dispatch_z = splits_per_batch * batch_size.
split_dim_inner = split_k_config.GetSplitDimInner();
dispatch_z = (dim_inner + split_dim_inner - 1) / split_dim_inner;
splits_per_batch = (dim_inner + split_dim_inner - 1) / split_dim_inner;
dispatch_z = narrow<uint32_t>(batch_size) * splits_per_batch;
Comment thread
Jiawei-Shao marked this conversation as resolved.
Outdated

// The output should be declared in atomic types in `MatMulProgram` for the use of atomic
// built-in functions.
Expand All @@ -281,7 +285,7 @@ Status ComputeMatMul(ComputeContext* context,
.CacheHint(activation.ToString(), absl::StrJoin(elements_per_thread, "-"), std::to_string(is_vec4), components, is_channels_last, split_dim_inner)
.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, a_shape_temp, components},
{b, ProgramTensorMetadataDependency::TypeAndRank, b_shape_temp, components}})
.AddUniformVariables({{dim_a_outer}, {dim_b_outer}, {dim_inner}, {dispatch_x}, {dispatch_y}, {dispatch_z}})
.AddUniformVariables({{dim_a_outer}, {dim_b_outer}, {dim_inner}, {dispatch_x}, {dispatch_y}, {dispatch_z}, {splits_per_batch}})
.AddIndices(outer_dims)
.SetDispatchGroupSize(dispatch_x, dispatch_y, dispatch_z)
.SetWorkgroupSize(MatMul::MATMUL_PACKED_WORKGROUP_SIZE_X, MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Y, MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z)
Expand All @@ -302,31 +306,32 @@ MatMulFillBiasOrZeroBeforeSplitKProgram CreateMatMulFillBiasOrZeroBeforeSplitKPr
bool is_gemm,
float beta,
uint32_t output_components,
const TensorShape& output_shape) {
const TensorShape& output_shape,
uint32_t batch_size) {
const bool has_bias = bias != nullptr;
const bool bias_is_scalar = has_bias ? bias->Shape().Size() == 1 : false;

// Currently we only support GEMM and channels last format for MatMul with Split-K.
MatMulFillBiasOrZeroBeforeSplitKProgram program(is_gemm, has_bias, output_components, bias_is_scalar);

const uint32_t dim_a_outer = narrow<uint32_t>(output_shape[output_shape.NumDimensions() - 2]);
const uint32_t dim_b_outer = narrow<uint32_t>(output_shape[output_shape.NumDimensions() - 1]);

// Fill one value per invocation. Now we use default workgroup size (64) for this program.
const uint32_t total_outputs = dim_a_outer * dim_b_outer;
const uint32_t dispatch_x = (total_outputs + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE;
// Fill one value per invocation across all batches.
const uint64_t total_outputs = static_cast<uint64_t>(batch_size) *
static_cast<uint64_t>(dim_a_outer) *
static_cast<uint64_t>(dim_b_outer);
const uint64_t dispatch_x_u64 = CeilDiv(total_outputs, static_cast<uint64_t>(WORKGROUP_SIZE));
ORT_ENFORCE(dispatch_x_u64 <= static_cast<uint64_t>(std::numeric_limits<uint32_t>::max()),
"dispatch_x exceeds uint32_t range: ", dispatch_x_u64);
const uint32_t dispatch_x = narrow<uint32_t>(dispatch_x_u64);

// To reuse `MatMulWriteFnSourceForGemm()` or `MatMulWriteFnSourceForMatMul()` we need to set
// `dim_b_outer` in components when `output_shape` is in `vec4`, while using `output_shape` directly
// as the output shape.
const uint32_t dim_b_outer_components = narrow<uint32_t>(dim_b_outer * output_components);
program.CacheHint(is_gemm, has_bias, output_components, bias_is_scalar)
.AddOutput({output, ProgramTensorMetadataDependency::TypeAndRank, output_shape, static_cast<int32_t>(output_components)})
.AddUniformVariables({{dim_a_outer}, {dim_b_outer_components}, {beta}})
.AddUniformVariables({{dim_a_outer}, {dim_b_outer_components}, {beta}, {batch_size}})
.SetDispatchGroupSize(dispatch_x);

if (has_bias) {
// We always use `c_components` as `output_components` in GEMM, and 4 in MatMul.
const TensorShape reduced_bias_shape = ReduceShapeByComponents(bias->Shape(), output_components);
program.AddInput({bias, ProgramTensorMetadataDependency::TypeAndRank, reduced_bias_shape, static_cast<int32_t>(output_components)});
}
Expand Down
3 changes: 2 additions & 1 deletion onnxruntime/core/providers/webgpu/math/matmul.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ MatMulFillBiasOrZeroBeforeSplitKProgram CreateMatMulFillBiasOrZeroBeforeSplitKPr
bool is_gemm,
float beta,
uint32_t output_components,
const TensorShape& output_shape);
const TensorShape& output_shape,
uint32_t batch_size = 1);

class MatMul final : public WebGpuKernel {
public:
Expand Down
12 changes: 7 additions & 5 deletions onnxruntime/core/providers/webgpu/math/matmul_packed.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ Status MatMulFillBiasOrZeroBeforeSplitKProgram::GenerateShaderCode(ShaderHelper&
}

// Handle bias with `MatMulWriteFnSourceForGemm()` or `MatMulWriteFnSourceForMatMul()`.
// const uint32_t bias_components = output_components_;
if (is_gemm_) {
MatMulWriteFnSourceForGemm(shader, output, bias, bias_is_scalar_);
} else {
Expand All @@ -77,15 +76,18 @@ Status MatMulFillBiasOrZeroBeforeSplitKProgram::GenerateShaderCode(ShaderHelper&
shader.MainFunctionBody() << R"(
let output_id = i32(global_idx);

let batch_size = i32(uniforms.batch_size);
let dim_a_outer = i32(uniforms.dim_a_outer);
let dim_b_outer = i32(uniforms.dim_b_outer) / output_components;
if (output_id >= dim_a_outer * dim_b_outer) {
let elements_per_batch = dim_a_outer * dim_b_outer;
if (output_id >= batch_size * elements_per_batch) {
return;
Comment thread
Jiawei-Shao marked this conversation as resolved.
}

let output_row = output_id / dim_b_outer;
let output_col = output_id % dim_b_outer;
let output_batch = 0;
let output_batch = output_id / elements_per_batch;
let remaining = output_id % elements_per_batch;
let output_row = remaining / dim_b_outer;
let output_col = remaining % dim_b_outer;
let output_value = output_value_t();
mm_write(output_batch, output_row, output_col, output_value);
)";
Expand Down
6 changes: 4 additions & 2 deletions onnxruntime/core/providers/webgpu/math/matmul_packed.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ class MatMulProgram final : public Program<MatMulProgram> {
{"dim_inner", ProgramUniformVariableDataType::Uint32},
{"logical_dispatch_x", ProgramUniformVariableDataType::Uint32},
{"logical_dispatch_y", ProgramUniformVariableDataType::Uint32},
{"logical_dispatch_z", ProgramUniformVariableDataType::Uint32});
{"logical_dispatch_z", ProgramUniformVariableDataType::Uint32},
{"splits_per_batch", ProgramUniformVariableDataType::Uint32});

bool NeedSplitK() const;

Expand Down Expand Up @@ -58,7 +59,8 @@ class MatMulFillBiasOrZeroBeforeSplitKProgram final : public Program<MatMulFillB

WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"dim_a_outer", ProgramUniformVariableDataType::Uint32},
{"dim_b_outer", ProgramUniformVariableDataType::Uint32},
{"beta", ProgramUniformVariableDataType::Float32});
{"beta", ProgramUniformVariableDataType::Float32},
{"batch_size", ProgramUniformVariableDataType::Uint32});

private:
bool is_gemm_ = false;
Expand Down
35 changes: 20 additions & 15 deletions onnxruntime/core/providers/webgpu/webgpu_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,31 +39,33 @@ SplitKConfig::SplitKConfig(const wgpu::AdapterInfo& adapter_info) {
// The thresholds below are only verified on Intel discrete GPUs and Lunar Lake iGPUs.
Comment thread
Jiawei-Shao marked this conversation as resolved.
Outdated
enable_split_k_ = true;

max_batch_size_ = 8;
split_dim_inner_ = 256;
min_dim_inner_with_split_k_ = split_dim_inner_ * 2;

configs_per_dim_inner_range_.emplace_back(768, 52.0f);
configs_per_dim_inner_range_.emplace_back(2304, 35.0f);
configs_per_dim_inner_range_.emplace_back(3072, 21.5f);
configs_per_dim_inner_range_.emplace_back(4096, 16.0f);
configs_per_dim_inner_range_.emplace_back(768, 52.0);
configs_per_dim_inner_range_.emplace_back(2304, 35.0);
configs_per_dim_inner_range_.emplace_back(3072, 21.5);
configs_per_dim_inner_range_.emplace_back(4096, 16.0);
} else {
// Below are the default thresholds on newer Intel GPUs. These values are chosen on
// Intel "gen-12lp" GPU with 32EUs.
enable_split_k_ = true;

max_batch_size_ = 8;
split_dim_inner_ = 256;
min_dim_inner_with_split_k_ = split_dim_inner_ * 2;

configs_per_dim_inner_range_.emplace_back(768, 20.0f);
configs_per_dim_inner_range_.emplace_back(1792, 13.0f);
configs_per_dim_inner_range_.emplace_back(3072, 8.0f);
configs_per_dim_inner_range_.emplace_back(4096, 6.0f);
configs_per_dim_inner_range_.emplace_back(768, 20.0);
configs_per_dim_inner_range_.emplace_back(1792, 13.0);
configs_per_dim_inner_range_.emplace_back(3072, 8.0);
configs_per_dim_inner_range_.emplace_back(4096, 6.0);
}
}
}

SplitKConfig::ConfigAtRange::ConfigAtRange(uint32_t max_dim_inner, float rate)
: max_dim_inner_with_rate(max_dim_inner), max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner(rate) {}
SplitKConfig::ConfigAtRange::ConfigAtRange(uint32_t max_dim_inner, double rate)
: max_dim_inner_with_rate(max_dim_inner), max_dim_a_outer_x_dim_b_outer_x_batch_size_divides_dim_inner(rate) {}

uint32_t SplitKConfig::GetMaxDimInnerWithSplitK() const {
assert(!configs_per_dim_inner_range_.empty());
Expand All @@ -87,7 +89,10 @@ bool SplitKConfig::UseSplitK(
// TODO: support the cases below.
use_split_k &= activation_kind == ActivationKind::None;
use_split_k &= is_vec4;
use_split_k &= batch_size == 1;

// Larger batches increase parallelism on their own, so we temporarily set a batch size threshold
// for using Split-K.
use_split_k &= batch_size <= max_batch_size_;

// `is_channels_last` should only affect Split-K gating when bias is applied in the non-GEMM
// MatMul/Conv|MatMul path. For GEMM and for MatMul or Conv|MatMul without bias, we need to
Expand All @@ -97,19 +102,19 @@ bool SplitKConfig::UseSplitK(
use_split_k &= is_channels_last;

// Split-K works best when `dim_inner` is relatively large compared with `dim_a_outer` and
// `dim_b_outer`. Currently we use the factor between `(dim_a_outer * dim_b_outer)` and
// `dim_inner)` as the metric to decide whether to use Split-K or not.
// `dim_b_outer`. Currently we use the factor between `(dim_a_outer * dim_b_outer * batch_size)`
// and `dim_inner` as the metric to decide whether to use Split-K or not.
use_split_k &= dim_inner >= min_dim_inner_with_split_k_;
use_split_k &= dim_inner <= GetMaxDimInnerWithSplitK();

if (!use_split_k) {
return false;
}

const float rate = dim_a_outer * dim_b_outer * 1.0f / dim_inner;
const double rate = static_cast<double>(dim_a_outer) * static_cast<double>(dim_b_outer) * static_cast<double>(batch_size) / static_cast<double>(dim_inner);
Comment thread
Jiawei-Shao marked this conversation as resolved.
for (const auto& config_at_range : configs_per_dim_inner_range_) {
if (dim_inner <= config_at_range.max_dim_inner_with_rate) {
return rate <= config_at_range.max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner;
return rate <= config_at_range.max_dim_a_outer_x_dim_b_outer_x_batch_size_divides_dim_inner;
}
}
return false;
Expand Down
5 changes: 3 additions & 2 deletions onnxruntime/core/providers/webgpu/webgpu_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,13 +115,14 @@ class SplitKConfig {
bool enable_split_k_ = false;
uint32_t split_dim_inner_ = 0;
uint32_t min_dim_inner_with_split_k_ = 0;
uint32_t max_batch_size_ = 0;

uint32_t GetMaxDimInnerWithSplitK() const;

struct ConfigAtRange {
ConfigAtRange(uint32_t max_dim_inner, float rate);
ConfigAtRange(uint32_t max_dim_inner, double rate);
uint32_t max_dim_inner_with_rate = 0;
float max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner = 0.0f;
double max_dim_a_outer_x_dim_b_outer_x_batch_size_divides_dim_inner = 0.0;
};
Comment thread
Jiawei-Shao marked this conversation as resolved.
std::vector<ConfigAtRange> configs_per_dim_inner_range_;
};
Expand Down
12 changes: 6 additions & 6 deletions onnxruntime/core/session/ort_version_check.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,21 @@

namespace onnxruntime::version_check {

// A simple consteval-friendly result type for ParseUint.
// A simple constexpr-friendly result type for ParseUint.
// std::optional triggers an internal compiler error in MSVC 14.44 when used with consteval.
Comment thread
Jiawei-Shao marked this conversation as resolved.
Outdated
struct ParseUintResult {
uint32_t value;
bool has_value;

consteval bool operator==(uint32_t other) const { return has_value && value == other; }
consteval bool operator!=(uint32_t other) const { return !(*this == other); }
constexpr bool operator==(uint32_t other) const { return has_value && value == other; }
constexpr bool operator!=(uint32_t other) const { return !(*this == other); }
};

inline consteval ParseUintResult ParseUintNone() { return {0, false}; }
inline constexpr ParseUintResult ParseUintNone() { return {0, false}; }

// Parse a non-negative integer from a string_view without leading zeros.
// Returns a result with has_value == false on failure (empty, leading zero, non-digit, or overflow).
consteval ParseUintResult ParseUint(std::string_view str) {
constexpr ParseUintResult ParseUint(std::string_view str) {
if (str.empty()) return ParseUintNone();
// Leading zeros are not allowed (except "0" itself).
if (str.size() > 1 && str[0] == '0') return ParseUintNone();
Expand All @@ -42,7 +42,7 @@ consteval ParseUintResult ParseUint(std::string_view str) {
// - Major version is 1
// - Y and Z are non-negative integers without leading zeros
// - Y (minor version) must equal expected_api_version (defaults to ORT_API_VERSION)
consteval bool IsOrtVersionValid(std::string_view version, uint32_t expected_api_version = ORT_API_VERSION) {
constexpr bool IsOrtVersionValid(std::string_view version, uint32_t expected_api_version = ORT_API_VERSION) {
size_t first_dot = version.find('.');
if (first_dot == std::string_view::npos) return false;
size_t second_dot = version.find('.', first_dot + 1);
Expand Down
Loading
Loading