diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts index c8e77d14117bf..d22acdbe0e782 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts @@ -123,21 +123,42 @@ export const createMatMulNBitsProgramInfo = ( } })(); + // Number of quantized values per u32 word and passes needed (each pass extracts 8 values). + const valuesPerWord = Math.floor(32 / attributes.bits); // Q4=8, Q2=16 + const passesPerWord = Math.floor(valuesPerWord / 8); // Q4=1, Q2=2 + const processOneWord = (): string => { - let calcStr = ` - // reuse a data - var input_offset = ${a.indicesToOffset(`${a.type.indices}(batch, row, word_offset)`)}; - var a_data: ${qDqDataType}; - for (var j: u32 = 0; j < ${8 / aComponents}; j++) { - a_data[j] = ${a.getByOffset('input_offset')}; - input_offset++; + let calcStr = ''; + for (let pass = 0; pass < passesPerWord; pass++) { + // Each pass processes 8 values from the current u32 word. + // lowerShift/upperShift are only used by the Q4 path (single pass: shifts 0 and 4). The Q2 path ignores them: each pass takes 16 bits of the word (b_value >> pass * 16) and spreads its four nibbles into separate bytes. + const lowerShift = pass * attributes.bits * 4; // bit offset for lower group within each byte + const upperShift = lowerShift + attributes.bits; + calcStr += ` + // reuse a data (pass ${pass}) + var input_offset${pass > 0 ? pass : ''} = ${pass === 0 ? a.indicesToOffset(`${a.type.indices}(batch, row, word_offset)`) : `input_offset`}; + var a_data${pass > 0 ? pass : ''}: ${qDqDataType}; + for (var j${pass > 0 ? pass : ''}: u32 = 0; j${pass > 0 ? pass : ''} < ${8 / aComponents}; j${pass > 0 ? pass : ''}++) { + a_data${pass > 0 ? pass : ''}[j${pass > 0 ? pass : ''}] = ${a.getByOffset(`input_offset${pass > 0 ? pass : ''}`)}; + input_offset${pass > 0 ? pass : ''}++; } `; - for (let c = 0; c < components * outputNumber; c++) { - calcStr += ` + for (let c = 0; c < components * outputNumber; c++) { + calcStr += ` b_value = ${bComponents === 1 ? 
`b${c}_data` : `b${c}_data[i]`}; - b_value_lower = unpack4xU8(b_value & b_mask); - b_value_upper = unpack4xU8((b_value >> 4) & b_mask); + ${ + attributes.bits === 2 + ? `{ + let half_word = b_value >> ${pass * 16}u; + let byte_lo = half_word & 0xFFu; + let byte_hi = (half_word >> 8u) & 0xFFu; + let spread_word = (byte_lo & 0xFu) | ((byte_lo >> 4u) << 8u) | ((byte_hi & 0xFu) << 16u) | ((byte_hi >> 4u) << 24u); + b_value_lower = unpack4xU8(spread_word & b_mask); + b_value_upper = unpack4xU8((spread_word >> 2u) & b_mask); + }` + : `b_value_lower = unpack4xU8((b_value >> ${lowerShift}u) & b_mask); + b_value_upper = unpack4xU8((b_value >> ${upperShift}u) & b_mask);` + } b_quantized_values = ${qDqDataType}(${Array.from( { length: 4 }, (_, i) => `${dataType}(b_value_lower[${i}]), ${dataType}(b_value_upper[${i}])`, @@ -159,11 +180,12 @@ export const createMatMulNBitsProgramInfo = ( (_, i) => `${ aComponents === 1 - ? `a_data[${i}] * b_dequantized_values[${i}]` - : `dot(a_data[${i}], b_dequantized_values[${i}])` + ? `a_data${pass > 0 ? pass : ''}[${i}] * b_dequantized_values[${i}]` + : `dot(a_data${pass > 0 ? pass : ''}[${i}], b_dequantized_values[${i}])` }`, ).join(' + ')}; `; + } } return calcStr; }; @@ -173,16 +195,17 @@ export const createMatMulNBitsProgramInfo = ( ${ zeroPoints ? ` - let zero_point_bytes_per_col = (nBlocksPerCol + 1) / 2; + let zero_point_values_per_byte: u32 = ${Math.floor(8 / attributes.bits)}u; + let zero_point_bytes_per_col = (nBlocksPerCol + zero_point_values_per_byte - 1u) / zero_point_values_per_byte; var zero_point_byte_count: u32; var zero_point_word_index: u32; var zero_point_byte_offset: u32; - let zero_point_nibble_offset: u32 = block & 0x1u; + let zero_point_sub_offset: u32 = block % zero_point_values_per_byte; var zero_point_bits_offset: u32; var zero_point_word: u32;` : ` - // The default zero point is 8 for unsigned 4-bit quantization. 
- let zero_point = ${dataType}(${8.0});` + // The default zero point is ${Math.pow(2, attributes.bits - 1)} for unsigned ${attributes.bits}-bit quantization. + let zero_point = ${dataType}(${Math.pow(2, attributes.bits - 1).toFixed(1)});` } `; for (let c = 0; c < components * outputNumber; c++) { @@ -191,12 +214,12 @@ export const createMatMulNBitsProgramInfo = ( ${ zeroPoints ? ` - zero_point_byte_count = col_index * zero_point_bytes_per_col + (block >> 0x1u); + zero_point_byte_count = col_index * zero_point_bytes_per_col + (block / zero_point_values_per_byte); zero_point_word_index = zero_point_byte_count >> 0x2u; zero_point_byte_offset = zero_point_byte_count & 0x3u; - zero_point_bits_offset = (zero_point_byte_offset << 3) + (zero_point_nibble_offset << 2); + zero_point_bits_offset = (zero_point_byte_offset << 3) + (zero_point_sub_offset * ${attributes.bits}u); zero_point_word = ${zeroPoints.getByOffset('zero_point_word_index')} >> zero_point_bits_offset; - let zero_point${c} = ${dataType}((zero_point_word) & 0xFu);` + let zero_point${c} = ${dataType}((zero_point_word) & ${attributes.bits === 2 ? '0x3u' : '0xFu'});` : '' } col_index += 1;`; @@ -212,7 +235,7 @@ export const createMatMulNBitsProgramInfo = ( } calcStr += ` var b_value: u32; - let b_mask: u32 = 0x0F0F0F0Fu; + let b_mask: u32 = ${attributes.bits === 2 ? '0x03030303u' : '0x0F0F0F0Fu'}; var b_value_lower: vec4; var b_value_upper: vec4; var b_quantized_values: ${qDqDataType}; @@ -237,7 +260,7 @@ export const createMatMulNBitsProgramInfo = ( ${prepareBData()} for (var i: u32 = 0; i < ${bComponents}; i++) { ${processOneWord()} - word_offset += ${8 / aComponents}; + word_offset += ${valuesPerWord / aComponents}; } } } @@ -291,7 +314,8 @@ export const createMatMulNBitsBlockSize32ProgramInfo = ( const workgroupSize = 128; const workgroupY = dimBOuter % 8 === 0 ? 8 : dimBOuter % 4 === 0 ? 
4 : 1; const workgroupX = workgroupSize / workgroupY; - const tileSize = workgroupX * bComponents * 8; // each uint32 has 8 data. + const valuesPerWordBs32 = Math.floor(32 / attributes.bits); // Q4=8, Q2=16 + const tileSize = workgroupX * bComponents * valuesPerWordBs32; // each uint32 has valuesPerWord data. const aLengthPerTile = tileSize / aComponents; const blocksPerTile = tileSize / attributes.blockSize; const dispatchSize = ShapeUtil.size(outputShape) / workgroupY; @@ -376,36 +400,59 @@ export const createMatMulNBitsBlockSize32ProgramInfo = ( ${ zeroPoints ? ` - let zero_point_bytes_per_col = (n_blocks_per_col + 1) / 2; - let zero_point_byte_count = b_row * zero_point_bytes_per_col + (block >> 0x1u); + let zero_point_values_per_byte: u32 = ${Math.floor(8 / attributes.bits)}u; + let zero_point_bytes_per_col = (n_blocks_per_col + zero_point_values_per_byte - 1u) / zero_point_values_per_byte; + let zero_point_byte_count = b_row * zero_point_bytes_per_col + (block / zero_point_values_per_byte); let zero_point_word_index = zero_point_byte_count >> 0x2u; let zero_point_byte_offset = zero_point_byte_count & 0x3u; - let zero_point_nibble_offset: u32 = block & 0x1u; - let zero_point_bits_offset = (zero_point_byte_offset << 3) + (zero_point_nibble_offset << 2); + let zero_point_sub_offset: u32 = block % zero_point_values_per_byte; + let zero_point_bits_offset = (zero_point_byte_offset << 3) + (zero_point_sub_offset * ${attributes.bits}u); let zero_point_word = ${zeroPoints.getByOffset('zero_point_word_index')} >> zero_point_bits_offset; - let zero_point = ${dataType}((zero_point_word) & 0xFu);` + let zero_point = ${dataType}((zero_point_word) & ${attributes.bits === 2 ? '0x3u' : '0xFu'});` : ` - // The default zero point is 8 for unsigned 4-bit quantization. - let zero_point = ${dataType}(${8.0});` + // The default zero point is ${Math.pow(2, attributes.bits - 1)} for unsigned ${attributes.bits}-bit quantization. 
+ let zero_point = ${dataType}(${Math.pow(2, attributes.bits - 1).toFixed(1)});` } let scale = ${scales.getByOffset(`b_row * n_blocks_per_col + block`)}; let b_data = ${b.getByIndices(`${b.type.indices}(b_row, block, 0)`)}; var word_offset = local_id.x * ${attributes.blockSize / aComponents}; for (var i: u32 = 0; i < ${bComponents}; i++) { - ${readA()} let b_value = ${bComponents === 1 ? `b_data` : `b_data[i]`}; - let b_value_lower = unpack4xU8(b_value & 0x0F0F0F0Fu); - let b_value_upper = unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu); - let b_quantized_values = mat2x4<${dataType}>(${Array.from( - { length: 4 }, - (_, i) => `${dataType}(b_value_lower[${i}]), ${dataType}(b_value_upper[${i}])`, - ).join(', ')}); - let b_dequantized_values = (b_quantized_values - mat2x4<${dataType}>(${Array(8).fill('zero_point').join(',')})) * scale; - inter_results[local_id.y][local_id.x] += ${Array.from( - { length: 2 }, - (_, i) => `${`dot(a_data${i}, b_dequantized_values[${i}])`}`, - ).join(' + ')}; - word_offset += ${8 / aComponents}; + ${(() => { + const passesPerWordBs32 = Math.floor(valuesPerWordBs32 / 8); + let code = ''; + for (let pass = 0; pass < passesPerWordBs32; pass++) { + const lowerShift = pass * attributes.bits * 4; + const upperShift = lowerShift + attributes.bits; + code += ` + ${readA()} + {${ + attributes.bits === 2 + ? 
` + let half_word = b_value >> ${pass * 16}u; + let byte_lo = half_word & 0xFFu; + let byte_hi = (half_word >> 8u) & 0xFFu; + let spread_word = (byte_lo & 0xFu) | ((byte_lo >> 4u) << 8u) | ((byte_hi & 0xFu) << 16u) | ((byte_hi >> 4u) << 24u); + let b_value_lower = unpack4xU8(spread_word & 0x03030303u); + let b_value_upper = unpack4xU8((spread_word >> 2u) & 0x03030303u);` + : ` + let b_value_lower = unpack4xU8((b_value >> ${lowerShift}u) & 0x0F0F0F0Fu); + let b_value_upper = unpack4xU8((b_value >> ${upperShift}u) & 0x0F0F0F0Fu);` + } + let b_quantized_values = mat2x4<${dataType}>(${Array.from( + { length: 4 }, + (_, i) => `${dataType}(b_value_lower[${i}]), ${dataType}(b_value_upper[${i}])`, + ).join(', ')}); + let b_dequantized_values = (b_quantized_values - mat2x4<${dataType}>(${Array(8).fill('zero_point').join(',')})) * scale; + inter_results[local_id.y][local_id.x] += ${Array.from( + { length: 2 }, + (_, i) => `${`dot(a_data${i}, b_dequantized_values[${i}])`}`, + ).join(' + ')}; + } + word_offset += ${8 / aComponents};`; + } + return code; + })()} } workgroupBarrier(); } diff --git a/onnxruntime/contrib_ops/js/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/js/quantization/matmul_nbits.h index cca2c4757765b..d27e222dec678 100644 --- a/onnxruntime/contrib_ops/js/quantization/matmul_nbits.h +++ b/onnxruntime/contrib_ops/js/quantization/matmul_nbits.h @@ -17,8 +17,8 @@ class MatMulNBits final : public JsKernel { accuracy_level_{info.GetAttrOrDefault("accuracy_level", 0)}, nbits_{narrow(info.GetAttr("bits"))}, block_size_{narrow(info.GetAttr("block_size"))} { - ORT_ENFORCE(nbits_ == 4, - "Only 4b quantization is supported for MatMulNBits op, additional bits support is planned."); + ORT_ENFORCE(nbits_ == 4 || nbits_ == 2, + "Only 2b and 4b quantization is supported for MatMulNBits op, additional bits support is planned."); ORT_ENFORCE(block_size_ >= 16 && !(block_size_ & (block_size_ - 1)), "Block size must be a power of 2 and greater than or equal to 16."); 
JSEP_INIT_KERNEL_ATTRIBUTE(MatMulNBits, ({ diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc index 56da24d7844d3..a2b89dfcee4d6 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc @@ -117,10 +117,8 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context ORT_ENFORCE(g_idx == nullptr, "group_idx as input is not supported yet."); const bool has_zero_points = zero_points != nullptr; - const uint32_t nbits = onnxruntime::narrow(bits_); if (has_zero_points) { ORT_ENFORCE(zero_points->DataType() == DataTypeImpl::GetType(), "Currently, only uint8 is supported for zero points, but got ", zero_points->DataType()); - ORT_ENFORCE(nbits != 2, "Currently, zero points are not supported for Q2 quantization."); } MatMulComputeHelper helper; @@ -205,9 +203,13 @@ Status ApplyMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales, const uint32_t components_a = GetMaxComponents(K); const uint32_t components_b = GetMaxComponents(blob_size_in_words); uint32_t components = GetMaxComponents(N); - // zero_points has shape[N * CeilDiv(n_blocks_per_col * bits, 8)]. So here we need to check whether n_blocks_per_col is divisible by 8/nbits. - // For bits==4, this is counted by elements of uint4. Need add 1 if not divisible by 2. - uint32_t zero_blocks_per_col = n_blocks_per_col % (8 / nbits) == 0 ? n_blocks_per_col : n_blocks_per_col + 1; + // zero_points has shape[N * CeilDiv(n_blocks_per_col * bits, 8)]. + // The shader uses a flat linear index to address individual n-bit zero point values. + // Since each column's zero points are byte-aligned in the packed buffer, we must round + // n_blocks_per_col up to the next multiple of (8/nbits) — the number of zero point + // values per byte — so that the linear stride correctly skips byte-boundary padding. 
+ const uint32_t zp_elements_per_byte = 8 / static_cast(nbits); + uint32_t zero_blocks_per_col = (n_blocks_per_col + zp_elements_per_byte - 1) / zp_elements_per_byte * zp_elements_per_byte; #if !defined(__wasm__) int32_t subgroup_matrix_config_index = -1; @@ -219,7 +221,9 @@ Status ApplyMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales, #endif // On FP32 only GPUs, integer math is faster than FP32 therefore always use DP4A independent of length of M. + // DP4A Q2 path uses a hardcoded LUT with zero_point=2, so skip DP4A for Q2 with custom zero points. if ((M >= kMinMForTileOptimization || y->DataType() == DataTypeImpl::GetType() || context.AdapterInfo().vendor == std::string_view{"qualcomm"}) && + !(has_zero_points && nbits == 2) && CanApplyDP4AMatrixMatMulNBits(context, accuracy_level, block_size, N, K, components_a)) { return ApplyDP4AMatrixMatMulNBits(a, b, scales, zero_points, bias, batch_count, M, N, K, block_size, zero_blocks_per_col, kMinMForTileOptimization, static_cast(nbits), context, y, weight_index); } diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_zero_pt.wgsl.template b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_zero_pt.wgsl.template index 0ad2f89f5263c..e017f97afff62 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_zero_pt.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_zero_pt.wgsl.template @@ -20,6 +20,7 @@ const bit_mask = 0xFFu; #elif n_bits == 2 const default_zero_point = 2; + const bit_mask = 0x3u; #endif #if has_zero_points diff --git a/onnxruntime/test/contrib_ops/matmul_2bits_test.cc b/onnxruntime/test/contrib_ops/matmul_2bits_test.cc index 3d5e3e5f360b4..c2685c69db877 100644 --- a/onnxruntime/test/contrib_ops/matmul_2bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_2bits_test.cc @@ -26,6 +26,7 @@ #include "core/session/onnxruntime_cxx_api.h" #include "core/session/ort_env.h" #include "core/util/qmath.h" +#include 
"core/providers/webgpu/webgpu_provider_options.h" extern std::unique_ptr ort_env; @@ -199,9 +200,7 @@ void RunTest2Bits(const TestOptions2Bits& opts) { std::vector> execution_providers; if constexpr (std::is_same::value) { #ifdef USE_WEBGPU - if (!opts.has_zero_point) { - execution_providers.push_back(DefaultWebGpuExecutionProvider()); - } + execution_providers.push_back(DefaultWebGpuExecutionProvider()); #endif execution_providers.emplace_back(DefaultCpuExecutionProvider()); test.ConfigEps(std::move(execution_providers)); @@ -449,6 +448,100 @@ TEST(MatMul2Bits, Float32_2b_Accuracy4) { TestMatMul2BitsTyped(); } +#ifdef USE_WEBGPU + +namespace { + +// Runs a 2-bit MatMulNBits test on WebGPU EP with CPU as baseline. +// The test quantizes random weights to 2 bits, dequantizes to compute +// expected output via matmul on CPU, then compares WebGPU output. +template +void RunWebGpu2BitsTest(int64_t M, int64_t N, int64_t K, int64_t block_size, + bool has_zero_point, float abs_error = 0.1f, float rel_error = 0.02f) { + TestOptions2Bits opts{}; + opts.M = M; + opts.N = N; + opts.K = K; + opts.block_size = block_size; + opts.has_zero_point = has_zero_point; + opts.output_abs_error = abs_error; + opts.output_rel_error = rel_error; + + RunTest2Bits(opts); +} + +} // namespace + +// WebGPU 2-bit tests: symmetric (no zero points) +TEST(MatMul2BitsWebGpu, Float32_Symmetric_Small) { + RunWebGpu2BitsTest(1, 32, 32, 16, false); + RunWebGpu2BitsTest(1, 32, 32, 32, false); + RunWebGpu2BitsTest(1, 32, 16, 16, false); +} + +TEST(MatMul2BitsWebGpu, Float32_Symmetric_Medium) { + RunWebGpu2BitsTest(1, 288, 16, 16, false); + RunWebGpu2BitsTest(4, 32, 32, 16, false); + RunWebGpu2BitsTest(4, 288, 16, 16, false); + RunWebGpu2BitsTest(100, 32, 32, 16, false); + RunWebGpu2BitsTest(100, 288, 16, 16, false); +} + +// WebGPU 2-bit tests: asymmetric (with zero points) — the primary accuracy concern +TEST(MatMul2BitsWebGpu, Float32_ZeroPoint_Small) { + RunWebGpu2BitsTest(1, 1, 16, 16, true); + 
RunWebGpu2BitsTest(1, 2, 16, 16, true); + RunWebGpu2BitsTest(1, 32, 16, 16, true); + RunWebGpu2BitsTest(1, 32, 32, 16, true); + RunWebGpu2BitsTest(1, 32, 32, 32, true); +} + +TEST(MatMul2BitsWebGpu, Float32_ZeroPoint_Medium) { + RunWebGpu2BitsTest(1, 288, 16, 16, true); + RunWebGpu2BitsTest(4, 32, 32, 16, true); + RunWebGpu2BitsTest(4, 288, 16, 16, true); + RunWebGpu2BitsTest(100, 32, 32, 16, true); + RunWebGpu2BitsTest(100, 288, 16, 16, true); +} + +TEST(MatMul2BitsWebGpu, Float32_ZeroPoint_BlockSize32) { + // blockSize=32 triggers the Intel Gen12 optimized path on matching hardware. + RunWebGpu2BitsTest(1, 32, 32, 32, true); + RunWebGpu2BitsTest(4, 32, 32, 32, true); + RunWebGpu2BitsTest(100, 32, 32, 32, true); +} + +TEST(MatMul2BitsWebGpu, Float32_ZeroPoint_BlockSize128) { + RunWebGpu2BitsTest(1, 32, 16, 128, true); + RunWebGpu2BitsTest(4, 32, 16, 128, true); + RunWebGpu2BitsTest(100, 32, 16, 128, true); +} + +// BlockSize=64 tests — covers nBlocksPerCol not a multiple of 4 (padding edge case). +// These match configurations found in real 2-bit quantized transformer models. +TEST(MatMul2BitsWebGpu, Float32_ZeroPoint_BlockSize64) { + RunWebGpu2BitsTest(1, 32, 64, 64, true); + RunWebGpu2BitsTest(1, 32, 128, 64, true); + RunWebGpu2BitsTest(1, 192, 384, 64, true, 0.3f, 0.05f); + RunWebGpu2BitsTest(1, 384, 1024, 64, true, 0.5f, 0.05f); +} + +TEST(MatMul2BitsWebGpu, Float32_Symmetric_BlockSize64) { + RunWebGpu2BitsTest(1, 32, 64, 64, false); + RunWebGpu2BitsTest(1, 32, 128, 64, false); + RunWebGpu2BitsTest(1, 192, 384, 64, false, 0.3f, 0.05f); +} + +// Larger K tests — exercises multi-word (multiple u32) extraction per block, +// verifying the Q2 nibble-spread and A-offset tracking across passes. +TEST(MatMul2BitsWebGpu, Float32_ZeroPoint_LargerK) { + RunWebGpu2BitsTest(1, 32, 64, 32, true); + RunWebGpu2BitsTest(1, 32, 128, 32, true); + RunWebGpu2BitsTest(1, 32, 256, 32, true, 0.3f, 0.05f); +} + +#endif // USE_WEBGPU + } // namespace test } // namespace onnxruntime