Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions onnxruntime/core/providers/webgpu/math/gemm_packed.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ Status GemmProgram::GenerateShaderCode(ShaderHelper& shader) const {
MatMulReadFnSource(shader, a, b, nullptr, transA_, transB_);
}
if (is_vec4_) {
ORT_RETURN_IF_ERROR(MakeMatMulPackedVec4Source(shader, elements_per_thread, WorkgroupSizeX(), WorkgroupSizeY(), data_type, nullptr, transA_, transB_, alpha_, need_handle_matmul_, output_components_, /*tile_inner*/ 32, need_split_k, split_dim_inner_));
ORT_RETURN_IF_ERROR(MakeMatMulPackedVec4Source(shader, elements_per_thread, WorkgroupSizeX(), WorkgroupSizeY(), data_type, /* batch_dims = */ nullptr, transA_, transB_, alpha_, need_handle_matmul_, output_components_, /*tile_inner*/ 32, need_split_k, split_dim_inner_));
} else {
ORT_RETURN_IF_ERROR(MakeMatMulPackedSource(shader, elements_per_thread, WorkgroupSizeX(), WorkgroupSizeY(), data_type, nullptr, transA_, transB_, alpha_, need_handle_matmul_));
ORT_RETURN_IF_ERROR(MakeMatMulPackedSource(shader, elements_per_thread, WorkgroupSizeX(), WorkgroupSizeY(), data_type, /* batch_dims = */ nullptr, transA_, transB_, alpha_, need_handle_matmul_));
}

