diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 60e98a60741..f4c5eca0df5 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -84,16 +84,16 @@ struct ggml_webgpu_shader_lib_context { ggml_tensor * src5; ggml_tensor * dst; - uint32_t max_wg_size; - size_t wg_mem_limit_bytes = 0; - bool supports_subgroups = false; - bool supports_subgroup_matrix = false; - uint32_t sg_mat_m = 0; - uint32_t sg_mat_n = 0; - uint32_t sg_mat_k = 0; - uint32_t min_subgroup_size = 0; - uint32_t max_subgroup_size = 0; - bool supports_dot_product = false; + uint32_t max_wg_size; + size_t wg_mem_limit_bytes = 0; + bool supports_subgroups = false; + bool supports_subgroup_matrix = false; + uint32_t sg_mat_m = 0; + uint32_t sg_mat_n = 0; + uint32_t sg_mat_k = 0; + uint32_t min_subgroup_size = 0; + uint32_t max_subgroup_size = 0; + bool supports_dot_product = false; std::string vendor; }; @@ -166,9 +166,11 @@ struct ggml_webgpu_set_rows_pipeline_key { int dst_type; int vec4; int i64_idx; + int pair_blocks; bool operator==(const ggml_webgpu_set_rows_pipeline_key & other) const { - return dst_type == other.dst_type && vec4 == other.vec4 && i64_idx == other.i64_idx; + return dst_type == other.dst_type && vec4 == other.vec4 && i64_idx == other.i64_idx && + pair_blocks == other.pair_blocks; } }; @@ -178,6 +180,7 @@ struct ggml_webgpu_set_rows_pipeline_key_hash { ggml_webgpu_hash_combine(seed, key.dst_type); ggml_webgpu_hash_combine(seed, key.vec4); ggml_webgpu_hash_combine(seed, key.i64_idx); + ggml_webgpu_hash_combine(seed, key.pair_blocks); return seed; } }; @@ -185,6 +188,7 @@ struct ggml_webgpu_set_rows_pipeline_key_hash { struct ggml_webgpu_set_rows_shader_decisions { bool vec4; bool i64_idx; + bool pair_blocks; uint32_t wg_size; }; @@ -772,31 +776,30 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions( (v_offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u); const bool kv_vec_type_supported = K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0; - const uint32_t kv_vec_head_align = K->type == GGML_TYPE_F16 ? GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH : - (uint32_t) ggml_blck_size(K->type); - const bool kv_vec_head_dims_aligned = context.src0->ne[0] % kv_vec_head_align == 0 && - context.src2->ne[0] % kv_vec_head_align == 0; + const uint32_t kv_vec_head_align = + K->type == GGML_TYPE_F16 ? GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH : (uint32_t) ggml_blck_size(K->type); + const bool kv_vec_head_dims_aligned = + context.src0->ne[0] % kv_vec_head_align == 0 && context.src2->ne[0] % kv_vec_head_align == 0; // Compile with enough invocations to cover the largest reported subgroup. - const bool use_vec = context.supports_subgroups && (context.src0->ne[1] < 20) && - kv_vec_head_dims_aligned && kv_vec_type_supported && - (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && + const bool use_vec = context.supports_subgroups && (context.src0->ne[1] < 20) && kv_vec_head_dims_aligned && + kv_vec_type_supported && (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && (context.src2->type == K->type); const bool tile_can_dispatch_all_q_rows = context.max_subgroup_size > 0 && context.max_wg_size >= GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * context.max_subgroup_size; - const bool use_subgroup_matrix = - context.supports_subgroup_matrix && context.sg_mat_k > 0 && context.sg_mat_n > 0 && - context.src0->ne[0] % context.sg_mat_k == 0 && context.src2->ne[0] % context.sg_mat_n == 0; + const bool use_subgroup_matrix = context.supports_subgroup_matrix && context.sg_mat_k > 0 && context.sg_mat_n > 0 && + context.src0->ne[0] % context.sg_mat_k == 0 && + context.src2->ne[0] % context.sg_mat_n == 0; const bool use_tile = context.supports_subgroups && !use_subgroup_matrix && K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && f16_vec4_aligned && (context.src0->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && (context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && tile_can_dispatch_all_q_rows && !use_vec; - decisions.path = use_vec ? GGML_WEBGPU_FLASH_ATTN_PATH_VEC : - use_tile ? GGML_WEBGPU_FLASH_ATTN_PATH_TILE : - use_subgroup_matrix ? GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX : - GGML_WEBGPU_FLASH_ATTN_PATH_NONE; + decisions.path = use_vec ? GGML_WEBGPU_FLASH_ATTN_PATH_VEC : + use_tile ? GGML_WEBGPU_FLASH_ATTN_PATH_TILE : + use_subgroup_matrix ? GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX : + GGML_WEBGPU_FLASH_ATTN_PATH_NONE; if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_NONE) { return decisions; @@ -1131,9 +1134,9 @@ class ggml_webgpu_shader_lib { ggml_webgpu_flash_attn_blk_pipeline_key_hash> flash_attn_blk_pipelines; std::unordered_map - mul_mat_vec_pipelines; // fast mat-vec (n==1) + mul_mat_vec_pipelines; // fast mat-vec (n==1) std::unordered_map - mul_mat_fast_pipelines; // fast mat-mat (reg-tile or subgroup) + mul_mat_fast_pipelines; // fast mat-mat (reg-tile or subgroup) std::unordered_map quantize_q8_pipelines; std::unordered_map mul_mat_id_gather_pipelines; // key is fixed @@ -1264,10 +1267,13 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_set_rows_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_set_rows_pipeline_key key = {}; - key.dst_type = context.dst->type; - key.vec4 = context.src0->ne[0] % 4 == 0; - key.i64_idx = context.src1->type == GGML_TYPE_I64; + const bool quantized = ggml_is_quantized(context.dst->type); + ggml_webgpu_set_rows_pipeline_key key = {}; + key.dst_type = context.dst->type; + key.vec4 = + (context.dst->type == GGML_TYPE_F32 || context.dst->type == GGML_TYPE_F16) && context.src0->ne[0] % 4 == 0; + key.i64_idx = context.src1->type == GGML_TYPE_I64; + key.pair_blocks = quantized && ((context.src0->ne[0] / ggml_blck_size(context.dst->type)) % 2 == 0); auto it = set_rows_pipelines.find(key); if (it != set_rows_pipelines.end()) { @@ -1286,6 +1292,14 @@ class ggml_webgpu_shader_lib { defines.push_back("DST_F16"); variant += "_dstf16"; break; + case GGML_TYPE_Q8_0: + defines.push_back("DST_Q8_0"); + variant += "_dstq8_0"; + break; + case GGML_TYPE_Q4_0: + defines.push_back("DST_Q4_0"); + variant += "_dstq4_0"; + break; default: GGML_ABORT("Unsupported dst type for set_rows shader"); } @@ -1298,13 +1312,19 @@ class ggml_webgpu_shader_lib { defines.push_back("I64_IDX"); variant += "_i64idx"; } + if (key.pair_blocks) { + defines.push_back("PAIR_BLOCKS"); + variant += "_pair_blocks"; + } defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); - auto processed = preprocessor.preprocess(wgsl_set_rows, defines); - auto decisions = std::make_shared(); + const auto & shader_source = quantized ? wgsl_set_rows_quant : wgsl_set_rows; + auto processed = preprocessor.preprocess(shader_source, defines); + auto decisions = std::make_shared(); decisions->vec4 = key.vec4; decisions->i64_idx = key.i64_idx; + decisions->pair_blocks = key.pair_blocks; decisions->wg_size = context.max_wg_size; set_rows_pipelines[key] = ggml_webgpu_create_pipeline(device, processed, variant); set_rows_pipelines[key].context = decisions; @@ -1660,7 +1680,7 @@ class ggml_webgpu_shader_lib { key.type = context.dst->type; key.d_state = (int) context.src0->ne[0]; key.xbc_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src4) && - ggml_webgpu_tensor_overlap(context.src1, context.src5); + ggml_webgpu_tensor_overlap(context.src1, context.src5); auto it = ssm_scan_pipelines.find(key); if (it != ssm_scan_pipelines.end()) { @@ -1819,7 +1839,7 @@ class ggml_webgpu_shader_lib { (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? 1 : 0; - key.use_mmvq = + key.use_mmvq = ggml_webgpu_can_use_mmvq(context.src0, context.src1, context.supports_dot_product, context.vendor); auto it = mul_mat_vec_pipelines.find(key); diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index f113da909ce..b37521ded84 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1334,7 +1334,11 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & ct } uint32_t threads; - if (decisions->vec4) { + if (ggml_is_quantized(dst->type)) { + const uint32_t blocks_per_row = src->ne[0] / ggml_blck_size(dst->type); + threads = + (src->ne[1] * src->ne[2] * src->ne[3]) * (decisions->pair_blocks ? (blocks_per_row / 2) : blocks_per_row); + } else if (decisions->vec4) { threads = (src->ne[1] * src->ne[2] * src->ne[3]) * (src->ne[0] / 4); } else { threads = src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3]; @@ -4045,8 +4049,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_I32); break; case GGML_OP_SET_ROWS: - supports_op = ((op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32) && src0->type == GGML_TYPE_F32 && - (src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32)); + supports_op = ((op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_Q8_0 || + op->type == GGML_TYPE_Q4_0) && + src0->type == GGML_TYPE_F32 && (src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32)); break; case GGML_OP_GET_ROWS: if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_webgpu_supported_qtype(src0->type)) { diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl index 99e9192c71a..09f2f0eddb3 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl @@ -71,7 +71,6 @@ fn main(@builtin(global_invocation_id) gid: vec3) { return; } - // getting the row from gid let elems_per_row = params.ne0 / VEC_SIZE; var i = gid.x / elems_per_row; @@ -104,6 +103,6 @@ fn main(@builtin(global_invocation_id) gid: vec3) { let i_dst_row = params.offset_dst + idx_val * params.stride_dst1 + i_src2 * params.stride_dst2 + i_src3 * params.stride_dst3; let i_src_row = params.offset_src + i_src1 * params.stride_src1 + i_src2 * params.stride_src2 + i_src3 * params.stride_src3; - let col_idx = (gid.x % elems_per_row); - dst[i_dst_row/VEC_SIZE + col_idx] = DST_TYPE(src[i_src_row/VEC_SIZE + col_idx]); + let col_idx = gid.x % elems_per_row; + dst[i_dst_row / VEC_SIZE + col_idx] = DST_TYPE(src[i_src_row / VEC_SIZE + col_idx]); } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl new file mode 100644 index 00000000000..876e65b6ae1 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl @@ -0,0 +1,224 @@ +#ifdef DST_Q8_0 +#define BLOCK_SIZE 32u +#define BLOCK_BYTES 34u +#define QS_WORDS 8u +#elif defined(DST_Q4_0) +#define BLOCK_SIZE 32u +#define BLOCK_BYTES 18u +#define QS_WORDS 4u +#endif + +@group(0) @binding(0) +var src: array; + +@group(0) @binding(1) +var idx: array; + +@group(0) @binding(2) +#ifdef PAIR_BLOCKS +var dst: array; +#else +var dst: array>; +#endif + +#ifdef I64_IDX +@group(0) @binding(3) +var error: atomic; +#define PARAMS_BINDING 4 +#else +#define PARAMS_BINDING 3 +#endif + +struct Params { + offset_src: u32, // in elements + offset_idx: u32, // in elements + offset_dst: u32, // in blocks + + // Strides (in elements / blocks) + stride_src1: u32, + stride_src2: u32, + stride_src3: u32, + + stride_idx0: u32, + stride_idx1: u32, + stride_idx2: u32, + + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + // Shape of src + ne0: u32, + n_rows: u32, + ne2: u32, + ne3: u32, + + // Shape of idx + idx1: u32, + idx2: u32, +}; + +@group(0) @binding(PARAMS_BINDING) +var params: Params; + +// if the quantization type is unaligned and there are an odd number of blocks per row, we need to store atomically +#ifndef PAIR_BLOCKS +fn merge_store_dst_word(word_idx: u32, mask: u32, bits: u32) { + loop { + let old = atomicLoad(&dst[word_idx]); + let merged = (old & ~mask) | (bits & mask); + let result = atomicCompareExchangeWeak(&dst[word_idx], old, merged); + if (result.exchanged) { + return; + } + } +} +#else +fn merge_store_dst_word(word_idx: u32, mask: u32, bits: u32) { + let old = dst[word_idx]; + dst[word_idx] = (old & ~mask) | (bits & mask); +} +#endif + +fn store_u16(dst_word_idx: u32, block_byte_offset: u32, byte_offset: u32, value: u32) { + let total_byte_offset = block_byte_offset + byte_offset; + let word_idx = dst_word_idx + total_byte_offset / 4u; + let shift = (total_byte_offset & 2u) * 8u; + let mask = 0xFFFFu << shift; + merge_store_dst_word(word_idx, mask, (value & 0xFFFFu) << shift); +} + +fn store_u32(dst_word_idx: u32, block_byte_offset: u32, byte_offset: u32, value: u32) { + let total_byte_offset = block_byte_offset + byte_offset; + let word_idx = dst_word_idx + total_byte_offset / 4u; + let shift = (total_byte_offset & 3u) * 8u; + + if (shift == 0u) { +#ifdef PAIR_BLOCKS + dst[word_idx] = value; +#else + atomicStore(&dst[word_idx], value); +#endif + return; + } + + let lo_mask = 0xFFFFFFFFu << shift; + let hi_mask = (1u << shift) - 1u; + merge_store_dst_word(word_idx, lo_mask, value << shift); + merge_store_dst_word(word_idx + 1u, hi_mask, value >> (32u - shift)); +} + +fn quantize_block_params(src_block: u32) -> vec2 { +#ifdef DST_Q8_0 + var amax = 0.0; + for (var j: u32 = 0u; j < BLOCK_SIZE; j++) { + amax = max(amax, abs(src[src_block + j])); + } + + let d = amax / 127.0; + let id = select(0.0, 1.0 / d, d > 0.0); + return vec2(d, id); +#elif defined(DST_Q4_0) + var amax = 0.0; + var max_val = 0.0; + for (var j: u32 = 0u; j < BLOCK_SIZE; j++) { + let v = src[src_block + j]; + let av = abs(v); + if (amax < av) { + amax = av; + max_val = v; + } + } + + let d = max_val / -8.0; + let id = select(0.0, 1.0 / d, d != 0.0); + return vec2(d, id); +#endif +} + +fn quantize_block_word(src_block: u32, j: u32, id: f32) -> u32 { +#ifdef DST_Q8_0 + let base = src_block + j * 4u; + return (u32(i32(round(src[base + 0u] * id)) & 0xFF) << 0u) | + (u32(i32(round(src[base + 1u] * id)) & 0xFF) << 8u) | + (u32(i32(round(src[base + 2u] * id)) & 0xFF) << 16u) | + (u32(i32(round(src[base + 3u] * id)) & 0xFF) << 24u); +#elif defined(DST_Q4_0) + var packed_q = 0u; + for (var k: u32 = 0u; k < 4u; k++) { + let x0 = src[src_block + j * 4u + k] * id; + let x1 = src[src_block + 16u + j * 4u + k] * id; + let q0 = u32(clamp(i32(x0 + 8.5), 0, 15)); + let q1 = u32(clamp(i32(x1 + 8.5), 0, 15)); + packed_q |= (q0 & 0xFu) << (8u * k); + packed_q |= (q1 & 0xFu) << (8u * k + 4u); + } + return packed_q; +#endif +} + +fn quantize_block(src_block: u32, dst_word_idx: u32, block_byte_offset: u32) { + let params = quantize_block_params(src_block); + let d = params.x; + let id = params.y; + let packed_d = pack2x16float(vec2(d, 0.0)) & 0xFFFFu; + store_u16(dst_word_idx, block_byte_offset, 0u, packed_d); + + for (var j: u32 = 0u; j < QS_WORDS; j++) { + store_u32(dst_word_idx, block_byte_offset, 2u + j * 4u, quantize_block_word(src_block, j, id)); + } +} + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(global_invocation_id) gid: vec3) { + let blocks_per_row = params.ne0 / BLOCK_SIZE; +#ifdef PAIR_BLOCKS + let blocks_per_invocation = 2u; +#else + let blocks_per_invocation = 1u; +#endif + let invocations_per_row = blocks_per_row / blocks_per_invocation; + let total_invocations = params.ne3 * params.ne2 * params.n_rows * invocations_per_row; + if (gid.x >= total_invocations) { + return; + } + + var i = gid.x / invocations_per_row; + let block_in_row = (gid.x % invocations_per_row) * blocks_per_invocation; + + let i_src3 = i / (params.ne2 * params.n_rows); + i = i % (params.ne2 * params.n_rows); + let i_src2 = i / params.n_rows; + let i_src1 = i % params.n_rows; + + let i_idx2 = i_src3 % params.idx2; + let i_idx1 = i_src2 % params.idx1; + let i_idx0 = i_src1; + +#ifdef I64_IDX + let idx_high = (params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2) * 2u; + let idx_val = idx[idx_high]; + let idx_low_val = idx[idx_high + 1u]; + + if (idx_low_val != 0u) { + atomicStore(&error, 1u); + return; + } +#else + let idx_i = params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2; + let idx_val = idx[idx_i]; +#endif + + let dst_row_blocks = params.offset_dst + idx_val * params.stride_dst1 + i_src2 * params.stride_dst2 + i_src3 * params.stride_dst3; + let src_row = params.offset_src + i_src1 * params.stride_src1 + i_src2 * params.stride_src2 + i_src3 * params.stride_src3; + let src_block = src_row + block_in_row * BLOCK_SIZE; + let dst_block_byte = (dst_row_blocks + block_in_row) * BLOCK_BYTES; + + let dst_word_idx = dst_block_byte / 4u; +#ifdef PAIR_BLOCKS + quantize_block(src_block, dst_word_idx, 0u); + quantize_block(src_block + BLOCK_SIZE, dst_word_idx, BLOCK_BYTES); +#else + quantize_block(src_block, dst_word_idx, dst_block_byte & 3u); +#endif +} diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 3853f03297b..7c62338e66f 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -2415,6 +2415,15 @@ struct test_set_rows : public test_case { } return 1e-7; } + + // See dicussion here: https://github.com/ggml-org/llama.cpp/pull/23760#issuecomment-4566312209 + double max_nmse_err(ggml_backend_t backend) override { + ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend)); + if (type == GGML_TYPE_Q8_0 && strcmp(ggml_backend_reg_name(reg), "WebGPU") == 0) { + return std::max(test_case::max_nmse_err(backend), 2e-7); + } + return test_case::max_nmse_err(backend); + } }; // GGML_OP_ROPE + GGML_OP_VIEW + GGML_OP_SET_ROWS