Skip to content
6 changes: 3 additions & 3 deletions onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context
* @return Status indicating whether the operation was successful or if an error occurred.
*
* @note Special optimizations are considered:
* - Subgroup matrix multiplication for eligible Apple/Intel GPUs.
* - Subgroup matrix multiplication for GPUs with supported configs.
* - DP4A-based multiplication on FP32-only GPUs for specific dimensions and conditions.
* - A wide tile program is used when block size, component count, and other criteria are met.
* - Otherwise, a default matmul program is used.
Expand Down Expand Up @@ -227,8 +227,8 @@ Status ApplyMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales,

#if !defined(__wasm__)
int32_t subgroup_matrix_config_index = -1;
// apple|intel - Experimental dawn support for subgroup matrix matmul.
if ((M >= kMinMForTileOptimization && !has_weight_idx_indirect) && (context.AdapterInfo().vendor == std::string_view{"apple"} || context.AdapterInfo().vendor == std::string_view{"intel"}) &&
// Experimental dawn support for subgroup matrix matmul (vendor-agnostic).
if ((M >= kMinMForTileOptimization && !has_weight_idx_indirect) &&
CanApplySubgroupMatrixMatMulNBits(context, accuracy_level, block_size, batch_count, N, K, static_cast<uint32_t>(nbits), y->DataType() == DataTypeImpl::GetType<MLFloat16>(), subgroup_matrix_config_index)) {
return ApplySubgroupMatrixMatMulNBits(a, b, scales, zero_points, bias, M, N, K, static_cast<uint32_t>(nbits), zero_blocks_per_col, subgroup_matrix_config_index, context, y, weight_index, weight_index_indirect);
}
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,6 @@

#if !defined(__wasm__)

#include <string>

#include "core/providers/webgpu/program.h"
#include "core/providers/webgpu/compute_context.h"
#include "core/providers/webgpu/program.h"
#include "core/providers/webgpu/shader_helper.h"
Expand All @@ -21,11 +18,10 @@ using namespace onnxruntime::webgpu;

class SubgroupMatrixMatMulNBitsProgram final : public Program<SubgroupMatrixMatMulNBitsProgram> {
public:
SubgroupMatrixMatMulNBitsProgram(uint32_t nbits, int32_t config_index, const wgpu::StringView& vendor, bool has_zero_points, bool has_bias, bool has_weight_idx, bool has_weight_idx_indirect)
SubgroupMatrixMatMulNBitsProgram(uint32_t nbits, int32_t config_index, bool has_zero_points, bool has_bias, bool has_weight_idx, bool has_weight_idx_indirect)
: Program{"SubgroupMatrixMatMulNBits"},
nbits_(nbits),
config_index_(config_index),
vendor_(vendor),
has_zero_points_(has_zero_points),
has_bias_(has_bias),
has_weight_idx_{has_weight_idx},
Expand All @@ -41,7 +37,6 @@ class SubgroupMatrixMatMulNBitsProgram final : public Program<SubgroupMatrixMatM
private:
uint32_t nbits_;
int32_t config_index_;
std::string vendor_;
bool has_zero_points_;
bool has_bias_;
bool has_weight_idx_;
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// Intel SubgroupMatrix prepack kernel
// SubgroupMatrix prepack kernel
// Rearranges input matrix A(MxK) so that each subgroup matrix (sg_mat_m x sg_mat_k)
// has its elements laid out contiguously in memory for subgroupMatrixLoad.
//
// The prepack buffer is padded to the workgroup tile size (which may exceed M).
// Fully OOB blocks (row_base >= M) are skipped entirely, and partial blocks
// copy only their in-bounds rows. Padding for rows >= M is therefore left
// uninitialized: downstream shaders must treat it as undefined and handle
// edge tiles explicitly rather than reading it as zeros.

#param sg_mat_k
#param sg_mat_m
Expand All @@ -14,11 +20,24 @@ const kSgMatK: u32 = u32(sg_mat_k);
// Prepack entry point. Each workgroup handles one (kSgMatM x kSgMatK) block of
// input_a, selected by workgroup_id.x (row block) and workgroup_id.y (column
// block), and writes that block's elements contiguously into output_a.
//
// NOTE(review): this listing appears to interleave the removed and added lines
// of a diff: `in_offset` is declared twice, and an unconditional
// subgroupMatrixLoad/Store precedes the bounds-checked one. Confirm against the
// actual file before relying on the exact statement sequence shown here.
$MAIN {
let M = uniforms.M;
let K = uniforms.K;
// NOTE(review): presumably the pre-change definition of in_offset; the two
// lines that follow compute the same value via row_base.
let in_offset = workgroup_id.x * kSgMatM * K + workgroup_id.y * kSgMatK;
// Index of the first row of input_a covered by this workgroup's block.
let row_base = workgroup_id.x * kSgMatM;
// Linear offset of the block's first element in input_a (row stride K).
let in_offset = row_base * K + workgroup_id.y * kSgMatK;
// Linear offset of the block in output_a: blocks are stored back to back,
// kSgMatM * kSgMatK elements each, with K / kSgMatK blocks per row block.
let out_offset = (workgroup_id.x * K / kSgMatK + workgroup_id.y) * kSgMatM * kSgMatK;

// Syntax: subgroupMatrixLoad src_ptr, src_offset, is_col_major, src_stride
// NOTE(review): the next three lines load and store the block without any
// check against M; presumably the pre-change path replaced by the guarded
// branches below.
var mat: subgroup_matrix_left<f16, sg_mat_k, sg_mat_m> =
subgroupMatrixLoad<subgroup_matrix_left<f16, sg_mat_k, sg_mat_m>>(&input_a, in_offset, false, uniforms.K);
subgroupMatrixStore(&output_a, out_offset, mat, false, kSgMatK);
if (row_base + kSgMatM <= M) {
// All rows in this block are within bounds - use fast subgroupMatrixLoad.
// The load uses stride K (source row pitch); the store uses stride kSgMatK,
// which packs the block's rows back to back in output_a.
var mat: subgroup_matrix_left<f16, sg_mat_k, sg_mat_m> =
subgroupMatrixLoad<subgroup_matrix_left<f16, sg_mat_k, sg_mat_m>>(&input_a, in_offset, false, K);
subgroupMatrixStore(&output_a, out_offset, mat, false, kSgMatK);
} else if (row_base < M) {
// Partial block: some rows are OOB. Use scalar copy for in-bounds rows only.
// Each invocation starts at element local_idx and advances by
// workgroup_size_x until all kSgMatM * kSgMatK elements are visited.
for (var r: u32 = local_idx; r < kSgMatM * kSgMatK; r += workgroup_size_x) {
// Decompose the flat element index into (row, col) within the block.
let row = r / kSgMatK;
let col = r % kSgMatK;
// Elements whose source row is >= M are not written; the matching
// output_a slots keep whatever they held before.
if (row_base + row < M) {
output_a[out_offset + r] = input_a[in_offset + row * K + col];
}
}
}
// Fully OOB blocks (row_base >= M): skip entirely.
} // MAIN
4 changes: 4 additions & 0 deletions onnxruntime/test/contrib_ops/matmul_4bits_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,8 @@ TEST(MatMulNBits, Float16_4b_Accuracy0) {
TestMatMulNBitsTyped<MLFloat16, 100, 288, 1024, 128, 0>();
TestMatMulNBitsTyped<MLFloat16, 100, 288, 93, 32, 0>();
TestMatMulNBitsTyped<MLFloat16, 100, 288, 1234, 16, 0>();
TestMatMulNBitsTyped<MLFloat16, 100, 256, 128, 32, 0>();
TestMatMulNBitsTyped<MLFloat16, 100, 192, 128, 32, 0>();
}

TEST(MatMulNBits, Float16_4b_Accuracy4) {
Expand Down Expand Up @@ -495,6 +497,8 @@ TEST(MatMulNBits, Float16_4b_Accuracy4) {
TestMatMulNBitsTyped<MLFloat16, 100, 288, 93, 32, 4>();
TestMatMulNBitsTyped<MLFloat16, 100, 288, 93, 128, 4>();
TestMatMulNBitsTyped<MLFloat16, 100, 288, 1234, 16, 4>();
TestMatMulNBitsTyped<MLFloat16, 100, 256, 128, 32, 4>();
TestMatMulNBitsTyped<MLFloat16, 100, 192, 128, 32, 4>();

// See PR #27412 for details on the following test case,
// which is added to cover a specific failure case in the past.
Expand Down
Loading