const ShaderVariableHelper* c = nullptr;
Expand Down
28 changes: 21 additions & 7 deletions onnxruntime/core/providers/webgpu/math/gemm_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -309,13 +309,27 @@ Status MakeMatMulPackedVec4Source(ShaderHelper& shader,
// atomic built-in functions in `HandleMatMulWithSplitK()`.
shader.MainFunctionBody()
<< "const kSplitK = " << split_dim_inner << ";\n"
<< " let num_tiles = (kSplitK - 1) / tileInner + 1;\n"
<< " var kStart = kSplitK * i32(logical_global_id.z);\n"

// When Split-K is used, `batch` should always be 0 and `logical_global_id.z` is used to indicate
// the index of split-k instead of batch.
<< " let batch = 0;\n"
<< " let batchIndices = 0u;\n";
<< " let num_tiles = (kSplitK - 1) / tileInner + 1;\n";
if (nullptr != batch_dims) {
Comment thread
Jiawei-Shao marked this conversation as resolved.
// With Split-K and batch (in MatMul and Conv2D|MatMul), `dispatch_z` is
// `splits_per_batch * batch_size`, and `logical_global_id.z` encodes both the
// batch index and the Split-K index within that range.
// We decompose it as:
// split_index = logical_global_id.z % splits_per_batch
// batch = logical_global_id.z / splits_per_batch
shader.MainFunctionBody()
<< " let splits_per_batch = uniforms.splits_per_batch;\n"
<< " let split_index = i32(logical_global_id.z) % i32(splits_per_batch);\n"
<< " var kStart = kSplitK * split_index;\n"
<< " let batch = i32(logical_global_id.z) / i32(splits_per_batch);\n"
<< (nullptr != batch_dims ? " let batchIndices = " + batch_dims->OffsetToIndices("u32(batch)") + ";\n" : "");
} else {
// With Split-K without batch (in Gemm), `logical_global_id.z` is exactly the Split-K index.
shader.MainFunctionBody()
<< " var kStart = kSplitK * i32(logical_global_id.z);\n"
<< " let batch = 0;\n"
<< " let batchIndices = 0u;\n";
}
} else {
shader.MainFunctionBody()
<< " let num_tiles = (uniforms.dim_inner - 1) / tileInner + 1;\n"
Expand Down
27 changes: 12 additions & 15 deletions onnxruntime/core/providers/webgpu/math/matmul.cc
Original file line number Diff line number Diff line change
Expand Up @@ -244,31 +244,32 @@ Status ComputeMatMul(ComputeContext* context,
const Tensor* bias = has_bias ? inputs[2] : nullptr;
bool use_bias_in_matmul = has_bias;
uint32_t split_dim_inner = 1;
uint32_t splits_per_batch = 1;

// Current Split-K implementation relies on atomic operations, which are not deterministic.
if (!context->KernelContext().GetUseDeterministicCompute()) {
const SplitKConfig& split_k_config = context->GetSplitKConfig();
const bool need_split_k = split_k_config.UseSplitK(is_vec4, activation.activation_kind_, batch_size, dim_a_outer, dim_b_outer, dim_inner, is_channels_last);
if (need_split_k) {
ORT_ENFORCE(batch_size == 1, "Split-K MatMul only supports batch_size == 1.");
ORT_ENFORCE(is_vec4, "Split-K MatMul requires vec4 packing.");

if (has_bias) {
ORT_ENFORCE(is_channels_last, "Split-K MatMul only supports channels-last format.");
}

// Initialize `output_tensor` with 0 or bias before MatMulProgram with Split-K enabled.
const auto fill_bias_program = CreateMatMulFillBiasOrZeroBeforeSplitKProgram(bias, output_tensor, /*is_gemm*/ false, /*beta*/ 1.0f, /*bias_components*/ 4, output_shape_temp);
const auto fill_bias_program = CreateMatMulFillBiasOrZeroBeforeSplitKProgram(bias, output_tensor, /*is_gemm*/ false, /*beta*/ 1.0f, /*bias_components*/ 4, output_shape_temp, narrow<uint32_t>(batch_size));
ORT_RETURN_IF_ERROR(context->RunProgram(fill_bias_program));

// `bias` has been handled in the execution of `fill_bias_program` so we don't need to set
// `bias` again in `MatMulProgram`.
use_bias_in_matmul = false;

// With Split-K, `dim_inner` will be split into multiple parts and `dispatch_z` will be the
// number of splits along `dim_inner`.
// With Split-K, `dim_inner` will be split into multiple parts. `dispatch_z` encodes
// both the split-k index and the batch index: dispatch_z = splits_per_batch * batch_size.
split_dim_inner = split_k_config.GetSplitDimInner();
dispatch_z = (dim_inner + split_dim_inner - 1) / split_dim_inner;
splits_per_batch = (dim_inner + split_dim_inner - 1) / split_dim_inner;
dispatch_z = narrow<uint32_t>(batch_size) * splits_per_batch;
Comment thread
Jiawei-Shao marked this conversation as resolved.
Outdated

// The output should be declared in atomic types in `MatMulProgram` for the use of atomic
// built-in functions.
Expand All @@ -281,7 +282,7 @@ Status ComputeMatMul(ComputeContext* context,
.CacheHint(activation.ToString(), absl::StrJoin(elements_per_thread, "-"), std::to_string(is_vec4), components, is_channels_last, split_dim_inner)
.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, a_shape_temp, components},
{b, ProgramTensorMetadataDependency::TypeAndRank, b_shape_temp, components}})
.AddUniformVariables({{dim_a_outer}, {dim_b_outer}, {dim_inner}, {dispatch_x}, {dispatch_y}, {dispatch_z}})
.AddUniformVariables({{dim_a_outer}, {dim_b_outer}, {dim_inner}, {dispatch_x}, {dispatch_y}, {dispatch_z}, {splits_per_batch}})
.AddIndices(outer_dims)
.SetDispatchGroupSize(dispatch_x, dispatch_y, dispatch_z)
.SetWorkgroupSize(MatMul::MATMUL_PACKED_WORKGROUP_SIZE_X, MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Y, MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z)
Expand All @@ -302,31 +303,27 @@ MatMulFillBiasOrZeroBeforeSplitKProgram CreateMatMulFillBiasOrZeroBeforeSplitKPr
bool is_gemm,
float beta,
uint32_t output_components,
const TensorShape& output_shape) {
const TensorShape& output_shape,
uint32_t batch_size) {
const bool has_bias = bias != nullptr;
const bool bias_is_scalar = has_bias ? bias->Shape().Size() == 1 : false;

// Currently we only support GEMM and channels last format for MatMul with Split-K.
MatMulFillBiasOrZeroBeforeSplitKProgram program(is_gemm, has_bias, output_components, bias_is_scalar);

const uint32_t dim_a_outer = narrow<uint32_t>(output_shape[output_shape.NumDimensions() - 2]);
const uint32_t dim_b_outer = narrow<uint32_t>(output_shape[output_shape.NumDimensions() - 1]);

// Fill one value per invocation. Now we use default workgroup size (64) for this program.
const uint32_t total_outputs = dim_a_outer * dim_b_outer;
// Fill one value per invocation across all batches.
const uint32_t total_outputs = batch_size * dim_a_outer * dim_b_outer;
const uint32_t dispatch_x = (total_outputs + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE;
Comment thread
Jiawei-Shao marked this conversation as resolved.
Outdated

// To reuse `MatMulWriteFnSourceForGemm()` or `MatMulWriteFnSourceForMatMul()` we need to set
// `dim_b_outer` in components when `output_shape` is in `vec4`, while using `output_shape` directly
// as the output shape.
const uint32_t dim_b_outer_components = narrow<uint32_t>(dim_b_outer * output_components);
program.CacheHint(is_gemm, has_bias, output_components, bias_is_scalar)
.AddOutput({output, ProgramTensorMetadataDependency::TypeAndRank, output_shape, static_cast<int32_t>(output_components)})
.AddUniformVariables({{dim_a_outer}, {dim_b_outer_components}, {beta}})
.AddUniformVariables({{dim_a_outer}, {dim_b_outer_components}, {beta}, {batch_size}})
.SetDispatchGroupSize(dispatch_x);

if (has_bias) {
// We always use `c_components` as `output_components` in GEMM, and 4 in MatMul.
const TensorShape reduced_bias_shape = ReduceShapeByComponents(bias->Shape(), output_components);
program.AddInput({bias, ProgramTensorMetadataDependency::TypeAndRank, reduced_bias_shape, static_cast<int32_t>(output_components)});
}
Expand Down
3 changes: 2 additions & 1 deletion onnxruntime/core/providers/webgpu/math/matmul.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ MatMulFillBiasOrZeroBeforeSplitKProgram CreateMatMulFillBiasOrZeroBeforeSplitKPr
bool is_gemm,
float beta,
uint32_t output_components,
const TensorShape& output_shape);
const TensorShape& output_shape,
uint32_t batch_size = 1);

class MatMul final : public WebGpuKernel {
public:
Expand Down
12 changes: 7 additions & 5 deletions onnxruntime/core/providers/webgpu/math/matmul_packed.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ Status MatMulFillBiasOrZeroBeforeSplitKProgram::GenerateShaderCode(ShaderHelper&
}

// Handle bias with `MatMulWriteFnSourceForGemm() or MatMulWriteFnSourceForMatMul()`.
// const uint32_t bias_components = output_components_;
if (is_gemm_) {
MatMulWriteFnSourceForGemm(shader, output, bias, bias_is_scalar_);
} else {
Expand All @@ -77,15 +76,18 @@ Status MatMulFillBiasOrZeroBeforeSplitKProgram::GenerateShaderCode(ShaderHelper&
shader.MainFunctionBody() << R"(
let output_id = i32(global_idx);

let batch_size = i32(uniforms.batch_size);
let dim_a_outer = i32(uniforms.dim_a_outer);
let dim_b_outer = i32(uniforms.dim_b_outer) / output_components;
if (output_id >= dim_a_outer * dim_b_outer) {
let elements_per_batch = dim_a_outer * dim_b_outer;
if (output_id >= batch_size * elements_per_batch) {
return;
Comment thread
Jiawei-Shao marked this conversation as resolved.
}

let output_row = output_id / dim_b_outer;
let output_col = output_id % dim_b_outer;
let output_batch = 0;
let output_batch = output_id / elements_per_batch;
let remaining = output_id % elements_per_batch;
let output_row = remaining / dim_b_outer;
let output_col = remaining % dim_b_outer;
let output_value = output_value_t();
mm_write(output_batch, output_row, output_col, output_value);
)";
Expand Down
6 changes: 4 additions & 2 deletions onnxruntime/core/providers/webgpu/math/matmul_packed.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ class MatMulProgram final : public Program<MatMulProgram> {
{"dim_inner", ProgramUniformVariableDataType::Uint32},
{"logical_dispatch_x", ProgramUniformVariableDataType::Uint32},
{"logical_dispatch_y", ProgramUniformVariableDataType::Uint32},
{"logical_dispatch_z", ProgramUniformVariableDataType::Uint32});
{"logical_dispatch_z", ProgramUniformVariableDataType::Uint32},
{"splits_per_batch", ProgramUniformVariableDataType::Uint32});

bool NeedSplitK() const;

Expand Down Expand Up @@ -58,7 +59,8 @@ class MatMulFillBiasOrZeroBeforeSplitKProgram final : public Program<MatMulFillB

WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"dim_a_outer", ProgramUniformVariableDataType::Uint32},
{"dim_b_outer", ProgramUniformVariableDataType::Uint32},
{"beta", ProgramUniformVariableDataType::Float32});
{"beta", ProgramUniformVariableDataType::Float32},
{"batch_size", ProgramUniformVariableDataType::Uint32});

