234 changes: 215 additions & 19 deletions onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
@@ -790,6 +790,183 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
return Status::OK();
}

// tile_N (output columns per workgroup) = 16, workgroup size = 64, scales_a components = 4, b components = 4, output components = 4
Status DP4AMatMulNBitsSmallMProgram::GenerateShaderCode(ShaderHelper& shader) const {
shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
shader.AddInput("scales_a", ShaderUsage::UseUniform);
shader.AddInput("input_b", ShaderUsage::UseUniform);
shader.AddInput("scales_b", ShaderUsage::UseUniform);
shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias);

shader.AdditionalImplementation() << R"ADDNL_FN(
const tile_size = 16u; // tile_size = tile_size_vec * output components
const tile_size_vec = 4u;
const tile_size_k_vec = 16u; // tile_size_vec * tile_size_k_vec = workgroup size

// Shared memory
var<workgroup> tile_A : array<vec4<u32>, 32>; // 32 x vec4<u32> = 512 packed 8-bit values of A (one k-tile)
var<workgroup> scale_A : vec4<output_element_t>; // 4 scales of A, one per 128 k-values, covering the 512-value tile
var<workgroup> inter_results: array<array<vec4<output_element_t>, tile_size_k_vec>, tile_size_vec>;

fn loadSHMA(a_global:u32, kidx_v:u32, col: u32)
{
let k_offset = kidx_v + col;
if (k_offset >= uniforms.K16)
{
return;
}
tile_A[col] = input_a[a_global*uniforms.K16+k_offset];
if (col == 0)
{
// k_offset is in units of 16 k-values. One vec4 of scales_a holds 4 scales (128 k-values each), so a single load covers this tile's 512 k-values (row stride K/512).
scale_A = scales_a[a_global*(uniforms.K/512) + k_offset/32];
}
}

// Scaled dot product over 8 pairs of u32 words, each packing four signed 8-bit values (32 multiply-accumulates), followed by one multiply by the dequantization scale.
fn SDP8AI(a1:vec4<u32>, b1:vec4<u32>, a2:vec4<u32>, b2:vec4<u32>, scale:output_element_t) -> output_element_t
{
var local_sum = dot4I8Packed(a1[0], b1[0]);
local_sum += dot4I8Packed(a1[1], b1[1]);
local_sum += dot4I8Packed(a1[2], b1[2]);
local_sum += dot4I8Packed(a1[3], b1[3]);
local_sum += dot4I8Packed(a2[0], b2[0]);
local_sum += dot4I8Packed(a2[1], b2[1]);
local_sum += dot4I8Packed(a2[2], b2[2]);
local_sum += dot4I8Packed(a2[3], b2[3]);
return output_element_t(local_sum) * scale;
}
)ADDNL_FN";

shader.MainFunctionBody() << R"MAIN_FN(
let a_global = workgroup_id.y;
let b_global_base = workgroup_id.x * tile_size;

let idx = local_idx % tile_size_k_vec;
let idy = local_idx / tile_size_k_vec;

for (var kidx_v:u32 = 0; kidx_v < uniforms.K32; kidx_v+=16)
{
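// kidx_v counts vec4<u32> elements of B (32 k-values each); stepping by 16
// means each outer iteration consumes 512 values of k.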
// Load Phase: Populate shared memory for the workgroup.
if (local_idx < 32)
{
loadSHMA(a_global, kidx_v * 2, local_idx);
}
workgroupBarrier();

var own_a: vec4<u32> = tile_A[idx*2];
var own_a1: vec4<u32> = tile_A[idx*2 + 1];
var own_scale_a: output_element_t = scale_A[idx / 4];

var own_b = vec4<u32>(0);
var own_b1 = vec4<u32>(0);
var own_scale_b = output_element_t(0);
let b_global = b_global_base + idy * 4;
let k_offset = kidx_v+idx;
if (b_global < uniforms.N && k_offset < uniforms.K32)
{
var b_offset = b_global*uniforms.K32+k_offset;
var b_value = input_b[b_offset];
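// b_value holds 32 4-bit weights (8 per u32). Within each byte the low
// nibble comes first in k order; subtract the implicit zero-point of 8 and
// repack as signed i8 so dot4I8Packed can consume the result.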
var b_value_lower = vec4<i32>(unpack4xU8(b_value[0] & 0x0F0F0F0Fu)) - vec4<i32>(8);
var b_value_upper = vec4<i32>(unpack4xU8((b_value[0] >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
own_b[0] = pack4xI8(vec4<i32>(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]));
own_b[1] = pack4xI8(vec4<i32>(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]));
b_value_lower = vec4<i32>(unpack4xU8(b_value[1] & 0x0F0F0F0Fu)) - vec4<i32>(8);
b_value_upper = vec4<i32>(unpack4xU8((b_value[1] >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
own_b[2] = pack4xI8(vec4<i32>(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]));
own_b[3] = pack4xI8(vec4<i32>(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]));
b_value_lower = vec4<i32>(unpack4xU8(b_value[2] & 0x0F0F0F0Fu)) - vec4<i32>(8);
b_value_upper = vec4<i32>(unpack4xU8((b_value[2] >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
own_b1[0] = pack4xI8(vec4<i32>(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]));
own_b1[1] = pack4xI8(vec4<i32>(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]));
b_value_lower = vec4<i32>(unpack4xU8(b_value[3] & 0x0F0F0F0Fu)) - vec4<i32>(8);
b_value_upper = vec4<i32>(unpack4xU8((b_value[3] >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
own_b1[2] = pack4xI8(vec4<i32>(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]));
own_b1[3] = pack4xI8(vec4<i32>(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]));
own_scale_b = scales_b[b_offset];
inter_results[idy][idx].x += SDP8AI(own_a, own_b, own_a1, own_b1, own_scale_a * own_scale_b);

b_offset = (b_global + 1)*uniforms.K32+k_offset;
b_value = input_b[b_offset];
b_value_lower = vec4<i32>(unpack4xU8(b_value[0] & 0x0F0F0F0Fu)) - vec4<i32>(8);
b_value_upper = vec4<i32>(unpack4xU8((b_value[0] >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
own_b[0] = pack4xI8(vec4<i32>(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]));
own_b[1] = pack4xI8(vec4<i32>(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]));
b_value_lower = vec4<i32>(unpack4xU8(b_value[1] & 0x0F0F0F0Fu)) - vec4<i32>(8);
b_value_upper = vec4<i32>(unpack4xU8((b_value[1] >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
own_b[2] = pack4xI8(vec4<i32>(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]));
own_b[3] = pack4xI8(vec4<i32>(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]));
b_value_lower = vec4<i32>(unpack4xU8(b_value[2] & 0x0F0F0F0Fu)) - vec4<i32>(8);
b_value_upper = vec4<i32>(unpack4xU8((b_value[2] >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
own_b1[0] = pack4xI8(vec4<i32>(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]));
own_b1[1] = pack4xI8(vec4<i32>(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]));
b_value_lower = vec4<i32>(unpack4xU8(b_value[3] & 0x0F0F0F0Fu)) - vec4<i32>(8);
b_value_upper = vec4<i32>(unpack4xU8((b_value[3] >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
own_b1[2] = pack4xI8(vec4<i32>(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]));
own_b1[3] = pack4xI8(vec4<i32>(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]));
own_scale_b = scales_b[b_offset];
inter_results[idy][idx].y += SDP8AI(own_a, own_b, own_a1, own_b1, own_scale_a * own_scale_b);

b_offset = (b_global + 2)*uniforms.K32+k_offset;
b_value = input_b[b_offset];
b_value_lower = vec4<i32>(unpack4xU8(b_value[0] & 0x0F0F0F0Fu)) - vec4<i32>(8);
b_value_upper = vec4<i32>(unpack4xU8((b_value[0] >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
own_b[0] = pack4xI8(vec4<i32>(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]));
own_b[1] = pack4xI8(vec4<i32>(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]));
b_value_lower = vec4<i32>(unpack4xU8(b_value[1] & 0x0F0F0F0Fu)) - vec4<i32>(8);
b_value_upper = vec4<i32>(unpack4xU8((b_value[1] >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
own_b[2] = pack4xI8(vec4<i32>(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]));
own_b[3] = pack4xI8(vec4<i32>(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]));
b_value_lower = vec4<i32>(unpack4xU8(b_value[2] & 0x0F0F0F0Fu)) - vec4<i32>(8);
b_value_upper = vec4<i32>(unpack4xU8((b_value[2] >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
own_b1[0] = pack4xI8(vec4<i32>(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]));
own_b1[1] = pack4xI8(vec4<i32>(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]));
b_value_lower = vec4<i32>(unpack4xU8(b_value[3] & 0x0F0F0F0Fu)) - vec4<i32>(8);
b_value_upper = vec4<i32>(unpack4xU8((b_value[3] >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
own_b1[2] = pack4xI8(vec4<i32>(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]));
own_b1[3] = pack4xI8(vec4<i32>(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]));
own_scale_b = scales_b[b_offset];
inter_results[idy][idx].z += SDP8AI(own_a, own_b, own_a1, own_b1, own_scale_a * own_scale_b);

b_offset = (b_global + 3)*uniforms.K32+k_offset;
b_value = input_b[b_offset];
b_value_lower = vec4<i32>(unpack4xU8(b_value[0] & 0x0F0F0F0Fu)) - vec4<i32>(8);
b_value_upper = vec4<i32>(unpack4xU8((b_value[0] >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
own_b[0] = pack4xI8(vec4<i32>(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]));
own_b[1] = pack4xI8(vec4<i32>(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]));
b_value_lower = vec4<i32>(unpack4xU8(b_value[1] & 0x0F0F0F0Fu)) - vec4<i32>(8);
b_value_upper = vec4<i32>(unpack4xU8((b_value[1] >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
own_b[2] = pack4xI8(vec4<i32>(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]));
own_b[3] = pack4xI8(vec4<i32>(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]));
b_value_lower = vec4<i32>(unpack4xU8(b_value[2] & 0x0F0F0F0Fu)) - vec4<i32>(8);
b_value_upper = vec4<i32>(unpack4xU8((b_value[2] >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
own_b1[0] = pack4xI8(vec4<i32>(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]));
own_b1[1] = pack4xI8(vec4<i32>(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]));
b_value_lower = vec4<i32>(unpack4xU8(b_value[3] & 0x0F0F0F0Fu)) - vec4<i32>(8);
b_value_upper = vec4<i32>(unpack4xU8((b_value[3] >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
own_b1[2] = pack4xI8(vec4<i32>(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]));
own_b1[3] = pack4xI8(vec4<i32>(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]));
own_scale_b = scales_b[b_offset];
inter_results[idy][idx].w += SDP8AI(own_a, own_b, own_a1, own_b1, own_scale_a * own_scale_b);
}
}

workgroupBarrier();
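// Reduction phase: the first tile_size_vec threads each sum the partials for
// their 4 output columns across the 16 k-slices and write one vec4 of output.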
if (local_idx < tile_size_vec) {
var output_value = vec4<output_element_t>(0);
for (var b = 0u; b < tile_size_k_vec; b++) {
output_value += inter_results[local_idx][b];
}
let b_global = b_global_base + local_idx * 4;
let output_idx = (a_global * uniforms.N + b_global)/4;
if (b_global < uniforms.N) {
output[output_idx] = output_value;
}
}
)MAIN_FN";

return Status::OK();
}
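
// Editorial sketch (not part of this change; the "Emu" helpers are
// hypothetical and assume <cstdint>): CPU emulation of one SDP8AI call -
// a signed 8-bit dot product over 32 lanes, then a single multiply by the
// combined dequantization scale.
static int32_t Dot4I8PackedEmu(uint32_t a, uint32_t b) {
  int32_t sum = 0;
  for (int lane = 0; lane < 4; ++lane) {
    // Reinterpret each u32 as four signed 8-bit lanes, like dot4I8Packed.
    const auto av = static_cast<int8_t>((a >> (8 * lane)) & 0xFFu);
    const auto bv = static_cast<int8_t>((b >> (8 * lane)) & 0xFFu);
    sum += static_cast<int32_t>(av) * static_cast<int32_t>(bv);
  }
  return sum;
}

static float SDP8AIEmu(const uint32_t a1[4], const uint32_t b1[4],
                       const uint32_t a2[4], const uint32_t b2[4],
                       float scale) {
  int32_t local_sum = 0;
  for (int i = 0; i < 4; ++i) {
    local_sum += Dot4I8PackedEmu(a1[i], b1[i]);  // first 16 lanes
    local_sum += Dot4I8PackedEmu(a2[i], b2[i]);  // second 16 lanes
  }
  return static_cast<float>(local_sum) * scale;
}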

Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const {
const Tensor* a = context.Input(0);
const Tensor* b = context.Input(1);
@@ -831,7 +1008,7 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context
const bool use_dp4a = has_subgroup && context.AdapterInfo().backendType != wgpu::BackendType::Metal;
if (accuracy_level_ == 4 && block_size == 32 &&
batch_count == 1 && components_a == 4 && K % 64 == 0 && N % 16 == 0 &&
!has_zero_points && use_dp4a && M >= kMinMForTileOptimization) {
!has_zero_points && use_dp4a) {
constexpr uint32_t kVec4Components = 4;
constexpr uint32_t kVec2Components = 2;
constexpr uint32_t kU32Components = 4;
@@ -849,24 +1026,43 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context
{&a_scale, ProgramTensorMetadataDependency::Rank, a_scale.Shape(), gsl::narrow<int>(1)}});
ORT_RETURN_IF_ERROR(context.RunProgram(quantize_program));

constexpr uint32_t kTileSize = 64;
TensorShape reshaped_y_shape{1, M, N / kVec4Components};
DP4AMatMulNBitsProgram mul_program;
mul_program.SetWorkgroupSize(256);
mul_program.SetDispatchGroupSize(
(M + kTileSize - 1) / kTileSize,
(N + kTileSize - 1) / kTileSize, 1);
mul_program.AddInputs({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(kVec4Components)},
{&a_scale, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(1)},
{b, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(kVec2Components * kU32Components)},
{scales, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(1)}})
.AddUniformVariables({{static_cast<uint32_t>(M)},
{static_cast<uint32_t>(N)},
{static_cast<uint32_t>(K)},
{static_cast<uint32_t>(K / 8)},
{static_cast<uint32_t>(K / 16)}})
.AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, gsl::narrow<int>(kVec4Components)});
return context.RunProgram(mul_program);
if (M >= kMinMForTileOptimization) {
constexpr uint32_t kTileSize = 64;
TensorShape reshaped_y_shape{1, M, N / kVec4Components};
DP4AMatMulNBitsProgram mul_program;
mul_program.SetWorkgroupSize(256);
mul_program.SetDispatchGroupSize(
(M + kTileSize - 1) / kTileSize,
(N + kTileSize - 1) / kTileSize, 1);
mul_program.AddInputs({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(kVec4Components)},
{&a_scale, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(1)},
{b, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(kVec2Components * kU32Components)},
{scales, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(1)}})
.AddUniformVariables({{static_cast<uint32_t>(M)},
{static_cast<uint32_t>(N)},
{static_cast<uint32_t>(K)},
{static_cast<uint32_t>(K / 8)},
{static_cast<uint32_t>(K / 16)}})
.AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, gsl::narrow<int>(kVec4Components)});
return context.RunProgram(mul_program);
} else {
constexpr uint32_t kTileSize = 16;
DP4AMatMulNBitsSmallMProgram mul_program;
mul_program.SetWorkgroupSize(64);
mul_program.SetDispatchGroupSize(
(N + kTileSize - 1) / kTileSize, M, 1);
mul_program.AddInputs({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(kVec4Components)},
{&a_scale, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(4)},
{b, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(kVec4Components * kU32Components)},
{scales, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(1)}})
.AddUniformVariables({{static_cast<uint32_t>(M)},
{static_cast<uint32_t>(N)},
{static_cast<uint32_t>(K)},
{static_cast<uint32_t>(K / 16)},
{static_cast<uint32_t>(K / 32)}})
.AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(4)});
return context.RunProgram(mul_program);
}
}
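
// Editorial note (illustrative numbers, not from this change): for a
// decode-time GEMV with M = 1, N = 4096, K = 4096, the small-M path above
// dispatches (4096 + 15) / 16 = 256 workgroups along x and M = 1 along y,
// 64 threads each, with uniforms K16 = 4096 / 16 = 256 and K32 = 128.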

// TODO: Support output_number > 1. Some cases are failed when output_number > 1.
12 changes: 12 additions & 0 deletions onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h
@@ -53,6 +53,18 @@ class DP4AMatMulNBitsProgram final : public Program<DP4AMatMulNBitsProgram> {
{"K16", ProgramUniformVariableDataType::Uint32});
};

class DP4AMatMulNBitsSmallMProgram final : public Program<DP4AMatMulNBitsSmallMProgram> {
public:
DP4AMatMulNBitsSmallMProgram() : Program{"DP4AMatMulNBitsSmallMProgram"} {}
Status GenerateShaderCode(ShaderHelper& sh) const override;
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
{"M", ProgramUniformVariableDataType::Uint32},
{"N", ProgramUniformVariableDataType::Uint32},
{"K", ProgramUniformVariableDataType::Uint32},
{"K16", ProgramUniformVariableDataType::Uint32},
{"K32", ProgramUniformVariableDataType::Uint32});
};

class MatMulNBits final : public WebGpuKernel {
public:
MatMulNBits(const OpKernelInfo& info) : WebGpuKernel(info) {