Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {

const uint32_t components_a = a.NumComponents();
const uint32_t components_b = b.NumComponents() / 4; // b is stored as uint32 which includes 4 uint8.
constexpr uint32_t tile_size_k_vec = 16;
const uint32_t tile_size_k_vec = tile_size_k_vec_;
const uint32_t elements_in_value_b = components_b * (32 / nbits_);
const uint32_t tile_size_k = tile_size_k_vec * elements_in_value_b;
const uint32_t a_length_per_tile = tile_size_k / components_a;
Expand Down Expand Up @@ -301,13 +301,17 @@ Status ApplyMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales,
return context.RunProgram(program);
}

// Default to tile_size_k_vec = 32 for better parallelism along the K dimension.
// Intel adapters keep 16, the value that was previously hard-coded, because their subgroup/cache characteristics differ.
const uint32_t tile_size_k_vec = (context.AdapterInfo().vendor == std::string_view{"intel"}) ? 16u : 32u;

constexpr uint32_t workgroup_size = 128;
constexpr uint32_t tile_size = 8;
constexpr uint32_t kU32Components = 4;
uint32_t components_b_with_u32 = components_b * kU32Components;
uint32_t num_N_tile = (N + tile_size - 1) / tile_size;
uint32_t K_of_b = (n_blocks_per_col * blob_size) / components_b_with_u32;
MatMulNBitsProgram program{tile_size, static_cast<uint32_t>(nbits), has_zero_points, has_bias, has_weight_idx, has_weight_idx_indirect, single_scale_weights};
MatMulNBitsProgram program{tile_size, static_cast<uint32_t>(nbits), has_zero_points, has_bias, has_weight_idx, has_weight_idx_indirect, single_scale_weights, tile_size_k_vec};
program.SetWorkgroupSize(workgroup_size);
program.SetDispatchGroupSize((N + tile_size - 1) / tile_size, M, batch_count);
program
Expand All @@ -326,7 +330,7 @@ Status ApplyMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales,
{num_N_tile},
{batch_count},
{weight_index}})
.CacheHint(nbits, has_zero_points, single_scale_weights, has_bias, has_weight_idx, has_weight_idx_indirect);
.CacheHint(nbits, has_zero_points, single_scale_weights, has_bias, has_weight_idx, has_weight_idx_indirect, tile_size_k_vec);
if (has_zero_points) {
program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4});
}
Expand Down
5 changes: 3 additions & 2 deletions onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ class MatMulNBitsWideTileProgram final : public Program<MatMulNBitsWideTileProgr

class MatMulNBitsProgram final : public Program<MatMulNBitsProgram> {
public:
MatMulNBitsProgram(uint32_t tile_size, uint32_t nbits, bool has_zero_points, bool has_bias, bool has_weight_idx, bool has_weight_idx_indirect, bool single_scale_weights)
: Program{"MatMulNBits"}, tile_size_(tile_size), nbits_(nbits), has_zero_points_(has_zero_points), has_bias_(has_bias), has_weight_idx_{has_weight_idx}, has_weight_idx_indirect_{has_weight_idx_indirect}, single_scale_weights_(single_scale_weights) {}
MatMulNBitsProgram(uint32_t tile_size, uint32_t nbits, bool has_zero_points, bool has_bias, bool has_weight_idx, bool has_weight_idx_indirect, bool single_scale_weights, uint32_t tile_size_k_vec = 16)
: Program{"MatMulNBits"}, tile_size_(tile_size), nbits_(nbits), has_zero_points_(has_zero_points), has_bias_(has_bias), has_weight_idx_{has_weight_idx}, has_weight_idx_indirect_{has_weight_idx_indirect}, single_scale_weights_(single_scale_weights), tile_size_k_vec_(tile_size_k_vec) {}
Status GenerateShaderCode(ShaderHelper& sh) const override;
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
{"M", ProgramUniformVariableDataType::Uint32},
Expand All @@ -65,6 +65,7 @@ class MatMulNBitsProgram final : public Program<MatMulNBitsProgram> {
bool has_weight_idx_;
bool has_weight_idx_indirect_;
bool single_scale_weights_;
uint32_t tile_size_k_vec_;
};

class MatMulNBits final : public WebGpuKernel {
Expand Down
Loading