private:
bool is_gemm_ = false;
Expand Down
9 changes: 7 additions & 2 deletions onnxruntime/core/providers/webgpu/webgpu_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ SplitKConfig::SplitKConfig(const wgpu::AdapterInfo& adapter_info) {
// The thresholds below have only been verified on Intel discrete GPUs and Lunar Lake iGPUs.
Comment thread
Jiawei-Shao marked this conversation as resolved.
Outdated
enable_split_k_ = true;

max_batch_size_ = 8;
split_dim_inner_ = 256;
min_dim_inner_with_split_k_ = split_dim_inner_ * 2;

Expand All @@ -51,6 +52,7 @@ SplitKConfig::SplitKConfig(const wgpu::AdapterInfo& adapter_info) {
// Intel "gen-12lp" GPU with 32EUs.
enable_split_k_ = true;

max_batch_size_ = 8;
split_dim_inner_ = 256;
min_dim_inner_with_split_k_ = split_dim_inner_ * 2;

Expand Down Expand Up @@ -87,7 +89,10 @@ bool SplitKConfig::UseSplitK(
// TODO: support the cases below.
use_split_k &= activation_kind == ActivationKind::None;
use_split_k &= is_vec4;
use_split_k &= batch_size == 1;

// Larger batches increase parallelism on their own, so we temporarily set a batch size threshold
// for using Split-K.
use_split_k &= batch_size <= max_batch_size_;

// `is_channels_last` should only affect Split-K gating when bias is applied in the non-GEMM
// MatMul/Conv|MatMul path. For GEMM and for MatMul or Conv|MatMul without bias, we need to
Expand All @@ -106,7 +111,7 @@ bool SplitKConfig::UseSplitK(
return false;
}

const float rate = dim_a_outer * dim_b_outer * 1.0f / dim_inner;
const float rate = dim_a_outer * dim_b_outer * batch_size * 1.0f / dim_inner;
Comment thread
Jiawei-Shao marked this conversation as resolved.
Outdated
for (const auto& config_at_range : configs_per_dim_inner_range_) {
if (dim_inner <= config_at_range.max_dim_inner_with_rate) {
return rate <= config_at_range.max_dim_a_outer_multiplies_dim_b_outer_divides_dim_inner;
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/core/providers/webgpu/webgpu_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ class SplitKConfig {
bool enable_split_k_ = false;
uint32_t split_dim_inner_ = 0;
uint32_t min_dim_inner_with_split_k_ = 0;
uint32_t max_batch_size_ = 0;

uint32_t GetMaxDimInnerWithSplitK() const;

Expand Down
12 changes: 6 additions & 6 deletions onnxruntime/core/session/ort_version_check.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,21 @@

namespace onnxruntime::version_check {

// A simple consteval-friendly result type for ParseUint.
// A simple constexpr-friendly result type for ParseUint.
// std::optional triggers an internal compiler error in MSVC 14.44 when used with consteval.
Comment thread
Jiawei-Shao marked this conversation as resolved.
Outdated
struct ParseUintResult {
uint32_t value;
bool has_value;

consteval bool operator==(uint32_t other) const { return has_value && value == other; }
consteval bool operator!=(uint32_t other) const { return !(*this == other); }
constexpr bool operator==(uint32_t other) const { return has_value && value == other; }
constexpr bool operator!=(uint32_t other) const { return !(*this == other); }
};

inline consteval ParseUintResult ParseUintNone() { return {0, false}; }
inline constexpr ParseUintResult ParseUintNone() { return {0, false}; }

// Parse a non-negative integer from a string_view without leading zeros.
// Returns a result with has_value == false on failure (empty, leading zero, non-digit, or overflow).
consteval ParseUintResult ParseUint(std::string_view str) {
constexpr ParseUintResult ParseUint(std::string_view str) {
if (str.empty()) return ParseUintNone();
// Leading zeros are not allowed (except "0" itself).
if (str.size() > 1 && str[0] == '0') return ParseUintNone();
Expand All @@ -42,7 +42,7 @@ consteval ParseUintResult ParseUint(std::string_view str) {
// - Major version is 1
// - Y and Z are non-negative integers without leading zeros
// - Y (minor version) must equal expected_api_version (defaults to ORT_API_VERSION)
consteval bool IsOrtVersionValid(std::string_view version, uint32_t expected_api_version = ORT_API_VERSION) {
constexpr bool IsOrtVersionValid(std::string_view version, uint32_t expected_api_version = ORT_API_VERSION) {
size_t first_dot = version.find('.');
if (first_dot == std::string_view::npos) return false;
size_t second_dot = version.find('.', first_dot + 1);
Expand Down
56 changes: 56 additions & 0 deletions onnxruntime/test/providers/cpu/math/matmul_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,62 @@ TEST(MathOpTest, MatMulSharedPrepackedWeights) {
}
}

// Test MatMul with batch_size > 1 that exercises the Split-K path.
// Split-K is triggered when dim_inner is large relative to dim_a_outer * dim_b_outer,
// is_vec4 is true, and the GPU supports it. This test validates correctness when
// batch_size > 1 with dimensions that would trigger Split-K on supported hardware.
TEST(MathOpTest, MatMulBatchedSplitK) {
// Dimensions chosen so dim_inner is large (triggers Split-K) and vec4-compatible.
// batch=2, M=4, K=768, N=64
constexpr int64_t batch = 2;
constexpr int64_t M = 4;
constexpr int64_t K = 768;
constexpr int64_t N = 64;

std::vector<int64_t> A_shape = {batch, M, K};
std::vector<int64_t> B_shape = {batch, K, N};
std::vector<int64_t> Y_shape = {batch, M, N};

// Generate deterministic, repeating data (index modulo a small prime) so the expected output is reproducible.
int64_t a_size = batch * M * K;
int64_t b_size = batch * K * N;
std::vector<float> A_data(a_size);
std::vector<float> B_data(b_size);

// Use small values to avoid fp32 overflow.
for (int64_t i = 0; i < a_size; ++i) {
A_data[i] = static_cast<float>((i % 11) - 5) * 0.01f;
}
for (int64_t i = 0; i < b_size; ++i) {
B_data[i] = static_cast<float>((i % 13) - 6) * 0.01f;
}

// Compute expected output on CPU.
std::vector<float> expected(batch * M * N, 0.0f);
for (int64_t b_idx = 0; b_idx < batch; ++b_idx) {
for (int64_t m = 0; m < M; ++m) {
for (int64_t n = 0; n < N; ++n) {
float sum = 0.0f;
for (int64_t k = 0; k < K; ++k) {
float a_val = A_data[b_idx * M * K + m * K + k];
float b_val = B_data[b_idx * K * N + k * N + n];
sum += a_val * b_val;
}
expected[b_idx * M * N + m * N + n] = sum;
}
}
}

OpTester test("MatMul", 13);
test.AddInput<float>("A", A_shape, A_data);
test.AddInput<float>("B", B_shape, B_data);
test.AddOutput<float>("Y", Y_shape, expected);

// Exclude providers that don't support this configuration.
test.ConfigExcludeEps({kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kQnnExecutionProvider})
Comment thread
Jiawei-Shao marked this conversation as resolved.
.RunWithConfig();
}

#endif

} // namespace test
Expand Down
Loading
Loading