From 898094250ee033255d64375115b439bb6b233063 Mon Sep 17 00:00:00 2001
From: Jiajia Qin
Date: Tue, 18 Mar 2025 10:02:19 +0800
Subject: [PATCH] [webgpu] Limit that K must be divisible by 128 to apply dp4a
 matmul

---
 .../webgpu/quantization/dp4a_matmul_nbits.cc | 16 +++-------------
 .../webgpu/quantization/dp4a_matmul_nbits.h  |  1 -
 .../test/contrib_ops/matmul_4bits_test.cc    |  2 ++
 3 files changed, 5 insertions(+), 14 deletions(-)

diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc
index 05cbfb1f99c48..65807b072bc80 100644
--- a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc
+++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc
@@ -12,21 +12,12 @@ Status DP4AMatMulQuantizeProgram::GenerateShaderCode(ShaderHelper& shader) const
   shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
   shader.AddOutput("output", ShaderUsage::UseUniform);
   shader.AddOutput("scales", ShaderUsage::UseUniform);
-  shader.AdditionalImplementation() << R"ADDNL_FN(
-    fn readInput(offset: u32) -> input_a_value_t
-    {
-      if (offset > uniforms.input_size) {
-        return input_a_value_t(0);
-      }
-      return input_a[offset];
-    }
-  )ADDNL_FN";
   shader.MainFunctionBody() << R"MAIN_FN(
   var local_a : array<vec4<input_a_element_t>, 32>;
   var max_value:vec4<input_a_element_t> = vec4<input_a_element_t>(0);
   for (var idx:u32=0;idx<32;idx+=1)
   {
-    local_a[idx] = readInput(workgroup_idx*32 + idx);
+    local_a[idx] = input_a[workgroup_idx*32 + idx];
     max_value = max(max_value, abs(local_a[idx]));
   }
   var scale = max(max_value.x, max_value.y);
@@ -279,8 +270,7 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor
   Tensor a_scale = context.CreateGPUTensor(a->DataType(), a_scales_dims);
   quantize_program.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(kVec4Components)}})
       .AddOutputs({{&a_quant, ProgramTensorMetadataDependency::Rank, a_quant.Shape(), 1},
-                   {&a_scale, ProgramTensorMetadataDependency::Rank, a_scale.Shape(), 1}})
-      .AddUniformVariable({static_cast<uint32_t>(M * K / kVec4Components)});
+                   {&a_scale, ProgramTensorMetadataDependency::Rank, a_scale.Shape(), 1}});
   ORT_RETURN_IF_ERROR(context.RunProgram(quantize_program));
 
   constexpr uint32_t kTileSize = 64;
@@ -317,7 +307,7 @@ bool CanApplyDP4AMatrixMatMulNBits(onnxruntime::webgpu::ComputeContext& context,
   bool use_dp4a = context.Device().HasFeature(wgpu::FeatureName::Subgroups) &&
                   context.AdapterInfo().backendType != wgpu::BackendType::Metal;
   return (accuracy_level == 4 && block_size % 32 == 0 &&
-          batch_count == 1 && components_k == 4 && K % 64 == 0 && N % 16 == 0 &&
+          batch_count == 1 && components_k == 4 && K % 128 == 0 && N % 16 == 0 &&
           !has_zero_points && use_dp4a);
 }
 
diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h
index 15b86d78301ad..7e4a8f5d68437 100644
--- a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h
+++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h
@@ -16,7 +16,6 @@ class DP4AMatMulQuantizeProgram final : public Program<DP4AMatMulQuantizeProgram> {
diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc
index b1779ded4a675..8187253311ed3 100644
--- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc
+++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc
@@ -389,6 +389,7 @@ TEST(MatMulNBits, Float32_Accuracy4) {
   TestMatMulNBitsTyped();
   TestMatMulNBitsTyped();
   TestMatMulNBitsTyped();
+  TestMatMulNBitsTyped();
   TestMatMulNBitsTyped();
   TestMatMulNBitsTyped();
   TestMatMulNBitsTyped();
@@ -458,6 +459,7 @@ TEST(MatMulNBits, Float16_Accuracy4) {
   TestMatMulNBitsTyped();
   TestMatMulNBitsTyped();
   TestMatMulNBitsTyped();
+  TestMatMulNBitsTyped();
   TestMatMulNBitsTyped();
   TestMatMulNBitsTyped();
   TestMatMulNBitsTyped();
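
For illustration, below is a minimal standalone sketch of the applicability gate this patch tightens. The condition mirrors the return expression of CanApplyDP4AMatrixMatMulNBits in the diff, with K % 64 == 0 replaced by K % 128 == 0; the helper name CheckDP4AShapeSketch, the plain bool use_dp4a parameter, and the main() driver are assumptions made so the example is self-contained, not code from onnxruntime.

// Sketch of the tightened dp4a matmul applicability check (hypothetical helper,
// not the ONNX Runtime implementation). Device/adapter feature queries from the
// real code are collapsed into the use_dp4a flag.
#include <cstdint>
#include <iostream>

bool CheckDP4AShapeSketch(int accuracy_level, uint32_t block_size, uint32_t batch_count,
                          uint32_t components_k, uint32_t K, uint32_t N,
                          bool has_zero_points, bool use_dp4a) {
  return accuracy_level == 4 && block_size % 32 == 0 &&
         batch_count == 1 && components_k == 4 &&
         K % 128 == 0 && N % 16 == 0 &&  // this patch: was K % 64 == 0
         !has_zero_points && use_dp4a;
}

int main() {
  // K = 192 satisfied the old K % 64 == 0 requirement but fails the new one.
  std::cout << CheckDP4AShapeSketch(4, 32, 1, 4, 192, 32, false, true) << "\n";  // 0
  // K = 256 still qualifies for the dp4a path.
  std::cout << CheckDP4AShapeSketch(4, 32, 1, 4, 256, 32, false, true) << "\n";  // 1
  return 0;
}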