From 87ed70e6bd1a66a2c9dc8e472599e89ae4ec8c9b Mon Sep 17 00:00:00 2001 From: Sushanth Rajasankar Date: Wed, 5 Mar 2025 15:22:14 -0800 Subject: [PATCH 01/13] Add block128 support to dp4a --- .../contrib_ops/webgpu/quantization/matmul_nbits.cc | 12 +++++++----- .../contrib_ops/webgpu/quantization/matmul_nbits.h | 4 +++- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc index 1534fd26d3ad9..2f4a07fb74663 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc @@ -598,6 +598,8 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { // this shader require A to be int8 quantized with block size 64. B is regular // matmulnbits input with block size 32. + shader.AdditionalImplementation() << " const block_size = " << block_size_ << ";"; + shader.AdditionalImplementation() << R"ADDNL_FN( const tile_size = 64; const subtile_size = 16; @@ -605,7 +607,6 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { const vec_factor = 4; const u32_factor = 4; const tile_size_k_vec = 2; - const block_size = 32; // Shared memory var tile_A : array, tile_size>, tile_size_k_vec>; // 64 x 32 @@ -648,7 +649,7 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { if (col == 0) { // kidx_v - each kidx_v covers 16 values of k - scale_B[row] = scales_b[b_global*(uniforms.K/32) + kidx_v/2]; + scale_B[row] = scales_b[b_global*(uniforms.K/block_size) + kidx_v/(block_size/16)]; } } @@ -826,7 +827,7 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context // macOS - Avoid using dp4a on Metal, as it does not appear to have native dp4a support. 
// https://github.com/gpuweb/gpuweb/issues/2677#issuecomment-1713292226 const bool use_dp4a = has_subgroup && context.AdapterInfo().backendType != wgpu::BackendType::Metal; - if (accuracy_level_ == 4 && block_size == 32 && + if (accuracy_level_ == 4 && block_size % 32 == 0 && batch_count == 1 && components_a == 4 && K % 64 == 0 && N % 16 == 0 && !has_zero_points && use_dp4a && M >= kMinMForTileOptimization) { constexpr uint32_t kVec4Components = 4; @@ -849,7 +850,7 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context constexpr uint32_t kTileSize = 64; TensorShape reshaped_y_shape{1, M, N / kVec4Components}; - DP4AMatMulNBitsProgram mul_program; + DP4AMatMulNBitsProgram mul_program{block_size}; mul_program.SetWorkgroupSize(256); mul_program.SetDispatchGroupSize( (M + kTileSize - 1) / kTileSize, @@ -863,7 +864,8 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context {static_cast(K)}, {static_cast(K / 8)}, {static_cast(K / 16)}}) - .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, gsl::narrow(kVec4Components)}); + .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, gsl::narrow(kVec4Components)}) + .CacheHint("Block" + std::to_string(block_size)); return context.RunProgram(mul_program); } diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h index 3d72629bf6b25..9c203744596a0 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h @@ -44,7 +44,7 @@ class DP4AMatMulQuantizeProgram final : public Program { public: - DP4AMatMulNBitsProgram() : Program{"DP4AMatMulNBits"} {} + DP4AMatMulNBitsProgram(uint32_t block_size) : Program{"DP4AMatMulNBits"}, block_size_(block_size) {} Status GenerateShaderCode(ShaderHelper& sh) const override; WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( {"M", 
ProgramUniformVariableDataType::Uint32}, @@ -52,6 +52,8 @@ class DP4AMatMulNBitsProgram final : public Program { {"K", ProgramUniformVariableDataType::Uint32}, {"K8", ProgramUniformVariableDataType::Uint32}, {"K16", ProgramUniformVariableDataType::Uint32}); + private: + uint32_t block_size_; }; class MatMulNBits final : public WebGpuKernel { From bb8877b46331bd8cda6d17e53b5793b873a6d6bc Mon Sep 17 00:00:00 2001 From: Sushanth Rajasankar Date: Wed, 5 Mar 2025 17:48:24 -0800 Subject: [PATCH 02/13] Support block size that are multiples of 32 in DP4A Matmul --- .../webgpu/quantization/dp4a_matmul_nbits.cc | 326 ++++++++++++++++++ .../webgpu/quantization/dp4a_matmul_nbits.h | 55 +++ .../webgpu/quantization/matmul_nbits.cc | 299 +--------------- .../webgpu/quantization/matmul_nbits.h | 21 -- 4 files changed, 386 insertions(+), 315 deletions(-) create mode 100644 onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc create mode 100644 onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc new file mode 100644 index 0000000000000..86adbe86bfd91 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc @@ -0,0 +1,326 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h" +#include "core/providers/webgpu/shader_helper.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +Status DP4AMatMulQuantizeProgram::GenerateShaderCode(ShaderHelper& shader) const {  // Emits WGSL in which each workgroup quantizes 32 input_a entries to packed int8 and writes one scale. + shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + shader.AddOutput("output", ShaderUsage::UseUniform); + shader.AddOutput("scales", ShaderUsage::UseUniform); + shader.AdditionalImplementation() << R"ADDNL_FN( + fn readInput(offset: u32) -> input_a_value_t + { + if (offset >= uniforms.input_size) { + return input_a_value_t(0); + } + return input_a[offset]; + } + )ADDNL_FN"; + shader.MainFunctionBody() << R"MAIN_FN( + var local_a : array, 32>; + var max_value:vec4 = vec4(0); + for (var idx:u32=0;idx<32;idx+=1) + { + local_a[idx] = readInput(workgroup_idx*32 + idx); + max_value = max(max_value, abs(local_a[idx])); + } + var scale = max(max_value.x, max_value.y); + scale = max(scale, max_value.z); + scale = max(scale, max_value.w); + for (var idx:u32=0;idx<32;idx+=1) + { + output[workgroup_idx*32+idx] = pack4x8snorm(vec4(local_a[idx]/scale)); + } + // 127 is the max value of signed int8 [-127,127] used by pack4x8snorm for 1.0f. + scales[workgroup_idx] = scale/127; + )MAIN_FN"; + return Status::OK(); +} + +Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { + shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); + shader.AddInput("scales_a", ShaderUsage::UseUniform); + shader.AddInput("input_b", ShaderUsage::UseUniform); + shader.AddInput("scales_b", ShaderUsage::UseUniform); + shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias); + + // This shader implements co-operative matrix multiply. 
The key idea here is to + // assume there is a primitive for medium size matrix multiply a subgroup can perform, + // using all its lanes and pooling all its registers to hold the values. + // + // The entire workgroup, which has N subgroups, first loads a tile into shared memory. + // Then each subgroup loads a subtile from shared memory into registers and uses + // the medium size matrix multiply primitive to perform the math. + // The values for tile/subtile size are chosen to conform to the resource limits + // of an Alder Lake/Tiger Lake GPU. A tile is 64x64, workgroup is 256 threads - + // therefore there are 16 subgroups and 16 lanes in each subgroup. + // K, the hidden dimension, is paged in from RAM at the k tile size, which is 32. + // All this puts the shared memory requirement slightly above 16KB. + // WebGPU limit is 16KB, output is moved to registers instead of SHM to make + // everything fit in shared memory. + // + // Each subgroup performs a 16 x 32 x 16 multiply which is implemented with + // subgroup shuffle as a placeholder for the day the medium matrix mul primitive + // becomes available in WGSL. The register requirement is ~2KB per subgroup; on + // Alder Lake/Tiger Lake a subgroup has 8KB of register space, pooling the + // 512B of registers from each lane. + // + // The medium size matmul is implemented using dot4I8Packed, so the inputs for + // this shader require A to be int8 quantized with block size 128. B is regular + // matmulnbits input whose block size is a multiple of 32. 
+ + shader.AdditionalImplementation() << " const block_size = " << block_size_ << ";"; + + shader.AdditionalImplementation() << R"ADDNL_FN( + const tile_size = 64; + const subtile_size = 16; + const tile_size_k = 32; + const vec_factor = 4; + const u32_factor = 4; + const tile_size_k_vec = 2; + + // Shared memory + var tile_A : array, tile_size>, tile_size_k_vec>; // 64 x 32 + var scale_A : array; // 64 x 1 + var tile_B : array, tile_size>, tile_size_k_vec>; // 64 x 32 + var scale_B : array; // 64 x 1 + + fn loadSHMA(a_global_base:u32, kidx_v:u32, row: u32, col: u32) + { + let a_global = a_global_base + row; + if (a_global >= uniforms.M) + { + return; + } + tile_A[col][row] = input_a[a_global*uniforms.K16+kidx_v+col]; + if (col == 0) + { + // kidx_v - covers 16 values of k + scale_A[row] = scales_a[a_global*(uniforms.K/128) + kidx_v/8]; + } + } + + fn loadSHMB(b_global_base:u32, kidx_v:u32, row: u32, col: u32) + { + let b_global = b_global_base + row; + if (b_global >= uniforms.N) + { + return; + } + + let b_value = input_b[b_global*uniforms.K16+kidx_v+col]; + var b_value_lower = vec4(unpack4xU8(b_value[0] & 0x0F0F0F0Fu)) - vec4(8); + var b_value_upper = vec4(unpack4xU8((b_value[0] >> 4) & 0x0F0F0F0Fu)) - vec4(8); + tile_B[col][row][0] = pack4xI8(vec4(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1])); + tile_B[col][row][1] = pack4xI8(vec4(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3])); + b_value_lower = vec4(unpack4xU8(b_value[1] & 0x0F0F0F0Fu)) - vec4(8); + b_value_upper = vec4(unpack4xU8((b_value[1] >> 4) & 0x0F0F0F0Fu)) - vec4(8); + tile_B[col][row][2] = pack4xI8(vec4(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1])); + tile_B[col][row][3] = pack4xI8(vec4(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3])); + if (col == 0) + { + // kidx_v - each kidx_v covers 16 values of k + scale_B[row] = scales_b[b_global*(uniforms.K/block_size) + kidx_v/(block_size/16)]; + } + } + + 
// Scaled dot product of 8 packed unsigned integers. + fn SDP8AI(a1:vec4, b1:vec4, a2:vec4, b2:vec4, scale:output_element_t) -> output_element_t + { + var local_sum = dot4I8Packed(a1[0], b1[0]); + local_sum += dot4I8Packed(a1[1], b1[1]); + local_sum += dot4I8Packed(a1[2], b1[2]); + local_sum += dot4I8Packed(a1[3], b1[3]); + local_sum += dot4I8Packed(a2[0], b2[0]); + local_sum += dot4I8Packed(a2[1], b2[1]); + local_sum += dot4I8Packed(a2[2], b2[2]); + local_sum += dot4I8Packed(a2[3], b2[3]); + return output_element_t(local_sum) * scale; + } + )ADDNL_FN"; + + shader.MainFunctionBody() << R"MAIN_FN( + // During the load phase we use all 256 threads to load 64 rows of A/B. + // For each row we load tile_size_k_vec (2) vectorized elements, which are 32 elements of K. + let a_global_base = workgroup_id.x * tile_size; + let b_global_base = workgroup_id.y * tile_size; + let load_AorB = u32(local_idx/128); + let load_row = u32((local_idx%128)/2); + let load_col = u32(local_idx%2); + + // During the compute phase, we have the 64x64 tile split into + // subtiles of 16x16. We have a grid of 4x4 subtiles. + let subtile_id = u32(local_idx / subtile_size); + let subtile_idx = u32(subtile_id / 4); + let subtile_idy = u32(subtile_id % 4); + let base_A = subtile_idx * 16; + let base_B = subtile_idy * 16; + // For each subtile we have 16 threads assigned. + let a_idx = u32(local_idx % subtile_size); + + var lane_output1: vec4; + var lane_output2: vec4; + var lane_output3: vec4; + var lane_output4: vec4; + // K's vectrorization is 16 items per index. See input_a/input_b. + // tile_size_k_vec - is the k tile size in vectorized space (1/16). That is + // k tile size is 32. In vectorized space that is 32/16 = 2. + for (var kidx_v:u32 = 0; kidx_v < uniforms.K16; kidx_v+=tile_size_k_vec) + { + // Load Phase: Populate shared memory for the workgroup. 
+ if (load_AorB == 0) + { + loadSHMA(a_global_base, kidx_v, load_row, load_col); + } + else + { + loadSHMB(b_global_base, kidx_v, load_row, load_col); + } + workgroupBarrier(); + + // Compute phase: Perform matmul for this subtile 16 x 32 x 16. + // Step 1: Load from shared memory into registers across entire subgroup. + var own_a0: vec4 = tile_A[0][base_A + a_idx]; + var own_a1: vec4 = tile_A[1][base_A + a_idx]; + var own_scale_a: output_element_t = scale_A[base_A + a_idx]; + if (sg_size == 16) + { + var own_b0: vec4 = tile_B[0][base_B + sg_id]; + var own_b1: vec4 = tile_B[1][base_B + sg_id]; + var own_scale_b: output_element_t = scale_B[base_B + sg_id]; + // Step 2: Access registers across the subgroup using subgroupShuffle and perform the matmul. + lane_output1[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 0), own_a1, subgroupShuffle(own_b1, 0), subgroupShuffle(own_scale_b, 0) * own_scale_a); + lane_output1[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 1), own_a1, subgroupShuffle(own_b1, 1), subgroupShuffle(own_scale_b, 1) * own_scale_a); + lane_output1[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 2), own_a1, subgroupShuffle(own_b1, 2), subgroupShuffle(own_scale_b, 2) * own_scale_a); + lane_output1[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 3), own_a1, subgroupShuffle(own_b1, 3), subgroupShuffle(own_scale_b, 3) * own_scale_a); + + lane_output2[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 4), own_a1, subgroupShuffle(own_b1, 4), subgroupShuffle(own_scale_b, 4) * own_scale_a); + lane_output2[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 5), own_a1, subgroupShuffle(own_b1, 5), subgroupShuffle(own_scale_b, 5) * own_scale_a); + lane_output2[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 6), own_a1, subgroupShuffle(own_b1, 6), subgroupShuffle(own_scale_b, 6) * own_scale_a); + lane_output2[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 7), own_a1, subgroupShuffle(own_b1, 7), subgroupShuffle(own_scale_b, 7) * own_scale_a); + + lane_output3[0] += SDP8AI(own_a0, 
subgroupShuffle(own_b0, 8), own_a1, subgroupShuffle(own_b1, 8), subgroupShuffle(own_scale_b, 8) * own_scale_a); + lane_output3[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 9), own_a1, subgroupShuffle(own_b1, 9), subgroupShuffle(own_scale_b, 9) * own_scale_a); + lane_output3[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 10), own_a1, subgroupShuffle(own_b1, 10), subgroupShuffle(own_scale_b, 10) * own_scale_a); + lane_output3[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 11), own_a1, subgroupShuffle(own_b1, 11), subgroupShuffle(own_scale_b, 11) * own_scale_a); + + lane_output4[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 12), own_a1, subgroupShuffle(own_b1, 12), subgroupShuffle(own_scale_b, 12) * own_scale_a); + lane_output4[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 13), own_a1, subgroupShuffle(own_b1, 13), subgroupShuffle(own_scale_b, 13) * own_scale_a); + lane_output4[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 14), own_a1, subgroupShuffle(own_b1, 14), subgroupShuffle(own_scale_b, 14) * own_scale_a); + lane_output4[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 15), own_a1, subgroupShuffle(own_b1, 15), subgroupShuffle(own_scale_b, 15) * own_scale_a); + } + else + { + // Code for other subgroup sizes, simply doesnt use subgroups at all. + // Relies on reads from single location tile_B[][base_B + col] by all + // being optimized by the hardware. 
+ lane_output1[0] += SDP8AI(own_a0, tile_B[0][base_B + 0], own_a1, tile_B[1][base_B + 0], own_scale_a * scale_B[base_B + 0]); + lane_output1[1] += SDP8AI(own_a0, tile_B[0][base_B + 1], own_a1, tile_B[1][base_B + 1], own_scale_a * scale_B[base_B + 1]); + lane_output1[2] += SDP8AI(own_a0, tile_B[0][base_B + 2], own_a1, tile_B[1][base_B + 2], own_scale_a * scale_B[base_B + 2]); + lane_output1[3] += SDP8AI(own_a0, tile_B[0][base_B + 3], own_a1, tile_B[1][base_B + 3], own_scale_a * scale_B[base_B + 3]); + + lane_output2[0] += SDP8AI(own_a0, tile_B[0][base_B + 4], own_a1, tile_B[1][base_B + 4], own_scale_a * scale_B[base_B + 4]); + lane_output2[1] += SDP8AI(own_a0, tile_B[0][base_B + 5], own_a1, tile_B[1][base_B + 5], own_scale_a * scale_B[base_B + 5]); + lane_output2[2] += SDP8AI(own_a0, tile_B[0][base_B + 6], own_a1, tile_B[1][base_B + 6], own_scale_a * scale_B[base_B + 6]); + lane_output2[3] += SDP8AI(own_a0, tile_B[0][base_B + 7], own_a1, tile_B[1][base_B + 7], own_scale_a * scale_B[base_B + 7]); + + lane_output3[0] += SDP8AI(own_a0, tile_B[0][base_B + 8], own_a1, tile_B[1][base_B + 8], own_scale_a * scale_B[base_B + 8]); + lane_output3[1] += SDP8AI(own_a0, tile_B[0][base_B + 9], own_a1, tile_B[1][base_B + 9], own_scale_a * scale_B[base_B + 9]); + lane_output3[2] += SDP8AI(own_a0, tile_B[0][base_B + 10], own_a1, tile_B[1][base_B + 10], own_scale_a * scale_B[base_B + 10]); + lane_output3[3] += SDP8AI(own_a0, tile_B[0][base_B + 11], own_a1, tile_B[1][base_B + 11], own_scale_a * scale_B[base_B + 11]); + + lane_output4[0] += SDP8AI(own_a0, tile_B[0][base_B + 12], own_a1, tile_B[1][base_B + 12], own_scale_a * scale_B[base_B + 12]); + lane_output4[1] += SDP8AI(own_a0, tile_B[0][base_B + 13], own_a1, tile_B[1][base_B + 13], own_scale_a * scale_B[base_B + 13]); + lane_output4[2] += SDP8AI(own_a0, tile_B[0][base_B + 14], own_a1, tile_B[1][base_B + 14], own_scale_a * scale_B[base_B + 14]); + lane_output4[3] += SDP8AI(own_a0, tile_B[0][base_B + 15], own_a1, tile_B[1][base_B + 
15], own_scale_a * scale_B[base_B + 15]); + } + workgroupBarrier(); + } + + let a_global = a_global_base + base_A + a_idx; + let b_global = b_global_base + base_B; + let output_idx = ((a_global) * uniforms.N + b_global)/4; + // This creates a shader requirement that uniforms.N % 16 == 0 + if (a_global < uniforms.M && b_global < uniforms.N) + { + output[output_idx] = lane_output1; + output[output_idx+1] = lane_output2; + output[output_idx+2] = lane_output3; + output[output_idx+3] = lane_output4; + } + )MAIN_FN"; + + return Status::OK(); +} + +Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales, + uint32_t M, + uint32_t N, + uint32_t K, + uint32_t block_size, + onnxruntime::webgpu::ComputeContext& context, + Tensor* y) { + constexpr uint32_t kVec4Components = 4; + constexpr uint32_t kVec2Components = 2; + constexpr uint32_t kU32Components = 4; + + constexpr uint32_t kBlockSizeA = 128; + DP4AMatMulQuantizeProgram quantize_program; + quantize_program.SetWorkgroupSize(1); + quantize_program.SetDispatchGroupSize(M * K / kBlockSizeA, 1, 1); + TensorShape a_quant_shape{1, M, K / kU32Components}; + Tensor a_quant = context.CreateGPUTensor(DataTypeImpl::GetType(), a_quant_shape); + TensorShapeVector a_scales_dims({1, 1, M, K / kBlockSizeA}); + Tensor a_scale = context.CreateGPUTensor(a->DataType(), a_scales_dims); + quantize_program.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(kVec4Components)}}) + .AddOutputs({{&a_quant, ProgramTensorMetadataDependency::Rank, a_quant.Shape(), gsl::narrow(1)}, + {&a_scale, ProgramTensorMetadataDependency::Rank, a_scale.Shape(), gsl::narrow(1)}}) + .AddUniformVariable({static_cast(M * K / kVec4Components)}); + ORT_RETURN_IF_ERROR(context.RunProgram(quantize_program)); + + constexpr uint32_t kTileSize = 64; + TensorShape reshaped_y_shape{1, M, N / kVec4Components}; + DP4AMatMulNBitsProgram mul_program{block_size}; + mul_program.SetWorkgroupSize(256); + 
mul_program.SetDispatchGroupSize( + (M + kTileSize - 1) / kTileSize, + (N + kTileSize - 1) / kTileSize, 1); + mul_program.AddInputs({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(kVec4Components)}, + {&a_scale, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(1)}, + {b, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(kVec2Components * kU32Components)}, + {scales, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(1)}}) + .AddUniformVariables({{static_cast(M)}, + {static_cast(N)}, + {static_cast(K)}, + {static_cast(K / 8)}, + {static_cast(K / 16)}}) + .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, gsl::narrow(kVec4Components)}) + .CacheHint("Block" + std::to_string(block_size)); + return context.RunProgram(mul_program); +} + +bool CanApplyDP4AMatrixMatMulNBits(onnxruntime::webgpu::ComputeContext& context,  // Returns true only when the DP4A path's shape, feature and backend preconditions all hold. + uint64_t accuracy_level, + uint32_t block_size, + uint32_t batch_count, + uint32_t N, + uint32_t K, + uint32_t components_k, + bool has_zero_points) { + // macOS - Avoid using dp4a on Metal, as it does not appear to have native dp4a support. + // https://github.com/gpuweb/gpuweb/issues/2677#issuecomment-1713292226 + bool use_dp4a = context.Device().HasFeature(wgpu::FeatureName::Subgroups) && + context.AdapterInfo().backendType != wgpu::BackendType::Metal; + return (accuracy_level == 4 && block_size % 32 == 0 && + batch_count == 1 && components_k == 4 && K % 128 == 0 && N % 16 == 0 &&  // K % 128: A is quantized in blocks of 128 (kBlockSizeA), so each row must hold whole blocks. + !has_zero_points && use_dp4a); +} + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h new file mode 100644 index 0000000000000..8f2e76a9f9a24 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h @@ -0,0 +1,55 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +using namespace onnxruntime::webgpu; + +class DP4AMatMulQuantizeProgram final : public Program {  // Shader program that quantizes A to packed int8 and emits one scale per workgroup. +public: + DP4AMatMulQuantizeProgram() : Program{"DP4AMatMulQuantize"} {} + Status GenerateShaderCode(ShaderHelper& sh) const override; + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"input_size", ProgramUniformVariableDataType::Uint32}); +}; + +class DP4AMatMulNBitsProgram final : public Program {  // Tiled matmul shader over quantized A and n-bit B; block_size is baked into the shader as a constant. +public: + explicit DP4AMatMulNBitsProgram(uint32_t block_size) : Program{"DP4AMatMulNBits"}, block_size_(block_size) {} + Status GenerateShaderCode(ShaderHelper& sh) const override; + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( + {"M", ProgramUniformVariableDataType::Uint32}, + {"N", ProgramUniformVariableDataType::Uint32}, + {"K", ProgramUniformVariableDataType::Uint32}, + {"K8", ProgramUniformVariableDataType::Uint32}, + {"K16", ProgramUniformVariableDataType::Uint32}); + private: + uint32_t block_size_;  // Used to index scales_b in the generated shader. +}; + +Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales,  // Runs the quantize program on a, then the DP4A matmul program, writing into y. + uint32_t M, + uint32_t N, + uint32_t K, + uint32_t block_size, + onnxruntime::webgpu::ComputeContext& context, + Tensor* y); + +bool CanApplyDP4AMatrixMatMulNBits(onnxruntime::webgpu::ComputeContext& context,  // Reports whether the DP4A path may be used for the given shapes and device. + uint64_t accuracy_level, + uint32_t block_size, + uint32_t batch_count, + uint32_t N, + uint32_t K, + uint32_t components_k, + bool has_zero_points); + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc index 2f4a07fb74663..e10a7f551eec9 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc @@ -5,6 +5,7 @@ #include "contrib_ops/webgpu/quantization/matmul_nbits.h" 
#include "contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h" +#include "contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h" #include "contrib_ops/webgpu/webgpu_contrib_kernels.h" #include "core/providers/cpu/math/matmul_helper.h" #include "core/providers/webgpu/shader_helper.h" @@ -532,256 +533,6 @@ Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { return Status::OK(); } -Status DP4AMatMulQuantizeProgram::GenerateShaderCode(ShaderHelper& shader) const { - shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); - shader.AddOutput("output", ShaderUsage::UseUniform); - shader.AddOutput("scales", ShaderUsage::UseUniform); - shader.AdditionalImplementation() << R"ADDNL_FN( - fn readInput(offset: u32) -> input_a_value_t - { - if (offset > uniforms.input_size) { - return input_a_value_t(0); - } - return input_a[offset]; - } -)ADDNL_FN"; - shader.MainFunctionBody() << R"MAIN_FN( - var local_a : array, 32>; - var max_value:vec4 = vec4(0); - for (var idx:u32=0;idx<32;idx+=1) - { - local_a[idx] = readInput(workgroup_idx*32 + idx); - max_value = max(max_value, abs(local_a[idx])); - } - var scale = max(max_value.x, max_value.y); - scale = max(scale, max_value.z); - scale = max(scale, max_value.w); - for (var idx:u32=0;idx<32;idx+=1) - { - output[workgroup_idx*32+idx] = pack4x8snorm(vec4(local_a[idx]/scale)); - } - // 127 is the max value of signed int8 [-127,127] used by pack4x8snorm for 1.0f. 
- scales[workgroup_idx] = scale/127; -)MAIN_FN"; - return Status::OK(); -} - -Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { - shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); - shader.AddInput("scales_a", ShaderUsage::UseUniform); - shader.AddInput("input_b", ShaderUsage::UseUniform); - shader.AddInput("scales_b", ShaderUsage::UseUniform); - shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias); - - // This shader implements co-operative matrix multiply. The key idea here is to - // assume there is a primitive for medium size matrix multiply a subgroup can perform, - // using all its lanes and pooling all its registers to keep the values in registry. - // - // The entire workgroup which has N subgroups first loads a tile into shared memory, - // Then each subgroup loads a subtile from shared memory into registers and uses - // the medium size matrix multiply primitive to perform the math. - // The values for tile/subtile size are chosen to conform to the resource limits - // of an alderlake/tiger lake gpu. A tile is 64x64, workgroup is 256 threads - - // therefore there are 16 subgroups and 16 lanes in each subgroup. - // K the hidden dimension is paged in from RAM at k tile size which is 64. - // All this puts the shared memory requirement slightly above 16KB. - // WebGPU limit is 16KB, output is moved to registers instead of SHM to make - // everything fit in shared memory. - // - // Each subgroup performs a 16 x 64 x 16 multiply which is implemented with - // subgroup shuffle as a placeholder for the day the medium matrix mul primitive - // becomes available in WGSL. The registry requirements is ~2KB per subgroup, on - // Alderlake/Tigerlake subgroup has 8KB of registry space pooling the - // 512B of registry from each lane. 
- // - // The medium size matmul is implemented using dot4I8Packed, so the inputs for - // this shader require A to be int8 quantized with block size 64. B is regular - // matmulnbits input with block size 32. - - shader.AdditionalImplementation() << " const block_size = " << block_size_ << ";"; - - shader.AdditionalImplementation() << R"ADDNL_FN( - const tile_size = 64; - const subtile_size = 16; - const tile_size_k = 32; - const vec_factor = 4; - const u32_factor = 4; - const tile_size_k_vec = 2; - - // Shared memory - var tile_A : array, tile_size>, tile_size_k_vec>; // 64 x 32 - var scale_A : array; // 64 x 1 - var tile_B : array, tile_size>, tile_size_k_vec>; // 64 x 32 - var scale_B : array; // 64 x 1 - - fn loadSHMA(a_global_base:u32, kidx_v:u32, row: u32, col: u32) - { - let a_global = a_global_base + row; - if (a_global >= uniforms.M) - { - return; - } - tile_A[col][row] = input_a[a_global*uniforms.K16+kidx_v+col]; - if (col == 0) - { - // kidx_v - covers 16 values of k - scale_A[row] = scales_a[a_global*(uniforms.K/128) + kidx_v/8]; - } - } - - fn loadSHMB(b_global_base:u32, kidx_v:u32, row: u32, col: u32) - { - let b_global = b_global_base + row; - if (b_global >= uniforms.N) - { - return; - } - - let b_value = input_b[b_global*uniforms.K16+kidx_v+col]; - var b_value_lower = vec4(unpack4xU8(b_value[0] & 0x0F0F0F0Fu)) - vec4(8); - var b_value_upper = vec4(unpack4xU8((b_value[0] >> 4) & 0x0F0F0F0Fu)) - vec4(8); - tile_B[col][row][0] = pack4xI8(vec4(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1])); - tile_B[col][row][1] = pack4xI8(vec4(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3])); - b_value_lower = vec4(unpack4xU8(b_value[1] & 0x0F0F0F0Fu)) - vec4(8); - b_value_upper = vec4(unpack4xU8((b_value[1] >> 4) & 0x0F0F0F0Fu)) - vec4(8); - tile_B[col][row][2] = pack4xI8(vec4(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1])); - tile_B[col][row][3] = pack4xI8(vec4(b_value_lower[2], 
b_value_upper[2], b_value_lower[3], b_value_upper[3])); - if (col == 0) - { - // kidx_v - each kidx_v covers 16 values of k - scale_B[row] = scales_b[b_global*(uniforms.K/block_size) + kidx_v/(block_size/16)]; - } - } - - // Scaled dot product of 8 packed unsigned integers. - fn SDP8AI(a1:vec4, b1:vec4, a2:vec4, b2:vec4, scale:output_element_t) -> output_element_t - { - var local_sum = dot4I8Packed(a1[0], b1[0]); - local_sum += dot4I8Packed(a1[1], b1[1]); - local_sum += dot4I8Packed(a1[2], b1[2]); - local_sum += dot4I8Packed(a1[3], b1[3]); - local_sum += dot4I8Packed(a2[0], b2[0]); - local_sum += dot4I8Packed(a2[1], b2[1]); - local_sum += dot4I8Packed(a2[2], b2[2]); - local_sum += dot4I8Packed(a2[3], b2[3]); - return output_element_t(local_sum) * scale; - } -)ADDNL_FN"; - - shader.MainFunctionBody() << R"MAIN_FN( - // During the load phase we use all 256 threads to load 64 rows of A/B. - // For each row we load tile_size_k_vec (2) vectorized elements, which are 32 elements of K. - let a_global_base = workgroup_id.x * tile_size; - let b_global_base = workgroup_id.y * tile_size; - let load_AorB = u32(local_idx/128); - let load_row = u32((local_idx%128)/2); - let load_col = u32(local_idx%2); - - // During the compute phase, we have the 64x64 tile split into - // subtiles of 16x16. We have a grid of 4x4 subtiles. - let subtile_id = u32(local_idx / subtile_size); - let subtile_idx = u32(subtile_id / 4); - let subtile_idy = u32(subtile_id % 4); - let base_A = subtile_idx * 16; - let base_B = subtile_idy * 16; - // For each subtile we have 16 threads assigned. - let a_idx = u32(local_idx % subtile_size); - - var lane_output1: vec4; - var lane_output2: vec4; - var lane_output3: vec4; - var lane_output4: vec4; - // K's vectrorization is 16 items per index. See input_a/input_b. - // tile_size_k_vec - is the k tile size in vectorized space (1/16). That is - // k tile size is 32. In vectorized space that is 32/16 = 2. 
- for (var kidx_v:u32 = 0; kidx_v < uniforms.K16; kidx_v+=tile_size_k_vec) - { - // Load Phase: Populate shared memory for the workgroup. - if (load_AorB == 0) - { - loadSHMA(a_global_base, kidx_v, load_row, load_col); - } - else - { - loadSHMB(b_global_base, kidx_v, load_row, load_col); - } - workgroupBarrier(); - - // Compute phase: Perform matmul for this subtile 16 x 32 x 16. - // Step 1: Load from shared memory into registers across entire subgroup. - var own_a0: vec4 = tile_A[0][base_A + a_idx]; - var own_a1: vec4 = tile_A[1][base_A + a_idx]; - var own_scale_a: output_element_t = scale_A[base_A + a_idx]; - if (sg_size == 16) - { - var own_b0: vec4 = tile_B[0][base_B + sg_id]; - var own_b1: vec4 = tile_B[1][base_B + sg_id]; - var own_scale_b: output_element_t = scale_B[base_B + sg_id]; - // Step 2: Access registers across the subgroup using subgroupShuffle and perform the matmul. - lane_output1[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 0), own_a1, subgroupShuffle(own_b1, 0), subgroupShuffle(own_scale_b, 0) * own_scale_a); - lane_output1[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 1), own_a1, subgroupShuffle(own_b1, 1), subgroupShuffle(own_scale_b, 1) * own_scale_a); - lane_output1[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 2), own_a1, subgroupShuffle(own_b1, 2), subgroupShuffle(own_scale_b, 2) * own_scale_a); - lane_output1[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 3), own_a1, subgroupShuffle(own_b1, 3), subgroupShuffle(own_scale_b, 3) * own_scale_a); - - lane_output2[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 4), own_a1, subgroupShuffle(own_b1, 4), subgroupShuffle(own_scale_b, 4) * own_scale_a); - lane_output2[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 5), own_a1, subgroupShuffle(own_b1, 5), subgroupShuffle(own_scale_b, 5) * own_scale_a); - lane_output2[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 6), own_a1, subgroupShuffle(own_b1, 6), subgroupShuffle(own_scale_b, 6) * own_scale_a); - lane_output2[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 7), 
own_a1, subgroupShuffle(own_b1, 7), subgroupShuffle(own_scale_b, 7) * own_scale_a); - - lane_output3[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 8), own_a1, subgroupShuffle(own_b1, 8), subgroupShuffle(own_scale_b, 8) * own_scale_a); - lane_output3[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 9), own_a1, subgroupShuffle(own_b1, 9), subgroupShuffle(own_scale_b, 9) * own_scale_a); - lane_output3[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 10), own_a1, subgroupShuffle(own_b1, 10), subgroupShuffle(own_scale_b, 10) * own_scale_a); - lane_output3[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 11), own_a1, subgroupShuffle(own_b1, 11), subgroupShuffle(own_scale_b, 11) * own_scale_a); - - lane_output4[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 12), own_a1, subgroupShuffle(own_b1, 12), subgroupShuffle(own_scale_b, 12) * own_scale_a); - lane_output4[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 13), own_a1, subgroupShuffle(own_b1, 13), subgroupShuffle(own_scale_b, 13) * own_scale_a); - lane_output4[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 14), own_a1, subgroupShuffle(own_b1, 14), subgroupShuffle(own_scale_b, 14) * own_scale_a); - lane_output4[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 15), own_a1, subgroupShuffle(own_b1, 15), subgroupShuffle(own_scale_b, 15) * own_scale_a); - } - else - { - // Code for other subgroup sizes, simply doesnt use subgroups at all. - // Relies on reads from single location tile_B[][base_B + col] by all - // being optimized by the hardware. 
- lane_output1[0] += SDP8AI(own_a0, tile_B[0][base_B + 0], own_a1, tile_B[1][base_B + 0], own_scale_a * scale_B[base_B + 0]); - lane_output1[1] += SDP8AI(own_a0, tile_B[0][base_B + 1], own_a1, tile_B[1][base_B + 1], own_scale_a * scale_B[base_B + 1]); - lane_output1[2] += SDP8AI(own_a0, tile_B[0][base_B + 2], own_a1, tile_B[1][base_B + 2], own_scale_a * scale_B[base_B + 2]); - lane_output1[3] += SDP8AI(own_a0, tile_B[0][base_B + 3], own_a1, tile_B[1][base_B + 3], own_scale_a * scale_B[base_B + 3]); - - lane_output2[0] += SDP8AI(own_a0, tile_B[0][base_B + 4], own_a1, tile_B[1][base_B + 4], own_scale_a * scale_B[base_B + 4]); - lane_output2[1] += SDP8AI(own_a0, tile_B[0][base_B + 5], own_a1, tile_B[1][base_B + 5], own_scale_a * scale_B[base_B + 5]); - lane_output2[2] += SDP8AI(own_a0, tile_B[0][base_B + 6], own_a1, tile_B[1][base_B + 6], own_scale_a * scale_B[base_B + 6]); - lane_output2[3] += SDP8AI(own_a0, tile_B[0][base_B + 7], own_a1, tile_B[1][base_B + 7], own_scale_a * scale_B[base_B + 7]); - - lane_output3[0] += SDP8AI(own_a0, tile_B[0][base_B + 8], own_a1, tile_B[1][base_B + 8], own_scale_a * scale_B[base_B + 8]); - lane_output3[1] += SDP8AI(own_a0, tile_B[0][base_B + 9], own_a1, tile_B[1][base_B + 9], own_scale_a * scale_B[base_B + 9]); - lane_output3[2] += SDP8AI(own_a0, tile_B[0][base_B + 10], own_a1, tile_B[1][base_B + 10], own_scale_a * scale_B[base_B + 10]); - lane_output3[3] += SDP8AI(own_a0, tile_B[0][base_B + 11], own_a1, tile_B[1][base_B + 11], own_scale_a * scale_B[base_B + 11]); - - lane_output4[0] += SDP8AI(own_a0, tile_B[0][base_B + 12], own_a1, tile_B[1][base_B + 12], own_scale_a * scale_B[base_B + 12]); - lane_output4[1] += SDP8AI(own_a0, tile_B[0][base_B + 13], own_a1, tile_B[1][base_B + 13], own_scale_a * scale_B[base_B + 13]); - lane_output4[2] += SDP8AI(own_a0, tile_B[0][base_B + 14], own_a1, tile_B[1][base_B + 14], own_scale_a * scale_B[base_B + 14]); - lane_output4[3] += SDP8AI(own_a0, tile_B[0][base_B + 15], own_a1, tile_B[1][base_B + 
15], own_scale_a * scale_B[base_B + 15]); - } - workgroupBarrier(); - } - - let a_global = a_global_base + base_A + a_idx; - let b_global = b_global_base + base_B; - let output_idx = ((a_global) * uniforms.N + b_global)/4; - // This creates a shader requirement that uniforms.N % 16 == 0 - if (a_global < uniforms.M && b_global < uniforms.N) - { - output[output_idx] = lane_output1; - output[output_idx+1] = lane_output2; - output[output_idx+2] = lane_output3; - output[output_idx+3] = lane_output4; - } -)MAIN_FN"; - - return Status::OK(); -} - Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { const Tensor* a = context.Input(0); const Tensor* b = context.Input(1); @@ -823,55 +574,15 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context return ApplySubgroupMatrixMatMulNBits(a, b, scales, M, N, K, context, y); } - const bool has_subgroup = context.Device().HasFeature(wgpu::FeatureName::Subgroups); - // macOS - Avoid using dp4a on Metal, as it does not appear to have native dp4a support. 
- // https://github.com/gpuweb/gpuweb/issues/2677#issuecomment-1713292226 - const bool use_dp4a = has_subgroup && context.AdapterInfo().backendType != wgpu::BackendType::Metal; - if (accuracy_level_ == 4 && block_size % 32 == 0 && - batch_count == 1 && components_a == 4 && K % 64 == 0 && N % 16 == 0 && - !has_zero_points && use_dp4a && M >= kMinMForTileOptimization) { - constexpr uint32_t kVec4Components = 4; - constexpr uint32_t kVec2Components = 2; - constexpr uint32_t kU32Components = 4; - - constexpr uint32_t kBlockSizeA = 128; - DP4AMatMulQuantizeProgram quantize_program; - quantize_program.SetWorkgroupSize(1); - quantize_program.SetDispatchGroupSize(M * K / kBlockSizeA, 1, 1); - TensorShape a_quant_shape{1, M, K / kU32Components}; - Tensor a_quant = context.CreateGPUTensor(DataTypeImpl::GetType(), a_quant_shape); - TensorShapeVector a_scales_dims({1, 1, M, K / kBlockSizeA}); - Tensor a_scale = context.CreateGPUTensor(a->DataType(), a_scales_dims); - quantize_program.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(kVec4Components)}}) - .AddOutputs({{&a_quant, ProgramTensorMetadataDependency::Rank, a_quant.Shape(), gsl::narrow(1)}, - {&a_scale, ProgramTensorMetadataDependency::Rank, a_scale.Shape(), gsl::narrow(1)}}) - .AddUniformVariable({static_cast(M * K / kVec4Components)}); - ORT_RETURN_IF_ERROR(context.RunProgram(quantize_program)); - - constexpr uint32_t kTileSize = 64; - TensorShape reshaped_y_shape{1, M, N / kVec4Components}; - DP4AMatMulNBitsProgram mul_program{block_size}; - mul_program.SetWorkgroupSize(256); - mul_program.SetDispatchGroupSize( - (M + kTileSize - 1) / kTileSize, - (N + kTileSize - 1) / kTileSize, 1); - mul_program.AddInputs({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(kVec4Components)}, - {&a_scale, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(1)}, - {b, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(kVec2Components * kU32Components)}, - {scales, 
ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(1)}}) - .AddUniformVariables({{static_cast(M)}, - {static_cast(N)}, - {static_cast(K)}, - {static_cast(K / 8)}, - {static_cast(K / 16)}}) - .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, gsl::narrow(kVec4Components)}) - .CacheHint("Block" + std::to_string(block_size)); - return context.RunProgram(mul_program); + if (M >= kMinMForTileOptimization && + CanApplyDP4AMatrixMatMulNBits(context, accuracy_level_, block_size, batch_count, N, K, components_a, has_zero_points)) { + return ApplyDP4AMatrixMatMulNBits(a, b, scales, M, N, K, block_size, context, y); } // TODO: Support output_number > 1. Some cases are failed when output_number > 1. constexpr uint32_t output_number = 1; const uint32_t tile_m = M > kMinMForTileOptimization ? 4 : 1; + const bool has_subgroup = context.Device().HasFeature(wgpu::FeatureName::Subgroups); const bool use_subgroup = has_subgroup && context.AdapterInfo().vendor == std::string_view{"intel"} && components_a == 4 && block_size == 32; MatMulNBitsProgram program{output_number, block_size, tile_m, gsl::narrow(components_b), has_zero_points, use_subgroup}; if (M > kMinMForTileOptimization && block_size == 32) { diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h index 9c203744596a0..10221e19c7400 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h @@ -35,27 +35,6 @@ class MatMulNBitsProgram final : public Program { bool use_subgroup_; }; -class DP4AMatMulQuantizeProgram final : public Program { - public: - DP4AMatMulQuantizeProgram() : Program{"DP4AMatMulQuantize"} {} - Status GenerateShaderCode(ShaderHelper& sh) const override; - WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"input_size", ProgramUniformVariableDataType::Uint32}); -}; - -class DP4AMatMulNBitsProgram final : public Program { - public: 
- DP4AMatMulNBitsProgram(uint32_t block_size) : Program{"DP4AMatMulNBits"}, block_size_(block_size) {} - Status GenerateShaderCode(ShaderHelper& sh) const override; - WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( - {"M", ProgramUniformVariableDataType::Uint32}, - {"N", ProgramUniformVariableDataType::Uint32}, - {"K", ProgramUniformVariableDataType::Uint32}, - {"K8", ProgramUniformVariableDataType::Uint32}, - {"K16", ProgramUniformVariableDataType::Uint32}); - private: - uint32_t block_size_; -}; - class MatMulNBits final : public WebGpuKernel { public: MatMulNBits(const OpKernelInfo& info) : WebGpuKernel(info) { From 82e48432dc7d3d5a92e195160b69affd52ea551b Mon Sep 17 00:00:00 2001 From: Sushanth Rajasankar <44513542+sushraja-msft@users.noreply.github.com> Date: Wed, 5 Mar 2025 17:59:27 -0800 Subject: [PATCH 03/13] Update onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- .../contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc index 86adbe86bfd91..1b8a67968cb1d 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc @@ -9,10 +9,10 @@ namespace contrib { namespace webgpu { Status DP4AMatMulQuantizeProgram::GenerateShaderCode(ShaderHelper& shader) const { - shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); - shader.AddOutput("output", ShaderUsage::UseUniform); - shader.AddOutput("scales", ShaderUsage::UseUniform); - shader.AdditionalImplementation() << R"ADDNL_FN( + shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | 
ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + shader.AddOutput("output", ShaderUsage::UseUniform); + shader.AddOutput("scales", ShaderUsage::UseUniform); + shader.AdditionalImplementation() << R"ADDNL_FN( fn readInput(offset: u32) -> input_a_value_t { if (offset > uniforms.input_size) { From 77d217384e1e736468b6cb4ca8e03597330da85b Mon Sep 17 00:00:00 2001 From: Sushanth Rajasankar <44513542+sushraja-msft@users.noreply.github.com> Date: Wed, 5 Mar 2025 17:59:35 -0800 Subject: [PATCH 04/13] Update onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- .../contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc index 1b8a67968cb1d..a7b794917a52e 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc @@ -21,7 +21,7 @@ Status DP4AMatMulQuantizeProgram::GenerateShaderCode(ShaderHelper& shader) const return input_a[offset]; } )ADDNL_FN"; - shader.MainFunctionBody() << R"MAIN_FN( + shader.MainFunctionBody() << R"MAIN_FN( var local_a : array, 32>; var max_value:vec4 = vec4(0); for (var idx:u32=0;idx<32;idx+=1) From f89313083ac840886977c284404a5d9baf0dc1e9 Mon Sep 17 00:00:00 2001 From: Sushanth Rajasankar <44513542+sushraja-msft@users.noreply.github.com> Date: Wed, 5 Mar 2025 17:59:44 -0800 Subject: [PATCH 05/13] Update onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- .../contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc 
b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc index a7b794917a52e..525d513fb7e5f 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc @@ -39,7 +39,7 @@ Status DP4AMatMulQuantizeProgram::GenerateShaderCode(ShaderHelper& shader) const // 127 is the max value of signed int8 [-127,127] used by pack4x8snorm for 1.0f. scales[workgroup_idx] = scale/127; )MAIN_FN"; - return Status::OK(); + return Status::OK(); } Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { From 48f4b60d1a9657781551215f75367b52cf6a7ac0 Mon Sep 17 00:00:00 2001 From: Sushanth Rajasankar <44513542+sushraja-msft@users.noreply.github.com> Date: Wed, 5 Mar 2025 17:59:55 -0800 Subject: [PATCH 06/13] Update onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- .../webgpu/quantization/dp4a_matmul_nbits.cc | 62 +++++++++---------- 1 file changed, 31 insertions(+), 31 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc index 525d513fb7e5f..0337f2ab01fef 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc @@ -43,40 +43,40 @@ Status DP4AMatMulQuantizeProgram::GenerateShaderCode(ShaderHelper& shader) const } Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { - shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); - shader.AddInput("scales_a", ShaderUsage::UseUniform); - shader.AddInput("input_b", ShaderUsage::UseUniform); - shader.AddInput("scales_b", ShaderUsage::UseUniform); - shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias); + 
shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); + shader.AddInput("scales_a", ShaderUsage::UseUniform); + shader.AddInput("input_b", ShaderUsage::UseUniform); + shader.AddInput("scales_b", ShaderUsage::UseUniform); + shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias); - // This shader implements co-operative matrix multiply. The key idea here is to - // assume there is a primitive for medium size matrix multiply a subgroup can perform, - // using all its lanes and pooling all its registers to keep the values in registry. - // - // The entire workgroup which has N subgroups first loads a tile into shared memory, - // Then each subgroup loads a subtile from shared memory into registers and uses - // the medium size matrix multiply primitive to perform the math. - // The values for tile/subtile size are chosen to conform to the resource limits - // of an alderlake/tiger lake gpu. A tile is 64x64, workgroup is 256 threads - - // therefore there are 16 subgroups and 16 lanes in each subgroup. - // K the hidden dimension is paged in from RAM at k tile size which is 64. - // All this puts the shared memory requirement slightly above 16KB. - // WebGPU limit is 16KB, output is moved to registers instead of SHM to make - // everything fit in shared memory. - // - // Each subgroup performs a 16 x 64 x 16 multiply which is implemented with - // subgroup shuffle as a placeholder for the day the medium matrix mul primitive - // becomes available in WGSL. The registry requirements is ~2KB per subgroup, on - // Alderlake/Tigerlake subgroup has 8KB of registry space pooling the - // 512B of registry from each lane. - // - // The medium size matmul is implemented using dot4I8Packed, so the inputs for - // this shader require A to be int8 quantized with block size 64. B is regular - // matmulnbits input with block size 32. 
+ // This shader implements co-operative matrix multiply. The key idea here is to + // assume there is a primitive for medium size matrix multiply a subgroup can perform, + // using all its lanes and pooling all its registers to keep the values in registry. + // + // The entire workgroup which has N subgroups first loads a tile into shared memory, + // Then each subgroup loads a subtile from shared memory into registers and uses + // the medium size matrix multiply primitive to perform the math. + // The values for tile/subtile size are chosen to conform to the resource limits + // of an alderlake/tiger lake gpu. A tile is 64x64, workgroup is 256 threads - + // therefore there are 16 subgroups and 16 lanes in each subgroup. + // K the hidden dimension is paged in from RAM at k tile size which is 64. + // All this puts the shared memory requirement slightly above 16KB. + // WebGPU limit is 16KB, output is moved to registers instead of SHM to make + // everything fit in shared memory. + // + // Each subgroup performs a 16 x 64 x 16 multiply which is implemented with + // subgroup shuffle as a placeholder for the day the medium matrix mul primitive + // becomes available in WGSL. The registry requirements is ~2KB per subgroup, on + // Alderlake/Tigerlake subgroup has 8KB of registry space pooling the + // 512B of registry from each lane. + // + // The medium size matmul is implemented using dot4I8Packed, so the inputs for + // this shader require A to be int8 quantized with block size 64. B is regular + // matmulnbits input with block size 32. 
- shader.AdditionalImplementation() << " const block_size = " << block_size_ << ";"; + shader.AdditionalImplementation() << " const block_size = " << block_size_ << ";"; - shader.AdditionalImplementation() << R"ADDNL_FN( + shader.AdditionalImplementation() << R"ADDNL_FN( const tile_size = 64; const subtile_size = 16; const tile_size_k = 32; From e709ef743a46062360ec5a9a64abb63a43883446 Mon Sep 17 00:00:00 2001 From: Sushanth Rajasankar <44513542+sushraja-msft@users.noreply.github.com> Date: Wed, 5 Mar 2025 18:00:03 -0800 Subject: [PATCH 07/13] Update onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- .../contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h index 8f2e76a9f9a24..3317686512003 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h +++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h @@ -13,10 +13,10 @@ namespace webgpu { using namespace onnxruntime::webgpu; class DP4AMatMulQuantizeProgram final : public Program { -public: - DP4AMatMulQuantizeProgram() : Program{"DP4AMatMulQuantize"} {} - Status GenerateShaderCode(ShaderHelper& sh) const override; - WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"input_size", ProgramUniformVariableDataType::Uint32}); + public: + DP4AMatMulQuantizeProgram() : Program{"DP4AMatMulQuantize"} {} + Status GenerateShaderCode(ShaderHelper& sh) const override; + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"input_size", ProgramUniformVariableDataType::Uint32}); }; class DP4AMatMulNBitsProgram final : public Program { From ec485d539f8a996cfb454210123fc9aabdd0769c Mon Sep 17 00:00:00 2001 From: Sushanth Rajasankar <44513542+sushraja-msft@users.noreply.github.com> Date: Wed, 5 Mar 2025 18:00:09 
-0800 Subject: [PATCH 08/13] Update onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- .../webgpu/quantization/dp4a_matmul_nbits.h | 23 ++++++++++--------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h index 3317686512003..5eb3c173b222e 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h +++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h @@ -20,17 +20,18 @@ class DP4AMatMulQuantizeProgram final : public Program { -public: - DP4AMatMulNBitsProgram(uint32_t block_size) : Program{"DP4AMatMulNBits"}, block_size_(block_size) {} - Status GenerateShaderCode(ShaderHelper& sh) const override; - WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( - {"M", ProgramUniformVariableDataType::Uint32}, - {"N", ProgramUniformVariableDataType::Uint32}, - {"K", ProgramUniformVariableDataType::Uint32}, - {"K8", ProgramUniformVariableDataType::Uint32}, - {"K16", ProgramUniformVariableDataType::Uint32}); - private: - uint32_t block_size_; + public: + DP4AMatMulNBitsProgram(uint32_t block_size) : Program{"DP4AMatMulNBits"}, block_size_(block_size) {} + Status GenerateShaderCode(ShaderHelper& sh) const override; + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( + {"M", ProgramUniformVariableDataType::Uint32}, + {"N", ProgramUniformVariableDataType::Uint32}, + {"K", ProgramUniformVariableDataType::Uint32}, + {"K8", ProgramUniformVariableDataType::Uint32}, + {"K16", ProgramUniformVariableDataType::Uint32}); + + private: + uint32_t block_size_; }; Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales, From bca1723610d37af26caaaaee427087b836be1cd8 Mon Sep 17 00:00:00 2001 From: Sushanth Rajasankar <44513542+sushraja-msft@users.noreply.github.com> Date: Wed, 5 Mar 2025 18:00:15 -0800 
Subject: [PATCH 09/13] Update onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- .../webgpu/quantization/dp4a_matmul_nbits.h | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h index 5eb3c173b222e..15b86d78301ad 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h +++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h @@ -35,22 +35,22 @@ class DP4AMatMulNBitsProgram final : public Program { }; Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales, - uint32_t M, - uint32_t N, - uint32_t K, - uint32_t block_size, - onnxruntime::webgpu::ComputeContext& context, - Tensor* y); + uint32_t M, + uint32_t N, + uint32_t K, + uint32_t block_size, + onnxruntime::webgpu::ComputeContext& context, + Tensor* y); bool CanApplyDP4AMatrixMatMulNBits(onnxruntime::webgpu::ComputeContext& context, - uint64_t accuracy_level, - uint32_t block_size, - uint32_t batch_count, - uint32_t N, - uint32_t K, - uint32_t components_k, - bool has_zero_points); + uint64_t accuracy_level, + uint32_t block_size, + uint32_t batch_count, + uint32_t N, + uint32_t K, + uint32_t components_k, + bool has_zero_points); -} // namespace onnxruntime -} // namespace contrib } // namespace webgpu +} // namespace contrib +} // namespace onnxruntime From d8bfc360a007781a925605408e0268af6f89ac53 Mon Sep 17 00:00:00 2001 From: Sushanth Rajasankar <44513542+sushraja-msft@users.noreply.github.com> Date: Wed, 5 Mar 2025 18:00:30 -0800 Subject: [PATCH 10/13] Update onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- .../webgpu/quantization/dp4a_matmul_nbits.cc | 28 
+++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc index 0337f2ab01fef..99fb76fb7d14b 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc @@ -305,20 +305,20 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor } bool CanApplyDP4AMatrixMatMulNBits(onnxruntime::webgpu::ComputeContext& context, - uint64_t accuracy_level, - uint32_t block_size, - uint32_t batch_count, - uint32_t N, - uint32_t K, - uint32_t components_k, - bool has_zero_points) { - // macOS - Avoid using dp4a on Metal, as it does not appear to have native dp4a support. - // https://github.com/gpuweb/gpuweb/issues/2677#issuecomment-1713292226 - bool use_dp4a = context.Device().HasFeature(wgpu::FeatureName::Subgroups) && - context.AdapterInfo().backendType != wgpu::BackendType::Metal; - return (accuracy_level == 4 && block_size % 32 == 0 && - batch_count == 1 && components_k == 4 && K % 64 == 0 && N % 16 == 0 && - !has_zero_points && use_dp4a); + uint64_t accuracy_level, + uint32_t block_size, + uint32_t batch_count, + uint32_t N, + uint32_t K, + uint32_t components_k, + bool has_zero_points) { + // macOS - Avoid using dp4a on Metal, as it does not appear to have native dp4a support. 
+ // https://github.com/gpuweb/gpuweb/issues/2677#issuecomment-1713292226 + bool use_dp4a = context.Device().HasFeature(wgpu::FeatureName::Subgroups) && + context.AdapterInfo().backendType != wgpu::BackendType::Metal; + return (accuracy_level == 4 && block_size % 32 == 0 && + batch_count == 1 && components_k == 4 && K % 64 == 0 && N % 16 == 0 && + !has_zero_points && use_dp4a); } } // namespace webgpu From e40bd53feb5465f42a1343b5df36cb55d256ab09 Mon Sep 17 00:00:00 2001 From: Sushanth Rajasankar <44513542+sushraja-msft@users.noreply.github.com> Date: Wed, 5 Mar 2025 18:00:38 -0800 Subject: [PATCH 11/13] Update onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- .../contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc index 99fb76fb7d14b..59830ff6c0c1f 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc @@ -299,9 +299,9 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor {static_cast(K)}, {static_cast(K / 8)}, {static_cast(K / 16)}}) - .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, gsl::narrow(kVec4Components)}) - .CacheHint("Block" + std::to_string(block_size)); - return context.RunProgram(mul_program); + .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, gsl::narrow(kVec4Components)}) + .CacheHint("Block" + std::to_string(block_size)); + return context.RunProgram(mul_program); } bool CanApplyDP4AMatrixMatMulNBits(onnxruntime::webgpu::ComputeContext& context, From 1d149f0679d2007895887f19dbd3a77bfdcee8bb Mon Sep 17 00:00:00 2001 From: Sushanth Rajasankar 
<44513542+sushraja-msft@users.noreply.github.com> Date: Wed, 5 Mar 2025 18:00:44 -0800 Subject: [PATCH 12/13] Update onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- .../contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc index 59830ff6c0c1f..051b7b28b4120 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc @@ -144,7 +144,7 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { } )ADDNL_FN"; - shader.MainFunctionBody() << R"MAIN_FN( + shader.MainFunctionBody() << R"MAIN_FN( // During the load phase we use all 256 threads to load 64 rows of A/B. // For each row we load tile_size_k_vec (2) vectorized elements, which are 32 elements of K. 
let a_global_base = workgroup_id.x * tile_size; From 10caf95db7f8a05dc34a3cd6405490cffa75e6f2 Mon Sep 17 00:00:00 2001 From: Sushanth Rajasankar Date: Wed, 5 Mar 2025 18:04:34 -0800 Subject: [PATCH 13/13] lint runner --- .../webgpu/quantization/dp4a_matmul_nbits.cc | 70 +++++++++---------- 1 file changed, 35 insertions(+), 35 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc index 051b7b28b4120..6720a6072f7bb 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc @@ -255,46 +255,46 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { } )MAIN_FN"; - return Status::OK(); + return Status::OK(); } Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales, - uint32_t M, - uint32_t N, - uint32_t K, - uint32_t block_size, - onnxruntime::webgpu::ComputeContext& context, - Tensor* y) { - constexpr uint32_t kVec4Components = 4; - constexpr uint32_t kVec2Components = 2; - constexpr uint32_t kU32Components = 4; + uint32_t M, + uint32_t N, + uint32_t K, + uint32_t block_size, + onnxruntime::webgpu::ComputeContext& context, + Tensor* y) { + constexpr uint32_t kVec4Components = 4; + constexpr uint32_t kVec2Components = 2; + constexpr uint32_t kU32Components = 4; - constexpr uint32_t kBlockSizeA = 128; - DP4AMatMulQuantizeProgram quantize_program; - quantize_program.SetWorkgroupSize(1); - quantize_program.SetDispatchGroupSize(M * K / kBlockSizeA, 1, 1); - TensorShape a_quant_shape{1, M, K / kU32Components}; - Tensor a_quant = context.CreateGPUTensor(DataTypeImpl::GetType(), a_quant_shape); - TensorShapeVector a_scales_dims({1, 1, M, K / kBlockSizeA}); - Tensor a_scale = context.CreateGPUTensor(a->DataType(), a_scales_dims); - quantize_program.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, 
gsl::narrow(kVec4Components)}}) - .AddOutputs({{&a_quant, ProgramTensorMetadataDependency::Rank, a_quant.Shape(), gsl::narrow(1)}, - {&a_scale, ProgramTensorMetadataDependency::Rank, a_scale.Shape(), gsl::narrow(1)}}) - .AddUniformVariable({static_cast(M * K / kVec4Components)}); - ORT_RETURN_IF_ERROR(context.RunProgram(quantize_program)); + constexpr uint32_t kBlockSizeA = 128; + DP4AMatMulQuantizeProgram quantize_program; + quantize_program.SetWorkgroupSize(1); + quantize_program.SetDispatchGroupSize(M * K / kBlockSizeA, 1, 1); + TensorShape a_quant_shape{1, M, K / kU32Components}; + Tensor a_quant = context.CreateGPUTensor(DataTypeImpl::GetType(), a_quant_shape); + TensorShapeVector a_scales_dims({1, 1, M, K / kBlockSizeA}); + Tensor a_scale = context.CreateGPUTensor(a->DataType(), a_scales_dims); + quantize_program.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(kVec4Components)}}) + .AddOutputs({{&a_quant, ProgramTensorMetadataDependency::Rank, a_quant.Shape(), gsl::narrow(1)}, + {&a_scale, ProgramTensorMetadataDependency::Rank, a_scale.Shape(), gsl::narrow(1)}}) + .AddUniformVariable({static_cast(M * K / kVec4Components)}); + ORT_RETURN_IF_ERROR(context.RunProgram(quantize_program)); - constexpr uint32_t kTileSize = 64; - TensorShape reshaped_y_shape{1, M, N / kVec4Components}; - DP4AMatMulNBitsProgram mul_program{block_size}; - mul_program.SetWorkgroupSize(256); - mul_program.SetDispatchGroupSize( - (M + kTileSize - 1) / kTileSize, - (N + kTileSize - 1) / kTileSize, 1); - mul_program.AddInputs({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(kVec4Components)}, - {&a_scale, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(1)}, - {b, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(kVec2Components * kU32Components)}, - {scales, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(1)}}) - .AddUniformVariables({{static_cast(M)}, + constexpr uint32_t kTileSize = 64; + TensorShape 
reshaped_y_shape{1, M, N / kVec4Components}; + DP4AMatMulNBitsProgram mul_program{block_size}; + mul_program.SetWorkgroupSize(256); + mul_program.SetDispatchGroupSize( + (M + kTileSize - 1) / kTileSize, + (N + kTileSize - 1) / kTileSize, 1); + mul_program.AddInputs({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(kVec4Components)}, + {&a_scale, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(1)}, + {b, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(kVec2Components * kU32Components)}, + {scales, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow(1)}}) + .AddUniformVariables({{static_cast(M)}, {static_cast(N)}, {static_cast(K)}, {static_cast(K / 8)},