diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
index c79efee65e5c5..821bdd24764fb 100644
--- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
+++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
@@ -790,6 +790,183 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
   return Status::OK();
 }

+// Tile size along N = 16, workgroup size = 64, scale_A components = 4, b components = 4, output components = 4.
+Status DP4AMatMulNBitsSmallMProgram::GenerateShaderCode(ShaderHelper& shader) const {
+  shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
+  shader.AddInput("scales_a", ShaderUsage::UseUniform);
+  shader.AddInput("input_b", ShaderUsage::UseUniform);
+  shader.AddInput("scales_b", ShaderUsage::UseUniform);
+  shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias);
+
+  shader.AdditionalImplementation() << R"ADDNL_FN(
+  const tile_size = 16u;        // tile_size = tile_size_vec * output components
+  const tile_size_vec = 4u;
+  const tile_size_k_vec = 16u;  // tile_size_vec * tile_size_k_vec = workgroup size
+
+  // Shared memory
+  var<workgroup> tile_A : array<vec4<u32>, 32>;     // 512 packed int8 values of A (one 512-wide slice of K)
+  var<workgroup> scale_A : vec4<output_element_t>;  // 4 scales of A, one per 128 values
+  var<workgroup> inter_results : array<array<vec4<output_element_t>, tile_size_k_vec>, tile_size_vec>;
+
+  fn loadSHMA(a_global:u32, kidx_v:u32, col: u32)
+  {
+    let k_offset = kidx_v + col;
+    if (k_offset >= uniforms.K16)
+    {
+      return;
+    }
+    tile_A[col] = input_a[a_global*uniforms.K16+k_offset];
+    if (col == 0)
+    {
+      // kidx_v - covers 16 values of k
+      scale_A = scales_a[a_global*(uniforms.K/512) + k_offset/32];
+    }
+  }
+
+  // Scaled dot product of 32 packed int8 values (8 u32 words each of A and B).
+  fn SDP8AI(a1:vec4<u32>, b1:vec4<u32>, a2:vec4<u32>, b2:vec4<u32>, scale:output_element_t) -> output_element_t
+  {
+    var local_sum = dot4I8Packed(a1[0], b1[0]);
+    local_sum += dot4I8Packed(a1[1], b1[1]);
+    local_sum += dot4I8Packed(a1[2], b1[2]);
+    local_sum += dot4I8Packed(a1[3], b1[3]);
+    local_sum += dot4I8Packed(a2[0], b2[0]);
+    local_sum += dot4I8Packed(a2[1], b2[1]);
+    local_sum += dot4I8Packed(a2[2], b2[2]);
+    local_sum += dot4I8Packed(a2[3], b2[3]);
+    return output_element_t(local_sum) * scale;
+  }
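+
+  // Note: dot4I8Packed treats each u32 operand as four signed 8-bit lanes and
+  // returns their i32 dot product, so one SDP8AI call consumes 32 elements of K.
+  // Worked micro-examples (illustrative values only):
+  //   dot4I8Packed(0x01020304u, 0x01010101u) == 4 + 3 + 2 + 1 == 10
+  //   dot4I8Packed(0xFF000000u, 0x01000000u) == -1  (0xFF is -1 as int8)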
+)ADDNL_FN";
+
+  shader.MainFunctionBody() << R"MAIN_FN(
+  let a_global = workgroup_id.y;
+  let b_global_base = workgroup_id.x * tile_size;
+
+  let idx = local_idx % tile_size_k_vec;
+  let idy = local_idx / tile_size_k_vec;
+
+  for (var kidx_v:u32 = 0; kidx_v < uniforms.K32; kidx_v+=16)
+  {
+    // Load Phase: Populate shared memory for the workgroup.
+    if (local_idx < 32)
+    {
+      loadSHMA(a_global, kidx_v * 2, local_idx);
+    }
+    workgroupBarrier();
+
+    var own_a: vec4<u32> = tile_A[idx*2];
+    var own_a1: vec4<u32> = tile_A[idx*2 + 1];
+    var own_scale_a: output_element_t = scale_A[idx / 4];
+
+    var own_b = vec4<u32>(0);
+    var own_b1 = vec4<u32>(0);
+    var own_scale_b = output_element_t(0);
+    let b_global = b_global_base + idy * 4;
+    let k_offset = kidx_v+idx;
+    if (b_global < uniforms.N && k_offset < uniforms.K32)
+    {
+      var b_offset = b_global*uniforms.K32+k_offset;
+      var b_value = input_b[b_offset];
+      var b_value_lower = vec4<i32>(unpack4xU8(b_value[0] & 0x0F0F0F0Fu)) - vec4<i32>(8);
+      var b_value_upper = vec4<i32>(unpack4xU8((b_value[0] >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
+      own_b[0] = pack4xI8(vec4<i32>(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]));
+      own_b[1] = pack4xI8(vec4<i32>(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]));
+      b_value_lower = vec4<i32>(unpack4xU8(b_value[1] & 0x0F0F0F0Fu)) - vec4<i32>(8);
+      b_value_upper = vec4<i32>(unpack4xU8((b_value[1] >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
+      own_b[2] = pack4xI8(vec4<i32>(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]));
+      own_b[3] = pack4xI8(vec4<i32>(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]));
+      b_value_lower = vec4<i32>(unpack4xU8(b_value[2] & 0x0F0F0F0Fu)) - vec4<i32>(8);
+      b_value_upper = vec4<i32>(unpack4xU8((b_value[2] >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
+      own_b1[0] = pack4xI8(vec4<i32>(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]));
+      own_b1[1] = pack4xI8(vec4<i32>(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]));
+      b_value_lower = vec4<i32>(unpack4xU8(b_value[3] & 0x0F0F0F0Fu)) - vec4<i32>(8);
+      b_value_upper = vec4<i32>(unpack4xU8((b_value[3] >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
+      own_b1[2] = pack4xI8(vec4<i32>(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]));
+      own_b1[3] = pack4xI8(vec4<i32>(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]));
+      own_scale_b = scales_b[b_offset];
+      inter_results[idy][idx].x += SDP8AI(own_a, own_b, own_a1, own_b1, own_scale_a * own_scale_b);
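+
+      // The same dequantize-and-accumulate sequence repeats below for the other
+      // three columns of this thread's tile (b_global + 1..3), accumulating into
+      // .y, .z and .w. Each byte of B holds two consecutive 4-bit weights (low
+      // nibble first), so lower/upper nibbles are interleaved when repacking to
+      // int8, and subtracting 8 maps the unsigned range [0, 15] to [-8, 7].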
+      b_offset = (b_global + 1)*uniforms.K32+k_offset;
+      b_value = input_b[b_offset];
+      b_value_lower = vec4<i32>(unpack4xU8(b_value[0] & 0x0F0F0F0Fu)) - vec4<i32>(8);
+      b_value_upper = vec4<i32>(unpack4xU8((b_value[0] >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
+      own_b[0] = pack4xI8(vec4<i32>(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]));
+      own_b[1] = pack4xI8(vec4<i32>(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]));
+      b_value_lower = vec4<i32>(unpack4xU8(b_value[1] & 0x0F0F0F0Fu)) - vec4<i32>(8);
+      b_value_upper = vec4<i32>(unpack4xU8((b_value[1] >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
+      own_b[2] = pack4xI8(vec4<i32>(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]));
+      own_b[3] = pack4xI8(vec4<i32>(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]));
+      b_value_lower = vec4<i32>(unpack4xU8(b_value[2] & 0x0F0F0F0Fu)) - vec4<i32>(8);
+      b_value_upper = vec4<i32>(unpack4xU8((b_value[2] >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
+      own_b1[0] = pack4xI8(vec4<i32>(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]));
+      own_b1[1] = pack4xI8(vec4<i32>(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]));
+      b_value_lower = vec4<i32>(unpack4xU8(b_value[3] & 0x0F0F0F0Fu)) - vec4<i32>(8);
+      b_value_upper = vec4<i32>(unpack4xU8((b_value[3] >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
+      own_b1[2] = pack4xI8(vec4<i32>(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]));
+      own_b1[3] = pack4xI8(vec4<i32>(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]));
+      own_scale_b = scales_b[b_offset];
+      inter_results[idy][idx].y += SDP8AI(own_a, own_b, own_a1, own_b1, own_scale_a * own_scale_b);
+
+      b_offset = (b_global + 2)*uniforms.K32+k_offset;
+      b_value = input_b[b_offset];
+      b_value_lower = vec4<i32>(unpack4xU8(b_value[0] & 0x0F0F0F0Fu)) - vec4<i32>(8);
+      b_value_upper = vec4<i32>(unpack4xU8((b_value[0] >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
+      own_b[0] = pack4xI8(vec4<i32>(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]));
+      own_b[1] = pack4xI8(vec4<i32>(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]));
+      b_value_lower = vec4<i32>(unpack4xU8(b_value[1] & 0x0F0F0F0Fu)) - vec4<i32>(8);
+      b_value_upper = vec4<i32>(unpack4xU8((b_value[1] >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
+      own_b[2] = pack4xI8(vec4<i32>(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]));
+      own_b[3] = pack4xI8(vec4<i32>(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]));
+      b_value_lower = vec4<i32>(unpack4xU8(b_value[2] & 0x0F0F0F0Fu)) - vec4<i32>(8);
+      b_value_upper = vec4<i32>(unpack4xU8((b_value[2] >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
+      own_b1[0] = pack4xI8(vec4<i32>(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]));
+      own_b1[1] = pack4xI8(vec4<i32>(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]));
+      b_value_lower = vec4<i32>(unpack4xU8(b_value[3] & 0x0F0F0F0Fu)) - vec4<i32>(8);
+      b_value_upper = vec4<i32>(unpack4xU8((b_value[3] >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
+      own_b1[2] = pack4xI8(vec4<i32>(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]));
+      own_b1[3] = pack4xI8(vec4<i32>(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]));
+      own_scale_b = scales_b[b_offset];
+      inter_results[idy][idx].z += SDP8AI(own_a, own_b, own_a1, own_b1, own_scale_a * own_scale_b);
+
+      b_offset = (b_global + 3)*uniforms.K32+k_offset;
+      b_value = input_b[b_offset];
+      b_value_lower = vec4<i32>(unpack4xU8(b_value[0] & 0x0F0F0F0Fu)) - vec4<i32>(8);
+      b_value_upper = vec4<i32>(unpack4xU8((b_value[0] >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
+      own_b[0] = pack4xI8(vec4<i32>(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]));
+      own_b[1] = pack4xI8(vec4<i32>(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]));
+      b_value_lower = vec4<i32>(unpack4xU8(b_value[1] & 0x0F0F0F0Fu)) - vec4<i32>(8);
+      b_value_upper = vec4<i32>(unpack4xU8((b_value[1] >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
+      own_b[2] = pack4xI8(vec4<i32>(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]));
+      own_b[3] = pack4xI8(vec4<i32>(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]));
+      b_value_lower = vec4<i32>(unpack4xU8(b_value[2] & 0x0F0F0F0Fu)) - vec4<i32>(8);
+      b_value_upper = vec4<i32>(unpack4xU8((b_value[2] >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
+      own_b1[0] = pack4xI8(vec4<i32>(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]));
+      own_b1[1] = pack4xI8(vec4<i32>(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]));
+      b_value_lower = vec4<i32>(unpack4xU8(b_value[3] & 0x0F0F0F0Fu)) - vec4<i32>(8);
+      b_value_upper = vec4<i32>(unpack4xU8((b_value[3] >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
+      own_b1[2] = pack4xI8(vec4<i32>(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]));
+      own_b1[3] = pack4xI8(vec4<i32>(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]));
+      own_scale_b = scales_b[b_offset];
+      inter_results[idy][idx].w += SDP8AI(own_a, own_b, own_a1, own_b1, own_scale_a * own_scale_b);
+    }
+    // Make sure all reads of tile_A are complete before the next iteration overwrites it.
+    workgroupBarrier();
+  }
+
+  workgroupBarrier();
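+  // Reduce Phase: the first tile_size_vec threads each sum the tile_size_k_vec
+  // partial results of one 4-wide column group and write a single vec4 of output.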
+  if (local_idx < tile_size_vec) {
+    var output_value = vec4<output_element_t>(0);
+    for (var b = 0u; b < tile_size_k_vec; b++) {
+      output_value += inter_results[local_idx][b];
+    }
+    let b_global = b_global_base + local_idx * 4;
+    let output_idx = (a_global * uniforms.N + b_global)/4;
+    if (b_global < uniforms.N) {
+      output[output_idx] = output_value;
+    }
+  }
+)MAIN_FN";
+
+  return Status::OK();
+}
+
 Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const {
   const Tensor* a = context.Input(0);
   const Tensor* b = context.Input(1);
@@ -831,7 +1008,7 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context
   const bool use_dp4a = has_subgroup && context.AdapterInfo().backendType != wgpu::BackendType::Metal;
   if (accuracy_level_ == 4 && block_size == 32 && batch_count == 1 && components_a == 4 && K % 64 == 0 && N % 16 == 0 &&
-      !has_zero_points && use_dp4a && M >= kMinMForTileOptimization) {
+      !has_zero_points && use_dp4a) {
     constexpr uint32_t kVec4Components = 4;
     constexpr uint32_t kVec2Components = 2;
     constexpr uint32_t kU32Components = 4;
@@ -849,24 +1026,43 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context
                              {&a_scale, ProgramTensorMetadataDependency::Rank, a_scale.Shape(), gsl::narrow<int>(1)}});
     ORT_RETURN_IF_ERROR(context.RunProgram(quantize_program));

-    constexpr uint32_t kTileSize = 64;
-    TensorShape reshaped_y_shape{1, M, N / kVec4Components};
-    DP4AMatMulNBitsProgram mul_program;
-    mul_program.SetWorkgroupSize(256);
-    mul_program.SetDispatchGroupSize(
-        (M + kTileSize - 1) / kTileSize,
-        (N + kTileSize - 1) / kTileSize, 1);
-    mul_program.AddInputs({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(kVec4Components)},
-                           {&a_scale, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(1)},
-                           {b, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(kVec2Components * kU32Components)},
-                           {scales, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(1)}})
-        .AddUniformVariables({{static_cast<uint32_t>(M)},
-                              {static_cast<uint32_t>(N)},
-                              {static_cast<uint32_t>(K)},
-                              {static_cast<uint32_t>(K / 8)},
-                              {static_cast<uint32_t>(K / 16)}})
-        .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, gsl::narrow<int>(kVec4Components)});
-    return context.RunProgram(mul_program);
+    if (M >= kMinMForTileOptimization) {
+      constexpr uint32_t kTileSize = 64;
+      TensorShape reshaped_y_shape{1, M, N / kVec4Components};
+      DP4AMatMulNBitsProgram mul_program;
+      mul_program.SetWorkgroupSize(256);
+      mul_program.SetDispatchGroupSize(
+          (M + kTileSize - 1) / kTileSize,
+          (N + kTileSize - 1) / kTileSize, 1);
+      mul_program.AddInputs({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(kVec4Components)},
+                             {&a_scale, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(1)},
+                             {b, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(kVec2Components * kU32Components)},
+                             {scales, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(1)}})
+          .AddUniformVariables({{static_cast<uint32_t>(M)},
+                                {static_cast<uint32_t>(N)},
+                                {static_cast<uint32_t>(K)},
+                                {static_cast<uint32_t>(K / 8)},
+                                {static_cast<uint32_t>(K / 16)}})
+          .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, gsl::narrow<int>(kVec4Components)});
+      return context.RunProgram(mul_program);
+    } else {
+      constexpr uint32_t kTileSize = 16;
+      DP4AMatMulNBitsSmallMProgram mul_program;
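+      // One 64-thread workgroup computes a 1 (M) x 16 (N) tile of the output:
+      // 16 threads cooperate along K for each group of 4 output columns, so the
+      // dispatch below is ceil(N / kTileSize) x M groups and parallelism comes
+      // from N rather than M (e.g. M = 1, N = 4096 still launches 256 groups).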
+      mul_program.SetWorkgroupSize(64);
+      mul_program.SetDispatchGroupSize(
+          (N + kTileSize - 1) / kTileSize, M, 1);
+      mul_program.AddInputs({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(kVec4Components)},
+                             {&a_scale, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(4)},
+                             {b, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(kVec4Components * kU32Components)},
+                             {scales, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(1)}})
+          .AddUniformVariables({{static_cast<uint32_t>(M)},
+                                {static_cast<uint32_t>(N)},
+                                {static_cast<uint32_t>(K)},
+                                {static_cast<uint32_t>(K / 16)},
+                                {static_cast<uint32_t>(K / 32)}})
+          .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(4)});
+      return context.RunProgram(mul_program);
+    }
   }

   // TODO: Support output_number > 1. Some cases are failed when output_number > 1.
diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h
index a2470d9268907..2dee8fbd2af09 100644
--- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h
+++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h
@@ -53,6 +53,18 @@ class DP4AMatMulNBitsProgram final : public Program<DP4AMatMulNBitsProgram> {
                                           {"K16", ProgramUniformVariableDataType::Uint32});
 };

+class DP4AMatMulNBitsSmallMProgram final : public Program<DP4AMatMulNBitsSmallMProgram> {
+ public:
+  DP4AMatMulNBitsSmallMProgram() : Program{"DP4AMatMulNBitsSmallMProgram"} {}
+  Status GenerateShaderCode(ShaderHelper& sh) const override;
+  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
+      {"M", ProgramUniformVariableDataType::Uint32},
+      {"N", ProgramUniformVariableDataType::Uint32},
+      {"K", ProgramUniformVariableDataType::Uint32},
+      {"K16", ProgramUniformVariableDataType::Uint32},
+      {"K32", ProgramUniformVariableDataType::Uint32});
+};
+
 class MatMulNBits final : public WebGpuKernel {
  public:
   MatMulNBits(const OpKernelInfo& info) : WebGpuKernel(info) {