Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {

const uint32_t components_a = a.NumComponents();
const uint32_t components_b = b.NumComponents() / 4; // b is stored as uint32 which includes 4 uint8.
constexpr uint32_t tile_size_k_vec = 16;
const uint32_t tile_size_k_vec = tile_size_k_vec_;
const uint32_t elements_in_value_b = components_b * (32 / nbits_);
const uint32_t tile_size_k = tile_size_k_vec * elements_in_value_b;
const uint32_t a_length_per_tile = tile_size_k / components_a;
Expand Down Expand Up @@ -301,13 +301,17 @@ Status ApplyMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales,
return context.RunProgram(program);
}

// Use tile_size_k_vec=32 by default for better K-dimension parallelism.
// Intel devices use 16 as they have different subgroup/cache characteristics.
const uint32_t tile_size_k_vec = (context.AdapterInfo().vendor == std::string_view{"intel"}) ? 16u : 32u;

constexpr uint32_t workgroup_size = 128;
constexpr uint32_t tile_size = 8;
constexpr uint32_t kU32Components = 4;
uint32_t components_b_with_u32 = components_b * kU32Components;
uint32_t num_N_tile = (N + tile_size - 1) / tile_size;
uint32_t K_of_b = (n_blocks_per_col * blob_size) / components_b_with_u32;
MatMulNBitsProgram program{tile_size, static_cast<uint32_t>(nbits), has_zero_points, has_bias, has_weight_idx, has_weight_idx_indirect, single_scale_weights};
MatMulNBitsProgram program{tile_size, static_cast<uint32_t>(nbits), has_zero_points, has_bias, has_weight_idx, has_weight_idx_indirect, single_scale_weights, tile_size_k_vec};
program.SetWorkgroupSize(workgroup_size);
program.SetDispatchGroupSize((N + tile_size - 1) / tile_size, M, batch_count);
program
Expand Down
5 changes: 3 additions & 2 deletions onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ class MatMulNBitsWideTileProgram final : public Program<MatMulNBitsWideTileProgr

class MatMulNBitsProgram final : public Program<MatMulNBitsProgram> {
public:
MatMulNBitsProgram(uint32_t tile_size, uint32_t nbits, bool has_zero_points, bool has_bias, bool has_weight_idx, bool has_weight_idx_indirect, bool single_scale_weights)
: Program{"MatMulNBits"}, tile_size_(tile_size), nbits_(nbits), has_zero_points_(has_zero_points), has_bias_(has_bias), has_weight_idx_{has_weight_idx}, has_weight_idx_indirect_{has_weight_idx_indirect}, single_scale_weights_(single_scale_weights) {}
MatMulNBitsProgram(uint32_t tile_size, uint32_t nbits, bool has_zero_points, bool has_bias, bool has_weight_idx, bool has_weight_idx_indirect, bool single_scale_weights, uint32_t tile_size_k_vec = 16)
: Program{"MatMulNBits"}, tile_size_(tile_size), nbits_(nbits), has_zero_points_(has_zero_points), has_bias_(has_bias), has_weight_idx_{has_weight_idx}, has_weight_idx_indirect_{has_weight_idx_indirect}, single_scale_weights_(single_scale_weights), tile_size_k_vec_(tile_size_k_vec) {}
Status GenerateShaderCode(ShaderHelper& sh) const override;
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
{"M", ProgramUniformVariableDataType::Uint32},
Expand All @@ -65,6 +65,7 @@ class MatMulNBitsProgram final : public Program<MatMulNBitsProgram> {
bool has_weight_idx_;
bool has_weight_idx_indirect_;
bool single_scale_weights_;
uint32_t tile_size_k_vec_;
};

class MatMulNBits final : public WebGpuKernel {
Expand Down
Loading