185 changes: 162 additions & 23 deletions onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc
@@ -7,6 +7,39 @@
namespace onnxruntime {
namespace contrib {
namespace webgpu {
namespace {

constexpr std::string_view commonFunctions = R"ADDNL_FN(
fn DequantizedFrom4BitsTo8Bits(in: vec2<u32>) -> vec4<u32>
{
var out = vec4<u32>(0);
var value_lower = vec4<i32>(unpack4xU8(in[0] & 0x0F0F0F0Fu)) - vec4<i32>(8);
var value_upper = vec4<i32>(unpack4xU8((in[0] >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
out[0] = pack4xI8(vec4<i32>(value_lower[0], value_upper[0], value_lower[1], value_upper[1]));
out[1] = pack4xI8(vec4<i32>(value_lower[2], value_upper[2], value_lower[3], value_upper[3]));
value_lower = vec4<i32>(unpack4xU8(in[1] & 0x0F0F0F0Fu)) - vec4<i32>(8);
value_upper = vec4<i32>(unpack4xU8((in[1] >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
out[2] = pack4xI8(vec4<i32>(value_lower[0], value_upper[0], value_lower[1], value_upper[1]));
out[3] = pack4xI8(vec4<i32>(value_lower[2], value_upper[2], value_lower[3], value_upper[3]));
return out;
}

// Scaled dot product of 8 packed 32-bit words per operand, each word packing four signed 8-bit integers.
fn SDP8AI(a1:vec4<u32>, b1:vec4<u32>, a2:vec4<u32>, b2:vec4<u32>, scale:output_element_t) -> output_element_t
{
var local_sum = dot4I8Packed(a1[0], b1[0]);
local_sum += dot4I8Packed(a1[1], b1[1]);
local_sum += dot4I8Packed(a1[2], b1[2]);
local_sum += dot4I8Packed(a1[3], b1[3]);
local_sum += dot4I8Packed(a2[0], b2[0]);
local_sum += dot4I8Packed(a2[1], b2[1]);
local_sum += dot4I8Packed(a2[2], b2[2]);
local_sum += dot4I8Packed(a2[3], b2[3]);
return output_element_t(local_sum) * scale;
}
)ADDNL_FN";

} // namespace
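
For reference, here is a minimal host-side C++ sketch of the nibble layout that DequantizedFrom4BitsTo8Bits handles; the helper name and function are illustrative only and not part of the change. As implied by the interleaving above, each u32 packs eight 4-bit weights, with byte i holding weight 2*i in its low nibble and weight 2*i+1 in its high nibble; each weight is shifted by the zero point of 8 into a signed 8-bit value, which SDP8AI then consumes via dot4I8Packed as output = scale * sum of the 32 int8 products.

#include <array>
#include <cstdint>

// Illustrative scalar equivalent of one u32 worth of DequantizedFrom4BitsTo8Bits.
std::array<int8_t, 8> DequantizeU32Nibbles(uint32_t word) {
  std::array<int8_t, 8> out{};
  for (int byte = 0; byte < 4; ++byte) {
    const uint32_t lo = (word >> (8 * byte)) & 0xFu;      // weight 2*byte (low nibble)
    const uint32_t hi = (word >> (8 * byte + 4)) & 0xFu;  // weight 2*byte + 1 (high nibble)
    out[2 * byte] = static_cast<int8_t>(static_cast<int>(lo) - 8);
    out[2 * byte + 1] = static_cast<int8_t>(static_cast<int>(hi) - 8);
  }
  return out;
}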

Status DP4AMatMulQuantizeProgram::GenerateShaderCode(ShaderHelper& shader) const {
shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
@@ -65,7 +98,8 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
// This shader requires A to be int8 quantized with block size 64. B is a regular
// matmulnbits input with block size 32.

shader.AdditionalImplementation() << " const block_size = " << block_size_ << ";";
shader.AdditionalImplementation() << commonFunctions
<< " const block_size = " << block_size_ << ";";

shader.AdditionalImplementation() << R"ADDNL_FN(
const tile_size = 64;
@@ -105,34 +139,13 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
}

let b_value = input_b[b_global*uniforms.K16+kidx_v+col];
var b_value_lower = vec4<i32>(unpack4xU8(b_value[0] & 0x0F0F0F0Fu)) - vec4<i32>(8);
var b_value_upper = vec4<i32>(unpack4xU8((b_value[0] >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
tile_B[col][row][0] = pack4xI8(vec4<i32>(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]));
tile_B[col][row][1] = pack4xI8(vec4<i32>(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]));
b_value_lower = vec4<i32>(unpack4xU8(b_value[1] & 0x0F0F0F0Fu)) - vec4<i32>(8);
b_value_upper = vec4<i32>(unpack4xU8((b_value[1] >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
tile_B[col][row][2] = pack4xI8(vec4<i32>(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]));
tile_B[col][row][3] = pack4xI8(vec4<i32>(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]));
tile_B[col][row] = DequantizedFrom4BitsTo8Bits(b_value);
if (col == 0)
{
// kidx_v - each kidx_v covers 16 values of k
scale_B[row] = scales_b[b_global*(uniforms.K/block_size) + kidx_v/(block_size/16)];
}
}

// Scaled dot product of 8 packed unsigned integers.
fn SDP8AI(a1:vec4<u32>, b1:vec4<u32>, a2:vec4<u32>, b2:vec4<u32>, scale:output_element_t) -> output_element_t
{
var local_sum = dot4I8Packed(a1[0], b1[0]);
local_sum += dot4I8Packed(a1[1], b1[1]);
local_sum += dot4I8Packed(a1[2], b1[2]);
local_sum += dot4I8Packed(a1[3], b1[3]);
local_sum += dot4I8Packed(a2[0], b2[0]);
local_sum += dot4I8Packed(a2[1], b2[1]);
local_sum += dot4I8Packed(a2[2], b2[2]);
local_sum += dot4I8Packed(a2[3], b2[3]);
return output_element_t(local_sum) * scale;
}
)ADDNL_FN";
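
A concrete reading of the scale_B index above: kidx_v counts groups of 16 k values, so with block_size = 32 (the B block size noted at the top of this shader) the divisor block_size/16 is 2 and two consecutive kidx_v steps share one B scale; for example, kidx_v = 6 addresses scales_b[b_global*(K/32) + 3], which covers k in [96, 128).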

shader.MainFunctionBody() << R"MAIN_FN(
@@ -249,11 +262,122 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
return Status::OK();
}

// scale_A components = 1, b components = 4, output components = 1
Status DP4AMatMulNBitsSmallMProgram::GenerateShaderCode(ShaderHelper& shader) const {
shader.AddInput("input_a", ShaderUsage::UseUniform);
shader.AddInput("scales_a", ShaderUsage::UseUniform);
shader.AddInput("input_b", ShaderUsage::UseUniform);
shader.AddInput("scales_b", ShaderUsage::UseUniform);
shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias);
// This algorithm computes the dot products along k in parallel: at each step, tile_size_k_vec threads each handle one slice of k,
// while the remaining threads in the workgroup process additional rows of b in parallel (so the values of A held in shared memory are reused).
// For each load of k, the tile_size_k_vec threads also reload B tile_size/num_concurrent_b_rows times to compute partial dot products for the other B rows,
// completing all tile_size b rows handled by this workgroup while reusing the A values already loaded into registers.

// 1. Each workgroup handles tile_size_k_vec (16) * k_vectorization_in_b (32) columns (512 in total) and num_concurrent_b_rows rows of matrix B at a time,
// iterating over the columns to compute a partial dot product.
// 2. Uses vec4 vectorization, where each K step represents 32 elements of matrix B.
constexpr uint32_t tile_size_k_vec = 16;

// 1. Workgroup Responsibility:
// - Processes one row of matrix A
// - Handles tile_size rows of matrix B
//
// 2. Computation Process:
// - Reads [tile_size][tile_size_k_vec] block of B data at a time
// - Each thread within the workgroup computes dot products over 32 A*B elements, since each K represents 32 elements of matrix B
// - Stores intermediate results in shared memory (inter_results)
// - Iterates through columns accumulating results in inter_results
// - Performs final reduction sum in inter_results for output
shader.AdditionalImplementation() << "const tile_size = " << tile_size_ << "u;\n"
<< "const tile_size_k_vec = " << tile_size_k_vec << "u;\n"
// sub_tile_size is the number of concurrent b rows processed by the workgroup.
<< "const sub_tile_size = " << WorkgroupSizeX() / tile_size_k_vec << "u;\n";
shader.AdditionalImplementation() << commonFunctions
<< R"ADDNL_FN(
// Shared memory
// Need 2 * tile_size_k_vec (32) entries to store a tile_A, since b is quantized to 4 bits while a is quantized to 8 bits.
var<workgroup> tile_A : array<vec4<u32>, 32>;
// Need 4 scale values since each tile_A covers 512 (4x4x32) scalars and the block_size is 128.
var<workgroup> scale_A : array<output_element_t, 4>;
var<workgroup> inter_results: array<array<output_element_t, tile_size_k_vec>, tile_size>;
fn loadSHMA(a_global: u32, kidx_v: u32, col: u32)
{
let k_offset = kidx_v + col;
if (k_offset >= uniforms.K16) {
return;
}

tile_A[col] = input_a[a_global*uniforms.K16+k_offset];
if (col < 4)
{
// kidx_v - covers 16 values of k in input_a
scale_A[col] = scales_a[a_global*(uniforms.K/128) + kidx_v/8 + col];
}
}
)ADDNL_FN";
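
To make the sizes above concrete: tile_A holds 2 * tile_size_k_vec = 32 vec4<u32> entries, i.e. 32 * 16 = 512 int8 values of A per iteration, and with A quantized at a block size of 128 (as the comment notes) that span needs 512 / 128 = 4 scales; this is why scale_A has 4 entries and why the index kidx_v/8 + col advances one scale per 8 groups of 16 k values.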

shader.MainFunctionBody() << R"MAIN_FN(
let a_global = u32(workgroup_idx / uniforms.num_N_tile);
let b_global_base = (workgroup_idx % uniforms.num_N_tile) * tile_size;
// Handle each workgroup threads as a block of [sub_tile_size][tile_size_k_vec]
let local_col = local_idx % tile_size_k_vec;
let local_row = local_idx / tile_size_k_vec;
for (var kidx_v:u32 = 0; kidx_v < uniforms.K32; kidx_v += tile_size_k_vec)
{
// Load Phase: Populate shared memory for the workgroup.
if (local_idx < 32)
{
loadSHMA(a_global, kidx_v * 2, local_idx);
}
workgroupBarrier();
var own_a: vec4<u32> = tile_A[local_col * 2];
var own_a1: vec4<u32> = tile_A[local_col * 2 + 1];
var own_scale_a = scale_A[local_col / 4];
var own_b = vec4<u32>(0);
var own_b1 = vec4<u32>(0);
let k_offset = kidx_v + local_col;
// Accumulate intermediate results into inter_results.
for (var row_offset = 0u; row_offset < tile_size; row_offset += sub_tile_size) {
let b_global = b_global_base + row_offset + local_row;
if (b_global < uniforms.N && k_offset < uniforms.K32)
{
let b_offset = b_global * uniforms.K32 + k_offset;
let b_value = input_b[b_offset];
own_b = DequantizedFrom4BitsTo8Bits(b_value.xy);
own_b1 = DequantizedFrom4BitsTo8Bits(b_value.zw);

// k_offset - covers 32 values of k in input_b
let own_scale_b = scales_b[b_global * uniforms.K / uniforms.block_size + k_offset * 32 / uniforms.block_size];
inter_results[row_offset + local_row][local_col] += SDP8AI(own_a, own_b, own_a1, own_b1, own_scale_a * own_scale_b);
}
}
workgroupBarrier();
}

if (local_idx < tile_size) {
// Do reduce sum to get final output.
var output_value = output_element_t(0);
for (var b = 0u; b < tile_size_k_vec; b++) {
output_value += inter_results[local_idx][b];
}
let b_global = b_global_base + local_idx;
let output_idx = a_global * uniforms.N + b_global;
if (b_global < uniforms.N) {
output[output_idx] = output_value;
}
}
)MAIN_FN";

return Status::OK();
}
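
A sketch of the thread mapping, assuming the workgroup size of 128 and tile_size of 32 that ApplyDP4AMatrixMatMulNBits passes in below: the 128 threads form a sub_tile_size x tile_size_k_vec = 8 x 16 grid; the row_offset loop runs tile_size / sub_tile_size = 4 times per k step, so each thread accumulates partials for 4 B rows in its own k-slice column of inter_results; in the final reduction, threads 0..31 each sum the 16 partials of one row to produce one output element.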

Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales,
uint32_t M,
uint32_t N,
uint32_t K,
uint32_t block_size,
uint32_t min_M_for_tile_optimization,
onnxruntime::webgpu::ComputeContext& context,
Tensor* y) {
constexpr uint32_t kVec4Components = 4;
@@ -273,6 +397,21 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor
{&a_scale, ProgramTensorMetadataDependency::Rank, a_scale.Shape(), 1}});
ORT_RETURN_IF_ERROR(context.RunProgram(quantize_program));

if (M < min_M_for_tile_optimization) {
constexpr uint32_t kTileSize = 32;
DP4AMatMulNBitsSmallMProgram mul_program{kTileSize};
uint32_t num_N_tile = (N + kTileSize - 1) / kTileSize;
mul_program.SetWorkgroupSize(128);
mul_program.SetDispatchGroupSize(M * num_N_tile);
mul_program.AddInputs({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(kVec4Components)},
{&a_scale, ProgramTensorMetadataDependency::TypeAndRank, 1},
{b, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(kVec4Components * kU32Components)},
{scales, ProgramTensorMetadataDependency::TypeAndRank, 1}})
.AddUniformVariables({M, N, K, K / 16, K / 32, block_size, num_N_tile})
.AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, 1});
return context.RunProgram(mul_program);
}
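
As a hypothetical sizing example for this small-M path (the numbers are illustrative, not from the PR): for a decode step with M = 1, N = 4096 and K = 4096, num_N_tile = ceil(4096 / 32) = 128, so 1 * 128 = 128 workgroups of 128 threads are dispatched, each producing 32 output columns of the single output row.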

constexpr uint32_t kTileSize = 64;
TensorShape reshaped_y_shape{1, M, N / kVec4Components};
DP4AMatMulNBitsProgram mul_program{block_size};
18 changes: 18 additions & 0 deletions onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h
@@ -33,11 +33,29 @@ class DP4AMatMulNBitsProgram final : public Program<DP4AMatMulNBitsProgram> {
uint32_t block_size_;
};

class DP4AMatMulNBitsSmallMProgram final : public Program<DP4AMatMulNBitsSmallMProgram> {
public:
DP4AMatMulNBitsSmallMProgram(uint32_t tile_size) : Program{"DP4AMatMulNBitsSmallMProgram"}, tile_size_(tile_size) {}
Status GenerateShaderCode(ShaderHelper& sh) const override;
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
{"M", ProgramUniformVariableDataType::Uint32},
{"N", ProgramUniformVariableDataType::Uint32},
{"K", ProgramUniformVariableDataType::Uint32},
{"K16", ProgramUniformVariableDataType::Uint32},
{"K32", ProgramUniformVariableDataType::Uint32},
{"block_size", ProgramUniformVariableDataType::Uint32},
{"num_N_tile", ProgramUniformVariableDataType::Uint32});

private:
uint32_t tile_size_;
};

Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales,
uint32_t M,
uint32_t N,
uint32_t K,
uint32_t block_size,
uint32_t min_M_for_tile_optimization,
onnxruntime::webgpu::ComputeContext& context,
Tensor* y);

6 changes: 3 additions & 3 deletions onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
@@ -574,9 +574,9 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context
return ApplySubgroupMatrixMatMulNBits(a, b, scales, M, N, K, context, y);
}

if (M >= kMinMForTileOptimization &&
CanApplyDP4AMatrixMatMulNBits(context, accuracy_level_, block_size, batch_count, N, K, components_a, has_zero_points)) {
return ApplyDP4AMatrixMatMulNBits(a, b, scales, M, N, K, block_size, context, y);
// On FP32-only GPUs, integer math is faster than FP32, so always use DP4A regardless of the length of M.
if ((M >= kMinMForTileOptimization || y->DataType() == DataTypeImpl::GetType<float>()) && CanApplyDP4AMatrixMatMulNBits(context, accuracy_level_, block_size, batch_count, N, K, components_a, has_zero_points)) {
return ApplyDP4AMatrixMatMulNBits(a, b, scales, M, N, K, block_size, kMinMForTileOptimization, context, y);
}

// TODO: Support output_number > 1. Some cases are failed when output_number > 1.