From 304383d34e8322503ed2b89402e495ebfdde032b Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Tue, 24 Mar 2026 19:56:33 +0800 Subject: [PATCH 1/2] webgpu: Increase MatMulNBits K-parallelism with tile_size_k_vec=32 Use tile_size_k_vec=32 (instead of 16) for MatMulNBits default kernel, doubling the number of threads working on K-dimension reduction per output row. This improves token generation throughput by ~3% on NVIDIA GPUs by better utilizing memory bandwidth. Intel devices retain tile_size_k_vec=16 due to different subgroup and cache characteristics. Changes: - matmul_nbits.h: Add tile_size_k_vec parameter (default 16) to MatMulNBitsProgram constructor. - matmul_nbits.cc: Select tile_size_k_vec=32 for non-Intel vendors, pass to program constructor. --- .../contrib_ops/webgpu/quantization/matmul_nbits.cc | 8 ++++++-- .../contrib_ops/webgpu/quantization/matmul_nbits.h | 5 +++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc index 765daa05ff25d..e6092c922f867 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc @@ -86,7 +86,7 @@ Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { const uint32_t components_a = a.NumComponents(); const uint32_t components_b = b.NumComponents() / 4; // b is stored as uint32 which includes 4 uint8. 
- constexpr uint32_t tile_size_k_vec = 16; + const uint32_t tile_size_k_vec = tile_size_k_vec_; const uint32_t elements_in_value_b = components_b * (32 / nbits_); const uint32_t tile_size_k = tile_size_k_vec * elements_in_value_b; const uint32_t a_length_per_tile = tile_size_k / components_a; @@ -301,13 +301,17 @@ Status ApplyMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales, return context.RunProgram(program); } + // Use tile_size_k_vec=32 by default for better K-dimension parallelism. + // Intel devices use 16 as they have different subgroup/cache characteristics. + const uint32_t tile_size_k_vec = (context.AdapterInfo().vendor == std::string_view{"intel"}) ? 16u : 32u; + constexpr uint32_t workgroup_size = 128; constexpr uint32_t tile_size = 8; constexpr uint32_t kU32Components = 4; uint32_t components_b_with_u32 = components_b * kU32Components; uint32_t num_N_tile = (N + tile_size - 1) / tile_size; uint32_t K_of_b = (n_blocks_per_col * blob_size) / components_b_with_u32; - MatMulNBitsProgram program{tile_size, static_cast<uint32_t>(nbits), has_zero_points, has_bias, has_weight_idx, has_weight_idx_indirect, single_scale_weights}; + MatMulNBitsProgram program{tile_size, static_cast<uint32_t>(nbits), has_zero_points, has_bias, has_weight_idx, has_weight_idx_indirect, single_scale_weights, tile_size_k_vec}; program.SetWorkgroupSize(workgroup_size); program.SetDispatchGroupSize((N + tile_size - 1) / tile_size, M, batch_count); program diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h index 70ddf2f818627..295a0fb90dd2a 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h @@ -41,8 +41,8 @@ class MatMulNBitsWideTileProgram final : public Program<MatMulNBitsWideTileProgram> { public: - MatMulNBitsProgram(uint32_t tile_size, uint32_t nbits, bool has_zero_points, bool has_bias, bool has_weight_idx, bool has_weight_idx_indirect, bool 
single_scale_weights) - : Program{"MatMulNBits"}, tile_size_(tile_size), nbits_(nbits), has_zero_points_(has_zero_points), has_bias_(has_bias), has_weight_idx_{has_weight_idx}, has_weight_idx_indirect_{has_weight_idx_indirect}, single_scale_weights_(single_scale_weights) {} + MatMulNBitsProgram(uint32_t tile_size, uint32_t nbits, bool has_zero_points, bool has_bias, bool has_weight_idx, bool has_weight_idx_indirect, bool single_scale_weights, uint32_t tile_size_k_vec = 16) + : Program{"MatMulNBits"}, tile_size_(tile_size), nbits_(nbits), has_zero_points_(has_zero_points), has_bias_(has_bias), has_weight_idx_{has_weight_idx}, has_weight_idx_indirect_{has_weight_idx_indirect}, single_scale_weights_(single_scale_weights), tile_size_k_vec_(tile_size_k_vec) {} Status GenerateShaderCode(ShaderHelper& sh) const override; WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( {"M", ProgramUniformVariableDataType::Uint32}, @@ -65,6 +65,7 @@ class MatMulNBitsProgram final : public Program<MatMulNBitsProgram> { bool has_weight_idx_; bool has_weight_idx_indirect_; bool single_scale_weights_; + uint32_t tile_size_k_vec_; }; class MatMulNBits final : public WebGpuKernel { From 8cd8454ce2c1a23bd65573f718bd7ff124df0163 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Mon, 30 Mar 2026 14:35:28 +0800 Subject: [PATCH 2/2] webgpu: Add tile_size_k_vec to MatMulNBits CacheHint tile_size_k_vec varies by GPU vendor (32 for NVIDIA, 16 for Intel). Without it in CacheHint, a cached shader compiled with one value could be incorrectly reused when tile_size_k_vec changes, producing wrong results or suboptimal performance. 
--- onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc index e6092c922f867..7d99256682b6b 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc @@ -330,7 +330,7 @@ Status ApplyMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales, {num_N_tile}, {batch_count}, {weight_index}}) - .CacheHint(nbits, has_zero_points, single_scale_weights, has_bias, has_weight_idx, has_weight_idx_indirect); + .CacheHint(nbits, has_zero_points, single_scale_weights, has_bias, has_weight_idx, has_weight_idx_indirect, tile_size_k_vec); if (has_zero_points) { program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4}); }