Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1098,6 +1098,7 @@ class ggml_webgpu_shader_lib {
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);

switch (key.src_type) {
case GGML_TYPE_Q1_0:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q8_0:
Expand Down Expand Up @@ -1134,7 +1135,9 @@ class ggml_webgpu_shader_lib {

defines.push_back("DST_TYPE=f32");

if ((key.src_type >= GGML_TYPE_Q4_0 && key.src_type <= GGML_TYPE_Q8_1) ||
if (key.src_type == GGML_TYPE_Q1_0) {
defines.push_back("BLOCK_SIZE=128u");
} else if ((key.src_type >= GGML_TYPE_Q4_0 && key.src_type <= GGML_TYPE_Q8_1) ||
key.src_type == GGML_TYPE_IQ4_NL) {
defines.push_back("BLOCK_SIZE=32u");
} else if (key.src_type >= GGML_TYPE_Q2_K) {
Expand Down Expand Up @@ -1403,7 +1406,9 @@ class ggml_webgpu_shader_lib {
uint32_t wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE;
uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG;

if (key.src0_type >= GGML_TYPE_Q2_K) {
if (key.src0_type == GGML_TYPE_Q1_0) {
outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG;
} else if (key.src0_type >= GGML_TYPE_Q2_K) {
outputs_per_wg = WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG;
} else if (key.src0_type >= GGML_TYPE_Q4_0) {
outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG;
Expand Down
4 changes: 4 additions & 0 deletions ggml/src/ggml-webgpu/ggml-webgpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1249,6 +1249,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q1_0:
use_fast = true;
break;
default:
Expand Down Expand Up @@ -3323,6 +3324,7 @@ static bool ggml_backend_webgpu_device_supports_buft(ggml_backend_dev_t dev, ggm

static bool ggml_webgpu_supported_qtype(ggml_type type) {
switch (type) {
case GGML_TYPE_Q1_0:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
Expand Down Expand Up @@ -3417,6 +3419,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
switch (src0->type) {
case GGML_TYPE_F32:
case GGML_TYPE_F16:
case GGML_TYPE_Q1_0:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
Expand Down Expand Up @@ -3455,6 +3458,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
switch (src0->type) {
case GGML_TYPE_F32:
case GGML_TYPE_F16:
case GGML_TYPE_Q1_0:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
Expand Down
18 changes: 18 additions & 0 deletions ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,24 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
}
#endif

#ifdef Q1_0
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_byte_base = (src_base + offset) * 18;
let d = load_f16_as_f32_at_src(block_byte_base);
for (var j: u32 = 0u; j < 4u; j++) {
let q_packed = load_u32_at_src(block_byte_base + 2u + j * 4u);
let dst_base128 = dst_base + offset * 128u + j * 32u;
for (var k: u32 = 0; k < 4u; k++) {
let q_byte = get_byte(q_packed, k);
for (var bit: u32 = 0; bit < 8u; bit++) {
let w = select(-d, d, ((q_byte >> bit) & 1u) != 0u);
dst[dst_base128 + k * 8u + bit] = w;
}
}
}
}
#endif

#ifdef Q4_0
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
let block_byte_base = (src_base + offset) * 18; // Block stride: 18 bytes
Expand Down
33 changes: 33 additions & 0 deletions ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,39 @@ fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u3
#endif // INIT_SRC1_SHMEM_FLOAT
#endif

#ifdef INIT_SRC0_SHMEM_Q1_0
const BLOCK_SIZE = 128u;
const BLOCK_SIZE_BYTES = 18u;
const NQ = 8u; // 8 weights (1 byte of qs) per thread per iteration

fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
let tile_m = i / TILE_K;
let tile_k_start = i % TILE_K;
let global_m = offset_m + tile_m;
let global_k_start = k_outer + tile_k_start;

if (global_m >= params.m) {
break;
}

let block_k = global_k_start / BLOCK_SIZE;
let byte_in_block = (global_k_start % BLOCK_SIZE) / 8u;
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
let d = load_f16_at_src0(block_byte_base);
let q_byte = load_u32_at_src0(block_byte_base + 2u + byte_in_block) & 0xFFu;

for (var bit = 0u; bit < NQ; bit++) {
let global_k = global_k_start + bit;
if (global_k < params.k) {
shmem[i + bit] = select(-d, d, ((q_byte >> bit) & 1u) != 0u);
}
}
}
}
#endif // INIT_SRC0_SHMEM_Q1_0

#ifdef INIT_SRC0_SHMEM_Q4_0
const BLOCK_SIZE = 32u;
const BLOCK_SIZE_BYTES = 18u;
Expand Down
32 changes: 32 additions & 0 deletions ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,38 @@ fn main(
}
#endif

#ifdef MUL_ACC_Q1_0
#define BLOCK_SIZE 128
#define BLOCK_SIZE_BYTES 18
#define THREADS_PER_BLOCK 16
#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK)

let num_blocks = params.k / BLOCK_SIZE;
let thread_within_block = thread_id % THREADS_PER_BLOCK;
for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) {
let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * ELEMS_PER_THREAD;
var x_block: array<f32, ELEMS_PER_THREAD>;
for (var i = 0u; i < ELEMS_PER_THREAD; i++) {
x_block[i] = f32(src1[x_base + i]);
}

for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
let output_row = row_base + row;
if (output_row < params.m) {
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
let d = f32(load_f16_at_src0(block_byte_base));
let q_byte = load_u32_at_src0(block_byte_base + 2u + thread_within_block) & 0xFFu;
var row_sum = 0.0;
for (var bit = 0u; bit < 8u; bit++) {
let w = select(-d, d, ((q_byte >> bit) & 1u) != 0u);
row_sum += w * x_block[bit];
}
acc[row] += row_sum;
}
}
}
#endif

#ifdef MUL_ACC_Q4_0
#define BLOCK_SIZE 32
#define BLOCK_SIZE_BYTES 18
Expand Down
Loading