From 976ebc63b91a12cef359980c17694df2a8860528 Mon Sep 17 00:00:00 2001 From: Zheyuan Chen Date: Wed, 25 Feb 2026 12:03:42 -0800 Subject: [PATCH 01/34] naive vectorized version --- .../wgsl-shaders/flash_attn_vec.wgsl | 584 ++++++++++++++++++ 1 file changed, 584 insertions(+) create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec.wgsl diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec.wgsl new file mode 100644 index 000000000000..6d2023e4658c --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec.wgsl @@ -0,0 +1,584 @@ +diagnostic(off, chromium.subgroup_matrix_uniformity); +diagnostic(off, subgroup_uniformity); +enable f16; +enable subgroups; +enable chromium_experimental_subgroup_matrix; + +#ifdef KV_F32 +#define KV_TYPE f32 +#else +#define KV_TYPE f16 +#endif + +// Default values +#define HEAD_DIM_QK 64 +#define HEAD_DIM_V 64 + +// The number of rows/columns/k in a subgroup matrix. MxK * KxN = MxN +// Note that the "K" here does not correspond to the K in attention's Q/K/V, it's just the common dimension. +#define SG_MAT_M 8 +#define SG_MAT_N 8 +#define SG_MAT_K 8 + +// Each workgroup processes one subgroup matrix of Q rows +#define Q_TILE SG_MAT_M +#define KV_TILE 16 +#define WG_SIZE 64 + +// Number of subgroup-matrix-width blocks that span the KV tile. SG_MAT_N must divide KV_TILE. 
+#define KV_BLOCKS (KV_TILE / SG_MAT_N) + +// Quantization constants/helpers +#define BLOCK_SIZE 32 +#define BLOCKS_K ((HEAD_DIM_QK + BLOCK_SIZE - 1) / BLOCK_SIZE) +#define BLOCKS_V ((HEAD_DIM_V + BLOCK_SIZE - 1) / BLOCK_SIZE) +// number of quantized elements processed per thread +#if defined(KV_Q4_0) +#define NQ 16 +// Q4_0 has 32 elements, 1 f16 for scale, 8 f16 for 4-bit weights +#define F16_PER_BLOCK 9 +#define WEIGHTS_PER_F16 4 +#elif defined(KV_Q8_0) +#define NQ 8 +// Q8_0 has 32 elements, 1 f16 for scale, 16 f16 for 8-bit weights +#define F16_PER_BLOCK 17 +#define WEIGHTS_PER_F16 2 +#endif +#define F16_PER_THREAD (NQ / WEIGHTS_PER_F16) + +// Ok not to put these in a define block, compiler will remove if unused +fn get_byte(value: u32, index: u32) -> u32 { + return (value >> (index * 8)) & 0xFF; +} + +fn get_byte_i32(value: u32, index: u32) -> i32 { + return bitcast(((value >> (index * 8)) & 0xFF) << 24) >> 24; +} + +struct Params { + offset_q: u32, + offset_k: u32, + offset_v: u32, + offset_mask: u32, + offset_sinks: u32, + offset_dst: u32, + + // shapes of Q/K/V + n_heads: u32, + seq_len_q: u32, + seq_len_kv: u32, + + // strides (in elements) + stride_q1: u32, + stride_q2: u32, + stride_q3: u32, + stride_k1: u32, + stride_k2: u32, + stride_k3: u32, + stride_v1: u32, + stride_v2: u32, + stride_v3: u32, + stride_mask3: u32, + + // repeat factors for K/V, e.g., MHA vs. MQA vs. 
GQA + q_per_kv: u32, + + // softmax params + scale: f32, + max_bias: f32, + logit_softcap: f32, + n_head_log2: f32, + m0: f32, + m1: f32, +}; + +@group(0) @binding(0) var Q: array; +@group(0) @binding(1) var K: array; +@group(0) @binding(2) var V: array>; + +#if defined(MASK) && defined(SINKS) +@group(0) @binding(3) var mask: array; +@group(0) @binding(4) var sinks: array; +#define DST_BINDING 5 +#define PARAMS_BINDING 6 +#elif defined(MASK) +@group(0) @binding(3) var mask: array; +#define DST_BINDING 4 +#define PARAMS_BINDING 5 +#elif defined(SINKS) +@group(0) @binding(3) var sinks: array; +#define DST_BINDING 4 +#define PARAMS_BINDING 5 +#else +#define DST_BINDING 3 +#define PARAMS_BINDING 4 +#endif + +@group(0) @binding(DST_BINDING) var dst: array>; +@group(0) @binding(PARAMS_BINDING) var params: Params; + +// Just a very small float value. +const FLOAT_MIN: f32 = -1.0e9; + +// The number of Q rows processed per workgroup +var q_shmem: array; + +#ifndef KV_DIRECT +const kv_shmem_size = KV_TILE * max(HEAD_DIM_QK, HEAD_DIM_V); +// we can reuse the same shmem for K and V since we only need one at a time +var kv_shmem: array; +#endif + +var o_shmem: array; // output shmem + +#ifdef MASK +// storage for mask values +var mask_shmem: array; +#endif + +// storage for output of Q*K^T scores for online softmax (S matrix from paper) +// also storage for diagonal matrix during online softmax (P matrix from paper) +// note that we reuse the same storage for both since we only need one at a time +var inter_shmem: array; + +// Storage for row max and exp sum during online softmax +var row_max_shmem: array; +var exp_sum_shmem: array; + +fn calc_softmax_term(kv_idx: u32, q_tile_row: u32, slope: f32) -> f32 { + var v = select(FLOAT_MIN, + f32(inter_shmem[kv_idx + q_tile_row * KV_TILE]) * params.scale, + kv_idx < KV_TILE); +#ifdef LOGIT_SOFTCAP + v = params.logit_softcap * tanh(v); +#endif +#ifdef MASK + let mask_val = select(0.0, f32(mask_shmem[q_tile_row * KV_TILE + kv_idx]), 
kv_idx < KV_TILE); + let mask_term = slope * mask_val; + v += mask_term; +#endif + return v; +} + +fn load_f32x4(buf: ptr>, read_write>, scalar_index: u32) -> vec4 { + return (*buf)[scalar_index >> 2u]; +} + +fn load_kvx4(buf: ptr>, read_write>, scalar_index: u32) -> vec4 { + return (*buf)[scalar_index >> 2u]; +} + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(workgroup_id) wg_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(subgroup_id) subgroup_id: u32, + @builtin(subgroup_size) subgroup_size: u32, + @builtin(num_subgroups) num_subgroups: u32, + @builtin(subgroup_invocation_id) sg_inv_id: u32) { + + // initialize row max for online softmax + for (var i = local_id.x; i < Q_TILE; i += WG_SIZE) { + row_max_shmem[i] = FLOAT_MIN; + exp_sum_shmem[i] = 0.0; + } + + for (var i = local_id.x; i < Q_TILE * HEAD_DIM_V; i += WG_SIZE) { + o_shmem[i] = 0.0; + } + + // workgroups per head/batch + let wg_per_head = (params.seq_len_q + Q_TILE - 1u) / Q_TILE; + let wg_per_batch = wg_per_head * params.n_heads; + + let dst2_stride = HEAD_DIM_V * params.n_heads; + let dst3_stride = dst2_stride * params.seq_len_q; + + // batch index + let batch_idx = wg_id.x / wg_per_batch; + let q_batch_offset = params.offset_q + batch_idx * params.stride_q3; + let k_batch_offset = params.offset_k + batch_idx * params.stride_k3; + let v_batch_offset = params.offset_v + batch_idx * params.stride_v3; + let dst_batch_offset = params.offset_dst + batch_idx * dst3_stride; + let wg_in_batch = wg_id.x % wg_per_batch; + + // head index + let head_idx = wg_in_batch / wg_per_head; + let q_head_offset = q_batch_offset + head_idx * params.stride_q2; + let k_head_idx = head_idx / params.q_per_kv; + let v_head_idx = k_head_idx; + let k_head_offset = k_batch_offset + k_head_idx * params.stride_k2; + let v_head_offset = v_batch_offset + v_head_idx * params.stride_v2; + + // starting Q row for this workgroup + let wg_in_head = wg_in_batch % wg_per_head; + let q_row_start = wg_in_head * 
Q_TILE; + +#ifdef MASK + // mask offset + let mask_global_offset = params.offset_mask + batch_idx * params.stride_mask3 + q_row_start * params.seq_len_kv; +#endif + + // note that the output is permuted, the layout is [head_dim_v, n_heads, seq_len_q, batch_size] + let dst_global_offset = dst_batch_offset + q_row_start * dst2_stride + head_idx * HEAD_DIM_V; + + let head = f32(head_idx); + let slope = select(1.0, select(pow(params.m1, 2.0 * (head - params.n_head_log2) + 1.0), pow(params.m0, head + 1.0), head < params.n_head_log2), params.max_bias > 0); + + // load q tile into shared memory + for (var elem_idx = local_id.x; elem_idx < Q_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) { + let q_row = elem_idx / HEAD_DIM_QK; + let q_col = elem_idx % HEAD_DIM_QK; + let head_q_row = q_row_start + q_row; + let global_q_row_offset = q_head_offset + head_q_row * params.stride_q1; + q_shmem[elem_idx] = f16(select( + 0.0, + Q[global_q_row_offset + q_col], + head_q_row < params.seq_len_q && q_col < HEAD_DIM_QK)); + } + + for (var kv_tile = 0u; kv_tile < params.seq_len_kv; kv_tile += KV_TILE) { + // clear inter_shmem to ensure zero-initialized accumulators + for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) { + inter_shmem[elem_idx] = 0.0; + } + + // load k tile into shared memory +// #if defined(KV_Q4_0) +// for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) { +// let blck_idx = elem_idx / BLOCK_SIZE; +// let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; +// let k_row = blck_idx / BLOCKS_K; +// let global_k_row = kv_tile + k_row; +// let block_k = blck_idx % BLOCKS_K; +// let row_offset = k_row * HEAD_DIM_QK; + +// if (global_k_row < params.seq_len_kv) { +// let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; +// let base_idx = global_block_idx * F16_PER_BLOCK; +// let d = K[base_idx]; // scale +// for (var j = 0u; j < F16_PER_THREAD; j += 2) { +// let q_0 = 
K[base_idx + 1u + block_offset + j]; +// let q_1 = K[base_idx + 1u + block_offset + j + 1]; +// let q_packed = bitcast(vec2(q_0, q_1)); +// for (var k = 0u; k < 4u; k++) { +// let q_byte = get_byte(q_packed, k); +// let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; +// let q_lo = (f16(q_byte & 0xF) - 8.0) * d; +// let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; +// kv_shmem[row_offset + idx] = q_lo; +// kv_shmem[row_offset + idx + 16u] = q_hi; +// } +// } +// } +// } +// #elif defined(KV_Q8_0) +// for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) { +// let blck_idx = elem_idx / BLOCK_SIZE; +// let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; +// let k_row = blck_idx / BLOCKS_K; +// let global_k_row = kv_tile + k_row; +// let block_k = blck_idx % BLOCKS_K; +// let row_offset = k_row * HEAD_DIM_QK; + +// if (global_k_row < params.seq_len_kv) { +// let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; +// let base_idx = global_block_idx * F16_PER_BLOCK; +// let d = K[base_idx]; // scale +// for (var j = 0u; j < F16_PER_THREAD; j += 2) { +// let q_0 = K[base_idx + 1u + block_offset + j]; +// let q_1 = K[base_idx + 1u + block_offset + j + 1]; +// let q_packed = bitcast(vec2(q_0, q_1)); +// for (var k = 0u; k < 4u; k++) { +// let q_byte = get_byte_i32(q_packed, k); +// let q_val = f16(q_byte) * d; +// let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; +// kv_shmem[row_offset + idx] = q_val; +// } +// } +// } +// } +// #endif + + workgroupBarrier(); + + // accumulate q block * k block into registers across the entire KV tile + // TODO: this loop seems to be the current largest bottleneck + // this bracket exists to scope the lifetime of variables, reducing register pressure + { + // vectorization + let num_of_threads = subgroup_size / 4u; + let tx = sg_inv_id % num_of_threads; + let ty = sg_inv_id / num_of_threads; + for (var q_tile_row = subgroup_id; 
q_tile_row < Q_TILE; q_tile_row += num_subgroups) { + let global_q_row = q_row_start + q_tile_row; + if (global_q_row >= params.seq_len_q) { + continue; + } + + for (var kv_base : u32 = 0u; kv_base < KV_TILE; kv_base += 4u) { + let kv_idx = kv_base + ty; + var partial_sum: f32 = 0.0; + let kv_valid = kv_idx < KV_TILE && (kv_tile + kv_idx) < params.seq_len_kv; + + if (kv_valid) { + for (var i = tx; i < (HEAD_DIM_QK / 4u); i += num_of_threads) { + let q_off = q_tile_row * HEAD_DIM_QK + i * 4u; + // let k_off = (kv_idx * HEAD_DIM_QK) + i * 4u; + let idx = k_head_offset + (kv_tile + kv_idx) * params.stride_k1 + (i * 4u); + + let qv = vec4(f32(q_shmem[q_off]), f32(q_shmem[q_off + 1u]), f32(q_shmem[q_off + 2u]), f32(q_shmem[q_off + 3u])); + // let kv = vec4(f32(kv_shmem[k_off]), f32(kv_shmem[k_off + 1u]), f32(kv_shmem[k_off + 2u]), f32(kv_shmem[k_off + 3u])); + let kv = vec4(f32(K[idx]), f32(K[idx + 1u]), f32(K[idx + 2u]), f32(K[idx + 3u])); + + partial_sum += dot(qv, kv); + + } + } + for (var g: u32 = 0u; g < 4u; g++) { + let kv_idx_g = kv_base + g; + let active_threads = (ty == g) && (kv_idx_g < KV_TILE) && ((kv_tile + kv_idx_g) < params.seq_len_kv); + + let contrib = select(0.0, partial_sum, active_threads); + let sum_g = subgroupAdd(contrib); + + if (tx == 0u && ty == g && kv_idx_g < KV_TILE) { + let dst_idx = q_tile_row * KV_TILE + kv_idx_g; + inter_shmem[dst_idx] = f16(select(FLOAT_MIN, sum_g, (kv_tile + kv_idx_g) < params.seq_len_kv)); + } + } + } + } + } + + +#ifdef MASK + // load mask tile into shared memory for this KV block + // TODO: optimize and skip if mask is -INF for the entire tile + for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) { + let mask_row = elem_idx / KV_TILE; + let mask_col = elem_idx % KV_TILE; + let global_q_row = q_row_start + mask_row; + let global_k_col = kv_tile + mask_col; + let mask_in_bounds = global_q_row < params.seq_len_q && global_k_col < params.seq_len_kv; + let mask_idx = mask_global_offset + 
mask_row * params.seq_len_kv + global_k_col; + mask_shmem[elem_idx] = select(0.0, mask[mask_idx], mask_in_bounds); + } +#endif + + workgroupBarrier(); + + // online softmax + for (var q_tile_row = subgroup_id; q_tile_row < Q_TILE; q_tile_row += num_subgroups) { + let global_q_row = q_row_start + q_tile_row; + if (global_q_row >= params.seq_len_q) { + break; + } + + // initialize running max for this row + var prev_max = row_max_shmem[q_tile_row]; + var final_max = prev_max; + // pass 1: compute final max across the full KV tile in chunks + for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) { + let kv_idx = kv_offset + sg_inv_id; + let softmax_term = calc_softmax_term(kv_idx, q_tile_row, slope); + final_max = subgroupMax(max(final_max, softmax_term)); + } + + var total_exp_term: f32 = 0.0; + // pass 2: compute exp sum and write P using final_max + for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) { + let kv_idx = kv_offset + sg_inv_id; + let softmax_term = calc_softmax_term(kv_idx, q_tile_row, slope); + let cur_p = select(0.0, + exp(softmax_term - final_max), + kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE); + total_exp_term += subgroupAdd(cur_p); + if (kv_idx < KV_TILE) { + inter_shmem[kv_idx + q_tile_row * KV_TILE] = f16(cur_p); + } + } + + let cur_exp = exp(prev_max - final_max); + + if (sg_inv_id == 0) { + row_max_shmem[q_tile_row] = final_max; + exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * cur_exp + total_exp_term; + } + + for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) { + let idx = q_tile_row * HEAD_DIM_V + elem_idx; + o_shmem[idx] = f16(f32(o_shmem[idx]) * cur_exp); + } + } + + // load v tile into shared memory +// #if defined(KV_Q4_0) +// for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) { +// let blck_idx = elem_idx / BLOCK_SIZE; +// let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; +// let v_row = 
blck_idx / BLOCKS_V; +// let global_v_row = kv_tile + v_row; +// let block_k = blck_idx % BLOCKS_V; +// let row_offset = v_row * HEAD_DIM_V; + +// if (global_v_row < params.seq_len_kv) { +// let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; +// let base_idx = global_block_idx * F16_PER_BLOCK; +// let d = V[base_idx]; // scale +// for (var j = 0u; j < F16_PER_THREAD; j += 2) { +// let q_0 = V[base_idx + 1u + block_offset + j]; +// let q_1 = V[base_idx + 1u + block_offset + j + 1]; +// let q_packed = bitcast(vec2(q_0, q_1)); +// for (var k = 0u; k < 4u; k++) { +// let q_byte = get_byte(q_packed, k); +// let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; +// let q_lo = (f16(q_byte & 0xF) - 8.0) * d; +// let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; +// kv_shmem[row_offset + idx] = q_lo; +// kv_shmem[row_offset + idx + 16u] = q_hi; +// } +// } +// } +// } +// #elif defined(KV_Q8_0) +// for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) { +// let blck_idx = elem_idx / BLOCK_SIZE; +// let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; +// let v_row = blck_idx / BLOCKS_V; +// let global_v_row = kv_tile + v_row; +// let block_k = blck_idx % BLOCKS_V; +// let row_offset = v_row * HEAD_DIM_V; + +// if (global_v_row < params.seq_len_kv) { +// let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; +// let base_idx = global_block_idx * F16_PER_BLOCK; +// let d = V[base_idx]; // scale +// for (var j = 0u; j < F16_PER_THREAD; j += 2) { +// let q_0 = V[base_idx + 1u + block_offset + j]; +// let q_1 = V[base_idx + 1u + block_offset + j + 1]; +// let q_packed = bitcast(vec2(q_0, q_1)); +// for (var k = 0u; k < 4u; k++) { +// let q_byte = get_byte_i32(q_packed, k); +// let q_val = f16(q_byte) * d; +// let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; +// kv_shmem[row_offset + idx] = q_val; +// } +// } +// } +// } +// #elif 
defined(KV_DIRECT) +// // Direct global loads for KV +// #else +// for (var elem_idx = local_id.x; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE) { +// let v_row = elem_idx / HEAD_DIM_V; +// let v_col = elem_idx % HEAD_DIM_V; +// let global_v_row = kv_tile + v_row; +// let global_v_row_offset = v_head_offset + global_v_row * params.stride_v1; +// kv_shmem[elem_idx] = f16(select( +// 0.0, +// V[global_v_row_offset + v_col], +// global_v_row < params.seq_len_kv && v_col < HEAD_DIM_V)); +// } +// #endif + + workgroupBarrier(); + + // we have P (Q_TILE x KV_TILE) in inter_shmem and V (KV_TILE x head_dim_v) in kv_shmem + // we want to compute O += P * V across the full KV tile + for (var q_tile_row = subgroup_id; + q_tile_row < Q_TILE; + q_tile_row += num_subgroups) { + for (var elem_base = sg_inv_id * 4u; elem_base < HEAD_DIM_V; elem_base += subgroup_size * 4u) { + // todo: load o_shmem + let o_base_idx = q_tile_row * HEAD_DIM_V + elem_base; + var acc = vec4(f32(o_shmem[o_base_idx]), f32(o_shmem[o_base_idx + 1u]), f32(o_shmem[o_base_idx + 2u]), f32(o_shmem[o_base_idx + 3u])); + for (var kv_idx : u32 = 0u; kv_idx < KV_TILE; kv_idx += 1u) { + let p = f32(inter_shmem[kv_idx + q_tile_row * KV_TILE]); + let v_row = kv_tile + kv_idx; + if (v_row >= params.seq_len_kv) { + continue; + } + // let v_idx = v_head_offset + v_row * params.stride_v1 + elem_base; + let v_idx = v_head_offset + v_row * params.stride_v1 + elem_base; + let v4 = vec4(V[v_idx/4u]); + + acc += p * v4; + } + // todo: write acc back to o_shmem + o_shmem[o_base_idx] = f16(acc.x); + o_shmem[o_base_idx + 1u] = f16(acc.y); + o_shmem[o_base_idx + 2u] = f16(acc.z); + o_shmem[o_base_idx + 3u] = f16(acc.w); + } + } + + workgroupBarrier(); + } + + +#ifdef SINKS + // add sinks (applied once after processing all KV tiles) + for (var q_tile_row = subgroup_id; + q_tile_row < Q_TILE; + q_tile_row += num_subgroups) { + // no need to process rows beyond seq_len_q + let global_q_row = q_row_start + q_tile_row; + if 
(global_q_row >= params.seq_len_q) { + break; + } + + var prev_max = row_max_shmem[q_tile_row]; + + // for non-sink threads, exp(FLOAT_MIN) effectively zeroes out their contribution to the sum + let sink_val = select(FLOAT_MIN, sinks[params.offset_sinks + head_idx], sg_inv_id == 0); + let new_max = subgroupMax(max(prev_max, sink_val)); + let max_exp = exp(prev_max - new_max); + let sink_exp = exp(sink_val - new_max); + + let sink_exp_sum = subgroupAdd(sink_exp); + + if (sg_inv_id == 0) { + exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * max_exp + sink_exp_sum; + } + + for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) { + let idx = q_tile_row * HEAD_DIM_V + elem_idx; + let val = f32(o_shmem[idx]) * max_exp; + o_shmem[idx] = f16(val); + } + } + workgroupBarrier(); +#endif + for (var q_tile_row = subgroup_id; + q_tile_row < Q_TILE; + q_tile_row += num_subgroups) { + + let global_q_row = q_row_start + q_tile_row; + if (global_q_row >= params.seq_len_q) { break; } + + let exp_sum = exp_sum_shmem[q_tile_row]; + let scale = select(0.0, 1.0 / exp_sum, exp_sum != 0.0); + + let row_base: u32 = dst_global_offset + q_tile_row * dst2_stride; + + for (var elem_base = sg_inv_id * 4u; + elem_base < HEAD_DIM_V; + elem_base += subgroup_size * 4u) { + + let i0 = q_tile_row * HEAD_DIM_V + (elem_base + 0u); + let i1 = q_tile_row * HEAD_DIM_V + (elem_base + 1u); + let i2 = q_tile_row * HEAD_DIM_V + (elem_base + 2u); + let i3 = q_tile_row * HEAD_DIM_V + (elem_base + 3u); + + let v = vec4( + f32(o_shmem[i0]) * scale, + f32(o_shmem[i1]) * scale, + f32(o_shmem[i2]) * scale, + f32(o_shmem[i3]) * scale + ); + + let dst_vec_index: u32 = (row_base + elem_base) >> 2u; + dst[dst_vec_index] = v; + } + } +} From 94abbac1ede4e91bb329af94e883ba5c65526612 Mon Sep 17 00:00:00 2001 From: Zheyuan Chen Date: Tue, 3 Mar 2026 12:17:07 -0800 Subject: [PATCH 02/34] add vectorized flash attention --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 232 ++++- 
ggml/src/ggml-webgpu/ggml-webgpu.cpp | 899 +++++++++++++++++- .../wgsl-shaders/flash_attn_pad.wgsl | 624 ++++++++++++ .../wgsl-shaders/flash_attn_vec_blk.wgsl | 94 ++ .../wgsl-shaders/flash_attn_vec_reduce.wgsl | 78 ++ .../wgsl-shaders/flash_attn_vec_split.wgsl | 784 +++++++++++++++ 6 files changed, 2675 insertions(+), 36 deletions(-) create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_pad.wgsl create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 3d7e59fddf32..c8a7fadeab76 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -95,6 +95,12 @@ struct ggml_webgpu_generic_shader_decisions { uint32_t wg_size = 0; }; +struct ggml_webgpu_processed_shader { + std::string wgsl; + std::string variant; + std::shared_ptr decisions; +}; + /** Argsort **/ struct ggml_webgpu_argsort_shader_lib_context { @@ -275,11 +281,16 @@ struct ggml_webgpu_flash_attn_pipeline_key { bool has_mask; bool has_sinks; bool uses_logit_softcap; + bool use_vec; + bool use_pad; + bool use_vec_split; + bool use_blk; bool operator==(const ggml_webgpu_flash_attn_pipeline_key & other) const { return kv_type == other.kv_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v && kv_direct == other.kv_direct && has_mask == other.has_mask && has_sinks == other.has_sinks && - uses_logit_softcap == other.uses_logit_softcap; + uses_logit_softcap == other.uses_logit_softcap && use_vec == other.use_vec && + use_pad == other.use_pad && use_vec_split == other.use_vec_split && use_blk == other.use_blk; } }; @@ -293,6 +304,10 @@ struct ggml_webgpu_flash_attn_pipeline_key_hash { ggml_webgpu_hash_combine(seed, key.has_mask); 
ggml_webgpu_hash_combine(seed, key.has_sinks); ggml_webgpu_hash_combine(seed, key.uses_logit_softcap); + ggml_webgpu_hash_combine(seed, key.use_vec); + ggml_webgpu_hash_combine(seed, key.use_pad); + ggml_webgpu_hash_combine(seed, key.use_vec_split); + ggml_webgpu_hash_combine(seed, key.use_blk); return seed; } }; @@ -312,6 +327,99 @@ struct ggml_webgpu_flash_attn_shader_decisions { uint32_t wg_size = 0; }; +struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key { + uint32_t head_dim_v; + uint32_t wg_size; +}; + +struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key_hash { + size_t operator()(const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.head_dim_v); + ggml_webgpu_hash_combine(seed, key.wg_size); + return seed; + } +}; + +inline bool operator==(const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & lhs, + const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & rhs) { + return lhs.head_dim_v == rhs.head_dim_v && lhs.wg_size == rhs.wg_size; +} + +struct ggml_webgpu_flash_attn_vec_reduce_shader_lib_context { + ggml_webgpu_flash_attn_vec_reduce_pipeline_key key; + uint32_t max_wg_size; +}; + +inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_vec_reduce_shader( + pre_wgsl::Preprocessor & preprocessor, + const char * shader_src, + const ggml_webgpu_flash_attn_vec_reduce_shader_lib_context & context) { + std::vector defines; + std::string variant = "flash_attn_vec_reduce"; + + defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.key.head_dim_v)); + variant += std::string("_hsv") + std::to_string(context.key.head_dim_v); + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + variant += std::string("_wg") + std::to_string(context.max_wg_size); + + ggml_webgpu_processed_shader result; + result.wgsl = preprocessor.preprocess(shader_src, defines); + result.variant = variant; + return result; +} + +struct 
ggml_webgpu_flash_attn_blk_pipeline_key { + uint32_t q_tile; + uint32_t kv_tile; + + bool operator==(const ggml_webgpu_flash_attn_blk_pipeline_key & other) const { + return q_tile == other.q_tile && kv_tile == other.kv_tile; + } +}; + +struct ggml_webgpu_flash_attn_blk_pipeline_key_hash { + size_t operator()(const ggml_webgpu_flash_attn_blk_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.q_tile); + ggml_webgpu_hash_combine(seed, key.kv_tile); + return seed; + } +}; + +struct ggml_webgpu_flash_attn_blk_shader_lib_context { + ggml_webgpu_flash_attn_blk_pipeline_key key; + uint32_t max_wg_size; +}; + +inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_blk_shader( + pre_wgsl::Preprocessor & preprocessor, + const char * shader_src, + const ggml_webgpu_flash_attn_blk_shader_lib_context & context) { + std::vector defines; + std::string variant = "flash_attn_vec_blk"; + + defines.push_back(std::string("Q_TILE=") + std::to_string(context.key.q_tile)); + variant += std::string("_qt") + std::to_string(context.key.q_tile); + + defines.push_back(std::string("KV_TILE=") + std::to_string(context.key.kv_tile)); + variant += std::string("_kvt") + std::to_string(context.key.kv_tile); + + uint32_t wg_size = 1; + const uint32_t target_wg = std::min(32u, context.max_wg_size); + while ((wg_size << 1) <= target_wg) { + wg_size <<= 1; + } + defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); + variant += std::string("_wg") + std::to_string(wg_size); + + ggml_webgpu_processed_shader result; + result.wgsl = preprocessor.preprocess(shader_src, defines); + result.variant = variant; + return result; +} + // This is exposed because it's necessary in supports_op inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile, uint32_t kv_tile, @@ -336,6 +444,128 @@ inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile, return f16_elems * GGML_WEBGPU_F16_SIZE_BYTES + f32_elems * GGML_WEBGPU_F32_SIZE_BYTES; } 
+inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_flash_attn_shader_lib_context & context) { + const size_t limit_bytes = context.wg_mem_limit_bytes; + const size_t q_tile = context.sg_mat_m; + const size_t base_q_bytes = + (context.key.head_dim_qk + context.key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES + + 2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES; + size_t bytes_per_kv = 0; + if (!context.key.kv_direct) { + bytes_per_kv += std::max(context.key.head_dim_qk, context.key.head_dim_v); + } + if (context.key.has_mask) { + bytes_per_kv += q_tile; + } + bytes_per_kv += q_tile; + bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES; + const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv; + return (max_kv_tile / context.sg_mat_n) * context.sg_mat_n; +} + +inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader( + pre_wgsl::Preprocessor & preprocessor, + const char * shader_src, + const ggml_webgpu_flash_attn_shader_lib_context & context) { + std::vector defines; + std::string variant = "flash_attn"; + + switch (context.key.kv_type) { + case GGML_TYPE_F32: + defines.push_back("KV_F32"); + break; + case GGML_TYPE_F16: + defines.push_back("KV_F16"); + break; + case GGML_TYPE_Q4_0: + defines.push_back("KV_Q4_0"); + break; + case GGML_TYPE_Q8_0: + defines.push_back("KV_Q8_0"); + break; + default: + GGML_ABORT("Unsupported KV type for flash attention shader"); + } + variant += std::string("_") + ggml_type_name(context.key.kv_type); + + if (context.key.has_mask) { + defines.push_back("MASK"); + variant += "_mask"; + } + if (context.key.has_sinks) { + defines.push_back("SINKS"); + variant += "_sinks"; + } + if (context.key.uses_logit_softcap) { + defines.push_back("LOGIT_SOFTCAP"); + variant += "_lgsc"; + } + + if (context.key.kv_direct) { + defines.push_back("KV_DIRECT"); + variant += "_kvdirect"; + } + if (context.key.use_pad) { + defines.push_back("PAD"); + variant += "_pad"; + } + if (context.key.use_blk) { + 
defines.push_back("BLK"); + variant += "_blk"; + } + + defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(context.key.head_dim_qk)); + variant += std::string("_hsqk") + std::to_string(context.key.head_dim_qk); + + defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.key.head_dim_v)); + variant += std::string("_hsv") + std::to_string(context.key.head_dim_v); + // For now these are not part of the variant name. + defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m)); + defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n)); + defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k)); + + // Add chosen Q/KV tile sizes. + // Mirror Metal's vec path for split kernels: 1 query per workgroup and 32 KV cache values per subgroup. + uint32_t q_tile = context.sg_mat_m; + uint32_t kv_tile = std::min(ggml_webgpu_flash_attn_max_kv_tile(context), + context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES); + if (context.key.use_vec_split) { + q_tile = 1; + kv_tile = 32; + GGML_ASSERT(kv_tile % context.sg_mat_n == 0); + GGML_ASSERT(kv_tile <= ggml_webgpu_flash_attn_max_kv_tile(context)); + } + if (context.key.kv_direct) { + GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD); + // Avoids bounds checks for direct KV loads. + while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) { + kv_tile -= context.sg_mat_n; + } + } + + defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile)); + defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile)); + + uint32_t wg_size = 0; + if (context.key.use_vec_split) { + // Keep vec-split to a single subgroup; aligns lane mapping with Metal's vec kernel. 
+ wg_size = context.max_subgroup_size; + } else { + wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE); + } + defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); + + ggml_webgpu_processed_shader result; + result.wgsl = preprocessor.preprocess(shader_src, defines); + result.variant = variant; + auto decisions = std::make_shared(); + decisions->q_tile = q_tile; + decisions->kv_tile = kv_tile; + decisions->wg_size = wg_size; + result.decisions = decisions; + return result; +} + /** Matrix Multiplication **/ struct ggml_webgpu_legacy_mul_mat_pipeline_key { diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 128b7dc3de8a..00aef8ed2056 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -86,7 +86,12 @@ static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim #define WEBGPU_WAIT_ANY_TIMEOUT_MS 0 // Maximum number of in-flight submissions per-thread, to avoid exhausting the // parameter buffer pool +<<<<<<< HEAD #define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD (WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) +======= +#define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE +#define WEBGPU_MAX_PARAM_BUFS_PER_CMD 4u +>>>>>>> 30923ffc9 (add vectorized flash attention) #define WEBGPU_PARAMS_BUF_SIZE_BYTES 128 // enough for 32 parameters #define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4 #define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4 @@ -269,12 +274,22 @@ struct webgpu_gpu_profile_buf_pool { #endif struct webgpu_command { +<<<<<<< HEAD uint32_t num_kernels; wgpu::CommandBuffer commands; std::vector params_bufs; +======= + uint32_t num_kernels; + wgpu::CommandBuffer commands; + std::vector params_bufs; + std::optional set_rows_error_bufs; + // Keep temporary resources alive until submitted work is complete. 
+ std::vector retained_buffers; +>>>>>>> 30923ffc9 (add vectorized flash attention) #ifdef GGML_WEBGPU_GPU_PROFILE webgpu_gpu_profile_bufs timestamp_query_bufs; std::string pipeline_name; + bool has_timestamp_query = false; #endif }; @@ -359,12 +374,21 @@ struct webgpu_context_struct { webgpu_global_context global_ctx; std::unique_ptr shader_lib; + pre_wgsl::Preprocessor p; webgpu_buf_pool param_buf_pool; wgpu::Buffer set_rows_dev_error_buf; wgpu::Buffer set_rows_host_error_buf; - std::map> cpy_pipelines; // src_type, dst_type + std::map> cpy_pipelines; // src_type, dst_type + std::unordered_map + flash_attn_pipelines; + std::unordered_map + flash_attn_vec_reduce_pipelines; + std::unordered_map + flash_attn_blk_pipelines; std::map rms_norm_pipelines; // inplace std::map>> rope_pipelines; // type, ff, inplace @@ -592,8 +616,14 @@ static webgpu_submission ggml_backend_webgpu_submit(webgpu_global_context & std::vector & commands, webgpu_buf_pool & param_buf_pool) { std::vector command_buffers; +<<<<<<< HEAD std::vector params_bufs; webgpu_submission submission; +======= + std::vector params_bufs; + std::vector set_rows_error_bufs; + std::vector retained_buffers; +>>>>>>> 30923ffc9 (add vectorized flash attention) #ifdef GGML_WEBGPU_GPU_PROFILE std::vector> pipeline_name_and_ts_bufs; #endif @@ -601,15 +631,23 @@ static webgpu_submission ggml_backend_webgpu_submit(webgpu_global_context & for (const auto & command : commands) { command_buffers.push_back(command.commands); params_bufs.insert(params_bufs.end(), command.params_bufs.begin(), command.params_bufs.end()); +<<<<<<< HEAD +======= + retained_buffers.insert(retained_buffers.end(), command.retained_buffers.begin(), command.retained_buffers.end()); + if (command.set_rows_error_bufs) { + set_rows_error_bufs.push_back(command.set_rows_error_bufs.value()); + } +>>>>>>> 30923ffc9 (add vectorized flash attention) } ctx->queue.Submit(command_buffers.size(), command_buffers.data()); wgpu::Future p_f = 
ctx->queue.OnSubmittedWorkDone( wgpu::CallbackMode::AllowSpontaneous, - [¶m_buf_pool, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { + [¶m_buf_pool, params_bufs, retained_buffers](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { if (status != wgpu::QueueWorkDoneStatus::Success) { GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", std::string(message).c_str()); } + (void) retained_buffers; // Free the staged buffers param_buf_pool.free_bufs(params_bufs); }); @@ -617,6 +655,9 @@ static webgpu_submission ggml_backend_webgpu_submit(webgpu_global_context & #ifdef GGML_WEBGPU_GPU_PROFILE for (const auto & command : commands) { + if (!command.has_timestamp_query) { + continue; + } auto label = command.pipeline_name; auto ts_bufs = command.timestamp_query_bufs; @@ -646,7 +687,13 @@ static webgpu_command ggml_backend_webgpu_build_multi( const std::vector & pipelines, const std::vector> & params_list, const std::vector> & bind_group_entries_list, +<<<<<<< HEAD const std::vector> & workgroups_list) { +======= + const std::vector> & workgroups_list, + const std::optional & set_rows_error_bufs = std::nullopt, + bool split_passes = false) { +>>>>>>> 30923ffc9 (add vectorized flash attention) GGML_ASSERT(pipelines.size() == params_list.size()); GGML_ASSERT(pipelines.size() == bind_group_entries_list.size()); GGML_ASSERT(pipelines.size() == workgroups_list.size()); @@ -657,10 +704,25 @@ static webgpu_command ggml_backend_webgpu_build_multi( for (size_t i = 0; i < pipelines.size(); i++) { wgpu::Buffer params_bufs = param_buf_pool.alloc_bufs(); +<<<<<<< HEAD std::vector entries = bind_group_entries_list[i]; uint32_t params_binding_num = entries.size(); entries.push_back( { .binding = params_binding_num, .buffer = params_bufs, .offset = 0, .size = params_bufs.GetSize() }); +======= + std::vector entries = bind_group_entries_list[i]; + // Bindings can be sparse (e.g. 0,1,2,4,5), so params must use max(binding)+1. 
+ uint32_t params_binding_num = 0; + for (const auto & entry : entries) { + if (entry.binding >= params_binding_num) { + params_binding_num = entry.binding + 1; + } + } + entries.push_back({ .binding = params_binding_num, + .buffer = params_bufs.dev_buf, + .offset = 0, + .size = params_bufs.dev_buf.GetSize() }); +>>>>>>> 30923ffc9 (add vectorized flash attention) wgpu::BindGroupDescriptor bind_group_desc; bind_group_desc.layout = pipelines[i].pipeline.GetBindGroupLayout(0); @@ -677,30 +739,52 @@ static webgpu_command ggml_backend_webgpu_build_multi( ctx->queue.WriteBuffer(params_bufs_list[i], 0, params_list[i].data(), params_list[i].size() * sizeof(uint32_t)); } + bool profile_pass = false; #ifdef GGML_WEBGPU_GPU_PROFILE - webgpu_gpu_profile_bufs ts_bufs = ctx->timestamp_query_buf_pool.alloc_bufs(); - if (ts_bufs.host_buf.GetMapState() == wgpu::BufferMapState::Mapped) { - ts_bufs.host_buf.Unmap(); + webgpu_gpu_profile_bufs ts_bufs = {}; + if (!split_passes) { + ts_bufs = ctx->timestamp_query_buf_pool.alloc_bufs(); + if (ts_bufs.host_buf.GetMapState() == wgpu::BufferMapState::Mapped) { + ts_bufs.host_buf.Unmap(); + } + profile_pass = true; } +#endif +#ifndef GGML_WEBGPU_GPU_PROFILE + GGML_UNUSED(profile_pass); +#endif - wgpu::PassTimestampWrites ts_writes = { .querySet = ts_bufs.query_set, - .beginningOfPassWriteIndex = 0, - .endOfPassWriteIndex = 1 }; - wgpu::ComputePassDescriptor pass_desc = { .timestampWrites = &ts_writes }; - wgpu::ComputePassEncoder pass = encoder.BeginComputePass(&pass_desc); + if (split_passes) { + for (size_t i = 0; i < pipelines.size(); i++) { + wgpu::ComputePassEncoder pass = encoder.BeginComputePass(); + pass.SetPipeline(pipelines[i].pipeline); + pass.SetBindGroup(0, bind_groups[i]); + pass.DispatchWorkgroups(workgroups_list[i].first, workgroups_list[i].second, 1); + pass.End(); + } + } else { +#ifdef GGML_WEBGPU_GPU_PROFILE + wgpu::PassTimestampWrites ts_writes = { .querySet = ts_bufs.query_set, + .beginningOfPassWriteIndex = 0, + 
.endOfPassWriteIndex = 1 }; + wgpu::ComputePassDescriptor pass_desc = { .timestampWrites = &ts_writes }; + wgpu::ComputePassEncoder pass = encoder.BeginComputePass(&pass_desc); #else - wgpu::ComputePassEncoder pass = encoder.BeginComputePass(); + wgpu::ComputePassEncoder pass = encoder.BeginComputePass(); #endif - for (size_t i = 0; i < pipelines.size(); i++) { - pass.SetPipeline(pipelines[i].pipeline); - pass.SetBindGroup(0, bind_groups[i]); - pass.DispatchWorkgroups(workgroups_list[i].first, workgroups_list[i].second, 1); + for (size_t i = 0; i < pipelines.size(); i++) { + pass.SetPipeline(pipelines[i].pipeline); + pass.SetBindGroup(0, bind_groups[i]); + pass.DispatchWorkgroups(workgroups_list[i].first, workgroups_list[i].second, 1); + } + pass.End(); } - pass.End(); #ifdef GGML_WEBGPU_GPU_PROFILE - encoder.ResolveQuerySet(ts_bufs.query_set, 0, 2, ts_bufs.dev_buf, 0); - encoder.CopyBufferToBuffer(ts_bufs.dev_buf, 0, ts_bufs.host_buf, 0, ts_bufs.host_buf.GetSize()); + if (profile_pass) { + encoder.ResolveQuerySet(ts_bufs.query_set, 0, 2, ts_bufs.dev_buf, 0); + encoder.CopyBufferToBuffer(ts_bufs.dev_buf, 0, ts_bufs.host_buf, 0, ts_bufs.host_buf.GetSize()); + } #endif wgpu::CommandBuffer commands = encoder.Finish(); @@ -709,9 +793,12 @@ static webgpu_command ggml_backend_webgpu_build_multi( result.params_bufs = params_bufs_list; result.num_kernels = pipelines.size(); #ifdef GGML_WEBGPU_GPU_PROFILE - result.timestamp_query_bufs = ts_bufs; + if (profile_pass) { + result.timestamp_query_bufs = ts_bufs; + result.has_timestamp_query = true; + } // TODO: handle multiple pipeline names - result.pipeline_name = pipelines.front().name; + result.pipeline_name = pipelines.front().name; #endif return result; } @@ -731,6 +818,238 @@ static webgpu_command ggml_backend_webgpu_build(webgpu_global_context & { { wg_x, wg_y } }); } +struct webgpu_buffer_clear_op { + wgpu::Buffer buffer; + uint64_t offset; + uint64_t size; +}; + +struct webgpu_buffer_copy_op { + wgpu::Buffer src; + 
uint64_t src_offset; + wgpu::Buffer dst; + uint64_t dst_offset; + uint64_t size; +}; + +static webgpu_command ggml_backend_webgpu_build_with_pre_ops( + webgpu_global_context & ctx, + webgpu_buf_pool & param_buf_pool, + webgpu_pipeline & pipeline, + const std::vector & params, + const std::vector bind_group_entries, + const std::vector & clear_ops, + const std::vector & copy_ops, + std::vector retained_buffers, + uint32_t wg_x, + uint32_t wg_y = 1) { + webgpu_pool_bufs params_bufs = param_buf_pool.alloc_bufs(); + + ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, wgpu::MapMode::Write, 0, params_bufs.host_buf.GetSize()); + uint32_t * _params = (uint32_t *) params_bufs.host_buf.GetMappedRange(); + for (size_t i = 0; i < params.size(); i++) { + _params[i] = params[i]; + } + params_bufs.host_buf.Unmap(); + + std::vector entries = bind_group_entries; + uint32_t params_binding_num = 0; + for (const auto & entry : entries) { + if (entry.binding >= params_binding_num) { + params_binding_num = entry.binding + 1; + } + } + entries.push_back({ .binding = params_binding_num, + .buffer = params_bufs.dev_buf, + .offset = 0, + .size = params_bufs.dev_buf.GetSize() }); + + wgpu::BindGroupDescriptor bind_group_desc; + bind_group_desc.layout = pipeline.pipeline.GetBindGroupLayout(0); + bind_group_desc.entryCount = entries.size(); + bind_group_desc.entries = entries.data(); + bind_group_desc.label = pipeline.name.c_str(); + wgpu::BindGroup bind_group = ctx->device.CreateBindGroup(&bind_group_desc); + + wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder(); + encoder.CopyBufferToBuffer(params_bufs.host_buf, 0, params_bufs.dev_buf, 0, params_bufs.dev_buf.GetSize()); + + for (const auto & op : clear_ops) { + encoder.ClearBuffer(op.buffer, op.offset, op.size); + } + for (const auto & op : copy_ops) { + encoder.CopyBufferToBuffer(op.src, op.src_offset, op.dst, op.dst_offset, op.size); + } + +#ifdef GGML_WEBGPU_GPU_PROFILE + webgpu_gpu_profile_bufs ts_bufs = 
ctx->timestamp_query_buf_pool.alloc_bufs(); + if (ts_bufs.host_buf.GetMapState() == wgpu::BufferMapState::Mapped) { + ts_bufs.host_buf.Unmap(); + } + + wgpu::PassTimestampWrites ts_writes = { .querySet = ts_bufs.query_set, + .beginningOfPassWriteIndex = 0, + .endOfPassWriteIndex = 1 }; + wgpu::ComputePassDescriptor pass_desc = { .timestampWrites = &ts_writes }; + wgpu::ComputePassEncoder pass = encoder.BeginComputePass(&pass_desc); +#else + wgpu::ComputePassEncoder pass = encoder.BeginComputePass(); +#endif + pass.SetPipeline(pipeline.pipeline); + pass.SetBindGroup(0, bind_group); + pass.DispatchWorkgroups(wg_x, wg_y, 1); + pass.End(); + +#ifdef GGML_WEBGPU_GPU_PROFILE + encoder.ResolveQuerySet(ts_bufs.query_set, 0, 2, ts_bufs.dev_buf, 0); + encoder.CopyBufferToBuffer(ts_bufs.dev_buf, 0, ts_bufs.host_buf, 0, ts_bufs.host_buf.GetSize()); +#endif + + webgpu_command result = {}; + result.commands = encoder.Finish(); + result.params_bufs = { params_bufs }; + result.retained_buffers = std::move(retained_buffers); +#ifdef GGML_WEBGPU_GPU_PROFILE + result.timestamp_query_bufs = ts_bufs; + result.pipeline_name = pipeline.name; +#endif + return result; +} + +static webgpu_command ggml_backend_webgpu_build_multi_with_pre_ops( + webgpu_global_context & ctx, + webgpu_buf_pool & param_buf_pool, + const std::vector & pipelines, + const std::vector> & params_list, + const std::vector> & bind_group_entries_list, + const std::vector> & workgroups_list, + const std::vector & clear_ops, + const std::vector & copy_ops, + std::vector retained_buffers, + const std::optional & set_rows_error_bufs = std::nullopt, + bool split_passes = false) { + GGML_ASSERT(pipelines.size() == params_list.size()); + GGML_ASSERT(pipelines.size() == bind_group_entries_list.size()); + GGML_ASSERT(pipelines.size() == workgroups_list.size()); + + std::vector params_bufs_list; + std::vector bind_groups; + + for (size_t i = 0; i < pipelines.size(); i++) { + webgpu_pool_bufs params_bufs = 
param_buf_pool.alloc_bufs(); + + ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, wgpu::MapMode::Write, 0, + params_bufs.host_buf.GetSize()); + uint32_t * _params = (uint32_t *) params_bufs.host_buf.GetMappedRange(); + for (size_t j = 0; j < params_list[i].size(); j++) { + _params[j] = params_list[i][j]; + } + params_bufs.host_buf.Unmap(); + + std::vector entries = bind_group_entries_list[i]; + uint32_t params_binding_num = 0; + for (const auto & entry : entries) { + if (entry.binding >= params_binding_num) { + params_binding_num = entry.binding + 1; + } + } + entries.push_back({ .binding = params_binding_num, + .buffer = params_bufs.dev_buf, + .offset = 0, + .size = params_bufs.dev_buf.GetSize() }); + + wgpu::BindGroupDescriptor bind_group_desc; + bind_group_desc.layout = pipelines[i].pipeline.GetBindGroupLayout(0); + bind_group_desc.entryCount = entries.size(); + bind_group_desc.entries = entries.data(); + bind_group_desc.label = pipelines[i].name.c_str(); + bind_groups.push_back(ctx->device.CreateBindGroup(&bind_group_desc)); + + params_bufs_list.push_back(params_bufs); + } + + wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder(); + for (const auto & params_bufs : params_bufs_list) { + encoder.CopyBufferToBuffer(params_bufs.host_buf, 0, params_bufs.dev_buf, 0, params_bufs.dev_buf.GetSize()); + } + + for (const auto & op : clear_ops) { + encoder.ClearBuffer(op.buffer, op.offset, op.size); + } + for (const auto & op : copy_ops) { + encoder.CopyBufferToBuffer(op.src, op.src_offset, op.dst, op.dst_offset, op.size); + } + + // If there are SET_ROWS operations in this submission, copy their error buffers to the host. 
+ if (set_rows_error_bufs) { + encoder.CopyBufferToBuffer(set_rows_error_bufs->dev_buf, 0, set_rows_error_bufs->host_buf, 0, + set_rows_error_bufs->host_buf.GetSize()); + } + + bool profile_pass = false; +#ifdef GGML_WEBGPU_GPU_PROFILE + webgpu_gpu_profile_bufs ts_bufs = {}; + if (!split_passes) { + ts_bufs = ctx->timestamp_query_buf_pool.alloc_bufs(); + if (ts_bufs.host_buf.GetMapState() == wgpu::BufferMapState::Mapped) { + ts_bufs.host_buf.Unmap(); + } + profile_pass = true; + } +#endif +#ifndef GGML_WEBGPU_GPU_PROFILE + GGML_UNUSED(profile_pass); +#endif + + if (split_passes) { + for (size_t i = 0; i < pipelines.size(); i++) { + wgpu::ComputePassEncoder pass = encoder.BeginComputePass(); + pass.SetPipeline(pipelines[i].pipeline); + pass.SetBindGroup(0, bind_groups[i]); + pass.DispatchWorkgroups(workgroups_list[i].first, workgroups_list[i].second, 1); + pass.End(); + } + } else { +#ifdef GGML_WEBGPU_GPU_PROFILE + wgpu::PassTimestampWrites ts_writes = { .querySet = ts_bufs.query_set, + .beginningOfPassWriteIndex = 0, + .endOfPassWriteIndex = 1 }; + wgpu::ComputePassDescriptor pass_desc = { .timestampWrites = &ts_writes }; + wgpu::ComputePassEncoder pass = encoder.BeginComputePass(&pass_desc); +#else + wgpu::ComputePassEncoder pass = encoder.BeginComputePass(); +#endif + for (size_t i = 0; i < pipelines.size(); i++) { + pass.SetPipeline(pipelines[i].pipeline); + pass.SetBindGroup(0, bind_groups[i]); + pass.DispatchWorkgroups(workgroups_list[i].first, workgroups_list[i].second, 1); + } + pass.End(); + } + +#ifdef GGML_WEBGPU_GPU_PROFILE + if (profile_pass) { + encoder.ResolveQuerySet(ts_bufs.query_set, 0, 2, ts_bufs.dev_buf, 0); + encoder.CopyBufferToBuffer(ts_bufs.dev_buf, 0, ts_bufs.host_buf, 0, ts_bufs.host_buf.GetSize()); + } +#endif + + wgpu::CommandBuffer commands = encoder.Finish(); + webgpu_command result = {}; + result.commands = commands; + result.params_bufs = params_bufs_list; + result.set_rows_error_bufs = set_rows_error_bufs; + result.retained_buffers 
= std::move(retained_buffers); +#ifdef GGML_WEBGPU_GPU_PROFILE + if (profile_pass) { + result.timestamp_query_bufs = ts_bufs; + result.has_timestamp_query = true; + } + result.pipeline_name = pipelines.front().name; +#endif + return result; +} + static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx, wgpu::Buffer & buf, uint32_t value, @@ -1306,27 +1625,506 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, .offset = ggml_webgpu_tensor_align_offset(ctx, dst), .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = Q, - .src1 = K, - .src2 = V, - .src3 = mask, - .src4 = sinks, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize, - .sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m, - .sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n, - .sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k, - .max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size, + bool kv_direct = (K->type == GGML_TYPE_F16) && (Q->ne[0] % ctx->global_ctx->capabilities.sg_mat_k == 0) && + (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0); + + // Match Metal's vec-kernel shape heuristic. + const bool use_vec = (Q->ne[1] < 20) && + (Q->ne[0] % 32 == 0) && + (V->ne[0] % 4 == 0) && + (K->type == GGML_TYPE_F16); + + const uint32_t vec_nwg_cap = + std::max(1u, std::min(32u, ctx->global_ctx->capabilities.max_subgroup_size)); + const bool use_vec_split = use_vec && vec_nwg_cap > 1u; + const bool use_blk = use_vec_split && has_mask; + + // Compute vec KV tile (same logic as shader preprocessing) to decide whether tail padding is needed. 
+ uint32_t kv_tile_for_vec = 0; + if (use_vec) { + ggml_webgpu_flash_attn_pipeline_key probe_key = { + .kv_type = K->type, + .head_dim_qk = (uint32_t) Q->ne[0], + .head_dim_v = (uint32_t) V->ne[0], + .kv_direct = kv_direct, + .has_mask = static_cast(has_mask), + .has_sinks = static_cast(has_sinks), + .uses_logit_softcap = logit_softcap != 0.0f, + .use_vec = true, + .use_pad = false, + .use_vec_split = use_vec_split, + .use_blk = false, + }; + ggml_webgpu_flash_attn_shader_lib_context probe_ctx = { + .key = probe_key, + .sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m, + .sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n, + .sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k, + .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize, + .max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size + }; + kv_tile_for_vec = std::min(ggml_webgpu_flash_attn_max_kv_tile(probe_ctx), + probe_ctx.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES); + if (probe_key.kv_direct) { + while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile_for_vec != 0) { + kv_tile_for_vec -= probe_ctx.sg_mat_n; + } + } + } + + const uint32_t tail_for_vec = (use_vec && kv_tile_for_vec > 0) ? 
(uint32_t) (K->ne[1] % kv_tile_for_vec) : 0u; + const bool copy_alignment_ok = + (ggml_webgpu_tensor_offset(K) % 4 == 0) && + (ggml_webgpu_tensor_offset(V) % 4 == 0) && + (K->nb[1] % 4 == 0) && + (V->nb[1] % 4 == 0) && + (!has_mask || ((ggml_webgpu_tensor_offset(mask) % 4 == 0) && + ((((uint64_t) K->ne[1]) * ggml_type_size(mask->type) % 4) == 0) && + ((((uint64_t) mask->nb[3]) % 4) == 0) && + (((uint64_t) tail_for_vec * ggml_type_size(mask->type)) % 4 == 0))); + const bool use_pad = use_vec && tail_for_vec != 0 && copy_alignment_ok; + + ggml_webgpu_flash_attn_pipeline_key key = { + .kv_type = K->type, + .head_dim_qk = (uint32_t) Q->ne[0], + .head_dim_v = (uint32_t) V->ne[0], + .kv_direct = kv_direct, + .has_mask = static_cast(has_mask), + .has_sinks = static_cast(has_sinks), + .uses_logit_softcap = logit_softcap != 0.0f, + .use_vec = use_vec, + .use_pad = use_pad, + .use_vec_split = use_vec_split, + .use_blk = use_blk, }; - webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline(shader_lib_ctx); + webgpu_pipeline pipeline; + auto it = ctx->flash_attn_pipelines.find(key); + if (it != ctx->flash_attn_pipelines.end()) { + pipeline = it->second; + } else { + ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = { + .key = key, + .sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m, + .sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n, + .sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k, + .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize, + .max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size + }; + + ggml_webgpu_processed_shader processed; + if (use_vec_split) { + processed = ggml_webgpu_preprocess_flash_attn_shader(ctx->p, wgsl_flash_attn_vec_split, shader_lib_ctx); + } else if (use_vec && use_pad) { + processed = ggml_webgpu_preprocess_flash_attn_shader(ctx->p, wgsl_flash_attn_pad, shader_lib_ctx); + } else if (use_vec) { + processed = ggml_webgpu_preprocess_flash_attn_shader(ctx->p, 
wgsl_flash_attn_vec, shader_lib_ctx); + } else { + processed = ggml_webgpu_preprocess_flash_attn_shader(ctx->p, wgsl_flash_attn, shader_lib_ctx); + } + pipeline = + ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); + pipeline.context = processed.decisions; + ctx->flash_attn_pipelines.emplace(key, pipeline); + } auto * decisions = static_cast(pipeline.context.get()); uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions->q_tile); uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches + + uint32_t vec_nwg = 1u; + if (use_vec_split) { + const uint64_t kv_span = (uint64_t) std::max(1u, decisions->kv_tile); + while ((2u * vec_nwg * kv_span) < (uint64_t) K->ne[1] && vec_nwg < vec_nwg_cap) { + vec_nwg <<= 1; + } + vec_nwg = std::min(vec_nwg, vec_nwg_cap); + } + + bool have_pad_buf = false; + wgpu::Buffer pad_buf = {}; + uint64_t pad_size_bytes = 0; + uint32_t pad_k_base_u32 = 0; + uint32_t pad_v_base_u32 = 0; + uint32_t pad_m_base_u32 = 0; + uint32_t pad_ncpsg = 0; + std::vector pad_clear_ops; + std::vector pad_copy_ops; + + bool have_blk_buf = false; + wgpu::Buffer blk_buf = {}; + uint64_t blk_size_bytes = 0; + uint32_t blk_nblk0 = 0; + uint32_t blk_nblk1 = 0; + uint32_t blk_batch_count = 0; + + if (use_pad) { + const uint32_t ncpsg = decisions->kv_tile; + const uint32_t tail = (uint32_t) (K->ne[1] % ncpsg); + const uint32_t tail_start = (uint32_t) K->ne[1] - tail; + GGML_ASSERT(tail > 0); + + const uint32_t stride_k1 = (uint32_t) (K->nb[1] / ggml_type_size(K->type)); + const uint32_t stride_v1 = (uint32_t) (V->nb[1] / ggml_type_size(V->type)); + const uint32_t kv_heads = (uint32_t) K->ne[2]; + const uint32_t kv_batches = (uint32_t) K->ne[3]; + const uint64_t kv_planes = (uint64_t) kv_heads * kv_batches; + const uint32_t stride_mask3 = has_mask ? (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)) : 0u; + const uint32_t mask_batch_planes = has_mask ? 
(stride_mask3 > 0 ? (uint32_t) Q->ne[3] : 1u) : 0u; + + const uint64_t pad_k_base = 0; + const uint64_t pad_v_base = pad_k_base + (uint64_t) stride_k1 * ncpsg * kv_planes; + const uint64_t pad_m_base = pad_v_base + (uint64_t) stride_v1 * ncpsg * kv_planes; + const uint64_t pad_m_elems = has_mask ? ((uint64_t) Q->ne[1] * ncpsg * mask_batch_planes) : 0u; + const uint64_t pad_total_elems = pad_m_base + pad_m_elems; + pad_size_bytes = + ROUNDUP_POW2(pad_total_elems * ggml_type_size(K->type), WEBGPU_STORAGE_BUF_BINDING_MULT); + + GGML_ASSERT(pad_k_base <= UINT32_MAX); + GGML_ASSERT(pad_v_base <= UINT32_MAX); + GGML_ASSERT(pad_m_base <= UINT32_MAX); + pad_k_base_u32 = (uint32_t) pad_k_base; + pad_v_base_u32 = (uint32_t) pad_v_base; + pad_m_base_u32 = (uint32_t) pad_m_base; + pad_ncpsg = ncpsg; + + ggml_webgpu_create_buffer(ctx->global_ctx->device, pad_buf, pad_size_bytes, + wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopyDst, + "flash_attn_pad_buf"); + have_pad_buf = true; + + pad_clear_ops = { + { .buffer = pad_buf, .offset = 0, .size = pad_size_bytes }, + }; + pad_copy_ops.clear(); + + const uint64_t k_tensor_off = ggml_webgpu_tensor_offset(K); + const uint64_t v_tensor_off = ggml_webgpu_tensor_offset(V); + const auto k_buf = ggml_webgpu_tensor_buf(K); + const auto v_buf = ggml_webgpu_tensor_buf(V); + + for (uint32_t b = 0; b < kv_batches; b++) { + for (uint32_t h = 0; h < kv_heads; h++) { + const uint64_t plane = (uint64_t) h + (uint64_t) b * kv_heads; + for (uint32_t r = 0; r < tail; r++) { + const uint64_t src_row_k = + k_tensor_off + (uint64_t) b * K->nb[3] + (uint64_t) h * K->nb[2] + (uint64_t) (tail_start + r) * K->nb[1]; + const uint64_t dst_row_k = + (pad_k_base + plane * stride_k1 * ncpsg + (uint64_t) r * stride_k1) * ggml_type_size(K->type); + pad_copy_ops.push_back({ .src = k_buf, + .src_offset = src_row_k, + .dst = pad_buf, + .dst_offset = dst_row_k, + .size = K->nb[1] }); + + const uint64_t src_row_v = + v_tensor_off + (uint64_t) b * V->nb[3] + 
(uint64_t) h * V->nb[2] + (uint64_t) (tail_start + r) * V->nb[1]; + const uint64_t dst_row_v = + (pad_v_base + plane * stride_v1 * ncpsg + (uint64_t) r * stride_v1) * ggml_type_size(V->type); + pad_copy_ops.push_back({ .src = v_buf, + .src_offset = src_row_v, + .dst = pad_buf, + .dst_offset = dst_row_v, + .size = V->nb[1] }); + } + } + } + + if (has_mask) { + const uint64_t mask_tensor_off = ggml_webgpu_tensor_offset(mask); + const auto mask_buf = ggml_webgpu_tensor_buf(mask); + const uint64_t mask_copy_size = (uint64_t) tail * ggml_type_size(mask->type); + for (uint32_t mb = 0; mb < mask_batch_planes; mb++) { + const uint32_t src_batch = + (stride_mask3 > 0 && mb >= (uint32_t) mask->ne[3]) ? (uint32_t) mask->ne[3] - 1 : mb; + for (uint32_t q = 0; q < (uint32_t) Q->ne[1]; q++) { + const uint64_t src_mask_elem = + (uint64_t) src_batch * stride_mask3 + (uint64_t) q * (uint32_t) K->ne[1] + tail_start; + const uint64_t dst_mask_elem = + pad_m_base + (uint64_t) mb * (uint32_t) Q->ne[1] * ncpsg + (uint64_t) q * ncpsg; + pad_copy_ops.push_back({ .src = mask_buf, + .src_offset = mask_tensor_off + src_mask_elem * ggml_type_size(mask->type), + .dst = pad_buf, + .dst_offset = dst_mask_elem * ggml_type_size(K->type), + .size = mask_copy_size }); + } + } + } + + if (!use_vec_split) { + std::vector pad_params = params; + pad_params.push_back(0u); // offset_pad + pad_params.push_back(pad_k_base_u32); // pad_k_base + pad_params.push_back(pad_v_base_u32); // pad_v_base + pad_params.push_back(pad_m_base_u32); // pad_m_base + pad_params.push_back(pad_ncpsg); // ncpsg + pad_params.push_back(1u); // nqptg + pad_params.push_back(1u); // nwg + + std::vector pad_entries = { + { .binding = 0, + .buffer = ggml_webgpu_tensor_buf(Q), + .offset = ggml_webgpu_tensor_align_offset(ctx, Q), + .size = ggml_webgpu_tensor_binding_size(ctx, Q) }, + { .binding = 1, + .buffer = ggml_webgpu_tensor_buf(K), + .offset = ggml_webgpu_tensor_align_offset(ctx, K), + .size = ggml_webgpu_tensor_binding_size(ctx, 
K) }, + { .binding = 2, + .buffer = ggml_webgpu_tensor_buf(V), + .offset = ggml_webgpu_tensor_align_offset(ctx, V), + .size = ggml_webgpu_tensor_binding_size(ctx, V) }, + { .binding = 3, .buffer = pad_buf, .offset = 0, .size = pad_size_bytes }, + }; + uint32_t pad_binding_index = 4; + if (has_mask) { + pad_entries.push_back({ .binding = pad_binding_index++, + .buffer = ggml_webgpu_tensor_buf(mask), + .offset = ggml_webgpu_tensor_align_offset(ctx, mask), + .size = ggml_webgpu_tensor_binding_size(ctx, mask) }); + } + if (has_sinks) { + pad_entries.push_back({ .binding = pad_binding_index++, + .buffer = ggml_webgpu_tensor_buf(sinks), + .offset = ggml_webgpu_tensor_align_offset(ctx, sinks), + .size = ggml_webgpu_tensor_binding_size(ctx, sinks) }); + } + pad_entries.push_back({ .binding = pad_binding_index++, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + + return ggml_backend_webgpu_build_with_pre_ops(ctx->global_ctx, ctx->param_buf_pool, pipeline, pad_params, + pad_entries, pad_clear_ops, pad_copy_ops, { pad_buf }, wg_x); + } + } + + if (use_vec_split) { + const uint32_t nwg = vec_nwg; + GGML_ASSERT(nwg <= ctx->global_ctx->capabilities.max_subgroup_size); + const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3]; + + const uint64_t tmp_data_elems = nrows * (uint64_t) V->ne[0] * nwg; + const uint64_t tmp_stats_base = tmp_data_elems; + const uint64_t tmp_stats_elems = nrows * 2u * nwg; + const uint64_t tmp_total_elems = tmp_data_elems + tmp_stats_elems; + const uint64_t tmp_size_bytes = + ROUNDUP_POW2(tmp_total_elems * sizeof(float), WEBGPU_STORAGE_BUF_BINDING_MULT); + GGML_ASSERT(tmp_stats_base <= UINT32_MAX); + GGML_ASSERT(nrows <= UINT32_MAX); + + wgpu::Buffer tmp_buf; + ggml_webgpu_create_buffer(ctx->global_ctx->device, tmp_buf, tmp_size_bytes, wgpu::BufferUsage::Storage, + "flash_attn_vec_tmp"); + + webgpu_pipeline blk_pipeline; + std::vector 
blk_params; + std::vector blk_entries; + if (use_blk) { + GGML_ASSERT(has_mask); + + blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], decisions->kv_tile); + blk_nblk1 = CEIL_DIV((uint32_t) Q->ne[1], decisions->q_tile); + const uint32_t stride_mask3 = (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)); + blk_batch_count = stride_mask3 > 0 ? (uint32_t) Q->ne[3] : 1u; + + const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count; + blk_size_bytes = ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT); + ggml_webgpu_create_buffer(ctx->global_ctx->device, blk_buf, blk_size_bytes, + wgpu::BufferUsage::Storage, + "flash_attn_vec_blk"); + have_blk_buf = true; + + const ggml_webgpu_flash_attn_blk_pipeline_key blk_key = { + .q_tile = decisions->q_tile, + .kv_tile = decisions->kv_tile, + }; + auto blk_it = ctx->flash_attn_blk_pipelines.find(blk_key); + if (blk_it != ctx->flash_attn_blk_pipelines.end()) { + blk_pipeline = blk_it->second; + } else { + ggml_webgpu_flash_attn_blk_shader_lib_context blk_shader_ctx = { + .key = blk_key, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, + }; + ggml_webgpu_processed_shader processed = ggml_webgpu_preprocess_flash_attn_blk_shader( + ctx->p, wgsl_flash_attn_vec_blk, blk_shader_ctx); + blk_pipeline = ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), + processed.variant.c_str()); + ctx->flash_attn_blk_pipelines.emplace(blk_key, blk_pipeline); + } + + blk_params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)), // offset_mask + (uint32_t) Q->ne[1], // seq_len_q + (uint32_t) K->ne[1], // seq_len_kv + stride_mask3, // stride_mask3 + blk_nblk0, // nblk0 + blk_nblk1, // nblk1 + }; + blk_entries = { + { .binding = 0, + .buffer = ggml_webgpu_tensor_buf(mask), + .offset = ggml_webgpu_tensor_align_offset(ctx, mask), + .size = ggml_webgpu_tensor_binding_size(ctx, mask) }, + { .binding = 1, .buffer = 
blk_buf, .offset = 0, .size = blk_size_bytes }, + }; + } + + std::vector split_params = params; + if (use_pad) { + GGML_ASSERT(have_pad_buf); + split_params.push_back(0u); // offset_pad + split_params.push_back(pad_k_base_u32); // pad_k_base + split_params.push_back(pad_v_base_u32); // pad_v_base + split_params.push_back(pad_m_base_u32); // pad_m_base + split_params.push_back(pad_ncpsg); // ncpsg + split_params.push_back(1u); // nqptg + } + if (use_blk) { + split_params.push_back(0u); // blk_base + split_params.push_back(blk_nblk0); // blk_nblk0 + split_params.push_back(blk_nblk1); // blk_nblk1 + } + split_params.push_back(0u); // tmp_data_base + split_params.push_back((uint32_t) tmp_stats_base); // tmp_stats_base + split_params.push_back(nwg); // nwg + + std::vector split_entries = { + { .binding = 0, + .buffer = ggml_webgpu_tensor_buf(Q), + .offset = ggml_webgpu_tensor_align_offset(ctx, Q), + .size = ggml_webgpu_tensor_binding_size(ctx, Q) }, + { .binding = 1, + .buffer = ggml_webgpu_tensor_buf(K), + .offset = ggml_webgpu_tensor_align_offset(ctx, K), + .size = ggml_webgpu_tensor_binding_size(ctx, K) }, + { .binding = 2, + .buffer = ggml_webgpu_tensor_buf(V), + .offset = ggml_webgpu_tensor_align_offset(ctx, V), + .size = ggml_webgpu_tensor_binding_size(ctx, V) }, + }; + uint32_t split_binding_index = 3; + if (use_pad) { + split_entries.push_back({ .binding = split_binding_index++, + .buffer = pad_buf, + .offset = 0, + .size = pad_size_bytes }); + } + if (has_mask) { + split_entries.push_back({ .binding = split_binding_index++, + .buffer = ggml_webgpu_tensor_buf(mask), + .offset = ggml_webgpu_tensor_align_offset(ctx, mask), + .size = ggml_webgpu_tensor_binding_size(ctx, mask) }); + } + if (has_sinks) { + split_entries.push_back({ .binding = split_binding_index++, + .buffer = ggml_webgpu_tensor_buf(sinks), + .offset = ggml_webgpu_tensor_align_offset(ctx, sinks), + .size = ggml_webgpu_tensor_binding_size(ctx, sinks) }); + } + if (use_blk) { + 
GGML_ASSERT(have_blk_buf); + split_entries.push_back({ .binding = split_binding_index++, + .buffer = blk_buf, + .offset = 0, + .size = blk_size_bytes }); + } + split_entries.push_back({ .binding = split_binding_index++, + .buffer = tmp_buf, + .offset = 0, + .size = tmp_size_bytes }); + split_entries.push_back({ .binding = split_binding_index++, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + + const uint32_t reduce_wg_size = std::max( + 32u, std::min(nwg * 32u, ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup)); + const ggml_webgpu_flash_attn_vec_reduce_pipeline_key reduce_key = { + .head_dim_v = (uint32_t) V->ne[0], + .wg_size = reduce_wg_size, + }; + webgpu_pipeline reduce_pipeline; + auto reduce_it = ctx->flash_attn_vec_reduce_pipelines.find(reduce_key); + if (reduce_it != ctx->flash_attn_vec_reduce_pipelines.end()) { + reduce_pipeline = reduce_it->second; + } else { + ggml_webgpu_flash_attn_vec_reduce_shader_lib_context reduce_shader_ctx = { + .key = reduce_key, + .max_wg_size = reduce_wg_size, + }; + ggml_webgpu_processed_shader processed = ggml_webgpu_preprocess_flash_attn_vec_reduce_shader( + ctx->p, wgsl_flash_attn_vec_reduce, reduce_shader_ctx); + reduce_pipeline = ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), + processed.variant.c_str()); + ctx->flash_attn_vec_reduce_pipelines.emplace(reduce_key, reduce_pipeline); + } + + std::vector reduce_params = { + (uint32_t) nrows, // nrows + (uint32_t) Q->ne[1], // seq_len_q + (uint32_t) Q->ne[2], // n_heads + nwg, // nwg + 0u, // tmp_data_base + (uint32_t) tmp_stats_base, // tmp_stats_base + }; + + std::vector reduce_entries = { + { .binding = 0, .buffer = tmp_buf, .offset = 0, .size = tmp_size_bytes }, + { .binding = 1, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = 
ggml_webgpu_tensor_binding_size(ctx, dst) }, + }; + + const uint64_t split_wg_total = (uint64_t) wg_x * nwg; + GGML_ASSERT(split_wg_total <= UINT32_MAX); + std::vector pipelines; + std::vector> params_list; + std::vector> entries_list; + std::vector> workgroups_list; + + if (use_blk) { + pipelines.push_back(blk_pipeline); + params_list.push_back(std::move(blk_params)); + entries_list.push_back(std::move(blk_entries)); + workgroups_list.push_back({ blk_nblk0, blk_nblk1 * blk_batch_count }); + } + pipelines.push_back(pipeline); + params_list.push_back(std::move(split_params)); + entries_list.push_back(std::move(split_entries)); + workgroups_list.push_back({ (uint32_t) split_wg_total, 1u }); + pipelines.push_back(reduce_pipeline); + params_list.push_back(std::move(reduce_params)); + entries_list.push_back(std::move(reduce_entries)); + workgroups_list.push_back({ (uint32_t) nrows, 1u }); + + const bool split_passes = use_blk; + + std::vector retained_buffers = { tmp_buf }; + if (use_pad) { + retained_buffers.push_back(pad_buf); + } + if (use_blk) { + retained_buffers.push_back(blk_buf); + } + + webgpu_command cmd; + if (use_pad) { + cmd = ggml_backend_webgpu_build_multi_with_pre_ops( + ctx->global_ctx, ctx->param_buf_pool, pipelines, params_list, entries_list, workgroups_list, + pad_clear_ops, pad_copy_ops, std::move(retained_buffers), std::nullopt, split_passes); + } else { + cmd = ggml_backend_webgpu_build_multi( + ctx->global_ctx, ctx->param_buf_pool, pipelines, params_list, entries_list, workgroups_list, + std::nullopt, split_passes); + cmd.retained_buffers = std::move(retained_buffers); + } + return cmd; + } + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } #endif @@ -2247,27 +3045,58 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str WEBGPU_CPU_PROFILE_TOTAL_START(graph_compute); +<<<<<<< HEAD std::vector commands; std::vector subs; uint32_t num_batched_kernels = 0; bool 
contains_set_rows = false; +======= + std::vector commands; + std::vector futures; + uint32_t batch_param_bufs = 0; + uint32_t num_batched_kernels = 0; +>>>>>>> 30923ffc9 (add vectorized flash attention) for (int i = 0; i < cgraph->n_nodes; i++) { if (cgraph->nodes[i]->op == GGML_OP_SET_ROWS) { contains_set_rows = true; } if (auto cmd = ggml_webgpu_encode_node(ctx, cgraph->nodes[i])) { + const uint32_t cmd_param_bufs = (uint32_t) cmd->params_bufs.size(); + + // Leave room for the next command so alloc_bufs() never blocks waiting for + // a submit that has not yet happened. + if (!commands.empty() && batch_param_bufs + cmd_param_bufs + WEBGPU_MAX_PARAM_BUFS_PER_CMD > WEBGPU_NUM_PARAM_BUFS) { + futures.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool, + &ctx->set_rows_error_buf_pool)); + ctx->global_ctx->instance.ProcessEvents(); + ggml_backend_webgpu_wait(ctx->global_ctx, futures, false); + commands.clear(); + batch_param_bufs = 0; + num_batched_kernels = 0; + } + commands.push_back(*cmd); - num_batched_kernels += cmd.value().num_kernels; + batch_param_bufs += cmd_param_bufs; + num_batched_kernels += cmd->num_kernels; } +<<<<<<< HEAD if (num_batched_kernels >= WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) { num_batched_kernels = 0; subs.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool)); +======= + if (num_batched_kernels >= WEBGPU_COMMAND_SUBMIT_BATCH_SIZE || + (!commands.empty() && batch_param_bufs + WEBGPU_MAX_PARAM_BUFS_PER_CMD > WEBGPU_NUM_PARAM_BUFS)) { + futures.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool, + &ctx->set_rows_error_buf_pool)); +>>>>>>> 30923ffc9 (add vectorized flash attention) // Process events and check for completed submissions ctx->global_ctx->instance.ProcessEvents(); ggml_backend_webgpu_wait(ctx->global_ctx, subs, false); commands.clear(); + batch_param_bufs = 0; + num_batched_kernels = 0; } } if (!commands.empty()) { diff --git 
a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_pad.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_pad.wgsl new file mode 100644 index 000000000000..34fc5b42591c --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_pad.wgsl @@ -0,0 +1,624 @@ +diagnostic(off, chromium.subgroup_matrix_uniformity); +diagnostic(off, subgroup_uniformity); +enable f16; +enable subgroups; +enable chromium_experimental_subgroup_matrix; + +#ifdef KV_F32 +#define KV_TYPE f32 +#else +#define KV_TYPE f16 +#endif + +// Default values +#define HEAD_DIM_QK 64 +#define HEAD_DIM_V 64 + +// The number of rows/columns/k in a subgroup matrix. MxK * KxN = MxN +// Note that the "K" here does not correspond to the K in attention's Q/K/V, it's just the common dimension. +#define SG_MAT_M 8 +#define SG_MAT_N 8 +#define SG_MAT_K 8 + +// Each workgroup processes one subgroup matrix of Q rows +#define Q_TILE SG_MAT_M +#define KV_TILE 16 +#define WG_SIZE 64 + +// Number of subgroup-matrix-width blocks that span the KV tile. SG_MAT_N must divide KV_TILE. 
+#define KV_BLOCKS (KV_TILE / SG_MAT_N) + +// Quantization constants/helpers +#define BLOCK_SIZE 32 +#define BLOCKS_K ((HEAD_DIM_QK + BLOCK_SIZE - 1) / BLOCK_SIZE) +#define BLOCKS_V ((HEAD_DIM_V + BLOCK_SIZE - 1) / BLOCK_SIZE) +// number of quantized elements processed per thread +#if defined(KV_Q4_0) +#define NQ 16 +// Q4_0 has 32 elements, 1 f16 for scale, 8 f16 for 4-bit weights +#define F16_PER_BLOCK 9 +#define WEIGHTS_PER_F16 4 +#elif defined(KV_Q8_0) +#define NQ 8 +// Q8_0 has 32 elements, 1 f16 for scale, 16 f16 for 8-bit weights +#define F16_PER_BLOCK 17 +#define WEIGHTS_PER_F16 2 +#endif +#define F16_PER_THREAD (NQ / WEIGHTS_PER_F16) + +// Ok not to put these in a define block, compiler will remove if unused +fn get_byte(value: u32, index: u32) -> u32 { + return (value >> (index * 8)) & 0xFF; +} + +fn get_byte_i32(value: u32, index: u32) -> i32 { + return bitcast(((value >> (index * 8)) & 0xFF) << 24) >> 24; +} + +struct Params { + offset_q: u32, + offset_k: u32, + offset_v: u32, + offset_mask: u32, + offset_sinks: u32, + offset_dst: u32, + + // shapes of Q/K/V + n_heads: u32, + seq_len_q: u32, + seq_len_kv: u32, + + // strides (in elements) + stride_q1: u32, + stride_q2: u32, + stride_q3: u32, + stride_k1: u32, + stride_k2: u32, + stride_k3: u32, + stride_v1: u32, + stride_v2: u32, + stride_v3: u32, + stride_mask3: u32, + + // repeat factors for K/V, e.g., MHA vs. MQA vs. 
GQA + q_per_kv: u32, + + // softmax params + scale: f32, + max_bias: f32, + logit_softcap: f32, + n_head_log2: f32, + m0: f32, + m1: f32, + + // padding + offset_pad: u32, + pad_k_base: u32, + pad_v_base: u32, + pad_m_base: u32, + + ncpsg: u32, // number of context positions per group, used for padding logic + nqptg: u32, // number of Q positions per group, used for padding logic + + nwg: u32, // total number of workgroups, used for padding logic +}; + +@group(0) @binding(0) var Q: array; +@group(0) @binding(1) var K: array; +@group(0) @binding(2) var V: array>; +@group(0) @binding(3) var pad: array; + +#if defined(MASK) && defined(SINKS) +@group(0) @binding(4) var mask: array; +@group(0) @binding(5) var sinks: array; +#define DST_BINDING 6 +#define PARAMS_BINDING 7 +#elif defined(MASK) +@group(0) @binding(4) var mask: array; +#define DST_BINDING 5 +#define PARAMS_BINDING 6 +#elif defined(SINKS) +@group(0) @binding(4) var sinks: array; +#define DST_BINDING 5 +#define PARAMS_BINDING 6 +#else +#define DST_BINDING 4 +#define PARAMS_BINDING 5 +#endif + +@group(0) @binding(DST_BINDING) var dst: array>; +@group(0) @binding(PARAMS_BINDING) var params: Params; + +// A large-magnitude negative float used as a stand-in for -infinity (not the smallest representable value). 
+const FLOAT_MIN: f32 = -1.0e9; + +// Q tile staged for the workgroup: Q_TILE rows of HEAD_DIM_QK values each +var q_shmem: array; + +#ifndef KV_DIRECT +const kv_shmem_size = KV_TILE * max(HEAD_DIM_QK, HEAD_DIM_V); +// we can reuse the same shmem for K and V since we only need one at a time +var kv_shmem: array; +#endif + +var o_shmem: array; // output shmem + +#ifdef MASK +// storage for mask values +var mask_shmem: array; +#endif + +// storage for output of Q*K^T scores for online softmax (S matrix from paper) +// also storage for the exponentiated scores exp(S - row max) during online softmax (P matrix from paper) +// note that we reuse the same storage for both since we only need one at a time +var inter_shmem: array; + +// Storage for row max and exp sum during online softmax +var row_max_shmem: array; +var exp_sum_shmem: array; + +fn calc_softmax_term(kv_idx: u32, q_tile_row: u32, slope: f32) -> f32 { + var v = select(FLOAT_MIN, + f32(inter_shmem[kv_idx + q_tile_row * KV_TILE]) * params.scale, + kv_idx < KV_TILE); +#ifdef LOGIT_SOFTCAP + v = params.logit_softcap * tanh(v); +#endif +#ifdef MASK + let mask_val = select(0.0, f32(mask_shmem[q_tile_row * KV_TILE + kv_idx]), kv_idx < KV_TILE); + let mask_term = slope * mask_val; + v += mask_term; +#endif + return v; +} + +fn load_f32x4(buf: ptr>, read_write>, scalar_index: u32) -> vec4 { + return (*buf)[scalar_index >> 2u]; +} + +fn load_kvx4(buf: ptr>, read_write>, scalar_index: u32) -> vec4 { + return (*buf)[scalar_index >> 2u]; +} + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(workgroup_id) wg_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(subgroup_id) subgroup_id: u32, + @builtin(subgroup_size) subgroup_size: u32, + @builtin(num_subgroups) num_subgroups: u32, + @builtin(subgroup_invocation_id) sg_inv_id: u32) { + + // initialize row max and exp sum for online softmax + for (var i = local_id.x; i < Q_TILE; i += WG_SIZE) { + row_max_shmem[i] = FLOAT_MIN; + exp_sum_shmem[i] = 0.0; + } + + for (var i = local_id.x; i < Q_TILE * HEAD_DIM_V; i += 
WG_SIZE) { + o_shmem[i] = 0.0; + } + + // workgroups per head/batch + let wg_per_head = (params.seq_len_q + Q_TILE - 1u) / Q_TILE; + let wg_per_batch = wg_per_head * params.n_heads; + + let dst2_stride = HEAD_DIM_V * params.n_heads; + let dst3_stride = dst2_stride * params.seq_len_q; + + // batch index + let batch_idx = wg_id.x / wg_per_batch; + let q_batch_offset = params.offset_q + batch_idx * params.stride_q3; + let k_batch_offset = params.offset_k + batch_idx * params.stride_k3; + let v_batch_offset = params.offset_v + batch_idx * params.stride_v3; + let dst_batch_offset = params.offset_dst + batch_idx * dst3_stride; + let wg_in_batch = wg_id.x % wg_per_batch; + + // head index + let head_idx = wg_in_batch / wg_per_head; + let q_head_offset = q_batch_offset + head_idx * params.stride_q2; + let k_head_idx = head_idx / params.q_per_kv; + let v_head_idx = k_head_idx; + let k_head_offset = k_batch_offset + k_head_idx * params.stride_k2; + let v_head_offset = v_batch_offset + v_head_idx * params.stride_v2; + + // starting Q row for this workgroup + let wg_in_head = wg_in_batch % wg_per_head; + let q_row_start = wg_in_head * Q_TILE; + +#ifdef MASK + // mask offset + let mask_global_offset = params.offset_mask + batch_idx * params.stride_mask3 + q_row_start * params.seq_len_kv; +#endif + + // note that the output is permuted, the layout is [head_dim_v, n_heads, seq_len_q, batch_size] + let dst_global_offset = dst_batch_offset + q_row_start * dst2_stride + head_idx * HEAD_DIM_V; + + let head = f32(head_idx); + let slope = select(1.0, select(pow(params.m1, 2.0 * (head - params.n_head_log2) + 1.0), pow(params.m0, head + 1.0), head < params.n_head_log2), params.max_bias > 0); + + // load q tile into shared memory + for (var elem_idx = local_id.x; elem_idx < Q_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) { + let q_row = elem_idx / HEAD_DIM_QK; + let q_col = elem_idx % HEAD_DIM_QK; + let head_q_row = q_row_start + q_row; + let global_q_row_offset = q_head_offset + head_q_row * 
params.stride_q1; + q_shmem[elem_idx] = f16(select( + 0.0, + Q[global_q_row_offset + q_col], + head_q_row < params.seq_len_q && q_col < HEAD_DIM_QK)); + } + + for (var kv_tile = 0u; kv_tile < params.seq_len_kv; kv_tile += KV_TILE) { + let tail = params.seq_len_kv % params.ncpsg; + let use_pad_tile = (tail != 0u) && (kv_tile + params.ncpsg >= params.seq_len_kv); + let kv_plane = k_head_idx + batch_idx * (params.n_heads / params.q_per_kv); + // clear inter_shmem to ensure zero-initialized accumulators + for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) { + inter_shmem[elem_idx] = 0.0; + } + + // load k tile into shared memory +// #if defined(KV_Q4_0) +// for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) { +// let blck_idx = elem_idx / BLOCK_SIZE; +// let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; +// let k_row = blck_idx / BLOCKS_K; +// let global_k_row = kv_tile + k_row; +// let block_k = blck_idx % BLOCKS_K; +// let row_offset = k_row * HEAD_DIM_QK; + +// if (global_k_row < params.seq_len_kv) { +// let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; +// let base_idx = global_block_idx * F16_PER_BLOCK; +// let d = K[base_idx]; // scale +// for (var j = 0u; j < F16_PER_THREAD; j += 2) { +// let q_0 = K[base_idx + 1u + block_offset + j]; +// let q_1 = K[base_idx + 1u + block_offset + j + 1]; +// let q_packed = bitcast(vec2(q_0, q_1)); +// for (var k = 0u; k < 4u; k++) { +// let q_byte = get_byte(q_packed, k); +// let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; +// let q_lo = (f16(q_byte & 0xF) - 8.0) * d; +// let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; +// kv_shmem[row_offset + idx] = q_lo; +// kv_shmem[row_offset + idx + 16u] = q_hi; +// } +// } +// } +// } +// #elif defined(KV_Q8_0) +// for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) { +// let blck_idx = elem_idx / 
BLOCK_SIZE; +// let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; +// let k_row = blck_idx / BLOCKS_K; +// let global_k_row = kv_tile + k_row; +// let block_k = blck_idx % BLOCKS_K; +// let row_offset = k_row * HEAD_DIM_QK; + +// if (global_k_row < params.seq_len_kv) { +// let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; +// let base_idx = global_block_idx * F16_PER_BLOCK; +// let d = K[base_idx]; // scale +// for (var j = 0u; j < F16_PER_THREAD; j += 2) { +// let q_0 = K[base_idx + 1u + block_offset + j]; +// let q_1 = K[base_idx + 1u + block_offset + j + 1]; +// let q_packed = bitcast(vec2(q_0, q_1)); +// for (var k = 0u; k < 4u; k++) { +// let q_byte = get_byte_i32(q_packed, k); +// let q_val = f16(q_byte) * d; +// let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; +// kv_shmem[row_offset + idx] = q_val; +// } +// } +// } +// } +// #endif + + workgroupBarrier(); + + // accumulate q block * k block into registers across the entire KV tile + // TODO: this loop seems to be the current largest bottleneck + // this bracket exists to scope the lifetime of variables, reducing register pressure + { + // vectorization + let num_of_threads = subgroup_size / 4u; + let tx = sg_inv_id % num_of_threads; + let ty = sg_inv_id / num_of_threads; + for (var q_tile_row = subgroup_id; q_tile_row < Q_TILE; q_tile_row += num_subgroups) { + let global_q_row = q_row_start + q_tile_row; + if (global_q_row >= params.seq_len_q) { + continue; + } + + for (var kv_base : u32 = 0u; kv_base < KV_TILE; kv_base += 4u) { + let kv_idx = kv_base + ty; + var partial_sum: f32 = 0.0; + let kv_valid = kv_idx < KV_TILE && (kv_tile + kv_idx) < params.seq_len_kv; + + if (kv_valid) { + for (var i = tx; i < (HEAD_DIM_QK / 4u); i += num_of_threads) { + let q_off = q_tile_row * HEAD_DIM_QK + i * 4u; + // let k_off = (kv_idx * HEAD_DIM_QK) + i * 4u; + var idx : u32 = 0u; + if (use_pad_tile == true) { + idx = params.offset_pad + params.pad_k_base + 
kv_plane * params.stride_k1 * params.ncpsg + kv_idx * params.stride_k1 + i * 4u; + + } else{ + idx = k_head_offset + (kv_tile + kv_idx) * params.stride_k1 + (i * 4u); + } + + let qv = vec4(f32(q_shmem[q_off]), f32(q_shmem[q_off + 1u]), f32(q_shmem[q_off + 2u]), f32(q_shmem[q_off + 3u])); + // let kv = vec4(f32(kv_shmem[k_off]), f32(kv_shmem[k_off + 1u]), f32(kv_shmem[k_off + 2u]), f32(kv_shmem[k_off + 3u])); + var kv :vec4; + if (use_pad_tile == true) { + kv = vec4(f32(pad[idx]), f32(pad[idx + 1u]), f32(pad[idx + 2u]), f32(pad[idx + 3u])); + } else { + kv = vec4(f32(K[idx]), f32(K[idx + 1u]), f32(K[idx + 2u]), f32(K[idx + 3u])); + } + partial_sum += dot(qv, kv); + + } + } + for (var g: u32 = 0u; g < 4u; g++) { + let kv_idx_g = kv_base + g; + let active_threads = (ty == g) && (kv_idx_g < KV_TILE) && ((kv_tile + kv_idx_g) < params.seq_len_kv); + + let contrib = select(0.0, partial_sum, active_threads); + let sum_g = subgroupAdd(contrib); + + if (tx == 0u && ty == g && kv_idx_g < KV_TILE) { + let dst_idx = q_tile_row * KV_TILE + kv_idx_g; + inter_shmem[dst_idx] = f16(select(FLOAT_MIN, sum_g, (kv_tile + kv_idx_g) < params.seq_len_kv)); + } + } + } + } + } + + +#ifdef MASK + // load mask tile into shared memory for this KV block + // TODO: optimize and skip if mask is -INF for the entire tile + for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) { + let mask_row = elem_idx / KV_TILE; + let mask_col = elem_idx % KV_TILE; + let global_q_row = q_row_start + mask_row; + let global_k_col = kv_tile + mask_col; + if (use_pad_tile == true) { + let mask_batch_idx = select(0u, batch_idx, params.stride_mask3 > 0u); + let pad_mask_plane_base = params.offset_pad + params.pad_m_base + + mask_batch_idx * params.seq_len_q * params.ncpsg; + let mask_in_bounds = global_q_row < params.seq_len_q; + let mask_idx = pad_mask_plane_base + global_q_row * params.ncpsg + mask_col; + mask_shmem[elem_idx] = f16(select(0.0, pad[mask_idx], mask_in_bounds)); + } else { + 
let mask_in_bounds = global_q_row < params.seq_len_q && global_k_col < params.seq_len_kv; + let mask_idx = mask_global_offset + mask_row * params.seq_len_kv + global_k_col; + mask_shmem[elem_idx] = select(0.0, mask[mask_idx], mask_in_bounds); + } + } +#endif + + workgroupBarrier(); + + // online softmax + for (var q_tile_row = subgroup_id; q_tile_row < Q_TILE; q_tile_row += num_subgroups) { + let global_q_row = q_row_start + q_tile_row; + if (global_q_row >= params.seq_len_q) { + break; + } + + // initialize running max for this row + var prev_max = row_max_shmem[q_tile_row]; + var final_max = prev_max; + // pass 1: compute final max across the full KV tile in chunks + for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) { + let kv_idx = kv_offset + sg_inv_id; + let softmax_term = calc_softmax_term(kv_idx, q_tile_row, slope); + final_max = subgroupMax(max(final_max, softmax_term)); + } + + var total_exp_term: f32 = 0.0; + // pass 2: compute exp sum and write P using final_max + for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) { + let kv_idx = kv_offset + sg_inv_id; + let softmax_term = calc_softmax_term(kv_idx, q_tile_row, slope); + let cur_p = select(0.0, + exp(softmax_term - final_max), + kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE); + total_exp_term += subgroupAdd(cur_p); + if (kv_idx < KV_TILE) { + inter_shmem[kv_idx + q_tile_row * KV_TILE] = f16(cur_p); + } + } + + let cur_exp = exp(prev_max - final_max); + + if (sg_inv_id == 0) { + row_max_shmem[q_tile_row] = final_max; + exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * cur_exp + total_exp_term; + } + + for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) { + let idx = q_tile_row * HEAD_DIM_V + elem_idx; + o_shmem[idx] = f16(f32(o_shmem[idx]) * cur_exp); + } + } + + // load v tile into shared memory +// #if defined(KV_Q4_0) +// for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE 
* NQ) { +// let blck_idx = elem_idx / BLOCK_SIZE; +// let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; +// let v_row = blck_idx / BLOCKS_V; +// let global_v_row = kv_tile + v_row; +// let block_k = blck_idx % BLOCKS_V; +// let row_offset = v_row * HEAD_DIM_V; + +// if (global_v_row < params.seq_len_kv) { +// let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; +// let base_idx = global_block_idx * F16_PER_BLOCK; +// let d = V[base_idx]; // scale +// for (var j = 0u; j < F16_PER_THREAD; j += 2) { +// let q_0 = V[base_idx + 1u + block_offset + j]; +// let q_1 = V[base_idx + 1u + block_offset + j + 1]; +// let q_packed = bitcast(vec2(q_0, q_1)); +// for (var k = 0u; k < 4u; k++) { +// let q_byte = get_byte(q_packed, k); +// let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; +// let q_lo = (f16(q_byte & 0xF) - 8.0) * d; +// let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; +// kv_shmem[row_offset + idx] = q_lo; +// kv_shmem[row_offset + idx + 16u] = q_hi; +// } +// } +// } +// } +// #elif defined(KV_Q8_0) +// for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) { +// let blck_idx = elem_idx / BLOCK_SIZE; +// let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; +// let v_row = blck_idx / BLOCKS_V; +// let global_v_row = kv_tile + v_row; +// let block_k = blck_idx % BLOCKS_V; +// let row_offset = v_row * HEAD_DIM_V; + +// if (global_v_row < params.seq_len_kv) { +// let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; +// let base_idx = global_block_idx * F16_PER_BLOCK; +// let d = V[base_idx]; // scale +// for (var j = 0u; j < F16_PER_THREAD; j += 2) { +// let q_0 = V[base_idx + 1u + block_offset + j]; +// let q_1 = V[base_idx + 1u + block_offset + j + 1]; +// let q_packed = bitcast(vec2(q_0, q_1)); +// for (var k = 0u; k < 4u; k++) { +// let q_byte = get_byte_i32(q_packed, k); +// let q_val = f16(q_byte) * d; +// let idx = block_k * 
BLOCK_SIZE + block_offset * 2u + j * 2u + k; +// kv_shmem[row_offset + idx] = q_val; +// } +// } +// } +// } +// #elif defined(KV_DIRECT) +// // Direct global loads for KV +// #else +// for (var elem_idx = local_id.x; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE) { +// let v_row = elem_idx / HEAD_DIM_V; +// let v_col = elem_idx % HEAD_DIM_V; +// let global_v_row = kv_tile + v_row; +// let global_v_row_offset = v_head_offset + global_v_row * params.stride_v1; +// kv_shmem[elem_idx] = f16(select( +// 0.0, +// V[global_v_row_offset + v_col], +// global_v_row < params.seq_len_kv && v_col < HEAD_DIM_V)); +// } +// #endif + + workgroupBarrier(); + + // we have P (Q_TILE x KV_TILE) in inter_shmem; V (KV_TILE x head_dim_v) is read directly from the V buffer, or from the pad buffer when use_pad_tile is set + // we want to compute O += P * V across the full KV tile + for (var q_tile_row = subgroup_id; + q_tile_row < Q_TILE; + q_tile_row += num_subgroups) { + for (var elem_base = sg_inv_id * 4u; elem_base < HEAD_DIM_V; elem_base += subgroup_size * 4u) { + // load the running output accumulator for these 4 columns from o_shmem + let o_base_idx = q_tile_row * HEAD_DIM_V + elem_base; + var acc = vec4(f32(o_shmem[o_base_idx]), f32(o_shmem[o_base_idx + 1u]), f32(o_shmem[o_base_idx + 2u]), f32(o_shmem[o_base_idx + 3u])); + for (var kv_idx : u32 = 0u; kv_idx < KV_TILE; kv_idx += 1u) { + let p = f32(inter_shmem[kv_idx + q_tile_row * KV_TILE]); + let v_row = kv_tile + kv_idx; + if (v_row >= params.seq_len_kv) { + continue; + } + // let v_idx = v_head_offset + v_row * params.stride_v1 + elem_base; + var v_idx : u32 = 0u; + var v4: vec4; + if (use_pad_tile == true) { + v_idx = params.offset_pad + params.pad_v_base + kv_plane * params.stride_v1 * params.ncpsg + kv_idx * params.stride_v1 + elem_base; + v4 = vec4(f32(pad[v_idx]), f32(pad[v_idx + 1u]), f32(pad[v_idx + 2u]), f32(pad[v_idx + 3u])); + } else { + v_idx = v_head_offset + v_row * params.stride_v1 + elem_base; + v4 = vec4(V[v_idx/4u]); + } + acc += p * v4; + } + // write acc back to o_shmem + o_shmem[o_base_idx] = f16(acc.x); + 
o_shmem[o_base_idx + 1u] = f16(acc.y); + o_shmem[o_base_idx + 2u] = f16(acc.z); + o_shmem[o_base_idx + 3u] = f16(acc.w); + } + } + + workgroupBarrier(); + } + + +#ifdef SINKS + // add sinks (applied once after processing all KV tiles) + for (var q_tile_row = subgroup_id; + q_tile_row < Q_TILE; + q_tile_row += num_subgroups) { + // no need to process rows beyond seq_len_q + let global_q_row = q_row_start + q_tile_row; + if (global_q_row >= params.seq_len_q) { + break; + } + + var prev_max = row_max_shmem[q_tile_row]; + + // for non-sink threads, exp(FLOAT_MIN) effectively zeroes out their contribution to the sum + let sink_val = select(FLOAT_MIN, sinks[params.offset_sinks + head_idx], sg_inv_id == 0); + let new_max = subgroupMax(max(prev_max, sink_val)); + let max_exp = exp(prev_max - new_max); + let sink_exp = exp(sink_val - new_max); + + let sink_exp_sum = subgroupAdd(sink_exp); + + if (sg_inv_id == 0) { + exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * max_exp + sink_exp_sum; + } + + for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) { + let idx = q_tile_row * HEAD_DIM_V + elem_idx; + let val = f32(o_shmem[idx]) * max_exp; + o_shmem[idx] = f16(val); + } + } + workgroupBarrier(); +#endif + for (var q_tile_row = subgroup_id; + q_tile_row < Q_TILE; + q_tile_row += num_subgroups) { + + let global_q_row = q_row_start + q_tile_row; + if (global_q_row >= params.seq_len_q) { break; } + + let exp_sum = exp_sum_shmem[q_tile_row]; + let scale = select(0.0, 1.0 / exp_sum, exp_sum != 0.0); + + let row_base: u32 = dst_global_offset + q_tile_row * dst2_stride; + + for (var elem_base = sg_inv_id * 4u; + elem_base < HEAD_DIM_V; + elem_base += subgroup_size * 4u) { + + let i0 = q_tile_row * HEAD_DIM_V + (elem_base + 0u); + let i1 = q_tile_row * HEAD_DIM_V + (elem_base + 1u); + let i2 = q_tile_row * HEAD_DIM_V + (elem_base + 2u); + let i3 = q_tile_row * HEAD_DIM_V + (elem_base + 3u); + + let v = vec4( + f32(o_shmem[i0]) * scale, + 
f32(o_shmem[i1]) * scale, + f32(o_shmem[i2]) * scale, + f32(o_shmem[i3]) * scale + ); + + let dst_vec_index: u32 = (row_base + elem_base) >> 2u; + dst[dst_vec_index] = v; + } + } +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl new file mode 100644 index 000000000000..d73b2207f320 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl @@ -0,0 +1,94 @@ +enable f16; + +#define Q_TILE 1 +#define KV_TILE 32 +#define WG_SIZE 32 + +struct Params { + offset_mask: u32, + seq_len_q: u32, + seq_len_kv: u32, + stride_mask3: u32, + nblk0: u32, + nblk1: u32, +}; + +@group(0) @binding(0) var mask: array; +@group(0) @binding(1) var blk: array; +@group(0) @binding(2) var params: Params; + +const MASK_MIN: f32 = -65504.0; +const MASK_MAX: f32 = 65504.0; +var wg_min: array; +var wg_max: array; +var wg_any: array; + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(workgroup_id) wg_id: vec3, + @builtin(local_invocation_id) local_id: vec3) { + let kv_blk = wg_id.x; + let y = wg_id.y; + let q_blk = y % params.nblk1; + let batch_idx = y / params.nblk1; + if (kv_blk >= params.nblk0) { + return; + } + + let q_start = q_blk * Q_TILE; + let k_start = kv_blk * KV_TILE; + + let mask_batch = select(0u, batch_idx, params.stride_mask3 > 0u); + let mask_batch_base = params.offset_mask + mask_batch * params.stride_mask3; + + var local_min = MASK_MAX; + var local_max = -MASK_MAX; + var local_any = 0u; + + for (var q_rel = 0u; q_rel < Q_TILE; q_rel += 1u) { + let q_row = q_start + q_rel; + if (q_row >= params.seq_len_q) { + continue; + } + let row_base = mask_batch_base + q_row * params.seq_len_kv; + for (var k_rel = local_id.x; k_rel < KV_TILE; k_rel += WG_SIZE) { + let k_col = k_start + k_rel; + if (k_col >= params.seq_len_kv) { + continue; + } + let mv = f32(mask[row_base + k_col]); + local_min = min(local_min, mv); + local_max = max(local_max, mv); + local_any = 1u; + } + } + + 
wg_min[local_id.x] = local_min; + wg_max[local_id.x] = local_max; + wg_any[local_id.x] = local_any; + workgroupBarrier(); + + if (local_id.x == 0u) { + var mmin = wg_min[0]; + var mmax = wg_max[0]; + var many = wg_any[0]; + for (var i = 1u; i < WG_SIZE; i += 1u) { + mmin = min(mmin, wg_min[i]); + mmax = max(mmax, wg_max[i]); + many = max(many, wg_any[i]); + } + + var state = 0u; + if (many != 0u) { + if (mmax <= MASK_MIN) { + state = 0u; + } else if (mmin == 0.0 && mmax == 0.0) { + state = 2u; + } else { + state = 1u; + } + } + + let blk_idx = (batch_idx * params.nblk1 + q_blk) * params.nblk0 + kv_blk; + blk[blk_idx] = state; + } +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl new file mode 100644 index 000000000000..00524ad1db62 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl @@ -0,0 +1,78 @@ +enable f16; +enable subgroups; + +// Default values +#define HEAD_DIM_V 64 +#define WG_SIZE 128 + +struct Params { + nrows: u32, + seq_len_q: u32, + n_heads: u32, + nwg: u32, + tmp_data_base: u32, + tmp_stats_base: u32, +}; + +@group(0) @binding(0) var tmp: array; +@group(0) @binding(1) var dst: array>; +@group(0) @binding(2) var params: Params; + +const FLOAT_MIN: f32 = -1.0e9; + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(workgroup_id) wg_id: vec3, + @builtin(subgroup_id) subgroup_id: u32, + @builtin(num_subgroups) num_subgroups: u32, + @builtin(subgroup_size) subgroup_size: u32, + @builtin(subgroup_invocation_id) sg_inv_id: u32) { + let rid = wg_id.x; + if (rid >= params.nrows) { + return; + } + + let rows_per_batch = params.n_heads * params.seq_len_q; + let batch_idx = rid / rows_per_batch; + let rem = rid % rows_per_batch; + let head_idx = rem / params.seq_len_q; + let q_row = rem % params.seq_len_q; + + let dst2_stride = HEAD_DIM_V * params.n_heads; + let dst3_stride = dst2_stride * params.seq_len_q; + let row_base = batch_idx * 
dst3_stride + q_row * dst2_stride + head_idx * HEAD_DIM_V; + + // Mirror Metal's vec-reduce mapping: each subgroup lane corresponds to one split index. + // This requires params.nwg <= subgroup_size. + let lane = sg_inv_id; + if (params.nwg > subgroup_size) { + return; + } + + let stats_base = params.tmp_stats_base + rid * (2u * params.nwg); + let active_lane = lane < params.nwg; + let si = select(0.0, tmp[stats_base + 2u * lane + 0u], active_lane); + let mi = select(FLOAT_MIN, tmp[stats_base + 2u * lane + 1u], active_lane); + let m = subgroupMax(mi); + let ms = select(0.0, exp(mi - m), active_lane); + let s = subgroupAdd(si * ms); + let inv_s = select(0.0, 1.0 / s, s != 0.0); + + let row_tmp_base = params.tmp_data_base + rid * (HEAD_DIM_V * params.nwg); + for (var elem_base = subgroup_id * 4u; elem_base < HEAD_DIM_V; elem_base += num_subgroups * 4u) { + var weighted = vec4(0.0, 0.0, 0.0, 0.0); + if (active_lane) { + let src = row_tmp_base + lane * HEAD_DIM_V + elem_base; + weighted = vec4(tmp[src + 0u], tmp[src + 1u], tmp[src + 2u], tmp[src + 3u]) * ms; + } + + let sum_x = subgroupAdd(weighted.x); + let sum_y = subgroupAdd(weighted.y); + let sum_z = subgroupAdd(weighted.z); + let sum_w = subgroupAdd(weighted.w); + + if (lane == 0u) { + let dst_vec_index = (row_base + elem_base) >> 2u; + dst[dst_vec_index] = vec4(sum_x, sum_y, sum_z, sum_w) * inv_s; + } + } +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl new file mode 100644 index 000000000000..0f36f5af5edd --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl @@ -0,0 +1,784 @@ +diagnostic(off, chromium.subgroup_matrix_uniformity); +diagnostic(off, subgroup_uniformity); +enable f16; +enable subgroups; +enable chromium_experimental_subgroup_matrix; + +#ifdef KV_F32 +#define KV_TYPE f32 +#else +#define KV_TYPE f16 +#endif + +// Default values +#define HEAD_DIM_QK 64 +#define HEAD_DIM_V 64 + +// The 
number of rows/columns/k in a subgroup matrix. MxK * KxN = MxN +// Note that the "K" here does not correspond to the K in attention's Q/K/V, it's just the common dimension. +#define SG_MAT_M 8 +#define SG_MAT_N 8 +#define SG_MAT_K 8 + +// Each workgroup processes one subgroup matrix of Q rows +#define Q_TILE SG_MAT_M +#define KV_TILE 16 +#define WG_SIZE 64 + +// Number of subgroup-matrix-width blocks that span the KV tile. SG_MAT_N must divide KV_TILE. +#define KV_BLOCKS (KV_TILE / SG_MAT_N) + +// Quantization constants/helpers +#define BLOCK_SIZE 32 +#define BLOCKS_K ((HEAD_DIM_QK + BLOCK_SIZE - 1) / BLOCK_SIZE) +#define BLOCKS_V ((HEAD_DIM_V + BLOCK_SIZE - 1) / BLOCK_SIZE) +// number of quantized elements processed per thread +#if defined(KV_Q4_0) +#define NQ 16 +// Q4_0 has 32 elements, 1 f16 for scale, 8 f16 for 4-bit weights +#define F16_PER_BLOCK 9 +#define WEIGHTS_PER_F16 4 +#elif defined(KV_Q8_0) +#define NQ 8 +// Q8_0 has 32 elements, 1 f16 for scale, 16 f16 for 8-bit weights +#define F16_PER_BLOCK 17 +#define WEIGHTS_PER_F16 2 +#endif +#define F16_PER_THREAD (NQ / WEIGHTS_PER_F16) + +// Ok not to put these in a define block, compiler will remove if unused +fn get_byte(value: u32, index: u32) -> u32 { + return (value >> (index * 8)) & 0xFF; +} + +fn get_byte_i32(value: u32, index: u32) -> i32 { + return bitcast(((value >> (index * 8)) & 0xFF) << 24) >> 24; +} + +struct Params { + offset_q: u32, + offset_k: u32, + offset_v: u32, + offset_mask: u32, + offset_sinks: u32, + offset_dst: u32, + + // shapes of Q/K/V + n_heads: u32, + seq_len_q: u32, + seq_len_kv: u32, + + // strides (in elements) + stride_q1: u32, + stride_q2: u32, + stride_q3: u32, + stride_k1: u32, + stride_k2: u32, + stride_k3: u32, + stride_v1: u32, + stride_v2: u32, + stride_v3: u32, + stride_mask3: u32, + + // repeat factors for K/V, e.g., MHA vs. MQA vs. 
GQA + q_per_kv: u32, + + // softmax params + scale: f32, + max_bias: f32, + logit_softcap: f32, + n_head_log2: f32, + m0: f32, + m1: f32, + +#ifdef PAD + offset_pad: u32, + pad_k_base: u32, + pad_v_base: u32, + pad_m_base: u32, + ncpsg: u32, + nqptg: u32, +#endif + +#ifdef BLK + blk_base: u32, + blk_nblk0: u32, + blk_nblk1: u32, +#endif + + tmp_data_base: u32, + tmp_stats_base: u32, + nwg: u32, +}; + +@group(0) @binding(0) var Q: array; +@group(0) @binding(1) var K: array; +@group(0) @binding(2) var V: array>; + +#ifdef PAD +@group(0) @binding(3) var pad: array; +#endif + +#ifdef PAD +#if defined(MASK) && defined(SINKS) +@group(0) @binding(4) var mask: array; +@group(0) @binding(5) var sinks: array; +#ifdef BLK +#define BLK_BINDING 6 +#define TMP_BINDING 7 +#define DST_BINDING 8 +#define PARAMS_BINDING 9 +#else +#define TMP_BINDING 6 +#define DST_BINDING 7 +#define PARAMS_BINDING 8 +#endif +#elif defined(MASK) +@group(0) @binding(4) var mask: array; +#ifdef BLK +#define BLK_BINDING 5 +#define TMP_BINDING 6 +#define DST_BINDING 7 +#define PARAMS_BINDING 8 +#else +#define TMP_BINDING 5 +#define DST_BINDING 6 +#define PARAMS_BINDING 7 +#endif +#elif defined(SINKS) +@group(0) @binding(4) var sinks: array; +#define TMP_BINDING 5 +#define DST_BINDING 6 +#define PARAMS_BINDING 7 +#else +#define TMP_BINDING 4 +#define DST_BINDING 5 +#define PARAMS_BINDING 6 +#endif +#else +#if defined(MASK) && defined(SINKS) +@group(0) @binding(3) var mask: array; +@group(0) @binding(4) var sinks: array; +#ifdef BLK +#define BLK_BINDING 5 +#define TMP_BINDING 6 +#define DST_BINDING 7 +#define PARAMS_BINDING 8 +#else +#define TMP_BINDING 5 +#define DST_BINDING 6 +#define PARAMS_BINDING 7 +#endif +#elif defined(MASK) +@group(0) @binding(3) var mask: array; +#ifdef BLK +#define BLK_BINDING 4 +#define TMP_BINDING 5 +#define DST_BINDING 6 +#define PARAMS_BINDING 7 +#else +#define TMP_BINDING 4 +#define DST_BINDING 5 +#define PARAMS_BINDING 6 +#endif +#elif defined(SINKS) +@group(0) @binding(3) 
var sinks: array; +#define TMP_BINDING 4 +#define DST_BINDING 5 +#define PARAMS_BINDING 6 +#else +#define TMP_BINDING 3 +#define DST_BINDING 4 +#define PARAMS_BINDING 5 +#endif +#endif + +#ifdef BLK +@group(0) @binding(BLK_BINDING) var blk: array; +#endif +@group(0) @binding(TMP_BINDING) var tmp: array; +@group(0) @binding(DST_BINDING) var dst: array>; +@group(0) @binding(PARAMS_BINDING) var params: Params; + +// Large-magnitude negative sentinel (not a tiny value): initial row max and the score substituted for out-of-tile / non-sink lanes. +const FLOAT_MIN: f32 = -1.0e9; + +// Shared-memory copy of the Q tile (indexed q_row * HEAD_DIM_QK + q_col), filled in main before the KV loop +var q_shmem: array; + +#ifndef KV_DIRECT +const kv_shmem_size = KV_TILE * max(HEAD_DIM_QK, HEAD_DIM_V); +// we can reuse the same shmem for K and V since we only need one at a time +var kv_shmem: array; +#endif + +var o_shmem: array; // output shmem + +#ifdef MASK +// storage for mask values +var mask_shmem: array; +#endif + +// storage for output of Q*K^T scores for online softmax (S matrix from paper) +// also storage for diagonal matrix during online softmax (P matrix from paper) +// note that we reuse the same storage for both since we only need one at a time +var inter_shmem: array; + +// Storage for row max and exp sum during online softmax +var row_max_shmem: array; +var exp_sum_shmem: array; +var blk_state_wg: u32; + +fn calc_softmax_term(kv_idx: u32, q_tile_row: u32, slope: f32, has_bias: bool, apply_mask: bool) -> f32 { + var v = select(FLOAT_MIN, + f32(inter_shmem[kv_idx + q_tile_row * KV_TILE]) * params.scale, + kv_idx < KV_TILE); +#ifdef LOGIT_SOFTCAP + v = params.logit_softcap * tanh(v); +#endif +#ifdef MASK + if (apply_mask) { + let mask_val = select(0.0, f32(mask_shmem[q_tile_row * KV_TILE + kv_idx]), kv_idx < KV_TILE); + // Common fast path (mask only): avoid extra mul when bias scaling is disabled. 
+ v += select(mask_val, slope * mask_val, has_bias); + } +#endif + return v; +} + +fn load_f32x4(buf: ptr>, read_write>, scalar_index: u32) -> vec4 { + return (*buf)[scalar_index >> 2u]; +} + +fn load_kvx4(buf: ptr>, read_write>, scalar_index: u32) -> vec4 { + return (*buf)[scalar_index >> 2u]; +} + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(workgroup_id) wg_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(subgroup_id) subgroup_id: u32, + @builtin(subgroup_size) subgroup_size: u32, + @builtin(num_subgroups) num_subgroups: u32, + @builtin(subgroup_invocation_id) sg_inv_id: u32) { + + // initialize row max for online softmax + for (var i = local_id.x; i < Q_TILE; i += WG_SIZE) { + row_max_shmem[i] = FLOAT_MIN; + exp_sum_shmem[i] = 0.0; + } + + for (var i = local_id.x; i < Q_TILE * HEAD_DIM_V; i += WG_SIZE) { + o_shmem[i] = 0.0; + } + + // workgroups per head/batch + let wg_per_head = (params.seq_len_q + Q_TILE - 1u) / Q_TILE; + let wg_per_batch = wg_per_head * params.n_heads; + + let dst2_stride = HEAD_DIM_V * params.n_heads; + let dst3_stride = dst2_stride * params.seq_len_q; + + let iwg = wg_id.x % params.nwg; + let base_wg_id = wg_id.x / params.nwg; + + // batch index + let batch_idx = base_wg_id / wg_per_batch; + let q_batch_offset = params.offset_q + batch_idx * params.stride_q3; + let k_batch_offset = params.offset_k + batch_idx * params.stride_k3; + let v_batch_offset = params.offset_v + batch_idx * params.stride_v3; + let wg_in_batch = base_wg_id % wg_per_batch; + + // head index + let head_idx = wg_in_batch / wg_per_head; + let q_head_offset = q_batch_offset + head_idx * params.stride_q2; + let k_head_idx = head_idx / params.q_per_kv; + let v_head_idx = k_head_idx; + let k_head_offset = k_batch_offset + k_head_idx * params.stride_k2; + let v_head_offset = v_batch_offset + v_head_idx * params.stride_v2; + + // starting Q row for this workgroup + let wg_in_head = wg_in_batch % wg_per_head; + let q_row_start = wg_in_head * Q_TILE; 
+ +#ifdef MASK + // mask offset + let mask_global_offset = params.offset_mask + batch_idx * params.stride_mask3 + q_row_start * params.seq_len_kv; +#endif + + let head = f32(head_idx); + let has_bias = params.max_bias > 0.0; + let slope = select(1.0, select(pow(params.m1, 2.0 * (head - params.n_head_log2) + 1.0), pow(params.m0, head + 1.0), head < params.n_head_log2), has_bias); + + // load q tile into shared memory + for (var elem_idx = local_id.x; elem_idx < Q_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) { + let q_row = elem_idx / HEAD_DIM_QK; + let q_col = elem_idx % HEAD_DIM_QK; + let head_q_row = q_row_start + q_row; + let global_q_row_offset = q_head_offset + head_q_row * params.stride_q1; + q_shmem[elem_idx] = f16(select( + 0.0, + Q[global_q_row_offset + q_col], + head_q_row < params.seq_len_q && q_col < HEAD_DIM_QK)); + } + + for (var kv_tile = iwg * KV_TILE; kv_tile < params.seq_len_kv; kv_tile += KV_TILE * params.nwg) { +#ifdef PAD + let tail = params.seq_len_kv % params.ncpsg; + let use_pad_tile = (tail != 0u) && (kv_tile + params.ncpsg >= params.seq_len_kv); + let kv_plane = k_head_idx + batch_idx * (params.n_heads / params.q_per_kv); +#endif +#ifdef BLK + let q_blk = q_row_start / Q_TILE; + let kv_blk = kv_tile / KV_TILE; + let blk_batch = select(0u, batch_idx, params.stride_mask3 > 0u); + let blk_idx = params.blk_base + (blk_batch * params.blk_nblk1 + q_blk) * params.blk_nblk0 + kv_blk; + let blk_state_local = blk[blk_idx]; +#else + let blk_state_local = 1u; +#endif + if (local_id.x == 0u) { + blk_state_wg = blk_state_local; + } + workgroupBarrier(); + let blk_state = blk_state_wg; + let skip_tile = blk_state == 0u; + // clear inter_shmem to ensure zero-initialized accumulators + for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) { + inter_shmem[elem_idx] = 0.0; + } + + // load k tile into shared memory +// #if defined(KV_Q4_0) +// for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * 
NQ) { +// let blck_idx = elem_idx / BLOCK_SIZE; +// let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; +// let k_row = blck_idx / BLOCKS_K; +// let global_k_row = kv_tile + k_row; +// let block_k = blck_idx % BLOCKS_K; +// let row_offset = k_row * HEAD_DIM_QK; + +// if (global_k_row < params.seq_len_kv) { +// let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; +// let base_idx = global_block_idx * F16_PER_BLOCK; +// let d = K[base_idx]; // scale +// for (var j = 0u; j < F16_PER_THREAD; j += 2) { +// let q_0 = K[base_idx + 1u + block_offset + j]; +// let q_1 = K[base_idx + 1u + block_offset + j + 1]; +// let q_packed = bitcast(vec2(q_0, q_1)); +// for (var k = 0u; k < 4u; k++) { +// let q_byte = get_byte(q_packed, k); +// let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; +// let q_lo = (f16(q_byte & 0xF) - 8.0) * d; +// let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; +// kv_shmem[row_offset + idx] = q_lo; +// kv_shmem[row_offset + idx + 16u] = q_hi; +// } +// } +// } +// } +// #elif defined(KV_Q8_0) +// for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) { +// let blck_idx = elem_idx / BLOCK_SIZE; +// let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; +// let k_row = blck_idx / BLOCKS_K; +// let global_k_row = kv_tile + k_row; +// let block_k = blck_idx % BLOCKS_K; +// let row_offset = k_row * HEAD_DIM_QK; + +// if (global_k_row < params.seq_len_kv) { +// let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; +// let base_idx = global_block_idx * F16_PER_BLOCK; +// let d = K[base_idx]; // scale +// for (var j = 0u; j < F16_PER_THREAD; j += 2) { +// let q_0 = K[base_idx + 1u + block_offset + j]; +// let q_1 = K[base_idx + 1u + block_offset + j + 1]; +// let q_packed = bitcast(vec2(q_0, q_1)); +// for (var k = 0u; k < 4u; k++) { +// let q_byte = get_byte_i32(q_packed, k); +// let q_val = f16(q_byte) * d; +// let idx = block_k * 
BLOCK_SIZE + block_offset * 2u + j * 2u + k; +// kv_shmem[row_offset + idx] = q_val; +// } +// } +// } +// } +// #endif + + workgroupBarrier(); + + // accumulate q block * k block into registers across the entire KV tile + // TODO: this loop seems to be the current largest bottleneck + // this bracket exists to scope the lifetime of variables, reducing register pressure + if (!skip_tile) { + // vectorization + let num_of_threads = subgroup_size / 4u; + let tx = sg_inv_id % num_of_threads; + let ty = sg_inv_id / num_of_threads; + for (var q_tile_row = subgroup_id; q_tile_row < Q_TILE; q_tile_row += num_subgroups) { + let global_q_row = q_row_start + q_tile_row; + if (global_q_row >= params.seq_len_q) { + continue; + } + + for (var kv_base : u32 = 0u; kv_base < KV_TILE; kv_base += 4u) { + let kv_idx = kv_base + ty; + var partial_sum: f32 = 0.0; + let kv_valid = kv_idx < KV_TILE && (kv_tile + kv_idx) < params.seq_len_kv; + + if (kv_valid) { + for (var i = tx; i < (HEAD_DIM_QK / 4u); i += num_of_threads) { + let q_off = q_tile_row * HEAD_DIM_QK + i * 4u; + // let k_off = (kv_idx * HEAD_DIM_QK) + i * 4u; + var idx: u32 = 0u; +#ifdef PAD + if (use_pad_tile) { + idx = params.offset_pad + params.pad_k_base + + kv_plane * params.stride_k1 * params.ncpsg + + kv_idx * params.stride_k1 + i * 4u; + } else { + idx = k_head_offset + (kv_tile + kv_idx) * params.stride_k1 + (i * 4u); + } +#else + idx = k_head_offset + (kv_tile + kv_idx) * params.stride_k1 + (i * 4u); +#endif + + let qv = vec4(f32(q_shmem[q_off]), f32(q_shmem[q_off + 1u]), f32(q_shmem[q_off + 2u]), f32(q_shmem[q_off + 3u])); + // let kv = vec4(f32(kv_shmem[k_off]), f32(kv_shmem[k_off + 1u]), f32(kv_shmem[k_off + 2u]), f32(kv_shmem[k_off + 3u])); + var kv: vec4; +#ifdef PAD + if (use_pad_tile) { + kv = vec4(f32(pad[idx]), f32(pad[idx + 1u]), f32(pad[idx + 2u]), f32(pad[idx + 3u])); + } else { + kv = vec4(f32(K[idx]), f32(K[idx + 1u]), f32(K[idx + 2u]), f32(K[idx + 3u])); + } +#else + kv = vec4(f32(K[idx]), f32(K[idx 
+ 1u]), f32(K[idx + 2u]), f32(K[idx + 3u]));
+#endif
+
+                            partial_sum += dot(qv, kv);
+
+                        }
+                    }
+                    // Match Metal vec reduction pattern: tree-reduce the num_of_threads (tx) lanes of each ty stripe; a step of delta d is applied only when the stripe has more than d lanes, so lane tx == 0 gets the whole stripe and nothing from the next one.
+                    var sum = partial_sum;
+                    if (num_of_threads > 16u) {
+                        sum += subgroupShuffleDown(sum, 16u);
+                    }
+                    if (num_of_threads > 8u) {
+                        sum += subgroupShuffleDown(sum, 8u);
+                    }
+                    if (num_of_threads > 4u) {
+                        sum += subgroupShuffleDown(sum, 4u);
+                    }
+                    if (num_of_threads > 2u) {
+                        sum += subgroupShuffleDown(sum, 2u);
+                    }
+                    if (num_of_threads > 1u) {
+                        sum += subgroupShuffleDown(sum, 1u);
+                    }
+
+                    let sum_bcast = subgroupShuffle(sum, num_of_threads * ty); // lane num_of_threads * ty is the tx == 0 lane of this stripe
+                    if (tx == 0u && kv_valid) {
+                        let dst_idx = q_tile_row * KV_TILE + kv_idx;
+                        inter_shmem[dst_idx] = f16(sum_bcast);
+                    }
+                }
+            }
+        }
+
+
+#ifdef MASK
+        let apply_mask = !skip_tile && (blk_state != 2u);
+        if (apply_mask) {
+            // load mask tile into shared memory for this KV block
+            for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) {
+                let mask_row = elem_idx / KV_TILE;
+                let mask_col = elem_idx % KV_TILE;
+                let global_q_row = q_row_start + mask_row;
+                let global_k_col = kv_tile + mask_col;
+#ifdef PAD
+                if (use_pad_tile) {
+                    let mask_batch_idx = select(0u, batch_idx, params.stride_mask3 > 0u);
+                    let pad_mask_plane_base = params.offset_pad + params.pad_m_base +
+                        mask_batch_idx * params.seq_len_q * params.ncpsg;
+                    let mask_in_bounds = global_q_row < params.seq_len_q;
+                    let mask_idx = pad_mask_plane_base + global_q_row * params.ncpsg + mask_col;
+                    mask_shmem[elem_idx] = f16(select(0.0, pad[mask_idx], mask_in_bounds));
+                } else {
+                    let mask_in_bounds = global_q_row < params.seq_len_q && global_k_col < params.seq_len_kv;
+                    let mask_idx = mask_global_offset + mask_row * params.seq_len_kv + global_k_col;
+                    mask_shmem[elem_idx] = select(0.0, mask[mask_idx], mask_in_bounds);
+                }
+#else
+                let mask_in_bounds = global_q_row < params.seq_len_q && global_k_col < params.seq_len_kv;
+                let mask_idx = mask_global_offset
+ mask_row * params.seq_len_kv + global_k_col; + mask_shmem[elem_idx] = select(0.0, mask[mask_idx], mask_in_bounds); +#endif + } + } +#else + let apply_mask = false; +#endif + + workgroupBarrier(); + + // online softmax + if (!skip_tile) { + for (var q_tile_row = subgroup_id; q_tile_row < Q_TILE; q_tile_row += num_subgroups) { + let global_q_row = q_row_start + q_tile_row; + if (global_q_row >= params.seq_len_q) { + break; + } + + // initialize running max for this row + var prev_max = row_max_shmem[q_tile_row]; + var final_max = prev_max; + // pass 1: compute final max across the full KV tile in chunks + for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) { + let kv_idx = kv_offset + sg_inv_id; + let softmax_term = calc_softmax_term(kv_idx, q_tile_row, slope, has_bias, apply_mask); + final_max = subgroupMax(max(final_max, softmax_term)); + } + + var total_exp_term: f32 = 0.0; + // pass 2: compute exp sum and write P using final_max + for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) { + let kv_idx = kv_offset + sg_inv_id; + let softmax_term = calc_softmax_term(kv_idx, q_tile_row, slope, has_bias, apply_mask); + let cur_p = select(0.0, + exp(softmax_term - final_max), + kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE); + total_exp_term += subgroupAdd(cur_p); + if (kv_idx < KV_TILE) { + inter_shmem[kv_idx + q_tile_row * KV_TILE] = f16(cur_p); + } + } + + let cur_exp = exp(prev_max - final_max); + + if (sg_inv_id == 0) { + row_max_shmem[q_tile_row] = final_max; + exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * cur_exp + total_exp_term; + } + + for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) { + let idx = q_tile_row * HEAD_DIM_V + elem_idx; + o_shmem[idx] = f16(f32(o_shmem[idx]) * cur_exp); + } + } + } + + // load v tile into shared memory +// #if defined(KV_Q4_0) +// for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) { +// let 
blck_idx = elem_idx / BLOCK_SIZE; +// let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; +// let v_row = blck_idx / BLOCKS_V; +// let global_v_row = kv_tile + v_row; +// let block_k = blck_idx % BLOCKS_V; +// let row_offset = v_row * HEAD_DIM_V; + +// if (global_v_row < params.seq_len_kv) { +// let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; +// let base_idx = global_block_idx * F16_PER_BLOCK; +// let d = V[base_idx]; // scale +// for (var j = 0u; j < F16_PER_THREAD; j += 2) { +// let q_0 = V[base_idx + 1u + block_offset + j]; +// let q_1 = V[base_idx + 1u + block_offset + j + 1]; +// let q_packed = bitcast(vec2(q_0, q_1)); +// for (var k = 0u; k < 4u; k++) { +// let q_byte = get_byte(q_packed, k); +// let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; +// let q_lo = (f16(q_byte & 0xF) - 8.0) * d; +// let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; +// kv_shmem[row_offset + idx] = q_lo; +// kv_shmem[row_offset + idx + 16u] = q_hi; +// } +// } +// } +// } +// #elif defined(KV_Q8_0) +// for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) { +// let blck_idx = elem_idx / BLOCK_SIZE; +// let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; +// let v_row = blck_idx / BLOCKS_V; +// let global_v_row = kv_tile + v_row; +// let block_k = blck_idx % BLOCKS_V; +// let row_offset = v_row * HEAD_DIM_V; + +// if (global_v_row < params.seq_len_kv) { +// let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; +// let base_idx = global_block_idx * F16_PER_BLOCK; +// let d = V[base_idx]; // scale +// for (var j = 0u; j < F16_PER_THREAD; j += 2) { +// let q_0 = V[base_idx + 1u + block_offset + j]; +// let q_1 = V[base_idx + 1u + block_offset + j + 1]; +// let q_packed = bitcast(vec2(q_0, q_1)); +// for (var k = 0u; k < 4u; k++) { +// let q_byte = get_byte_i32(q_packed, k); +// let q_val = f16(q_byte) * d; +// let idx = block_k * BLOCK_SIZE + 
block_offset * 2u + j * 2u + k; +// kv_shmem[row_offset + idx] = q_val; +// } +// } +// } +// } +// #elif defined(KV_DIRECT) +// // Direct global loads for KV +// #else +// for (var elem_idx = local_id.x; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE) { +// let v_row = elem_idx / HEAD_DIM_V; +// let v_col = elem_idx % HEAD_DIM_V; +// let global_v_row = kv_tile + v_row; +// let global_v_row_offset = v_head_offset + global_v_row * params.stride_v1; +// kv_shmem[elem_idx] = f16(select( +// 0.0, +// V[global_v_row_offset + v_col], +// global_v_row < params.seq_len_kv && v_col < HEAD_DIM_V)); +// } +// #endif + + workgroupBarrier(); + + if (!skip_tile) { + // we have P (Q_TILE x KV_TILE) in inter_shmem and V (KV_TILE x head_dim_v) in kv_shmem + // we want to compute O += P * V across the full KV tile + let ne_lanes = 4u; + let nl_lanes = max(1u, subgroup_size / ne_lanes); + let tx_pv = sg_inv_id % nl_lanes; + let ty_pv = sg_inv_id / nl_lanes; + for (var q_tile_row = subgroup_id; + q_tile_row < Q_TILE; + q_tile_row += num_subgroups) { + for (var vec_col = tx_pv; vec_col < (HEAD_DIM_V / 4u); vec_col += nl_lanes) { + var lo = vec4(0.0, 0.0, 0.0, 0.0); + for (var cc = 0u; cc < KV_TILE / ne_lanes; cc += 1u) { + let kv_idx = cc * ne_lanes + ty_pv; + let v_row = kv_tile + kv_idx; + if (v_row >= params.seq_len_kv) { + continue; + } + + let p = f32(inter_shmem[kv_idx + q_tile_row * KV_TILE]); + var v_idx: u32 = 0u; + var v4: vec4; +#ifdef PAD + if (use_pad_tile) { + v_idx = params.offset_pad + params.pad_v_base + + kv_plane * params.stride_v1 * params.ncpsg + + kv_idx * params.stride_v1 + vec_col * 4u; + v4 = vec4(f32(pad[v_idx]), f32(pad[v_idx + 1u]), f32(pad[v_idx + 2u]), f32(pad[v_idx + 3u])); + } else { + v_idx = v_head_offset + v_row * params.stride_v1 + vec_col * 4u; + v4 = vec4(V[v_idx / 4u]); + } +#else + v_idx = v_head_offset + v_row * params.stride_v1 + vec_col * 4u; + v4 = vec4(V[v_idx / 4u]); +#endif + lo += p * v4; + } + + // Match Metal's vec PV reduction: 
reduce across ty lanes. + var lo_x = lo.x; + var lo_y = lo.y; + var lo_z = lo.z; + var lo_w = lo.w; + var delta = nl_lanes * (ne_lanes >> 1u); + loop { + if (delta == 0u || delta < nl_lanes) { + break; + } + lo_x += subgroupShuffleDown(lo_x, delta); + lo_y += subgroupShuffleDown(lo_y, delta); + lo_z += subgroupShuffleDown(lo_z, delta); + lo_w += subgroupShuffleDown(lo_w, delta); + delta = delta >> 1u; + } + + if (ty_pv == 0u) { + let elem_base = vec_col * 4u; + let o_base_idx = q_tile_row * HEAD_DIM_V + elem_base; + o_shmem[o_base_idx + 0u] = f16(f32(o_shmem[o_base_idx + 0u]) + lo_x); + o_shmem[o_base_idx + 1u] = f16(f32(o_shmem[o_base_idx + 1u]) + lo_y); + o_shmem[o_base_idx + 2u] = f16(f32(o_shmem[o_base_idx + 2u]) + lo_z); + o_shmem[o_base_idx + 3u] = f16(f32(o_shmem[o_base_idx + 3u]) + lo_w); + } + } + } + } + + workgroupBarrier(); + } + + +#ifdef SINKS + // Sinks are global terms and must be applied exactly once across split workgroups. + if (iwg == 0u) { + for (var q_tile_row = subgroup_id; + q_tile_row < Q_TILE; + q_tile_row += num_subgroups) { + // no need to process rows beyond seq_len_q + let global_q_row = q_row_start + q_tile_row; + if (global_q_row >= params.seq_len_q) { + break; + } + + var prev_max = row_max_shmem[q_tile_row]; + + // for non-sink threads, exp(FLOAT_MIN) effectively zeroes out their contribution to the sum + let sink_val = select(FLOAT_MIN, sinks[params.offset_sinks + head_idx], sg_inv_id == 0); + let new_max = subgroupMax(max(prev_max, sink_val)); + let max_exp = exp(prev_max - new_max); + let sink_exp = exp(sink_val - new_max); + + let sink_exp_sum = subgroupAdd(sink_exp); + + if (sg_inv_id == 0) { + exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * max_exp + sink_exp_sum; + } + + for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) { + let idx = q_tile_row * HEAD_DIM_V + elem_idx; + let val = f32(o_shmem[idx]) * max_exp; + o_shmem[idx] = f16(val); + } + } + workgroupBarrier(); + } +#endif + let 
rows_per_batch = params.n_heads * params.seq_len_q; + for (var q_tile_row = subgroup_id; + q_tile_row < Q_TILE; + q_tile_row += num_subgroups) { + + let global_q_row = q_row_start + q_tile_row; + if (global_q_row >= params.seq_len_q) { break; } + + let rid = batch_idx * rows_per_batch + head_idx * params.seq_len_q + global_q_row; + let tmp_row_data_base = params.tmp_data_base + rid * (HEAD_DIM_V * params.nwg) + iwg * HEAD_DIM_V; + let tmp_row_stats_base = params.tmp_stats_base + rid * (2u * params.nwg) + 2u * iwg; + + for (var elem_base = sg_inv_id * 4u; + elem_base < HEAD_DIM_V; + elem_base += subgroup_size * 4u) { + + let i0 = q_tile_row * HEAD_DIM_V + (elem_base + 0u); + let i1 = q_tile_row * HEAD_DIM_V + (elem_base + 1u); + let i2 = q_tile_row * HEAD_DIM_V + (elem_base + 2u); + let i3 = q_tile_row * HEAD_DIM_V + (elem_base + 3u); + + let tbase = tmp_row_data_base + elem_base; + tmp[tbase + 0u] = f32(o_shmem[i0]); + tmp[tbase + 1u] = f32(o_shmem[i1]); + tmp[tbase + 2u] = f32(o_shmem[i2]); + tmp[tbase + 3u] = f32(o_shmem[i3]); + } + + if (sg_inv_id == 0u) { + tmp[tmp_row_stats_base + 0u] = exp_sum_shmem[q_tile_row]; + tmp[tmp_row_stats_base + 1u] = row_max_shmem[q_tile_row]; + } + } +} From 10330856b7be97175477cd25a306170e618da778 Mon Sep 17 00:00:00 2001 From: Zheyuan Chen Date: Tue, 3 Mar 2026 12:19:24 -0800 Subject: [PATCH 03/34] update vec version --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 00aef8ed2056..c56f76f23c9f 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -389,7 +389,6 @@ struct webgpu_context_struct { std::unordered_map flash_attn_blk_pipelines; - std::map rms_norm_pipelines; // inplace std::map>> rope_pipelines; // type, ff, inplace std::map>> glu_pipelines; // glu_op, type, split From c307a4bf03c6bef0d33e5048bb7b8f2a71cdcb01 Mon Sep 17 00:00:00 2001 From: Zheyuan Chen Date: 
Tue, 3 Mar 2026 12:41:11 -0800 Subject: [PATCH 04/34] remove unused path and shader --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 8 +- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 233 +------ .../wgsl-shaders/flash_attn_pad.wgsl | 624 ------------------ .../wgsl-shaders/flash_attn_vec.wgsl | 584 ---------------- 4 files changed, 5 insertions(+), 1444 deletions(-) delete mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_pad.wgsl delete mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec.wgsl diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index c8a7fadeab76..f80247bc8954 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -282,7 +282,6 @@ struct ggml_webgpu_flash_attn_pipeline_key { bool has_sinks; bool uses_logit_softcap; bool use_vec; - bool use_pad; bool use_vec_split; bool use_blk; @@ -290,7 +289,7 @@ struct ggml_webgpu_flash_attn_pipeline_key { return kv_type == other.kv_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v && kv_direct == other.kv_direct && has_mask == other.has_mask && has_sinks == other.has_sinks && uses_logit_softcap == other.uses_logit_softcap && use_vec == other.use_vec && - use_pad == other.use_pad && use_vec_split == other.use_vec_split && use_blk == other.use_blk; + use_vec_split == other.use_vec_split && use_blk == other.use_blk; } }; @@ -305,7 +304,6 @@ struct ggml_webgpu_flash_attn_pipeline_key_hash { ggml_webgpu_hash_combine(seed, key.has_sinks); ggml_webgpu_hash_combine(seed, key.uses_logit_softcap); ggml_webgpu_hash_combine(seed, key.use_vec); - ggml_webgpu_hash_combine(seed, key.use_pad); ggml_webgpu_hash_combine(seed, key.use_vec_split); ggml_webgpu_hash_combine(seed, key.use_blk); return seed; @@ -505,10 +503,6 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader( defines.push_back("KV_DIRECT"); variant += "_kvdirect"; } - if 
(context.key.use_pad) { - defines.push_back("PAD"); - variant += "_pad"; - } if (context.key.use_blk) { defines.push_back("BLK"); variant += "_blk"; diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index c56f76f23c9f..f5233645ca85 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1638,51 +1638,6 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, const bool use_vec_split = use_vec && vec_nwg_cap > 1u; const bool use_blk = use_vec_split && has_mask; - // Compute vec KV tile (same logic as shader preprocessing) to decide whether tail padding is needed. - uint32_t kv_tile_for_vec = 0; - if (use_vec) { - ggml_webgpu_flash_attn_pipeline_key probe_key = { - .kv_type = K->type, - .head_dim_qk = (uint32_t) Q->ne[0], - .head_dim_v = (uint32_t) V->ne[0], - .kv_direct = kv_direct, - .has_mask = static_cast(has_mask), - .has_sinks = static_cast(has_sinks), - .uses_logit_softcap = logit_softcap != 0.0f, - .use_vec = true, - .use_pad = false, - .use_vec_split = use_vec_split, - .use_blk = false, - }; - ggml_webgpu_flash_attn_shader_lib_context probe_ctx = { - .key = probe_key, - .sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m, - .sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n, - .sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k, - .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize, - .max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size - }; - kv_tile_for_vec = std::min(ggml_webgpu_flash_attn_max_kv_tile(probe_ctx), - probe_ctx.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES); - if (probe_key.kv_direct) { - while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile_for_vec != 0) { - kv_tile_for_vec -= probe_ctx.sg_mat_n; - } - } - } - - const uint32_t tail_for_vec = (use_vec && kv_tile_for_vec > 0) ? 
(uint32_t) (K->ne[1] % kv_tile_for_vec) : 0u; - const bool copy_alignment_ok = - (ggml_webgpu_tensor_offset(K) % 4 == 0) && - (ggml_webgpu_tensor_offset(V) % 4 == 0) && - (K->nb[1] % 4 == 0) && - (V->nb[1] % 4 == 0) && - (!has_mask || ((ggml_webgpu_tensor_offset(mask) % 4 == 0) && - ((((uint64_t) K->ne[1]) * ggml_type_size(mask->type) % 4) == 0) && - ((((uint64_t) mask->nb[3]) % 4) == 0) && - (((uint64_t) tail_for_vec * ggml_type_size(mask->type)) % 4 == 0))); - const bool use_pad = use_vec && tail_for_vec != 0 && copy_alignment_ok; - ggml_webgpu_flash_attn_pipeline_key key = { .kv_type = K->type, .head_dim_qk = (uint32_t) Q->ne[0], @@ -1692,7 +1647,6 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, .has_sinks = static_cast(has_sinks), .uses_logit_softcap = logit_softcap != 0.0f, .use_vec = use_vec, - .use_pad = use_pad, .use_vec_split = use_vec_split, .use_blk = use_blk, }; @@ -1714,10 +1668,6 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, ggml_webgpu_processed_shader processed; if (use_vec_split) { processed = ggml_webgpu_preprocess_flash_attn_shader(ctx->p, wgsl_flash_attn_vec_split, shader_lib_ctx); - } else if (use_vec && use_pad) { - processed = ggml_webgpu_preprocess_flash_attn_shader(ctx->p, wgsl_flash_attn_pad, shader_lib_ctx); - } else if (use_vec) { - processed = ggml_webgpu_preprocess_flash_attn_shader(ctx->p, wgsl_flash_attn_vec, shader_lib_ctx); } else { processed = ggml_webgpu_preprocess_flash_attn_shader(ctx->p, wgsl_flash_attn, shader_lib_ctx); } @@ -1741,16 +1691,6 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, vec_nwg = std::min(vec_nwg, vec_nwg_cap); } - bool have_pad_buf = false; - wgpu::Buffer pad_buf = {}; - uint64_t pad_size_bytes = 0; - uint32_t pad_k_base_u32 = 0; - uint32_t pad_v_base_u32 = 0; - uint32_t pad_m_base_u32 = 0; - uint32_t pad_ncpsg = 0; - std::vector pad_clear_ops; - std::vector pad_copy_ops; - bool have_blk_buf = false; wgpu::Buffer blk_buf = {}; uint64_t 
blk_size_bytes = 0; @@ -1758,147 +1698,6 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, uint32_t blk_nblk1 = 0; uint32_t blk_batch_count = 0; - if (use_pad) { - const uint32_t ncpsg = decisions->kv_tile; - const uint32_t tail = (uint32_t) (K->ne[1] % ncpsg); - const uint32_t tail_start = (uint32_t) K->ne[1] - tail; - GGML_ASSERT(tail > 0); - - const uint32_t stride_k1 = (uint32_t) (K->nb[1] / ggml_type_size(K->type)); - const uint32_t stride_v1 = (uint32_t) (V->nb[1] / ggml_type_size(V->type)); - const uint32_t kv_heads = (uint32_t) K->ne[2]; - const uint32_t kv_batches = (uint32_t) K->ne[3]; - const uint64_t kv_planes = (uint64_t) kv_heads * kv_batches; - const uint32_t stride_mask3 = has_mask ? (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)) : 0u; - const uint32_t mask_batch_planes = has_mask ? (stride_mask3 > 0 ? (uint32_t) Q->ne[3] : 1u) : 0u; - - const uint64_t pad_k_base = 0; - const uint64_t pad_v_base = pad_k_base + (uint64_t) stride_k1 * ncpsg * kv_planes; - const uint64_t pad_m_base = pad_v_base + (uint64_t) stride_v1 * ncpsg * kv_planes; - const uint64_t pad_m_elems = has_mask ? 
((uint64_t) Q->ne[1] * ncpsg * mask_batch_planes) : 0u; - const uint64_t pad_total_elems = pad_m_base + pad_m_elems; - pad_size_bytes = - ROUNDUP_POW2(pad_total_elems * ggml_type_size(K->type), WEBGPU_STORAGE_BUF_BINDING_MULT); - - GGML_ASSERT(pad_k_base <= UINT32_MAX); - GGML_ASSERT(pad_v_base <= UINT32_MAX); - GGML_ASSERT(pad_m_base <= UINT32_MAX); - pad_k_base_u32 = (uint32_t) pad_k_base; - pad_v_base_u32 = (uint32_t) pad_v_base; - pad_m_base_u32 = (uint32_t) pad_m_base; - pad_ncpsg = ncpsg; - - ggml_webgpu_create_buffer(ctx->global_ctx->device, pad_buf, pad_size_bytes, - wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopyDst, - "flash_attn_pad_buf"); - have_pad_buf = true; - - pad_clear_ops = { - { .buffer = pad_buf, .offset = 0, .size = pad_size_bytes }, - }; - pad_copy_ops.clear(); - - const uint64_t k_tensor_off = ggml_webgpu_tensor_offset(K); - const uint64_t v_tensor_off = ggml_webgpu_tensor_offset(V); - const auto k_buf = ggml_webgpu_tensor_buf(K); - const auto v_buf = ggml_webgpu_tensor_buf(V); - - for (uint32_t b = 0; b < kv_batches; b++) { - for (uint32_t h = 0; h < kv_heads; h++) { - const uint64_t plane = (uint64_t) h + (uint64_t) b * kv_heads; - for (uint32_t r = 0; r < tail; r++) { - const uint64_t src_row_k = - k_tensor_off + (uint64_t) b * K->nb[3] + (uint64_t) h * K->nb[2] + (uint64_t) (tail_start + r) * K->nb[1]; - const uint64_t dst_row_k = - (pad_k_base + plane * stride_k1 * ncpsg + (uint64_t) r * stride_k1) * ggml_type_size(K->type); - pad_copy_ops.push_back({ .src = k_buf, - .src_offset = src_row_k, - .dst = pad_buf, - .dst_offset = dst_row_k, - .size = K->nb[1] }); - - const uint64_t src_row_v = - v_tensor_off + (uint64_t) b * V->nb[3] + (uint64_t) h * V->nb[2] + (uint64_t) (tail_start + r) * V->nb[1]; - const uint64_t dst_row_v = - (pad_v_base + plane * stride_v1 * ncpsg + (uint64_t) r * stride_v1) * ggml_type_size(V->type); - pad_copy_ops.push_back({ .src = v_buf, - .src_offset = src_row_v, - .dst = pad_buf, - .dst_offset = dst_row_v, - 
.size = V->nb[1] }); - } - } - } - - if (has_mask) { - const uint64_t mask_tensor_off = ggml_webgpu_tensor_offset(mask); - const auto mask_buf = ggml_webgpu_tensor_buf(mask); - const uint64_t mask_copy_size = (uint64_t) tail * ggml_type_size(mask->type); - for (uint32_t mb = 0; mb < mask_batch_planes; mb++) { - const uint32_t src_batch = - (stride_mask3 > 0 && mb >= (uint32_t) mask->ne[3]) ? (uint32_t) mask->ne[3] - 1 : mb; - for (uint32_t q = 0; q < (uint32_t) Q->ne[1]; q++) { - const uint64_t src_mask_elem = - (uint64_t) src_batch * stride_mask3 + (uint64_t) q * (uint32_t) K->ne[1] + tail_start; - const uint64_t dst_mask_elem = - pad_m_base + (uint64_t) mb * (uint32_t) Q->ne[1] * ncpsg + (uint64_t) q * ncpsg; - pad_copy_ops.push_back({ .src = mask_buf, - .src_offset = mask_tensor_off + src_mask_elem * ggml_type_size(mask->type), - .dst = pad_buf, - .dst_offset = dst_mask_elem * ggml_type_size(K->type), - .size = mask_copy_size }); - } - } - } - - if (!use_vec_split) { - std::vector pad_params = params; - pad_params.push_back(0u); // offset_pad - pad_params.push_back(pad_k_base_u32); // pad_k_base - pad_params.push_back(pad_v_base_u32); // pad_v_base - pad_params.push_back(pad_m_base_u32); // pad_m_base - pad_params.push_back(pad_ncpsg); // ncpsg - pad_params.push_back(1u); // nqptg - pad_params.push_back(1u); // nwg - - std::vector pad_entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(Q), - .offset = ggml_webgpu_tensor_align_offset(ctx, Q), - .size = ggml_webgpu_tensor_binding_size(ctx, Q) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(K), - .offset = ggml_webgpu_tensor_align_offset(ctx, K), - .size = ggml_webgpu_tensor_binding_size(ctx, K) }, - { .binding = 2, - .buffer = ggml_webgpu_tensor_buf(V), - .offset = ggml_webgpu_tensor_align_offset(ctx, V), - .size = ggml_webgpu_tensor_binding_size(ctx, V) }, - { .binding = 3, .buffer = pad_buf, .offset = 0, .size = pad_size_bytes }, - }; - uint32_t pad_binding_index = 4; - if (has_mask) { - 
pad_entries.push_back({ .binding = pad_binding_index++, - .buffer = ggml_webgpu_tensor_buf(mask), - .offset = ggml_webgpu_tensor_align_offset(ctx, mask), - .size = ggml_webgpu_tensor_binding_size(ctx, mask) }); - } - if (has_sinks) { - pad_entries.push_back({ .binding = pad_binding_index++, - .buffer = ggml_webgpu_tensor_buf(sinks), - .offset = ggml_webgpu_tensor_align_offset(ctx, sinks), - .size = ggml_webgpu_tensor_binding_size(ctx, sinks) }); - } - pad_entries.push_back({ .binding = pad_binding_index++, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); - - return ggml_backend_webgpu_build_with_pre_ops(ctx->global_ctx, ctx->param_buf_pool, pipeline, pad_params, - pad_entries, pad_clear_ops, pad_copy_ops, { pad_buf }, wg_x); - } - } - if (use_vec_split) { const uint32_t nwg = vec_nwg; GGML_ASSERT(nwg <= ctx->global_ctx->capabilities.max_subgroup_size); @@ -1972,15 +1771,6 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, } std::vector split_params = params; - if (use_pad) { - GGML_ASSERT(have_pad_buf); - split_params.push_back(0u); // offset_pad - split_params.push_back(pad_k_base_u32); // pad_k_base - split_params.push_back(pad_v_base_u32); // pad_v_base - split_params.push_back(pad_m_base_u32); // pad_m_base - split_params.push_back(pad_ncpsg); // ncpsg - split_params.push_back(1u); // nqptg - } if (use_blk) { split_params.push_back(0u); // blk_base split_params.push_back(blk_nblk0); // blk_nblk0 @@ -2005,12 +1795,6 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, V) }, }; uint32_t split_binding_index = 3; - if (use_pad) { - split_entries.push_back({ .binding = split_binding_index++, - .buffer = pad_buf, - .offset = 0, - .size = pad_size_bytes }); - } if (has_mask) { split_entries.push_back({ .binding = split_binding_index++, .buffer = ggml_webgpu_tensor_buf(mask), @@ -2103,24 
+1887,15 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, const bool split_passes = use_blk; std::vector retained_buffers = { tmp_buf }; - if (use_pad) { - retained_buffers.push_back(pad_buf); - } if (use_blk) { retained_buffers.push_back(blk_buf); } webgpu_command cmd; - if (use_pad) { - cmd = ggml_backend_webgpu_build_multi_with_pre_ops( - ctx->global_ctx, ctx->param_buf_pool, pipelines, params_list, entries_list, workgroups_list, - pad_clear_ops, pad_copy_ops, std::move(retained_buffers), std::nullopt, split_passes); - } else { - cmd = ggml_backend_webgpu_build_multi( - ctx->global_ctx, ctx->param_buf_pool, pipelines, params_list, entries_list, workgroups_list, - std::nullopt, split_passes); - cmd.retained_buffers = std::move(retained_buffers); - } + cmd = ggml_backend_webgpu_build_multi( + ctx->global_ctx, ctx->param_buf_pool, pipelines, params_list, entries_list, workgroups_list, + std::nullopt, split_passes); + cmd.retained_buffers = std::move(retained_buffers); return cmd; } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_pad.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_pad.wgsl deleted file mode 100644 index 34fc5b42591c..000000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_pad.wgsl +++ /dev/null @@ -1,624 +0,0 @@ -diagnostic(off, chromium.subgroup_matrix_uniformity); -diagnostic(off, subgroup_uniformity); -enable f16; -enable subgroups; -enable chromium_experimental_subgroup_matrix; - -#ifdef KV_F32 -#define KV_TYPE f32 -#else -#define KV_TYPE f16 -#endif - -// Default values -#define HEAD_DIM_QK 64 -#define HEAD_DIM_V 64 - -// The number of rows/columns/k in a subgroup matrix. MxK * KxN = MxN -// Note that the "K" here does not correspond to the K in attention's Q/K/V, it's just the common dimension. 
-#define SG_MAT_M 8 -#define SG_MAT_N 8 -#define SG_MAT_K 8 - -// Each workgroup processes one subgroup matrix of Q rows -#define Q_TILE SG_MAT_M -#define KV_TILE 16 -#define WG_SIZE 64 - -// Number of subgroup-matrix-width blocks that span the KV tile. SG_MAT_N must divide KV_TILE. -#define KV_BLOCKS (KV_TILE / SG_MAT_N) - -// Quantization constants/helpers -#define BLOCK_SIZE 32 -#define BLOCKS_K ((HEAD_DIM_QK + BLOCK_SIZE - 1) / BLOCK_SIZE) -#define BLOCKS_V ((HEAD_DIM_V + BLOCK_SIZE - 1) / BLOCK_SIZE) -// number of quantized elements processed per thread -#if defined(KV_Q4_0) -#define NQ 16 -// Q4_0 has 32 elements, 1 f16 for scale, 8 f16 for 4-bit weights -#define F16_PER_BLOCK 9 -#define WEIGHTS_PER_F16 4 -#elif defined(KV_Q8_0) -#define NQ 8 -// Q8_0 has 32 elements, 1 f16 for scale, 16 f16 for 8-bit weights -#define F16_PER_BLOCK 17 -#define WEIGHTS_PER_F16 2 -#endif -#define F16_PER_THREAD (NQ / WEIGHTS_PER_F16) - -// Ok not to put these in a define block, compiler will remove if unused -fn get_byte(value: u32, index: u32) -> u32 { - return (value >> (index * 8)) & 0xFF; -} - -fn get_byte_i32(value: u32, index: u32) -> i32 { - return bitcast(((value >> (index * 8)) & 0xFF) << 24) >> 24; -} - -struct Params { - offset_q: u32, - offset_k: u32, - offset_v: u32, - offset_mask: u32, - offset_sinks: u32, - offset_dst: u32, - - // shapes of Q/K/V - n_heads: u32, - seq_len_q: u32, - seq_len_kv: u32, - - // strides (in elements) - stride_q1: u32, - stride_q2: u32, - stride_q3: u32, - stride_k1: u32, - stride_k2: u32, - stride_k3: u32, - stride_v1: u32, - stride_v2: u32, - stride_v3: u32, - stride_mask3: u32, - - // repeat factors for K/V, e.g., MHA vs. MQA vs. 
GQA - q_per_kv: u32, - - // softmax params - scale: f32, - max_bias: f32, - logit_softcap: f32, - n_head_log2: f32, - m0: f32, - m1: f32, - - // padding - offset_pad: u32, - pad_k_base: u32, - pad_v_base: u32, - pad_m_base: u32, - - ncpsg: u32, // number of context positions per group, used for padding logic - nqptg: u32, // number of Q positions per group, used for padding logic - - nwg: u32, // total number of workgroups, used for padding logic -}; - -@group(0) @binding(0) var Q: array; -@group(0) @binding(1) var K: array; -@group(0) @binding(2) var V: array>; -@group(0) @binding(3) var pad: array; - -#if defined(MASK) && defined(SINKS) -@group(0) @binding(4) var mask: array; -@group(0) @binding(5) var sinks: array; -#define DST_BINDING 6 -#define PARAMS_BINDING 7 -#elif defined(MASK) -@group(0) @binding(4) var mask: array; -#define DST_BINDING 5 -#define PARAMS_BINDING 6 -#elif defined(SINKS) -@group(0) @binding(4) var sinks: array; -#define DST_BINDING 5 -#define PARAMS_BINDING 6 -#else -#define DST_BINDING 4 -#define PARAMS_BINDING 5 -#endif - -@group(0) @binding(DST_BINDING) var dst: array>; -@group(0) @binding(PARAMS_BINDING) var params: Params; - -// Just a very small float value. 
-const FLOAT_MIN: f32 = -1.0e9; - -// The number of Q rows processed per workgroup -var q_shmem: array; - -#ifndef KV_DIRECT -const kv_shmem_size = KV_TILE * max(HEAD_DIM_QK, HEAD_DIM_V); -// we can reuse the same shmem for K and V since we only need one at a time -var kv_shmem: array; -#endif - -var o_shmem: array; // output shmem - -#ifdef MASK -// storage for mask values -var mask_shmem: array; -#endif - -// storage for output of Q*K^T scores for online softmax (S matrix from paper) -// also storage for diagonal matrix during online softmax (P matrix from paper) -// note that we reuse the same storage for both since we only need one at a time -var inter_shmem: array; - -// Storage for row max and exp sum during online softmax -var row_max_shmem: array; -var exp_sum_shmem: array; - -fn calc_softmax_term(kv_idx: u32, q_tile_row: u32, slope: f32) -> f32 { - var v = select(FLOAT_MIN, - f32(inter_shmem[kv_idx + q_tile_row * KV_TILE]) * params.scale, - kv_idx < KV_TILE); -#ifdef LOGIT_SOFTCAP - v = params.logit_softcap * tanh(v); -#endif -#ifdef MASK - let mask_val = select(0.0, f32(mask_shmem[q_tile_row * KV_TILE + kv_idx]), kv_idx < KV_TILE); - let mask_term = slope * mask_val; - v += mask_term; -#endif - return v; -} - -fn load_f32x4(buf: ptr>, read_write>, scalar_index: u32) -> vec4 { - return (*buf)[scalar_index >> 2u]; -} - -fn load_kvx4(buf: ptr>, read_write>, scalar_index: u32) -> vec4 { - return (*buf)[scalar_index >> 2u]; -} - -@compute @workgroup_size(WG_SIZE) -fn main(@builtin(workgroup_id) wg_id: vec3, - @builtin(local_invocation_id) local_id: vec3, - @builtin(subgroup_id) subgroup_id: u32, - @builtin(subgroup_size) subgroup_size: u32, - @builtin(num_subgroups) num_subgroups: u32, - @builtin(subgroup_invocation_id) sg_inv_id: u32) { - - // initialize row max for online softmax - for (var i = local_id.x; i < Q_TILE; i += WG_SIZE) { - row_max_shmem[i] = FLOAT_MIN; - exp_sum_shmem[i] = 0.0; - } - - for (var i = local_id.x; i < Q_TILE * HEAD_DIM_V; i += 
WG_SIZE) { - o_shmem[i] = 0.0; - } - - // workgroups per head/batch - let wg_per_head = (params.seq_len_q + Q_TILE - 1u) / Q_TILE; - let wg_per_batch = wg_per_head * params.n_heads; - - let dst2_stride = HEAD_DIM_V * params.n_heads; - let dst3_stride = dst2_stride * params.seq_len_q; - - // batch index - let batch_idx = wg_id.x / wg_per_batch; - let q_batch_offset = params.offset_q + batch_idx * params.stride_q3; - let k_batch_offset = params.offset_k + batch_idx * params.stride_k3; - let v_batch_offset = params.offset_v + batch_idx * params.stride_v3; - let dst_batch_offset = params.offset_dst + batch_idx * dst3_stride; - let wg_in_batch = wg_id.x % wg_per_batch; - - // head index - let head_idx = wg_in_batch / wg_per_head; - let q_head_offset = q_batch_offset + head_idx * params.stride_q2; - let k_head_idx = head_idx / params.q_per_kv; - let v_head_idx = k_head_idx; - let k_head_offset = k_batch_offset + k_head_idx * params.stride_k2; - let v_head_offset = v_batch_offset + v_head_idx * params.stride_v2; - - // starting Q row for this workgroup - let wg_in_head = wg_in_batch % wg_per_head; - let q_row_start = wg_in_head * Q_TILE; - -#ifdef MASK - // mask offset - let mask_global_offset = params.offset_mask + batch_idx * params.stride_mask3 + q_row_start * params.seq_len_kv; -#endif - - // note that the output is permuted, the layout is [head_dim_v, n_heads, seq_len_q, batch_size] - let dst_global_offset = dst_batch_offset + q_row_start * dst2_stride + head_idx * HEAD_DIM_V; - - let head = f32(head_idx); - let slope = select(1.0, select(pow(params.m1, 2.0 * (head - params.n_head_log2) + 1.0), pow(params.m0, head + 1.0), head < params.n_head_log2), params.max_bias > 0); - - // load q tile into shared memory - for (var elem_idx = local_id.x; elem_idx < Q_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) { - let q_row = elem_idx / HEAD_DIM_QK; - let q_col = elem_idx % HEAD_DIM_QK; - let head_q_row = q_row_start + q_row; - let global_q_row_offset = q_head_offset + head_q_row * 
params.stride_q1; - q_shmem[elem_idx] = f16(select( - 0.0, - Q[global_q_row_offset + q_col], - head_q_row < params.seq_len_q && q_col < HEAD_DIM_QK)); - } - - for (var kv_tile = 0u; kv_tile < params.seq_len_kv; kv_tile += KV_TILE) { - let tail = params.seq_len_kv % params.ncpsg; - let use_pad_tile = (tail != 0u) && (kv_tile + params.ncpsg >= params.seq_len_kv); - let kv_plane = k_head_idx + batch_idx * (params.n_heads / params.q_per_kv); - // clear inter_shmem to ensure zero-initialized accumulators - for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) { - inter_shmem[elem_idx] = 0.0; - } - - // load k tile into shared memory -// #if defined(KV_Q4_0) -// for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) { -// let blck_idx = elem_idx / BLOCK_SIZE; -// let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; -// let k_row = blck_idx / BLOCKS_K; -// let global_k_row = kv_tile + k_row; -// let block_k = blck_idx % BLOCKS_K; -// let row_offset = k_row * HEAD_DIM_QK; - -// if (global_k_row < params.seq_len_kv) { -// let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; -// let base_idx = global_block_idx * F16_PER_BLOCK; -// let d = K[base_idx]; // scale -// for (var j = 0u; j < F16_PER_THREAD; j += 2) { -// let q_0 = K[base_idx + 1u + block_offset + j]; -// let q_1 = K[base_idx + 1u + block_offset + j + 1]; -// let q_packed = bitcast(vec2(q_0, q_1)); -// for (var k = 0u; k < 4u; k++) { -// let q_byte = get_byte(q_packed, k); -// let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; -// let q_lo = (f16(q_byte & 0xF) - 8.0) * d; -// let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; -// kv_shmem[row_offset + idx] = q_lo; -// kv_shmem[row_offset + idx + 16u] = q_hi; -// } -// } -// } -// } -// #elif defined(KV_Q8_0) -// for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) { -// let blck_idx = elem_idx / 
BLOCK_SIZE; -// let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; -// let k_row = blck_idx / BLOCKS_K; -// let global_k_row = kv_tile + k_row; -// let block_k = blck_idx % BLOCKS_K; -// let row_offset = k_row * HEAD_DIM_QK; - -// if (global_k_row < params.seq_len_kv) { -// let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; -// let base_idx = global_block_idx * F16_PER_BLOCK; -// let d = K[base_idx]; // scale -// for (var j = 0u; j < F16_PER_THREAD; j += 2) { -// let q_0 = K[base_idx + 1u + block_offset + j]; -// let q_1 = K[base_idx + 1u + block_offset + j + 1]; -// let q_packed = bitcast(vec2(q_0, q_1)); -// for (var k = 0u; k < 4u; k++) { -// let q_byte = get_byte_i32(q_packed, k); -// let q_val = f16(q_byte) * d; -// let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; -// kv_shmem[row_offset + idx] = q_val; -// } -// } -// } -// } -// #endif - - workgroupBarrier(); - - // accumulate q block * k block into registers across the entire KV tile - // TODO: this loop seems to be the current largest bottleneck - // this bracket exists to scope the lifetime of variables, reducing register pressure - { - // vectorization - let num_of_threads = subgroup_size / 4u; - let tx = sg_inv_id % num_of_threads; - let ty = sg_inv_id / num_of_threads; - for (var q_tile_row = subgroup_id; q_tile_row < Q_TILE; q_tile_row += num_subgroups) { - let global_q_row = q_row_start + q_tile_row; - if (global_q_row >= params.seq_len_q) { - continue; - } - - for (var kv_base : u32 = 0u; kv_base < KV_TILE; kv_base += 4u) { - let kv_idx = kv_base + ty; - var partial_sum: f32 = 0.0; - let kv_valid = kv_idx < KV_TILE && (kv_tile + kv_idx) < params.seq_len_kv; - - if (kv_valid) { - for (var i = tx; i < (HEAD_DIM_QK / 4u); i += num_of_threads) { - let q_off = q_tile_row * HEAD_DIM_QK + i * 4u; - // let k_off = (kv_idx * HEAD_DIM_QK) + i * 4u; - var idx : u32 = 0u; - if (use_pad_tile == true) { - idx = params.offset_pad + params.pad_k_base + 
kv_plane * params.stride_k1 * params.ncpsg + kv_idx * params.stride_k1 + i * 4u; - - } else{ - idx = k_head_offset + (kv_tile + kv_idx) * params.stride_k1 + (i * 4u); - } - - let qv = vec4(f32(q_shmem[q_off]), f32(q_shmem[q_off + 1u]), f32(q_shmem[q_off + 2u]), f32(q_shmem[q_off + 3u])); - // let kv = vec4(f32(kv_shmem[k_off]), f32(kv_shmem[k_off + 1u]), f32(kv_shmem[k_off + 2u]), f32(kv_shmem[k_off + 3u])); - var kv :vec4; - if (use_pad_tile == true) { - kv = vec4(f32(pad[idx]), f32(pad[idx + 1u]), f32(pad[idx + 2u]), f32(pad[idx + 3u])); - } else { - kv = vec4(f32(K[idx]), f32(K[idx + 1u]), f32(K[idx + 2u]), f32(K[idx + 3u])); - } - partial_sum += dot(qv, kv); - - } - } - for (var g: u32 = 0u; g < 4u; g++) { - let kv_idx_g = kv_base + g; - let active_threads = (ty == g) && (kv_idx_g < KV_TILE) && ((kv_tile + kv_idx_g) < params.seq_len_kv); - - let contrib = select(0.0, partial_sum, active_threads); - let sum_g = subgroupAdd(contrib); - - if (tx == 0u && ty == g && kv_idx_g < KV_TILE) { - let dst_idx = q_tile_row * KV_TILE + kv_idx_g; - inter_shmem[dst_idx] = f16(select(FLOAT_MIN, sum_g, (kv_tile + kv_idx_g) < params.seq_len_kv)); - } - } - } - } - } - - -#ifdef MASK - // load mask tile into shared memory for this KV block - // TODO: optimize and skip if mask is -INF for the entire tile - for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) { - let mask_row = elem_idx / KV_TILE; - let mask_col = elem_idx % KV_TILE; - let global_q_row = q_row_start + mask_row; - let global_k_col = kv_tile + mask_col; - if (use_pad_tile == true) { - let mask_batch_idx = select(0u, batch_idx, params.stride_mask3 > 0u); - let pad_mask_plane_base = params.offset_pad + params.pad_m_base + - mask_batch_idx * params.seq_len_q * params.ncpsg; - let mask_in_bounds = global_q_row < params.seq_len_q; - let mask_idx = pad_mask_plane_base + global_q_row * params.ncpsg + mask_col; - mask_shmem[elem_idx] = f16(select(0.0, pad[mask_idx], mask_in_bounds)); - } else { - 
let mask_in_bounds = global_q_row < params.seq_len_q && global_k_col < params.seq_len_kv; - let mask_idx = mask_global_offset + mask_row * params.seq_len_kv + global_k_col; - mask_shmem[elem_idx] = select(0.0, mask[mask_idx], mask_in_bounds); - } - } -#endif - - workgroupBarrier(); - - // online softmax - for (var q_tile_row = subgroup_id; q_tile_row < Q_TILE; q_tile_row += num_subgroups) { - let global_q_row = q_row_start + q_tile_row; - if (global_q_row >= params.seq_len_q) { - break; - } - - // initialize running max for this row - var prev_max = row_max_shmem[q_tile_row]; - var final_max = prev_max; - // pass 1: compute final max across the full KV tile in chunks - for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) { - let kv_idx = kv_offset + sg_inv_id; - let softmax_term = calc_softmax_term(kv_idx, q_tile_row, slope); - final_max = subgroupMax(max(final_max, softmax_term)); - } - - var total_exp_term: f32 = 0.0; - // pass 2: compute exp sum and write P using final_max - for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) { - let kv_idx = kv_offset + sg_inv_id; - let softmax_term = calc_softmax_term(kv_idx, q_tile_row, slope); - let cur_p = select(0.0, - exp(softmax_term - final_max), - kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE); - total_exp_term += subgroupAdd(cur_p); - if (kv_idx < KV_TILE) { - inter_shmem[kv_idx + q_tile_row * KV_TILE] = f16(cur_p); - } - } - - let cur_exp = exp(prev_max - final_max); - - if (sg_inv_id == 0) { - row_max_shmem[q_tile_row] = final_max; - exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * cur_exp + total_exp_term; - } - - for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) { - let idx = q_tile_row * HEAD_DIM_V + elem_idx; - o_shmem[idx] = f16(f32(o_shmem[idx]) * cur_exp); - } - } - - // load v tile into shared memory -// #if defined(KV_Q4_0) -// for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE 
* NQ) { -// let blck_idx = elem_idx / BLOCK_SIZE; -// let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; -// let v_row = blck_idx / BLOCKS_V; -// let global_v_row = kv_tile + v_row; -// let block_k = blck_idx % BLOCKS_V; -// let row_offset = v_row * HEAD_DIM_V; - -// if (global_v_row < params.seq_len_kv) { -// let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; -// let base_idx = global_block_idx * F16_PER_BLOCK; -// let d = V[base_idx]; // scale -// for (var j = 0u; j < F16_PER_THREAD; j += 2) { -// let q_0 = V[base_idx + 1u + block_offset + j]; -// let q_1 = V[base_idx + 1u + block_offset + j + 1]; -// let q_packed = bitcast(vec2(q_0, q_1)); -// for (var k = 0u; k < 4u; k++) { -// let q_byte = get_byte(q_packed, k); -// let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; -// let q_lo = (f16(q_byte & 0xF) - 8.0) * d; -// let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; -// kv_shmem[row_offset + idx] = q_lo; -// kv_shmem[row_offset + idx + 16u] = q_hi; -// } -// } -// } -// } -// #elif defined(KV_Q8_0) -// for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) { -// let blck_idx = elem_idx / BLOCK_SIZE; -// let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; -// let v_row = blck_idx / BLOCKS_V; -// let global_v_row = kv_tile + v_row; -// let block_k = blck_idx % BLOCKS_V; -// let row_offset = v_row * HEAD_DIM_V; - -// if (global_v_row < params.seq_len_kv) { -// let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; -// let base_idx = global_block_idx * F16_PER_BLOCK; -// let d = V[base_idx]; // scale -// for (var j = 0u; j < F16_PER_THREAD; j += 2) { -// let q_0 = V[base_idx + 1u + block_offset + j]; -// let q_1 = V[base_idx + 1u + block_offset + j + 1]; -// let q_packed = bitcast(vec2(q_0, q_1)); -// for (var k = 0u; k < 4u; k++) { -// let q_byte = get_byte_i32(q_packed, k); -// let q_val = f16(q_byte) * d; -// let idx = block_k * 
BLOCK_SIZE + block_offset * 2u + j * 2u + k; -// kv_shmem[row_offset + idx] = q_val; -// } -// } -// } -// } -// #elif defined(KV_DIRECT) -// // Direct global loads for KV -// #else -// for (var elem_idx = local_id.x; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE) { -// let v_row = elem_idx / HEAD_DIM_V; -// let v_col = elem_idx % HEAD_DIM_V; -// let global_v_row = kv_tile + v_row; -// let global_v_row_offset = v_head_offset + global_v_row * params.stride_v1; -// kv_shmem[elem_idx] = f16(select( -// 0.0, -// V[global_v_row_offset + v_col], -// global_v_row < params.seq_len_kv && v_col < HEAD_DIM_V)); -// } -// #endif - - workgroupBarrier(); - - // we have P (Q_TILE x KV_TILE) in inter_shmem and V (KV_TILE x head_dim_v) in kv_shmem - // we want to compute O += P * V across the full KV tile - for (var q_tile_row = subgroup_id; - q_tile_row < Q_TILE; - q_tile_row += num_subgroups) { - for (var elem_base = sg_inv_id * 4u; elem_base < HEAD_DIM_V; elem_base += subgroup_size * 4u) { - // todo: load o_shmem - let o_base_idx = q_tile_row * HEAD_DIM_V + elem_base; - var acc = vec4(f32(o_shmem[o_base_idx]), f32(o_shmem[o_base_idx + 1u]), f32(o_shmem[o_base_idx + 2u]), f32(o_shmem[o_base_idx + 3u])); - for (var kv_idx : u32 = 0u; kv_idx < KV_TILE; kv_idx += 1u) { - let p = f32(inter_shmem[kv_idx + q_tile_row * KV_TILE]); - let v_row = kv_tile + kv_idx; - if (v_row >= params.seq_len_kv) { - continue; - } - // let v_idx = v_head_offset + v_row * params.stride_v1 + elem_base; - var v_idx : u32 = 0u; - var v4: vec4; - if (use_pad_tile == true) { - v_idx = params.offset_pad + params.pad_v_base + kv_plane * params.stride_v1 * params.ncpsg + kv_idx * params.stride_v1 + elem_base; - v4 = vec4(f32(pad[v_idx]), f32(pad[v_idx + 1u]), f32(pad[v_idx + 2u]), f32(pad[v_idx + 3u])); - } else { - v_idx = v_head_offset + v_row * params.stride_v1 + elem_base; - v4 = vec4(V[v_idx/4u]); - } - acc += p * v4; - } - // todo: write acc back to o_shmem - o_shmem[o_base_idx] = f16(acc.x); - 
o_shmem[o_base_idx + 1u] = f16(acc.y); - o_shmem[o_base_idx + 2u] = f16(acc.z); - o_shmem[o_base_idx + 3u] = f16(acc.w); - } - } - - workgroupBarrier(); - } - - -#ifdef SINKS - // add sinks (applied once after processing all KV tiles) - for (var q_tile_row = subgroup_id; - q_tile_row < Q_TILE; - q_tile_row += num_subgroups) { - // no need to process rows beyond seq_len_q - let global_q_row = q_row_start + q_tile_row; - if (global_q_row >= params.seq_len_q) { - break; - } - - var prev_max = row_max_shmem[q_tile_row]; - - // for non-sink threads, exp(FLOAT_MIN) effectively zeroes out their contribution to the sum - let sink_val = select(FLOAT_MIN, sinks[params.offset_sinks + head_idx], sg_inv_id == 0); - let new_max = subgroupMax(max(prev_max, sink_val)); - let max_exp = exp(prev_max - new_max); - let sink_exp = exp(sink_val - new_max); - - let sink_exp_sum = subgroupAdd(sink_exp); - - if (sg_inv_id == 0) { - exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * max_exp + sink_exp_sum; - } - - for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) { - let idx = q_tile_row * HEAD_DIM_V + elem_idx; - let val = f32(o_shmem[idx]) * max_exp; - o_shmem[idx] = f16(val); - } - } - workgroupBarrier(); -#endif - for (var q_tile_row = subgroup_id; - q_tile_row < Q_TILE; - q_tile_row += num_subgroups) { - - let global_q_row = q_row_start + q_tile_row; - if (global_q_row >= params.seq_len_q) { break; } - - let exp_sum = exp_sum_shmem[q_tile_row]; - let scale = select(0.0, 1.0 / exp_sum, exp_sum != 0.0); - - let row_base: u32 = dst_global_offset + q_tile_row * dst2_stride; - - for (var elem_base = sg_inv_id * 4u; - elem_base < HEAD_DIM_V; - elem_base += subgroup_size * 4u) { - - let i0 = q_tile_row * HEAD_DIM_V + (elem_base + 0u); - let i1 = q_tile_row * HEAD_DIM_V + (elem_base + 1u); - let i2 = q_tile_row * HEAD_DIM_V + (elem_base + 2u); - let i3 = q_tile_row * HEAD_DIM_V + (elem_base + 3u); - - let v = vec4( - f32(o_shmem[i0]) * scale, - 
f32(o_shmem[i1]) * scale, - f32(o_shmem[i2]) * scale, - f32(o_shmem[i3]) * scale - ); - - let dst_vec_index: u32 = (row_base + elem_base) >> 2u; - dst[dst_vec_index] = v; - } - } -} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec.wgsl deleted file mode 100644 index 6d2023e4658c..000000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec.wgsl +++ /dev/null @@ -1,584 +0,0 @@ -diagnostic(off, chromium.subgroup_matrix_uniformity); -diagnostic(off, subgroup_uniformity); -enable f16; -enable subgroups; -enable chromium_experimental_subgroup_matrix; - -#ifdef KV_F32 -#define KV_TYPE f32 -#else -#define KV_TYPE f16 -#endif - -// Default values -#define HEAD_DIM_QK 64 -#define HEAD_DIM_V 64 - -// The number of rows/columns/k in a subgroup matrix. MxK * KxN = MxN -// Note that the "K" here does not correspond to the K in attention's Q/K/V, it's just the common dimension. -#define SG_MAT_M 8 -#define SG_MAT_N 8 -#define SG_MAT_K 8 - -// Each workgroup processes one subgroup matrix of Q rows -#define Q_TILE SG_MAT_M -#define KV_TILE 16 -#define WG_SIZE 64 - -// Number of subgroup-matrix-width blocks that span the KV tile. SG_MAT_N must divide KV_TILE. 
-#define KV_BLOCKS (KV_TILE / SG_MAT_N) - -// Quantization constants/helpers -#define BLOCK_SIZE 32 -#define BLOCKS_K ((HEAD_DIM_QK + BLOCK_SIZE - 1) / BLOCK_SIZE) -#define BLOCKS_V ((HEAD_DIM_V + BLOCK_SIZE - 1) / BLOCK_SIZE) -// number of quantized elements processed per thread -#if defined(KV_Q4_0) -#define NQ 16 -// Q4_0 has 32 elements, 1 f16 for scale, 8 f16 for 4-bit weights -#define F16_PER_BLOCK 9 -#define WEIGHTS_PER_F16 4 -#elif defined(KV_Q8_0) -#define NQ 8 -// Q8_0 has 32 elements, 1 f16 for scale, 16 f16 for 8-bit weights -#define F16_PER_BLOCK 17 -#define WEIGHTS_PER_F16 2 -#endif -#define F16_PER_THREAD (NQ / WEIGHTS_PER_F16) - -// Ok not to put these in a define block, compiler will remove if unused -fn get_byte(value: u32, index: u32) -> u32 { - return (value >> (index * 8)) & 0xFF; -} - -fn get_byte_i32(value: u32, index: u32) -> i32 { - return bitcast(((value >> (index * 8)) & 0xFF) << 24) >> 24; -} - -struct Params { - offset_q: u32, - offset_k: u32, - offset_v: u32, - offset_mask: u32, - offset_sinks: u32, - offset_dst: u32, - - // shapes of Q/K/V - n_heads: u32, - seq_len_q: u32, - seq_len_kv: u32, - - // strides (in elements) - stride_q1: u32, - stride_q2: u32, - stride_q3: u32, - stride_k1: u32, - stride_k2: u32, - stride_k3: u32, - stride_v1: u32, - stride_v2: u32, - stride_v3: u32, - stride_mask3: u32, - - // repeat factors for K/V, e.g., MHA vs. MQA vs. 
GQA - q_per_kv: u32, - - // softmax params - scale: f32, - max_bias: f32, - logit_softcap: f32, - n_head_log2: f32, - m0: f32, - m1: f32, -}; - -@group(0) @binding(0) var Q: array; -@group(0) @binding(1) var K: array; -@group(0) @binding(2) var V: array>; - -#if defined(MASK) && defined(SINKS) -@group(0) @binding(3) var mask: array; -@group(0) @binding(4) var sinks: array; -#define DST_BINDING 5 -#define PARAMS_BINDING 6 -#elif defined(MASK) -@group(0) @binding(3) var mask: array; -#define DST_BINDING 4 -#define PARAMS_BINDING 5 -#elif defined(SINKS) -@group(0) @binding(3) var sinks: array; -#define DST_BINDING 4 -#define PARAMS_BINDING 5 -#else -#define DST_BINDING 3 -#define PARAMS_BINDING 4 -#endif - -@group(0) @binding(DST_BINDING) var dst: array>; -@group(0) @binding(PARAMS_BINDING) var params: Params; - -// Just a very small float value. -const FLOAT_MIN: f32 = -1.0e9; - -// The number of Q rows processed per workgroup -var q_shmem: array; - -#ifndef KV_DIRECT -const kv_shmem_size = KV_TILE * max(HEAD_DIM_QK, HEAD_DIM_V); -// we can reuse the same shmem for K and V since we only need one at a time -var kv_shmem: array; -#endif - -var o_shmem: array; // output shmem - -#ifdef MASK -// storage for mask values -var mask_shmem: array; -#endif - -// storage for output of Q*K^T scores for online softmax (S matrix from paper) -// also storage for diagonal matrix during online softmax (P matrix from paper) -// note that we reuse the same storage for both since we only need one at a time -var inter_shmem: array; - -// Storage for row max and exp sum during online softmax -var row_max_shmem: array; -var exp_sum_shmem: array; - -fn calc_softmax_term(kv_idx: u32, q_tile_row: u32, slope: f32) -> f32 { - var v = select(FLOAT_MIN, - f32(inter_shmem[kv_idx + q_tile_row * KV_TILE]) * params.scale, - kv_idx < KV_TILE); -#ifdef LOGIT_SOFTCAP - v = params.logit_softcap * tanh(v); -#endif -#ifdef MASK - let mask_val = select(0.0, f32(mask_shmem[q_tile_row * KV_TILE + kv_idx]), 
kv_idx < KV_TILE); - let mask_term = slope * mask_val; - v += mask_term; -#endif - return v; -} - -fn load_f32x4(buf: ptr>, read_write>, scalar_index: u32) -> vec4 { - return (*buf)[scalar_index >> 2u]; -} - -fn load_kvx4(buf: ptr>, read_write>, scalar_index: u32) -> vec4 { - return (*buf)[scalar_index >> 2u]; -} - -@compute @workgroup_size(WG_SIZE) -fn main(@builtin(workgroup_id) wg_id: vec3, - @builtin(local_invocation_id) local_id: vec3, - @builtin(subgroup_id) subgroup_id: u32, - @builtin(subgroup_size) subgroup_size: u32, - @builtin(num_subgroups) num_subgroups: u32, - @builtin(subgroup_invocation_id) sg_inv_id: u32) { - - // initialize row max for online softmax - for (var i = local_id.x; i < Q_TILE; i += WG_SIZE) { - row_max_shmem[i] = FLOAT_MIN; - exp_sum_shmem[i] = 0.0; - } - - for (var i = local_id.x; i < Q_TILE * HEAD_DIM_V; i += WG_SIZE) { - o_shmem[i] = 0.0; - } - - // workgroups per head/batch - let wg_per_head = (params.seq_len_q + Q_TILE - 1u) / Q_TILE; - let wg_per_batch = wg_per_head * params.n_heads; - - let dst2_stride = HEAD_DIM_V * params.n_heads; - let dst3_stride = dst2_stride * params.seq_len_q; - - // batch index - let batch_idx = wg_id.x / wg_per_batch; - let q_batch_offset = params.offset_q + batch_idx * params.stride_q3; - let k_batch_offset = params.offset_k + batch_idx * params.stride_k3; - let v_batch_offset = params.offset_v + batch_idx * params.stride_v3; - let dst_batch_offset = params.offset_dst + batch_idx * dst3_stride; - let wg_in_batch = wg_id.x % wg_per_batch; - - // head index - let head_idx = wg_in_batch / wg_per_head; - let q_head_offset = q_batch_offset + head_idx * params.stride_q2; - let k_head_idx = head_idx / params.q_per_kv; - let v_head_idx = k_head_idx; - let k_head_offset = k_batch_offset + k_head_idx * params.stride_k2; - let v_head_offset = v_batch_offset + v_head_idx * params.stride_v2; - - // starting Q row for this workgroup - let wg_in_head = wg_in_batch % wg_per_head; - let q_row_start = wg_in_head * 
Q_TILE; - -#ifdef MASK - // mask offset - let mask_global_offset = params.offset_mask + batch_idx * params.stride_mask3 + q_row_start * params.seq_len_kv; -#endif - - // note that the output is permuted, the layout is [head_dim_v, n_heads, seq_len_q, batch_size] - let dst_global_offset = dst_batch_offset + q_row_start * dst2_stride + head_idx * HEAD_DIM_V; - - let head = f32(head_idx); - let slope = select(1.0, select(pow(params.m1, 2.0 * (head - params.n_head_log2) + 1.0), pow(params.m0, head + 1.0), head < params.n_head_log2), params.max_bias > 0); - - // load q tile into shared memory - for (var elem_idx = local_id.x; elem_idx < Q_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) { - let q_row = elem_idx / HEAD_DIM_QK; - let q_col = elem_idx % HEAD_DIM_QK; - let head_q_row = q_row_start + q_row; - let global_q_row_offset = q_head_offset + head_q_row * params.stride_q1; - q_shmem[elem_idx] = f16(select( - 0.0, - Q[global_q_row_offset + q_col], - head_q_row < params.seq_len_q && q_col < HEAD_DIM_QK)); - } - - for (var kv_tile = 0u; kv_tile < params.seq_len_kv; kv_tile += KV_TILE) { - // clear inter_shmem to ensure zero-initialized accumulators - for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) { - inter_shmem[elem_idx] = 0.0; - } - - // load k tile into shared memory -// #if defined(KV_Q4_0) -// for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) { -// let blck_idx = elem_idx / BLOCK_SIZE; -// let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; -// let k_row = blck_idx / BLOCKS_K; -// let global_k_row = kv_tile + k_row; -// let block_k = blck_idx % BLOCKS_K; -// let row_offset = k_row * HEAD_DIM_QK; - -// if (global_k_row < params.seq_len_kv) { -// let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; -// let base_idx = global_block_idx * F16_PER_BLOCK; -// let d = K[base_idx]; // scale -// for (var j = 0u; j < F16_PER_THREAD; j += 2) { -// let q_0 = 
K[base_idx + 1u + block_offset + j]; -// let q_1 = K[base_idx + 1u + block_offset + j + 1]; -// let q_packed = bitcast(vec2(q_0, q_1)); -// for (var k = 0u; k < 4u; k++) { -// let q_byte = get_byte(q_packed, k); -// let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; -// let q_lo = (f16(q_byte & 0xF) - 8.0) * d; -// let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; -// kv_shmem[row_offset + idx] = q_lo; -// kv_shmem[row_offset + idx + 16u] = q_hi; -// } -// } -// } -// } -// #elif defined(KV_Q8_0) -// for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) { -// let blck_idx = elem_idx / BLOCK_SIZE; -// let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; -// let k_row = blck_idx / BLOCKS_K; -// let global_k_row = kv_tile + k_row; -// let block_k = blck_idx % BLOCKS_K; -// let row_offset = k_row * HEAD_DIM_QK; - -// if (global_k_row < params.seq_len_kv) { -// let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; -// let base_idx = global_block_idx * F16_PER_BLOCK; -// let d = K[base_idx]; // scale -// for (var j = 0u; j < F16_PER_THREAD; j += 2) { -// let q_0 = K[base_idx + 1u + block_offset + j]; -// let q_1 = K[base_idx + 1u + block_offset + j + 1]; -// let q_packed = bitcast(vec2(q_0, q_1)); -// for (var k = 0u; k < 4u; k++) { -// let q_byte = get_byte_i32(q_packed, k); -// let q_val = f16(q_byte) * d; -// let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; -// kv_shmem[row_offset + idx] = q_val; -// } -// } -// } -// } -// #endif - - workgroupBarrier(); - - // accumulate q block * k block into registers across the entire KV tile - // TODO: this loop seems to be the current largest bottleneck - // this bracket exists to scope the lifetime of variables, reducing register pressure - { - // vectorization - let num_of_threads = subgroup_size / 4u; - let tx = sg_inv_id % num_of_threads; - let ty = sg_inv_id / num_of_threads; - for (var q_tile_row = subgroup_id; 
q_tile_row < Q_TILE; q_tile_row += num_subgroups) { - let global_q_row = q_row_start + q_tile_row; - if (global_q_row >= params.seq_len_q) { - continue; - } - - for (var kv_base : u32 = 0u; kv_base < KV_TILE; kv_base += 4u) { - let kv_idx = kv_base + ty; - var partial_sum: f32 = 0.0; - let kv_valid = kv_idx < KV_TILE && (kv_tile + kv_idx) < params.seq_len_kv; - - if (kv_valid) { - for (var i = tx; i < (HEAD_DIM_QK / 4u); i += num_of_threads) { - let q_off = q_tile_row * HEAD_DIM_QK + i * 4u; - // let k_off = (kv_idx * HEAD_DIM_QK) + i * 4u; - let idx = k_head_offset + (kv_tile + kv_idx) * params.stride_k1 + (i * 4u); - - let qv = vec4(f32(q_shmem[q_off]), f32(q_shmem[q_off + 1u]), f32(q_shmem[q_off + 2u]), f32(q_shmem[q_off + 3u])); - // let kv = vec4(f32(kv_shmem[k_off]), f32(kv_shmem[k_off + 1u]), f32(kv_shmem[k_off + 2u]), f32(kv_shmem[k_off + 3u])); - let kv = vec4(f32(K[idx]), f32(K[idx + 1u]), f32(K[idx + 2u]), f32(K[idx + 3u])); - - partial_sum += dot(qv, kv); - - } - } - for (var g: u32 = 0u; g < 4u; g++) { - let kv_idx_g = kv_base + g; - let active_threads = (ty == g) && (kv_idx_g < KV_TILE) && ((kv_tile + kv_idx_g) < params.seq_len_kv); - - let contrib = select(0.0, partial_sum, active_threads); - let sum_g = subgroupAdd(contrib); - - if (tx == 0u && ty == g && kv_idx_g < KV_TILE) { - let dst_idx = q_tile_row * KV_TILE + kv_idx_g; - inter_shmem[dst_idx] = f16(select(FLOAT_MIN, sum_g, (kv_tile + kv_idx_g) < params.seq_len_kv)); - } - } - } - } - } - - -#ifdef MASK - // load mask tile into shared memory for this KV block - // TODO: optimize and skip if mask is -INF for the entire tile - for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) { - let mask_row = elem_idx / KV_TILE; - let mask_col = elem_idx % KV_TILE; - let global_q_row = q_row_start + mask_row; - let global_k_col = kv_tile + mask_col; - let mask_in_bounds = global_q_row < params.seq_len_q && global_k_col < params.seq_len_kv; - let mask_idx = mask_global_offset + 
mask_row * params.seq_len_kv + global_k_col; - mask_shmem[elem_idx] = select(0.0, mask[mask_idx], mask_in_bounds); - } -#endif - - workgroupBarrier(); - - // online softmax - for (var q_tile_row = subgroup_id; q_tile_row < Q_TILE; q_tile_row += num_subgroups) { - let global_q_row = q_row_start + q_tile_row; - if (global_q_row >= params.seq_len_q) { - break; - } - - // initialize running max for this row - var prev_max = row_max_shmem[q_tile_row]; - var final_max = prev_max; - // pass 1: compute final max across the full KV tile in chunks - for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) { - let kv_idx = kv_offset + sg_inv_id; - let softmax_term = calc_softmax_term(kv_idx, q_tile_row, slope); - final_max = subgroupMax(max(final_max, softmax_term)); - } - - var total_exp_term: f32 = 0.0; - // pass 2: compute exp sum and write P using final_max - for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) { - let kv_idx = kv_offset + sg_inv_id; - let softmax_term = calc_softmax_term(kv_idx, q_tile_row, slope); - let cur_p = select(0.0, - exp(softmax_term - final_max), - kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE); - total_exp_term += subgroupAdd(cur_p); - if (kv_idx < KV_TILE) { - inter_shmem[kv_idx + q_tile_row * KV_TILE] = f16(cur_p); - } - } - - let cur_exp = exp(prev_max - final_max); - - if (sg_inv_id == 0) { - row_max_shmem[q_tile_row] = final_max; - exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * cur_exp + total_exp_term; - } - - for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) { - let idx = q_tile_row * HEAD_DIM_V + elem_idx; - o_shmem[idx] = f16(f32(o_shmem[idx]) * cur_exp); - } - } - - // load v tile into shared memory -// #if defined(KV_Q4_0) -// for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) { -// let blck_idx = elem_idx / BLOCK_SIZE; -// let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; -// let v_row = 
blck_idx / BLOCKS_V; -// let global_v_row = kv_tile + v_row; -// let block_k = blck_idx % BLOCKS_V; -// let row_offset = v_row * HEAD_DIM_V; - -// if (global_v_row < params.seq_len_kv) { -// let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; -// let base_idx = global_block_idx * F16_PER_BLOCK; -// let d = V[base_idx]; // scale -// for (var j = 0u; j < F16_PER_THREAD; j += 2) { -// let q_0 = V[base_idx + 1u + block_offset + j]; -// let q_1 = V[base_idx + 1u + block_offset + j + 1]; -// let q_packed = bitcast(vec2(q_0, q_1)); -// for (var k = 0u; k < 4u; k++) { -// let q_byte = get_byte(q_packed, k); -// let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; -// let q_lo = (f16(q_byte & 0xF) - 8.0) * d; -// let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; -// kv_shmem[row_offset + idx] = q_lo; -// kv_shmem[row_offset + idx + 16u] = q_hi; -// } -// } -// } -// } -// #elif defined(KV_Q8_0) -// for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) { -// let blck_idx = elem_idx / BLOCK_SIZE; -// let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; -// let v_row = blck_idx / BLOCKS_V; -// let global_v_row = kv_tile + v_row; -// let block_k = blck_idx % BLOCKS_V; -// let row_offset = v_row * HEAD_DIM_V; - -// if (global_v_row < params.seq_len_kv) { -// let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; -// let base_idx = global_block_idx * F16_PER_BLOCK; -// let d = V[base_idx]; // scale -// for (var j = 0u; j < F16_PER_THREAD; j += 2) { -// let q_0 = V[base_idx + 1u + block_offset + j]; -// let q_1 = V[base_idx + 1u + block_offset + j + 1]; -// let q_packed = bitcast(vec2(q_0, q_1)); -// for (var k = 0u; k < 4u; k++) { -// let q_byte = get_byte_i32(q_packed, k); -// let q_val = f16(q_byte) * d; -// let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; -// kv_shmem[row_offset + idx] = q_val; -// } -// } -// } -// } -// #elif 
defined(KV_DIRECT) -// // Direct global loads for KV -// #else -// for (var elem_idx = local_id.x; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE) { -// let v_row = elem_idx / HEAD_DIM_V; -// let v_col = elem_idx % HEAD_DIM_V; -// let global_v_row = kv_tile + v_row; -// let global_v_row_offset = v_head_offset + global_v_row * params.stride_v1; -// kv_shmem[elem_idx] = f16(select( -// 0.0, -// V[global_v_row_offset + v_col], -// global_v_row < params.seq_len_kv && v_col < HEAD_DIM_V)); -// } -// #endif - - workgroupBarrier(); - - // we have P (Q_TILE x KV_TILE) in inter_shmem and V (KV_TILE x head_dim_v) in kv_shmem - // we want to compute O += P * V across the full KV tile - for (var q_tile_row = subgroup_id; - q_tile_row < Q_TILE; - q_tile_row += num_subgroups) { - for (var elem_base = sg_inv_id * 4u; elem_base < HEAD_DIM_V; elem_base += subgroup_size * 4u) { - // todo: load o_shmem - let o_base_idx = q_tile_row * HEAD_DIM_V + elem_base; - var acc = vec4(f32(o_shmem[o_base_idx]), f32(o_shmem[o_base_idx + 1u]), f32(o_shmem[o_base_idx + 2u]), f32(o_shmem[o_base_idx + 3u])); - for (var kv_idx : u32 = 0u; kv_idx < KV_TILE; kv_idx += 1u) { - let p = f32(inter_shmem[kv_idx + q_tile_row * KV_TILE]); - let v_row = kv_tile + kv_idx; - if (v_row >= params.seq_len_kv) { - continue; - } - // let v_idx = v_head_offset + v_row * params.stride_v1 + elem_base; - let v_idx = v_head_offset + v_row * params.stride_v1 + elem_base; - let v4 = vec4(V[v_idx/4u]); - - acc += p * v4; - } - // todo: write acc back to o_shmem - o_shmem[o_base_idx] = f16(acc.x); - o_shmem[o_base_idx + 1u] = f16(acc.y); - o_shmem[o_base_idx + 2u] = f16(acc.z); - o_shmem[o_base_idx + 3u] = f16(acc.w); - } - } - - workgroupBarrier(); - } - - -#ifdef SINKS - // add sinks (applied once after processing all KV tiles) - for (var q_tile_row = subgroup_id; - q_tile_row < Q_TILE; - q_tile_row += num_subgroups) { - // no need to process rows beyond seq_len_q - let global_q_row = q_row_start + q_tile_row; - if 
(global_q_row >= params.seq_len_q) { - break; - } - - var prev_max = row_max_shmem[q_tile_row]; - - // for non-sink threads, exp(FLOAT_MIN) effectively zeroes out their contribution to the sum - let sink_val = select(FLOAT_MIN, sinks[params.offset_sinks + head_idx], sg_inv_id == 0); - let new_max = subgroupMax(max(prev_max, sink_val)); - let max_exp = exp(prev_max - new_max); - let sink_exp = exp(sink_val - new_max); - - let sink_exp_sum = subgroupAdd(sink_exp); - - if (sg_inv_id == 0) { - exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * max_exp + sink_exp_sum; - } - - for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) { - let idx = q_tile_row * HEAD_DIM_V + elem_idx; - let val = f32(o_shmem[idx]) * max_exp; - o_shmem[idx] = f16(val); - } - } - workgroupBarrier(); -#endif - for (var q_tile_row = subgroup_id; - q_tile_row < Q_TILE; - q_tile_row += num_subgroups) { - - let global_q_row = q_row_start + q_tile_row; - if (global_q_row >= params.seq_len_q) { break; } - - let exp_sum = exp_sum_shmem[q_tile_row]; - let scale = select(0.0, 1.0 / exp_sum, exp_sum != 0.0); - - let row_base: u32 = dst_global_offset + q_tile_row * dst2_stride; - - for (var elem_base = sg_inv_id * 4u; - elem_base < HEAD_DIM_V; - elem_base += subgroup_size * 4u) { - - let i0 = q_tile_row * HEAD_DIM_V + (elem_base + 0u); - let i1 = q_tile_row * HEAD_DIM_V + (elem_base + 1u); - let i2 = q_tile_row * HEAD_DIM_V + (elem_base + 2u); - let i3 = q_tile_row * HEAD_DIM_V + (elem_base + 3u); - - let v = vec4( - f32(o_shmem[i0]) * scale, - f32(o_shmem[i1]) * scale, - f32(o_shmem[i2]) * scale, - f32(o_shmem[i3]) * scale - ); - - let dst_vec_index: u32 = (row_base + elem_base) >> 2u; - dst[dst_vec_index] = v; - } - } -} From f8e317c429ecf8e423cb9406d12989993f20cb6b Mon Sep 17 00:00:00 2001 From: Zheyuan Chen Date: Tue, 3 Mar 2026 12:55:25 -0800 Subject: [PATCH 05/34] remove unused helper functions --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 232 --------------------------- 1 
file changed, 232 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index f5233645ca85..94991f1594f6 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -817,238 +817,6 @@ static webgpu_command ggml_backend_webgpu_build(webgpu_global_context & { { wg_x, wg_y } }); } -struct webgpu_buffer_clear_op { - wgpu::Buffer buffer; - uint64_t offset; - uint64_t size; -}; - -struct webgpu_buffer_copy_op { - wgpu::Buffer src; - uint64_t src_offset; - wgpu::Buffer dst; - uint64_t dst_offset; - uint64_t size; -}; - -static webgpu_command ggml_backend_webgpu_build_with_pre_ops( - webgpu_global_context & ctx, - webgpu_buf_pool & param_buf_pool, - webgpu_pipeline & pipeline, - const std::vector & params, - const std::vector bind_group_entries, - const std::vector & clear_ops, - const std::vector & copy_ops, - std::vector retained_buffers, - uint32_t wg_x, - uint32_t wg_y = 1) { - webgpu_pool_bufs params_bufs = param_buf_pool.alloc_bufs(); - - ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, wgpu::MapMode::Write, 0, params_bufs.host_buf.GetSize()); - uint32_t * _params = (uint32_t *) params_bufs.host_buf.GetMappedRange(); - for (size_t i = 0; i < params.size(); i++) { - _params[i] = params[i]; - } - params_bufs.host_buf.Unmap(); - - std::vector entries = bind_group_entries; - uint32_t params_binding_num = 0; - for (const auto & entry : entries) { - if (entry.binding >= params_binding_num) { - params_binding_num = entry.binding + 1; - } - } - entries.push_back({ .binding = params_binding_num, - .buffer = params_bufs.dev_buf, - .offset = 0, - .size = params_bufs.dev_buf.GetSize() }); - - wgpu::BindGroupDescriptor bind_group_desc; - bind_group_desc.layout = pipeline.pipeline.GetBindGroupLayout(0); - bind_group_desc.entryCount = entries.size(); - bind_group_desc.entries = entries.data(); - bind_group_desc.label = pipeline.name.c_str(); - wgpu::BindGroup bind_group = 
ctx->device.CreateBindGroup(&bind_group_desc); - - wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder(); - encoder.CopyBufferToBuffer(params_bufs.host_buf, 0, params_bufs.dev_buf, 0, params_bufs.dev_buf.GetSize()); - - for (const auto & op : clear_ops) { - encoder.ClearBuffer(op.buffer, op.offset, op.size); - } - for (const auto & op : copy_ops) { - encoder.CopyBufferToBuffer(op.src, op.src_offset, op.dst, op.dst_offset, op.size); - } - -#ifdef GGML_WEBGPU_GPU_PROFILE - webgpu_gpu_profile_bufs ts_bufs = ctx->timestamp_query_buf_pool.alloc_bufs(); - if (ts_bufs.host_buf.GetMapState() == wgpu::BufferMapState::Mapped) { - ts_bufs.host_buf.Unmap(); - } - - wgpu::PassTimestampWrites ts_writes = { .querySet = ts_bufs.query_set, - .beginningOfPassWriteIndex = 0, - .endOfPassWriteIndex = 1 }; - wgpu::ComputePassDescriptor pass_desc = { .timestampWrites = &ts_writes }; - wgpu::ComputePassEncoder pass = encoder.BeginComputePass(&pass_desc); -#else - wgpu::ComputePassEncoder pass = encoder.BeginComputePass(); -#endif - pass.SetPipeline(pipeline.pipeline); - pass.SetBindGroup(0, bind_group); - pass.DispatchWorkgroups(wg_x, wg_y, 1); - pass.End(); - -#ifdef GGML_WEBGPU_GPU_PROFILE - encoder.ResolveQuerySet(ts_bufs.query_set, 0, 2, ts_bufs.dev_buf, 0); - encoder.CopyBufferToBuffer(ts_bufs.dev_buf, 0, ts_bufs.host_buf, 0, ts_bufs.host_buf.GetSize()); -#endif - - webgpu_command result = {}; - result.commands = encoder.Finish(); - result.params_bufs = { params_bufs }; - result.retained_buffers = std::move(retained_buffers); -#ifdef GGML_WEBGPU_GPU_PROFILE - result.timestamp_query_bufs = ts_bufs; - result.pipeline_name = pipeline.name; -#endif - return result; -} - -static webgpu_command ggml_backend_webgpu_build_multi_with_pre_ops( - webgpu_global_context & ctx, - webgpu_buf_pool & param_buf_pool, - const std::vector & pipelines, - const std::vector> & params_list, - const std::vector> & bind_group_entries_list, - const std::vector> & workgroups_list, - const 
std::vector & clear_ops, - const std::vector & copy_ops, - std::vector retained_buffers, - const std::optional & set_rows_error_bufs = std::nullopt, - bool split_passes = false) { - GGML_ASSERT(pipelines.size() == params_list.size()); - GGML_ASSERT(pipelines.size() == bind_group_entries_list.size()); - GGML_ASSERT(pipelines.size() == workgroups_list.size()); - - std::vector params_bufs_list; - std::vector bind_groups; - - for (size_t i = 0; i < pipelines.size(); i++) { - webgpu_pool_bufs params_bufs = param_buf_pool.alloc_bufs(); - - ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, wgpu::MapMode::Write, 0, - params_bufs.host_buf.GetSize()); - uint32_t * _params = (uint32_t *) params_bufs.host_buf.GetMappedRange(); - for (size_t j = 0; j < params_list[i].size(); j++) { - _params[j] = params_list[i][j]; - } - params_bufs.host_buf.Unmap(); - - std::vector entries = bind_group_entries_list[i]; - uint32_t params_binding_num = 0; - for (const auto & entry : entries) { - if (entry.binding >= params_binding_num) { - params_binding_num = entry.binding + 1; - } - } - entries.push_back({ .binding = params_binding_num, - .buffer = params_bufs.dev_buf, - .offset = 0, - .size = params_bufs.dev_buf.GetSize() }); - - wgpu::BindGroupDescriptor bind_group_desc; - bind_group_desc.layout = pipelines[i].pipeline.GetBindGroupLayout(0); - bind_group_desc.entryCount = entries.size(); - bind_group_desc.entries = entries.data(); - bind_group_desc.label = pipelines[i].name.c_str(); - bind_groups.push_back(ctx->device.CreateBindGroup(&bind_group_desc)); - - params_bufs_list.push_back(params_bufs); - } - - wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder(); - for (const auto & params_bufs : params_bufs_list) { - encoder.CopyBufferToBuffer(params_bufs.host_buf, 0, params_bufs.dev_buf, 0, params_bufs.dev_buf.GetSize()); - } - - for (const auto & op : clear_ops) { - encoder.ClearBuffer(op.buffer, op.offset, op.size); - } - for (const auto & op : copy_ops) { - 
encoder.CopyBufferToBuffer(op.src, op.src_offset, op.dst, op.dst_offset, op.size); - } - - // If there are SET_ROWS operations in this submission, copy their error buffers to the host. - if (set_rows_error_bufs) { - encoder.CopyBufferToBuffer(set_rows_error_bufs->dev_buf, 0, set_rows_error_bufs->host_buf, 0, - set_rows_error_bufs->host_buf.GetSize()); - } - - bool profile_pass = false; -#ifdef GGML_WEBGPU_GPU_PROFILE - webgpu_gpu_profile_bufs ts_bufs = {}; - if (!split_passes) { - ts_bufs = ctx->timestamp_query_buf_pool.alloc_bufs(); - if (ts_bufs.host_buf.GetMapState() == wgpu::BufferMapState::Mapped) { - ts_bufs.host_buf.Unmap(); - } - profile_pass = true; - } -#endif -#ifndef GGML_WEBGPU_GPU_PROFILE - GGML_UNUSED(profile_pass); -#endif - - if (split_passes) { - for (size_t i = 0; i < pipelines.size(); i++) { - wgpu::ComputePassEncoder pass = encoder.BeginComputePass(); - pass.SetPipeline(pipelines[i].pipeline); - pass.SetBindGroup(0, bind_groups[i]); - pass.DispatchWorkgroups(workgroups_list[i].first, workgroups_list[i].second, 1); - pass.End(); - } - } else { -#ifdef GGML_WEBGPU_GPU_PROFILE - wgpu::PassTimestampWrites ts_writes = { .querySet = ts_bufs.query_set, - .beginningOfPassWriteIndex = 0, - .endOfPassWriteIndex = 1 }; - wgpu::ComputePassDescriptor pass_desc = { .timestampWrites = &ts_writes }; - wgpu::ComputePassEncoder pass = encoder.BeginComputePass(&pass_desc); -#else - wgpu::ComputePassEncoder pass = encoder.BeginComputePass(); -#endif - for (size_t i = 0; i < pipelines.size(); i++) { - pass.SetPipeline(pipelines[i].pipeline); - pass.SetBindGroup(0, bind_groups[i]); - pass.DispatchWorkgroups(workgroups_list[i].first, workgroups_list[i].second, 1); - } - pass.End(); - } - -#ifdef GGML_WEBGPU_GPU_PROFILE - if (profile_pass) { - encoder.ResolveQuerySet(ts_bufs.query_set, 0, 2, ts_bufs.dev_buf, 0); - encoder.CopyBufferToBuffer(ts_bufs.dev_buf, 0, ts_bufs.host_buf, 0, ts_bufs.host_buf.GetSize()); - } -#endif - - wgpu::CommandBuffer commands = 
encoder.Finish(); - webgpu_command result = {}; - result.commands = commands; - result.params_bufs = params_bufs_list; - result.set_rows_error_bufs = set_rows_error_bufs; - result.retained_buffers = std::move(retained_buffers); -#ifdef GGML_WEBGPU_GPU_PROFILE - if (profile_pass) { - result.timestamp_query_bufs = ts_bufs; - result.has_timestamp_query = true; - } - result.pipeline_name = pipelines.front().name; -#endif - return result; -} - static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx, wgpu::Buffer & buf, uint32_t value, From 52709dd48d3df0563d538361864c756513e84f49 Mon Sep 17 00:00:00 2001 From: Zheyuan Chen Date: Tue, 3 Mar 2026 21:43:20 -0800 Subject: [PATCH 06/34] add comments --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 3 +-- .../wgsl-shaders/flash_attn_vec_blk.wgsl | 11 +++++++++++ .../wgsl-shaders/flash_attn_vec_reduce.wgsl | 19 +++++++++++++++++-- 3 files changed, 29 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index f80247bc8954..e5139c895ac6 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -519,7 +519,6 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader( defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k)); // Add chosen Q/KV tile sizes. - // Mirror Metal's vec path for split kernels: 1 query per workgroup and 32 KV cache values per subgroup. uint32_t q_tile = context.sg_mat_m; uint32_t kv_tile = std::min(ggml_webgpu_flash_attn_max_kv_tile(context), context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES); @@ -542,7 +541,7 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader( uint32_t wg_size = 0; if (context.key.use_vec_split) { - // Keep vec-split to a single subgroup; aligns lane mapping with Metal's vec kernel. 
+ // Keep vec-split to a single subgroup to avoid complexity in the reduction. wg_size = context.max_subgroup_size; } else { wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl index d73b2207f320..afffc8f78fa7 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl @@ -8,7 +8,10 @@ struct Params { offset_mask: u32, seq_len_q: u32, seq_len_kv: u32, + // plane b base = offset_mask + b * stride_mask3. stride_mask3: u32, + // Number of KV blocks and Q blocks per batch. + // nblk0 = ceil(seq_len_kv / KV_TILE), nblk1 = ceil(seq_len_q / Q_TILE). nblk0: u32, nblk1: u32, }; @@ -26,6 +29,9 @@ var wg_any: array; @compute @workgroup_size(WG_SIZE) fn main(@builtin(workgroup_id) wg_id: vec3, @builtin(local_invocation_id) local_id: vec3) { + // Dispatch mapping: + // - x indexes KV blocks + // - y flattens (batch_idx, q_blk) as y = batch_idx * nblk1 + q_blk let kv_blk = wg_id.x; let y = wg_id.y; let q_blk = y % params.nblk1; @@ -40,6 +46,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let mask_batch = select(0u, batch_idx, params.stride_mask3 > 0u); let mask_batch_base = params.offset_mask + mask_batch * params.stride_mask3; + // We keep min/max to classify: + // - fully masked (max <= MASK_MIN) + // - all-zero mask (min == 0 && max == 0) + // - mixed/general mask var local_min = MASK_MAX; var local_max = -MASK_MAX; var local_any = 0u; @@ -67,6 +77,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, wg_any[local_id.x] = local_any; workgroupBarrier(); + // Thread 0 writes one state per block. 
if (local_id.x == 0u) { var mmin = wg_min[0]; var mmax = wg_max[0]; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl index 00524ad1db62..b8b8bf5a72f7 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl @@ -6,10 +6,14 @@ enable subgroups; #define WG_SIZE 128 struct Params { + // Total rows to reduce: nrows = batch * n_heads * seq_len_q. nrows: u32, seq_len_q: u32, n_heads: u32, + // Number of split workgroups used in the vec-split pass. + // Each split contributes one partial (o, l, m) per row. nwg: u32, + // Bases into tmp for partial output vectors and partial stats. tmp_data_base: u32, tmp_stats_base: u32, }; @@ -26,11 +30,13 @@ fn main(@builtin(workgroup_id) wg_id: vec3, @builtin(num_subgroups) num_subgroups: u32, @builtin(subgroup_size) subgroup_size: u32, @builtin(subgroup_invocation_id) sg_inv_id: u32) { + // One workgroup reduces one logical output row rid. let rid = wg_id.x; if (rid >= params.nrows) { return; } + // Decode flattened row id back to (batch, head, q_row). let rows_per_batch = params.n_heads * params.seq_len_q; let batch_idx = rid / rows_per_batch; let rem = rid % rows_per_batch; @@ -41,22 +47,30 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let dst3_stride = dst2_stride * params.seq_len_q; let row_base = batch_idx * dst3_stride + q_row * dst2_stride + head_idx * HEAD_DIM_V; - // Mirror Metal's vec-reduce mapping: each subgroup lane corresponds to one split index. - // This requires params.nwg <= subgroup_size. + // Each subgroup lane corresponds to one split index g in [0, nwg). + // This kernel requires params.nwg <= subgroup_size. let lane = sg_inv_id; if (params.nwg > subgroup_size) { return; } + // Load split stats for this row: + // si = l_g (exp sum), mi = m_g (row max) from split g. 
let stats_base = params.tmp_stats_base + rid * (2u * params.nwg); let active_lane = lane < params.nwg; let si = select(0.0, tmp[stats_base + 2u * lane + 0u], active_lane); let mi = select(FLOAT_MIN, tmp[stats_base + 2u * lane + 1u], active_lane); + + // Merge split softmax normalizers: + // m = max_g m_g + // l = sum_g l_g * exp(m_g - m) let m = subgroupMax(mi); let ms = select(0.0, exp(mi - m), active_lane); let s = subgroupAdd(si * ms); let inv_s = select(0.0, 1.0 / s, s != 0.0); + // Merge partial output vectors: + // O = (sum_g O_g * exp(m_g - m)) / l let row_tmp_base = params.tmp_data_base + rid * (HEAD_DIM_V * params.nwg); for (var elem_base = subgroup_id * 4u; elem_base < HEAD_DIM_V; elem_base += num_subgroups * 4u) { var weighted = vec4(0.0, 0.0, 0.0, 0.0); @@ -70,6 +84,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let sum_z = subgroupAdd(weighted.z); let sum_w = subgroupAdd(weighted.w); + // Lane 0 writes the final normalized vec4 chunk. if (lane == 0u) { let dst_vec_index = (row_base + elem_base) >> 2u; dst[dst_vec_index] = vec4(sum_x, sum_y, sum_z, sum_w) * inv_s; From df6ef45af1683ace7e8d3f7905db8a8cbd987291 Mon Sep 17 00:00:00 2001 From: Zheyuan Chen Date: Tue, 3 Mar 2026 22:11:03 -0800 Subject: [PATCH 07/34] remove pad path --- .../wgsl-shaders/flash_attn_vec_split.wgsl | 139 ++---------------- 1 file changed, 12 insertions(+), 127 deletions(-) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl index 0f36f5af5edd..57bc96fc54fe 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl @@ -91,15 +91,6 @@ struct Params { m0: f32, m1: f32, -#ifdef PAD - offset_pad: u32, - pad_k_base: u32, - pad_v_base: u32, - pad_m_base: u32, - ncpsg: u32, - nqptg: u32, -#endif - #ifdef BLK blk_base: u32, blk_nblk0: u32, @@ -114,48 +105,6 @@ struct Params { @group(0) @binding(0) var Q: array; @group(0) 
@binding(1) var K: array; @group(0) @binding(2) var V: array>; - -#ifdef PAD -@group(0) @binding(3) var pad: array; -#endif - -#ifdef PAD -#if defined(MASK) && defined(SINKS) -@group(0) @binding(4) var mask: array; -@group(0) @binding(5) var sinks: array; -#ifdef BLK -#define BLK_BINDING 6 -#define TMP_BINDING 7 -#define DST_BINDING 8 -#define PARAMS_BINDING 9 -#else -#define TMP_BINDING 6 -#define DST_BINDING 7 -#define PARAMS_BINDING 8 -#endif -#elif defined(MASK) -@group(0) @binding(4) var mask: array; -#ifdef BLK -#define BLK_BINDING 5 -#define TMP_BINDING 6 -#define DST_BINDING 7 -#define PARAMS_BINDING 8 -#else -#define TMP_BINDING 5 -#define DST_BINDING 6 -#define PARAMS_BINDING 7 -#endif -#elif defined(SINKS) -@group(0) @binding(4) var sinks: array; -#define TMP_BINDING 5 -#define DST_BINDING 6 -#define PARAMS_BINDING 7 -#else -#define TMP_BINDING 4 -#define DST_BINDING 5 -#define PARAMS_BINDING 6 -#endif -#else #if defined(MASK) && defined(SINKS) @group(0) @binding(3) var mask: array; @group(0) @binding(4) var sinks: array; @@ -191,7 +140,6 @@ struct Params { #define DST_BINDING 4 #define PARAMS_BINDING 5 #endif -#endif #ifdef BLK @group(0) @binding(BLK_BINDING) var blk: array; @@ -323,11 +271,6 @@ fn main(@builtin(workgroup_id) wg_id: vec3, } for (var kv_tile = iwg * KV_TILE; kv_tile < params.seq_len_kv; kv_tile += KV_TILE * params.nwg) { -#ifdef PAD - let tail = params.seq_len_kv % params.ncpsg; - let use_pad_tile = (tail != 0u) && (kv_tile + params.ncpsg >= params.seq_len_kv); - let kv_plane = k_head_idx + batch_idx * (params.n_heads / params.q_per_kv); -#endif #ifdef BLK let q_blk = q_row_start / Q_TILE; let kv_blk = kv_tile / KV_TILE; @@ -429,53 +372,24 @@ fn main(@builtin(workgroup_id) wg_id: vec3, if (kv_valid) { for (var i = tx; i < (HEAD_DIM_QK / 4u); i += num_of_threads) { let q_off = q_tile_row * HEAD_DIM_QK + i * 4u; - // let k_off = (kv_idx * HEAD_DIM_QK) + i * 4u; - var idx: u32 = 0u; -#ifdef PAD - if (use_pad_tile) { - idx = 
params.offset_pad + params.pad_k_base + - kv_plane * params.stride_k1 * params.ncpsg + - kv_idx * params.stride_k1 + i * 4u; - } else { - idx = k_head_offset + (kv_tile + kv_idx) * params.stride_k1 + (i * 4u); - } -#else - idx = k_head_offset + (kv_tile + kv_idx) * params.stride_k1 + (i * 4u); -#endif + let idx = k_head_offset + (kv_tile + kv_idx) * params.stride_k1 + (i * 4u); let qv = vec4(f32(q_shmem[q_off]), f32(q_shmem[q_off + 1u]), f32(q_shmem[q_off + 2u]), f32(q_shmem[q_off + 3u])); - // let kv = vec4(f32(kv_shmem[k_off]), f32(kv_shmem[k_off + 1u]), f32(kv_shmem[k_off + 2u]), f32(kv_shmem[k_off + 3u])); - var kv: vec4; -#ifdef PAD - if (use_pad_tile) { - kv = vec4(f32(pad[idx]), f32(pad[idx + 1u]), f32(pad[idx + 2u]), f32(pad[idx + 3u])); - } else { - kv = vec4(f32(K[idx]), f32(K[idx + 1u]), f32(K[idx + 2u]), f32(K[idx + 3u])); - } -#else - kv = vec4(f32(K[idx]), f32(K[idx + 1u]), f32(K[idx + 2u]), f32(K[idx + 3u])); -#endif + let kv = vec4(f32(K[idx]), f32(K[idx + 1u]), f32(K[idx + 2u]), f32(K[idx + 3u])); partial_sum += dot(qv, kv); } } - // Match Metal vec reduction pattern: reduce within each ty stripe using subgroup shuffles. + // Reduce along tx lanes inside each ty stripe. 
var sum = partial_sum; - if (num_of_threads <= 1u) { - sum += subgroupShuffleDown(sum, 16u); - } - if (num_of_threads <= 2u) { - sum += subgroupShuffleDown(sum, 8u); - } - if (num_of_threads <= 4u) { - sum += subgroupShuffleDown(sum, 4u); - } - if (num_of_threads <= 8u) { - sum += subgroupShuffleDown(sum, 2u); - } - if (num_of_threads <= 16u) { - sum += subgroupShuffleDown(sum, 1u); + var delta = num_of_threads >> 1u; + loop { + if (delta == 0u) { + break; + } + sum += subgroupShuffleDown(sum, delta); + delta = delta >> 1u; } let sum_bcast = subgroupShuffle(sum, num_of_threads * ty); @@ -497,24 +411,9 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let mask_col = elem_idx % KV_TILE; let global_q_row = q_row_start + mask_row; let global_k_col = kv_tile + mask_col; -#ifdef PAD - if (use_pad_tile) { - let mask_batch_idx = select(0u, batch_idx, params.stride_mask3 > 0u); - let pad_mask_plane_base = params.offset_pad + params.pad_m_base + - mask_batch_idx * params.seq_len_q * params.ncpsg; - let mask_in_bounds = global_q_row < params.seq_len_q; - let mask_idx = pad_mask_plane_base + global_q_row * params.ncpsg + mask_col; - mask_shmem[elem_idx] = f16(select(0.0, pad[mask_idx], mask_in_bounds)); - } else { - let mask_in_bounds = global_q_row < params.seq_len_q && global_k_col < params.seq_len_kv; - let mask_idx = mask_global_offset + mask_row * params.seq_len_kv + global_k_col; - mask_shmem[elem_idx] = select(0.0, mask[mask_idx], mask_in_bounds); - } -#else let mask_in_bounds = global_q_row < params.seq_len_q && global_k_col < params.seq_len_kv; let mask_idx = mask_global_offset + mask_row * params.seq_len_kv + global_k_col; mask_shmem[elem_idx] = select(0.0, mask[mask_idx], mask_in_bounds); -#endif } } #else @@ -661,22 +560,8 @@ fn main(@builtin(workgroup_id) wg_id: vec3, } let p = f32(inter_shmem[kv_idx + q_tile_row * KV_TILE]); - var v_idx: u32 = 0u; - var v4: vec4; -#ifdef PAD - if (use_pad_tile) { - v_idx = params.offset_pad + params.pad_v_base + - kv_plane * 
params.stride_v1 * params.ncpsg + - kv_idx * params.stride_v1 + vec_col * 4u; - v4 = vec4(f32(pad[v_idx]), f32(pad[v_idx + 1u]), f32(pad[v_idx + 2u]), f32(pad[v_idx + 3u])); - } else { - v_idx = v_head_offset + v_row * params.stride_v1 + vec_col * 4u; - v4 = vec4(V[v_idx / 4u]); - } -#else - v_idx = v_head_offset + v_row * params.stride_v1 + vec_col * 4u; - v4 = vec4(V[v_idx / 4u]); -#endif + let v_idx = v_head_offset + v_row * params.stride_v1 + vec_col * 4u; + let v4 = vec4(V[v_idx / 4u]); lo += p * v4; } From 838306f4bdcfba2aef9e1d00a982dc0bdb4f7ee7 Mon Sep 17 00:00:00 2001 From: Zheyuan Chen Date: Fri, 6 Mar 2026 16:37:16 -0800 Subject: [PATCH 08/34] ggml-webgpu: fix flash-attn vec nwg=1 path and tighten vec specialization --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 11 +- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 108 +++++++------ .../wgsl-shaders/flash_attn_vec_reduce.wgsl | 31 ++-- .../wgsl-shaders/flash_attn_vec_split.wgsl | 151 +++++++++++------- 4 files changed, 165 insertions(+), 136 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index e5139c895ac6..d0ad4266006a 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -524,9 +524,12 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader( context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES); if (context.key.use_vec_split) { q_tile = 1; - kv_tile = 32; + const uint32_t max_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(context); + kv_tile = std::max(context.sg_mat_n, std::min(32u, max_kv_tile)); + kv_tile = (kv_tile / context.sg_mat_n) * context.sg_mat_n; GGML_ASSERT(kv_tile % context.sg_mat_n == 0); - GGML_ASSERT(kv_tile <= ggml_webgpu_flash_attn_max_kv_tile(context)); + GGML_ASSERT(kv_tile <= max_kv_tile); + defines.push_back("VEC_NE=1"); } if (context.key.kv_direct) { GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD); @@ -541,8 +544,8 
@@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader( uint32_t wg_size = 0; if (context.key.use_vec_split) { - // Keep vec-split to a single subgroup to avoid complexity in the reduction. - wg_size = context.max_subgroup_size; + // Keep vec-split to exactly one subgroup to preserve thread mapping. + wg_size = std::max(1u, std::min(32u, context.max_subgroup_size)); } else { wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE); } diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 94991f1594f6..4f12590c8d92 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1395,11 +1395,11 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, bool kv_direct = (K->type == GGML_TYPE_F16) && (Q->ne[0] % ctx->global_ctx->capabilities.sg_mat_k == 0) && (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0); - // Match Metal's vec-kernel shape heuristic. const bool use_vec = (Q->ne[1] < 20) && - (Q->ne[0] % 32 == 0) && - (V->ne[0] % 4 == 0) && - (K->type == GGML_TYPE_F16); + (Q->ne[0] % 32 == 0) && + (V->ne[0] % 4 == 0) && + (K->type == GGML_TYPE_F16) && + (Q->ne[2] == K->ne[2]); const uint32_t vec_nwg_cap = std::max(1u, std::min(32u, ctx->global_ctx->capabilities.max_subgroup_size)); @@ -1450,14 +1450,7 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions->q_tile); uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches - uint32_t vec_nwg = 1u; - if (use_vec_split) { - const uint64_t kv_span = (uint64_t) std::max(1u, decisions->kv_tile); - while ((2u * vec_nwg * kv_span) < (uint64_t) K->ne[1] && vec_nwg < vec_nwg_cap) { - vec_nwg <<= 1; - } - vec_nwg = std::min(vec_nwg, vec_nwg_cap); - } + uint32_t vec_nwg = use_vec_split ? 
1u : vec_nwg_cap; bool have_blk_buf = false; wgpu::Buffer blk_buf = {}; @@ -1470,6 +1463,9 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, const uint32_t nwg = vec_nwg; GGML_ASSERT(nwg <= ctx->global_ctx->capabilities.max_subgroup_size); const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3]; + // For a single split workgroup there is nothing to merge. + // Let vec split write final dst directly and skip reduce. + const bool use_vec_reduce = nwg > 1u; const uint64_t tmp_data_elems = nrows * (uint64_t) V->ne[0] * nwg; const uint64_t tmp_stats_base = tmp_data_elems; @@ -1481,7 +1477,8 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, GGML_ASSERT(nrows <= UINT32_MAX); wgpu::Buffer tmp_buf; - ggml_webgpu_create_buffer(ctx->global_ctx->device, tmp_buf, tmp_size_bytes, wgpu::BufferUsage::Storage, + ggml_webgpu_create_buffer(ctx->global_ctx->device, tmp_buf, tmp_size_bytes, + wgpu::BufferUsage::Storage, "flash_attn_vec_tmp"); webgpu_pipeline blk_pipeline; @@ -1591,44 +1588,49 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, .offset = ggml_webgpu_tensor_align_offset(ctx, dst), .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); - const uint32_t reduce_wg_size = std::max( - 32u, std::min(nwg * 32u, ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup)); - const ggml_webgpu_flash_attn_vec_reduce_pipeline_key reduce_key = { - .head_dim_v = (uint32_t) V->ne[0], - .wg_size = reduce_wg_size, - }; webgpu_pipeline reduce_pipeline; - auto reduce_it = ctx->flash_attn_vec_reduce_pipelines.find(reduce_key); - if (reduce_it != ctx->flash_attn_vec_reduce_pipelines.end()) { - reduce_pipeline = reduce_it->second; - } else { - ggml_webgpu_flash_attn_vec_reduce_shader_lib_context reduce_shader_ctx = { - .key = reduce_key, - .max_wg_size = reduce_wg_size, + std::vector reduce_params; + std::vector reduce_entries; + if (use_vec_reduce) { + const uint32_t reduce_wg_size = std::max( + 32u, std::min(nwg 
* 32u, ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup)); + const ggml_webgpu_flash_attn_vec_reduce_pipeline_key reduce_key = { + .head_dim_v = (uint32_t) V->ne[0], + .wg_size = reduce_wg_size, }; - ggml_webgpu_processed_shader processed = ggml_webgpu_preprocess_flash_attn_vec_reduce_shader( - ctx->p, wgsl_flash_attn_vec_reduce, reduce_shader_ctx); - reduce_pipeline = ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), - processed.variant.c_str()); - ctx->flash_attn_vec_reduce_pipelines.emplace(reduce_key, reduce_pipeline); - } + auto reduce_it = ctx->flash_attn_vec_reduce_pipelines.find(reduce_key); + if (reduce_it != ctx->flash_attn_vec_reduce_pipelines.end()) { + reduce_pipeline = reduce_it->second; + } else { + ggml_webgpu_flash_attn_vec_reduce_shader_lib_context reduce_shader_ctx = { + .key = reduce_key, + .max_wg_size = reduce_wg_size, + }; + ggml_webgpu_processed_shader processed = ggml_webgpu_preprocess_flash_attn_vec_reduce_shader( + ctx->p, wgsl_flash_attn_vec_reduce, reduce_shader_ctx); + reduce_pipeline = ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), + processed.variant.c_str()); + ctx->flash_attn_vec_reduce_pipelines.emplace(reduce_key, reduce_pipeline); + } - std::vector reduce_params = { - (uint32_t) nrows, // nrows - (uint32_t) Q->ne[1], // seq_len_q - (uint32_t) Q->ne[2], // n_heads - nwg, // nwg - 0u, // tmp_data_base - (uint32_t) tmp_stats_base, // tmp_stats_base - }; + reduce_params = { + (uint32_t) nrows, // nrows + (uint32_t) Q->ne[1], // seq_len_q + (uint32_t) Q->ne[2], // n_heads + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), // offset_dst + nwg, // nwg + 0u, // tmp_data_base + (uint32_t) tmp_stats_base, // tmp_stats_base + }; - std::vector reduce_entries = { - { .binding = 0, .buffer = tmp_buf, .offset = 0, .size = tmp_size_bytes }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = 
ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }, - }; + reduce_entries = { + { .binding = 0, .buffer = tmp_buf, .offset = 0, .size = tmp_size_bytes }, + { .binding = 1, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst) }, + }; + } const uint64_t split_wg_total = (uint64_t) wg_x * nwg; GGML_ASSERT(split_wg_total <= UINT32_MAX); @@ -1647,10 +1649,12 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, params_list.push_back(std::move(split_params)); entries_list.push_back(std::move(split_entries)); workgroups_list.push_back({ (uint32_t) split_wg_total, 1u }); - pipelines.push_back(reduce_pipeline); - params_list.push_back(std::move(reduce_params)); - entries_list.push_back(std::move(reduce_entries)); - workgroups_list.push_back({ (uint32_t) nrows, 1u }); + if (use_vec_reduce) { + pipelines.push_back(reduce_pipeline); + params_list.push_back(std::move(reduce_params)); + entries_list.push_back(std::move(reduce_entries)); + workgroups_list.push_back({ (uint32_t) nrows, 1u }); + } const bool split_passes = use_blk; @@ -3183,7 +3187,7 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix = valid_subgroup_matrix_config; #endif - // For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate. + // For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate. // Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter. 
ctx->webgpu_global_ctx->capabilities.max_subgroup_size = info.subgroupMaxSize; // Initialize device diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl index b8b8bf5a72f7..6b38f581bb25 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl @@ -6,20 +6,17 @@ enable subgroups; #define WG_SIZE 128 struct Params { - // Total rows to reduce: nrows = batch * n_heads * seq_len_q. nrows: u32, seq_len_q: u32, n_heads: u32, - // Number of split workgroups used in the vec-split pass. - // Each split contributes one partial (o, l, m) per row. + offset_dst: u32, nwg: u32, - // Bases into tmp for partial output vectors and partial stats. tmp_data_base: u32, tmp_stats_base: u32, }; @group(0) @binding(0) var tmp: array; -@group(0) @binding(1) var dst: array>; +@group(0) @binding(1) var dst: array; @group(0) @binding(2) var params: Params; const FLOAT_MIN: f32 = -1.0e9; @@ -30,13 +27,11 @@ fn main(@builtin(workgroup_id) wg_id: vec3, @builtin(num_subgroups) num_subgroups: u32, @builtin(subgroup_size) subgroup_size: u32, @builtin(subgroup_invocation_id) sg_inv_id: u32) { - // One workgroup reduces one logical output row rid. let rid = wg_id.x; if (rid >= params.nrows) { return; } - // Decode flattened row id back to (batch, head, q_row). let rows_per_batch = params.n_heads * params.seq_len_q; let batch_idx = rid / rows_per_batch; let rem = rid % rows_per_batch; @@ -45,32 +40,22 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let dst2_stride = HEAD_DIM_V * params.n_heads; let dst3_stride = dst2_stride * params.seq_len_q; - let row_base = batch_idx * dst3_stride + q_row * dst2_stride + head_idx * HEAD_DIM_V; + let row_base = params.offset_dst + batch_idx * dst3_stride + q_row * dst2_stride + head_idx * HEAD_DIM_V; - // Each subgroup lane corresponds to one split index g in [0, nwg). 
- // This kernel requires params.nwg <= subgroup_size. let lane = sg_inv_id; if (params.nwg > subgroup_size) { return; } - // Load split stats for this row: - // si = l_g (exp sum), mi = m_g (row max) from split g. let stats_base = params.tmp_stats_base + rid * (2u * params.nwg); let active_lane = lane < params.nwg; - let si = select(0.0, tmp[stats_base + 2u * lane + 0u], active_lane); + let si = select(0.0, tmp[stats_base + 2u * lane + 0u], active_lane); let mi = select(FLOAT_MIN, tmp[stats_base + 2u * lane + 1u], active_lane); - - // Merge split softmax normalizers: - // m = max_g m_g - // l = sum_g l_g * exp(m_g - m) let m = subgroupMax(mi); let ms = select(0.0, exp(mi - m), active_lane); let s = subgroupAdd(si * ms); let inv_s = select(0.0, 1.0 / s, s != 0.0); - // Merge partial output vectors: - // O = (sum_g O_g * exp(m_g - m)) / l let row_tmp_base = params.tmp_data_base + rid * (HEAD_DIM_V * params.nwg); for (var elem_base = subgroup_id * 4u; elem_base < HEAD_DIM_V; elem_base += num_subgroups * 4u) { var weighted = vec4(0.0, 0.0, 0.0, 0.0); @@ -84,10 +69,12 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let sum_z = subgroupAdd(weighted.z); let sum_w = subgroupAdd(weighted.w); - // Lane 0 writes the final normalized vec4 chunk. 
if (lane == 0u) { - let dst_vec_index = (row_base + elem_base) >> 2u; - dst[dst_vec_index] = vec4(sum_x, sum_y, sum_z, sum_w) * inv_s; + let dst_base = row_base + elem_base; + dst[dst_base + 0u] = sum_x * inv_s; + dst[dst_base + 1u] = sum_y * inv_s; + dst[dst_base + 2u] = sum_z * inv_s; + dst[dst_base + 3u] = sum_w * inv_s; } } } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl index 57bc96fc54fe..a500fe8a733f 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl @@ -10,37 +10,32 @@ enable chromium_experimental_subgroup_matrix; #define KV_TYPE f16 #endif -// Default values #define HEAD_DIM_QK 64 #define HEAD_DIM_V 64 -// The number of rows/columns/k in a subgroup matrix. MxK * KxN = MxN -// Note that the "K" here does not correspond to the K in attention's Q/K/V, it's just the common dimension. + #define SG_MAT_M 8 #define SG_MAT_N 8 #define SG_MAT_K 8 -// Each workgroup processes one subgroup matrix of Q rows #define Q_TILE SG_MAT_M #define KV_TILE 16 #define WG_SIZE 64 +#ifndef VEC_NE +#define VEC_NE 4u +#endif -// Number of subgroup-matrix-width blocks that span the KV tile. SG_MAT_N must divide KV_TILE. 
#define KV_BLOCKS (KV_TILE / SG_MAT_N) -// Quantization constants/helpers #define BLOCK_SIZE 32 #define BLOCKS_K ((HEAD_DIM_QK + BLOCK_SIZE - 1) / BLOCK_SIZE) #define BLOCKS_V ((HEAD_DIM_V + BLOCK_SIZE - 1) / BLOCK_SIZE) -// number of quantized elements processed per thread #if defined(KV_Q4_0) #define NQ 16 -// Q4_0 has 32 elements, 1 f16 for scale, 8 f16 for 4-bit weights #define F16_PER_BLOCK 9 #define WEIGHTS_PER_F16 4 #elif defined(KV_Q8_0) #define NQ 8 -// Q8_0 has 32 elements, 1 f16 for scale, 16 f16 for 8-bit weights #define F16_PER_BLOCK 17 #define WEIGHTS_PER_F16 2 #endif @@ -104,7 +99,7 @@ struct Params { @group(0) @binding(0) var Q: array; @group(0) @binding(1) var K: array; -@group(0) @binding(2) var V: array>; +@group(0) @binding(2) var V: array; #if defined(MASK) && defined(SINKS) @group(0) @binding(3) var mask: array; @group(0) @binding(4) var sinks: array; @@ -151,7 +146,6 @@ struct Params { // Just a very small float value. const FLOAT_MIN: f32 = -1.0e9; -// The number of Q rows processed per workgroup var q_shmem: array; #ifndef KV_DIRECT @@ -178,30 +172,25 @@ var exp_sum_shmem: array; var blk_state_wg: u32; fn calc_softmax_term(kv_idx: u32, q_tile_row: u32, slope: f32, has_bias: bool, apply_mask: bool) -> f32 { - var v = select(FLOAT_MIN, - f32(inter_shmem[kv_idx + q_tile_row * KV_TILE]) * params.scale, - kv_idx < KV_TILE); + var v = FLOAT_MIN; + if (kv_idx < KV_TILE) { + v = f32(inter_shmem[kv_idx + q_tile_row * KV_TILE]) * params.scale; + } #ifdef LOGIT_SOFTCAP v = params.logit_softcap * tanh(v); #endif #ifdef MASK if (apply_mask) { - let mask_val = select(0.0, f32(mask_shmem[q_tile_row * KV_TILE + kv_idx]), kv_idx < KV_TILE); - // Common fast path (mask only): avoid extra mul when bias scaling is disabled. 
+ var mask_val = 0.0; + if (kv_idx < KV_TILE) { + mask_val = f32(mask_shmem[q_tile_row * KV_TILE + kv_idx]); + } v += select(mask_val, slope * mask_val, has_bias); } #endif return v; } -fn load_f32x4(buf: ptr>, read_write>, scalar_index: u32) -> vec4 { - return (*buf)[scalar_index >> 2u]; -} - -fn load_kvx4(buf: ptr>, read_write>, scalar_index: u32) -> vec4 { - return (*buf)[scalar_index >> 2u]; -} - @compute @workgroup_size(WG_SIZE) fn main(@builtin(workgroup_id) wg_id: vec3, @builtin(local_invocation_id) local_id: vec3, @@ -258,18 +247,20 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let has_bias = params.max_bias > 0.0; let slope = select(1.0, select(pow(params.m1, 2.0 * (head - params.n_head_log2) + 1.0), pow(params.m0, head + 1.0), head < params.n_head_log2), has_bias); - // load q tile into shared memory + // Load Q tile once and keep it in f16 to match scalar/Metal precision path. for (var elem_idx = local_id.x; elem_idx < Q_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) { let q_row = elem_idx / HEAD_DIM_QK; let q_col = elem_idx % HEAD_DIM_QK; - let head_q_row = q_row_start + q_row; - let global_q_row_offset = q_head_offset + head_q_row * params.stride_q1; + let global_q_row = q_row_start + q_row; + let global_q_row_offset = q_head_offset + global_q_row * params.stride_q1; q_shmem[elem_idx] = f16(select( 0.0, Q[global_q_row_offset + q_col], - head_q_row < params.seq_len_q && q_col < HEAD_DIM_QK)); + global_q_row < params.seq_len_q && q_col < HEAD_DIM_QK)); } + workgroupBarrier(); + for (var kv_tile = iwg * KV_TILE; kv_tile < params.seq_len_kv; kv_tile += KV_TILE * params.nwg) { #ifdef BLK let q_blk = q_row_start / Q_TILE; @@ -286,9 +277,8 @@ fn main(@builtin(workgroup_id) wg_id: vec3, workgroupBarrier(); let blk_state = blk_state_wg; let skip_tile = blk_state == 0u; - // clear inter_shmem to ensure zero-initialized accumulators for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) { - inter_shmem[elem_idx] = 0.0; + 
inter_shmem[elem_idx] = f16(0.0); } // load k tile into shared memory @@ -354,8 +344,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, // TODO: this loop seems to be the current largest bottleneck // this bracket exists to scope the lifetime of variables, reducing register pressure if (!skip_tile) { - // vectorization - let num_of_threads = subgroup_size / 4u; + let num_of_threads = subgroup_size / VEC_NE; let tx = sg_inv_id % num_of_threads; let ty = sg_inv_id / num_of_threads; for (var q_tile_row = subgroup_id; q_tile_row < Q_TILE; q_tile_row += num_subgroups) { @@ -363,33 +352,43 @@ fn main(@builtin(workgroup_id) wg_id: vec3, if (global_q_row >= params.seq_len_q) { continue; } + let local_q_row_offset = q_tile_row * HEAD_DIM_QK; - for (var kv_base : u32 = 0u; kv_base < KV_TILE; kv_base += 4u) { + for (var kv_base : u32 = 0u; kv_base < KV_TILE; kv_base += VEC_NE) { let kv_idx = kv_base + ty; var partial_sum: f32 = 0.0; let kv_valid = kv_idx < KV_TILE && (kv_tile + kv_idx) < params.seq_len_kv; if (kv_valid) { for (var i = tx; i < (HEAD_DIM_QK / 4u); i += num_of_threads) { - let q_off = q_tile_row * HEAD_DIM_QK + i * 4u; + let q_off = local_q_row_offset + i * 4u; let idx = k_head_offset + (kv_tile + kv_idx) * params.stride_k1 + (i * 4u); - let qv = vec4(f32(q_shmem[q_off]), f32(q_shmem[q_off + 1u]), f32(q_shmem[q_off + 2u]), f32(q_shmem[q_off + 3u])); - let kv = vec4(f32(K[idx]), f32(K[idx + 1u]), f32(K[idx + 2u]), f32(K[idx + 3u])); + let qv = vec4(f32(q_shmem[q_off + 0u]), + f32(q_shmem[q_off + 1u]), + f32(q_shmem[q_off + 2u]), + f32(q_shmem[q_off + 3u])); + let kv = vec4(f32(K[idx + 0u]), + f32(K[idx + 1u]), + f32(K[idx + 2u]), + f32(K[idx + 3u])); partial_sum += dot(qv, kv); } } - // Reduce along tx lanes inside each ty stripe. var sum = partial_sum; - var delta = num_of_threads >> 1u; + // Reduce over tx lanes (NL) for this ty stripe. 
+ var tx_delta = num_of_threads >> 1u; loop { - if (delta == 0u) { + if (tx_delta == 0u) { break; } - sum += subgroupShuffleDown(sum, delta); - delta = delta >> 1u; + let sh = subgroupShuffleDown(sum, tx_delta); + if (tx < tx_delta) { + sum += sh; + } + tx_delta >>= 1u; } let sum_bcast = subgroupShuffle(sum, num_of_threads * ty); @@ -436,7 +435,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3, // pass 1: compute final max across the full KV tile in chunks for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) { let kv_idx = kv_offset + sg_inv_id; - let softmax_term = calc_softmax_term(kv_idx, q_tile_row, slope, has_bias, apply_mask); + let kv_valid = kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE; + let softmax_term = select(FLOAT_MIN, + calc_softmax_term(kv_idx, q_tile_row, slope, has_bias, apply_mask), + kv_valid); final_max = subgroupMax(max(final_max, softmax_term)); } @@ -543,7 +545,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, if (!skip_tile) { // we have P (Q_TILE x KV_TILE) in inter_shmem and V (KV_TILE x head_dim_v) in kv_shmem // we want to compute O += P * V across the full KV tile - let ne_lanes = 4u; + let ne_lanes : u32 = VEC_NE; let nl_lanes = max(1u, subgroup_size / ne_lanes); let tx_pv = sg_inv_id % nl_lanes; let ty_pv = sg_inv_id / nl_lanes; @@ -561,25 +563,35 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let p = f32(inter_shmem[kv_idx + q_tile_row * KV_TILE]); let v_idx = v_head_offset + v_row * params.stride_v1 + vec_col * 4u; - let v4 = vec4(V[v_idx / 4u]); + let v4 = vec4(f32(V[v_idx + 0u]), + f32(V[v_idx + 1u]), + f32(V[v_idx + 2u]), + f32(V[v_idx + 3u])); lo += p * v4; } - // Match Metal's vec PV reduction: reduce across ty lanes. var lo_x = lo.x; var lo_y = lo.y; var lo_z = lo.z; var lo_w = lo.w; - var delta = nl_lanes * (ne_lanes >> 1u); + // Reduce over ty lanes (NE) for this tx lane. 
+ var ty_delta = ne_lanes >> 1u; loop { - if (delta == 0u || delta < nl_lanes) { + if (ty_delta == 0u) { break; } - lo_x += subgroupShuffleDown(lo_x, delta); - lo_y += subgroupShuffleDown(lo_y, delta); - lo_z += subgroupShuffleDown(lo_z, delta); - lo_w += subgroupShuffleDown(lo_w, delta); - delta = delta >> 1u; + let lane_delta = ty_delta * nl_lanes; + let shx = subgroupShuffleDown(lo_x, lane_delta); + let shy = subgroupShuffleDown(lo_y, lane_delta); + let shz = subgroupShuffleDown(lo_z, lane_delta); + let shw = subgroupShuffleDown(lo_w, lane_delta); + if (ty_pv < ty_delta) { + lo_x += shx; + lo_y += shy; + lo_z += shz; + lo_w += shw; + } + ty_delta >>= 1u; } if (ty_pv == 0u) { @@ -604,7 +616,6 @@ fn main(@builtin(workgroup_id) wg_id: vec3, for (var q_tile_row = subgroup_id; q_tile_row < Q_TILE; q_tile_row += num_subgroups) { - // no need to process rows beyond seq_len_q let global_q_row = q_row_start + q_tile_row; if (global_q_row >= params.seq_len_q) { break; @@ -621,14 +632,14 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let sink_exp_sum = subgroupAdd(sink_exp); if (sg_inv_id == 0) { + row_max_shmem[q_tile_row] = new_max; exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * max_exp + sink_exp_sum; } - for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) { - let idx = q_tile_row * HEAD_DIM_V + elem_idx; - let val = f32(o_shmem[idx]) * max_exp; - o_shmem[idx] = f16(val); - } + for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) { + let idx = q_tile_row * HEAD_DIM_V + elem_idx; + o_shmem[idx] = f16(f32(o_shmem[idx]) * max_exp); + } } workgroupBarrier(); } @@ -665,5 +676,29 @@ fn main(@builtin(workgroup_id) wg_id: vec3, tmp[tmp_row_stats_base + 0u] = exp_sum_shmem[q_tile_row]; tmp[tmp_row_stats_base + 1u] = row_max_shmem[q_tile_row]; } + + if (params.nwg == 1u) { + let exp_sum = exp_sum_shmem[q_tile_row]; + let scale = select(0.0, 1.0 / exp_sum, exp_sum != 0.0); + let row_base: u32 = + params.offset_dst + 
batch_idx * dst3_stride + global_q_row * dst2_stride + head_idx * HEAD_DIM_V; + + for (var elem_base = sg_inv_id * 4u; elem_base < HEAD_DIM_V; elem_base += subgroup_size * 4u) { + let i0 = q_tile_row * HEAD_DIM_V + (elem_base + 0u); + let i1 = q_tile_row * HEAD_DIM_V + (elem_base + 1u); + let i2 = q_tile_row * HEAD_DIM_V + (elem_base + 2u); + let i3 = q_tile_row * HEAD_DIM_V + (elem_base + 3u); + + let v = vec4( + f32(o_shmem[i0]) * scale, + f32(o_shmem[i1]) * scale, + f32(o_shmem[i2]) * scale, + f32(o_shmem[i3]) * scale + ); + + let dst_vec_index: u32 = (row_base + elem_base) >> 2u; + dst[dst_vec_index] = v; + } + } } } From d61ec8f203a1406cd6f7f29fbe2a27efe1f8a23c Mon Sep 17 00:00:00 2001 From: Zheyuan Chen Date: Fri, 6 Mar 2026 16:39:50 -0800 Subject: [PATCH 09/34] change back to vec4 --- .../ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl | 9 +++------ .../ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl | 7 ++----- 2 files changed, 5 insertions(+), 11 deletions(-) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl index 6b38f581bb25..43dbf7d3aa08 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl @@ -16,7 +16,7 @@ struct Params { }; @group(0) @binding(0) var tmp: array; -@group(0) @binding(1) var dst: array; +@group(0) @binding(1) var dst: array>; @group(0) @binding(2) var params: Params; const FLOAT_MIN: f32 = -1.0e9; @@ -70,11 +70,8 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let sum_w = subgroupAdd(weighted.w); if (lane == 0u) { - let dst_base = row_base + elem_base; - dst[dst_base + 0u] = sum_x * inv_s; - dst[dst_base + 1u] = sum_y * inv_s; - dst[dst_base + 2u] = sum_z * inv_s; - dst[dst_base + 3u] = sum_w * inv_s; + let dst_vec_index = (row_base + elem_base) >> 2u; + dst[dst_vec_index] = vec4(sum_x, sum_y, sum_z, sum_w) * inv_s; } } } diff --git 
a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl index a500fe8a733f..c5b7372ed30c 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl @@ -99,7 +99,7 @@ struct Params { @group(0) @binding(0) var Q: array; @group(0) @binding(1) var K: array; -@group(0) @binding(2) var V: array; +@group(0) @binding(2) var V: array>; #if defined(MASK) && defined(SINKS) @group(0) @binding(3) var mask: array; @group(0) @binding(4) var sinks: array; @@ -563,10 +563,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let p = f32(inter_shmem[kv_idx + q_tile_row * KV_TILE]); let v_idx = v_head_offset + v_row * params.stride_v1 + vec_col * 4u; - let v4 = vec4(f32(V[v_idx + 0u]), - f32(V[v_idx + 1u]), - f32(V[v_idx + 2u]), - f32(V[v_idx + 3u])); + let v4 = vec4(V[v_idx >> 2u]); lo += p * v4; } From 042a1a5697a6c84400d7fee913dff79da785567f Mon Sep 17 00:00:00 2001 From: Zheyuan Chen Date: Fri, 6 Mar 2026 17:09:11 -0800 Subject: [PATCH 10/34] enable multi split --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 4f12590c8d92..0f0adb142e92 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1450,7 +1450,14 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions->q_tile); uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches - uint32_t vec_nwg = use_vec_split ? 
1u : vec_nwg_cap; + uint32_t vec_nwg = 1u; + if (use_vec_split) { + const uint64_t kv_span = (uint64_t) std::max(1u, decisions->kv_tile); + while ((2u * vec_nwg * kv_span) < (uint64_t) K->ne[1] && vec_nwg < vec_nwg_cap) { + vec_nwg <<= 1; + } + vec_nwg = std::min(vec_nwg, vec_nwg_cap); + } bool have_blk_buf = false; wgpu::Buffer blk_buf = {}; @@ -1656,7 +1663,7 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, workgroups_list.push_back({ (uint32_t) nrows, 1u }); } - const bool split_passes = use_blk; + const bool split_passes = use_blk || use_vec_reduce; std::vector retained_buffers = { tmp_buf }; if (use_blk) { From b61e63d8279d47e54bd4242297bc5405d1ccc194 Mon Sep 17 00:00:00 2001 From: Zheyuan Chen Date: Fri, 6 Mar 2026 17:19:39 -0800 Subject: [PATCH 11/34] enable vec path when: - Q->ne[1] < 20 - Q->ne[0] % 32 == 0 - V->ne[0] % 4 == 0 - K->type == f16 --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 402 +++++++++++---------------- 1 file changed, 156 insertions(+), 246 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 0f0adb142e92..abf64ea87243 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -8,6 +8,7 @@ #include "ggml-backend-impl.h" #include "ggml-impl.h" #include "ggml-webgpu-shader-lib.hpp" +#include "pre_wgsl.hpp" #ifdef __EMSCRIPTEN__ # include @@ -19,18 +20,11 @@ #include #include #include -#ifdef GGML_WEBGPU_GPU_PROFILE -# include -#endif -#if defined(GGML_WEBGPU_DEBUG) || defined(GGML_WEBGPU_CPU_PROFILE) || defined(GGML_WEBGPU_GPU_PROFILE) -# include -#endif #include #include #include #include #include -#include #include #define ROUNDUP_POW2(x, pow2) (((x) + ((pow2) - 1)) & ~((pow2) - 1)) @@ -75,26 +69,23 @@ static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim #endif // GGML_WEBGPU_CPU_PROFILE #ifdef GGML_WEBGPU_GPU_PROFILE -# define WEBGPU_NUM_TIMESTAMP_QUERY_BUFS 32 +# define WEBGPU_NUM_TIMESTAMP_QUERY_BUFS 
24 # define WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES 16 // e.g. enough for two timestamps #endif /* Constants */ -#define WEBGPU_NUM_PARAM_BUFS 96u -#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 32u +#define WEBGPU_NUM_PARAM_BUFS 48u +#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 16u #define WEBGPU_WAIT_ANY_TIMEOUT_MS 0 // Maximum number of in-flight submissions per-thread, to avoid exhausting the // parameter buffer pool -<<<<<<< HEAD -#define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD (WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) -======= #define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE #define WEBGPU_MAX_PARAM_BUFS_PER_CMD 4u ->>>>>>> 30923ffc9 (add vectorized flash attention) #define WEBGPU_PARAMS_BUF_SIZE_BYTES 128 // enough for 32 parameters +#define WEBGPU_NUM_SET_ROWS_ERROR_BUFS 16 #define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4 -#define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4 +#define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4 // For operations which process a row in parallel, this seems like a reasonable // default @@ -127,9 +118,19 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device, wgpu::BufferUsage usage, const char * label); +struct webgpu_pool_bufs { + wgpu::Buffer host_buf; + wgpu::Buffer dev_buf; +}; + +// The futures to wait on for a single queue submission +struct webgpu_submission_futures { + std::vector futures; +}; + // Holds a pool of parameter buffers for WebGPU operations struct webgpu_buf_pool { - std::vector free; + std::vector free; // The pool must be synchronized because // 1. 
The memset pool is shared globally by every ggml buffer, @@ -142,6 +143,7 @@ struct webgpu_buf_pool { size_t cur_pool_size; size_t max_pool_size; wgpu::Device device; + wgpu::BufferUsage host_buf_usage; wgpu::BufferUsage dev_buf_usage; size_t buf_size; bool should_grow; @@ -150,47 +152,53 @@ struct webgpu_buf_pool { int num_bufs, size_t buf_size, wgpu::BufferUsage dev_buf_usage, + wgpu::BufferUsage host_buf_usage, bool should_grow = false, size_t max_pool_size = WEBGPU_NUM_PARAM_BUFS * 2) { - this->max_pool_size = max_pool_size; - this->cur_pool_size = num_bufs; - this->device = device; - this->dev_buf_usage = dev_buf_usage; - this->buf_size = buf_size; - this->should_grow = should_grow; + this->max_pool_size = max_pool_size; + this->cur_pool_size = num_bufs; + this->device = device; + this->host_buf_usage = host_buf_usage; + this->dev_buf_usage = dev_buf_usage; + this->buf_size = buf_size; + this->should_grow = should_grow; for (int i = 0; i < num_bufs; i++) { + wgpu::Buffer host_buf; wgpu::Buffer dev_buf; + ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_pool_buf"); ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf"); - free.push_back(dev_buf); + free.push_back({ host_buf, dev_buf }); } } - wgpu::Buffer alloc_bufs() { + webgpu_pool_bufs alloc_bufs() { std::unique_lock lock(mutex); if (!free.empty()) { - wgpu::Buffer buf = free.back(); + webgpu_pool_bufs bufs = free.back(); free.pop_back(); - return buf; + return bufs; } // Try growing the pool if no free buffers if (free.empty() && cur_pool_size < max_pool_size && should_grow) { cur_pool_size++; + wgpu::Buffer host_buf; wgpu::Buffer dev_buf; + ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_pool_buf"); ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf"); - if (!dev_buf) { + if (!(host_buf && dev_buf)) { GGML_ABORT("webgpu_buf_pool: failed to 
allocate buffers"); } - return dev_buf; + return webgpu_pool_bufs{ host_buf, dev_buf }; } cv.wait(lock, [this] { return !free.empty(); }); - wgpu::Buffer buf = free.back(); + webgpu_pool_bufs bufs = free.back(); free.pop_back(); - return buf; + return bufs; } - void free_bufs(std::vector bufs) { + void free_bufs(std::vector bufs) { std::lock_guard lock(mutex); free.insert(free.end(), bufs.begin(), bufs.end()); cv.notify_all(); @@ -198,9 +206,12 @@ struct webgpu_buf_pool { void cleanup() { std::lock_guard lock(mutex); - for (auto & buf : free) { - if (buf) { - buf.Destroy(); + for (auto & bufs : free) { + if (bufs.host_buf) { + bufs.host_buf.Destroy(); + } + if (bufs.dev_buf) { + bufs.dev_buf.Destroy(); } } free.clear(); @@ -274,18 +285,12 @@ struct webgpu_gpu_profile_buf_pool { #endif struct webgpu_command { -<<<<<<< HEAD - uint32_t num_kernels; - wgpu::CommandBuffer commands; - std::vector params_bufs; -======= uint32_t num_kernels; wgpu::CommandBuffer commands; std::vector params_bufs; std::optional set_rows_error_bufs; // Keep temporary resources alive until submitted work is complete. 
std::vector retained_buffers; ->>>>>>> 30923ffc9 (add vectorized flash attention) #ifdef GGML_WEBGPU_GPU_PROFILE webgpu_gpu_profile_bufs timestamp_query_bufs; std::string pipeline_name; @@ -361,13 +366,6 @@ struct webgpu_global_context_struct { typedef std::shared_ptr webgpu_global_context; -struct webgpu_submission { - wgpu::FutureWaitInfo submit_done; -#ifdef GGML_WEBGPU_GPU_PROFILE - std::vector profile_futures; -#endif -}; - // All the base objects needed to run operations on a WebGPU device struct webgpu_context_struct { // Points to global instances owned by ggml_backend_webgpu_reg_context @@ -377,8 +375,7 @@ struct webgpu_context_struct { pre_wgsl::Preprocessor p; webgpu_buf_pool param_buf_pool; - wgpu::Buffer set_rows_dev_error_buf; - wgpu::Buffer set_rows_host_error_buf; + webgpu_buf_pool set_rows_error_buf_pool; std::map> cpy_pipelines; // src_type, dst_type std::unordered_map @@ -478,104 +475,32 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device, /** WebGPU Actions */ -static bool ggml_backend_webgpu_handle_wait_status(wgpu::WaitStatus status, bool allow_timeout = false) { - switch (status) { - case wgpu::WaitStatus::Success: - return true; - case wgpu::WaitStatus::TimedOut: - if (allow_timeout) { - return false; - } - GGML_LOG_ERROR("ggml_webgpu: WaitAny timed out unexpectedly\n"); - return false; - case wgpu::WaitStatus::Error: - GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n"); - return false; - default: - GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an unknown status\n"); - return false; - } -} - -#ifdef GGML_WEBGPU_GPU_PROFILE -static void ggml_backend_webgpu_erase_completed_futures(std::vector & futures) { - futures.erase(std::remove_if(futures.begin(), futures.end(), - [](const wgpu::FutureWaitInfo & info) { return info.completed; }), - futures.end()); -} - -static void ggml_backend_webgpu_wait_profile_futures(webgpu_global_context & ctx, - std::vector & futures, - bool block) { - if (futures.empty()) { - return; - } - - 
uint64_t timeout_ms = block ? UINT64_MAX : 0; - if (block) { - while (!futures.empty()) { - auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms); - if (ggml_backend_webgpu_handle_wait_status(waitStatus)) { - ggml_backend_webgpu_erase_completed_futures(futures); - } - } - } else { - auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms); - if (ggml_backend_webgpu_handle_wait_status(waitStatus, true)) { - ggml_backend_webgpu_erase_completed_futures(futures); - } - } -} -#endif - // Wait for the queue to finish processing all submitted work -static void ggml_backend_webgpu_wait(webgpu_global_context & ctx, - std::vector & subs, - bool block = true) { +static void ggml_backend_webgpu_wait(webgpu_global_context & ctx, + std::vector & futures, + bool block = true) { // If we have too many in-flight submissions, wait on the oldest one first. - if (subs.empty()) { - return; - } - while (subs.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD) { - auto waitStatus = ctx->instance.WaitAny(1, &subs[0].submit_done, UINT64_MAX); - if (ggml_backend_webgpu_handle_wait_status(waitStatus)) { -#ifdef GGML_WEBGPU_GPU_PROFILE - ggml_backend_webgpu_wait_profile_futures(ctx, subs[0].profile_futures, true); -#endif - subs.erase(subs.begin()); - } - } - - if (subs.empty()) { - return; - } - - if (block) { - for (auto & sub : subs) { - while (!sub.submit_done.completed) { - auto waitStatus = ctx->instance.WaitAny(1, &sub.submit_done, UINT64_MAX); - ggml_backend_webgpu_handle_wait_status(waitStatus); - } -#ifdef GGML_WEBGPU_GPU_PROFILE - ggml_backend_webgpu_wait_profile_futures(ctx, sub.profile_futures, true); -#endif - } - subs.clear(); - } else { - // Poll each submit future once and remove completed submissions. 
- for (auto sub = subs.begin(); sub != subs.end();) { - auto waitStatus = ctx->instance.WaitAny(1, &sub->submit_done, 0); - ggml_backend_webgpu_handle_wait_status(waitStatus, true); -#ifdef GGML_WEBGPU_GPU_PROFILE - ggml_backend_webgpu_wait_profile_futures(ctx, sub->profile_futures, false); - if (sub->submit_done.completed && sub->profile_futures.empty()) { -#else - if (sub->submit_done.completed) { -#endif - sub = subs.erase(sub); - } else { - ++sub; - } + uint64_t timeout_ms = block ? UINT64_MAX : 0; + while (futures.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD) { + ctx->instance.WaitAny(futures[0].futures.size(), futures[0].futures.data(), UINT64_MAX); + futures.erase(futures.begin()); + } + size_t i = 0; + while (i < futures.size()) { + auto waitStatus = ctx->instance.WaitAny(futures[i].futures.size(), futures[i].futures.data(), timeout_ms); + switch (waitStatus) { + case wgpu::WaitStatus::Success: + futures.erase(futures.begin() + i); + break; + case wgpu::WaitStatus::TimedOut: + i++; + break; + case wgpu::WaitStatus::Error: + GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n"); + break; + default: + GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an unknown status\n"); + break; } } } @@ -611,18 +536,14 @@ static void ggml_backend_webgpu_debug(webgpu_global_context & ctx) { } #endif -static webgpu_submission ggml_backend_webgpu_submit(webgpu_global_context & ctx, - std::vector & commands, - webgpu_buf_pool & param_buf_pool) { +static webgpu_submission_futures ggml_backend_webgpu_submit(webgpu_global_context ctx, + std::vector commands, + webgpu_buf_pool & param_buf_pool, + webgpu_buf_pool * set_rows_error_buf_pool = nullptr) { std::vector command_buffers; -<<<<<<< HEAD - std::vector params_bufs; - webgpu_submission submission; -======= std::vector params_bufs; std::vector set_rows_error_bufs; std::vector retained_buffers; ->>>>>>> 30923ffc9 (add vectorized flash attention) #ifdef GGML_WEBGPU_GPU_PROFILE std::vector> pipeline_name_and_ts_bufs; #endif @@ 
-630,16 +551,15 @@ static webgpu_submission ggml_backend_webgpu_submit(webgpu_global_context & for (const auto & command : commands) { command_buffers.push_back(command.commands); params_bufs.insert(params_bufs.end(), command.params_bufs.begin(), command.params_bufs.end()); -<<<<<<< HEAD -======= retained_buffers.insert(retained_buffers.end(), command.retained_buffers.begin(), command.retained_buffers.end()); if (command.set_rows_error_bufs) { set_rows_error_bufs.push_back(command.set_rows_error_bufs.value()); } ->>>>>>> 30923ffc9 (add vectorized flash attention) } ctx->queue.Submit(command_buffers.size(), command_buffers.data()); + std::vector futures; + wgpu::Future p_f = ctx->queue.OnSubmittedWorkDone( wgpu::CallbackMode::AllowSpontaneous, [&param_buf_pool, params_bufs, retained_buffers](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { @@ -650,7 +570,27 @@ static webgpu_submission ggml_backend_webgpu_submit(webgpu_global_context & // Free the staged buffers param_buf_pool.free_bufs(params_bufs); }); - submission.submit_done = { p_f }; + futures.push_back({ p_f }); + + for (const auto & bufs : set_rows_error_bufs) { + wgpu::Future f = bufs.host_buf.MapAsync( + wgpu::MapMode::Read, 0, bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous, + [set_rows_error_buf_pool, bufs](wgpu::MapAsyncStatus status, wgpu::StringView message) { + if (status != wgpu::MapAsyncStatus::Success) { + GGML_LOG_ERROR("ggml_webgpu: Failed to map error buffer: %s\n", std::string(message).c_str()); + } else { + const uint32_t * error_data = (const uint32_t *) bufs.host_buf.GetConstMappedRange(); + if (*error_data) { + GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported."); + } + // We can't unmap in here due to WebGPU reentrancy limitations.
+ if (set_rows_error_buf_pool) { + set_rows_error_buf_pool->free_bufs({ bufs }); + } + } + }); + futures.push_back({ f }); + } #ifdef GGML_WEBGPU_GPU_PROFILE for (const auto & command : commands) { @@ -670,14 +610,14 @@ static webgpu_submission ggml_backend_webgpu_submit(webgpu_global_context & // WebGPU timestamps are in ns; convert to ms double elapsed_ms = double(ts_data[1] - ts_data[0]) * 1e-6; ctx->shader_gpu_time_ms[label] += elapsed_ms; + // We can't unmap in here due to WebGPU reentrancy limitations. + ctx->timestamp_query_buf_pool.free_bufs({ ts_bufs }); } - // We can't unmap in here due to WebGPU reentrancy limitations. - ctx->timestamp_query_buf_pool.free_bufs({ ts_bufs }); }); - submission.profile_futures.push_back({ f }); + futures.push_back({ f }); } #endif - return submission; + return { futures }; } static webgpu_command ggml_backend_webgpu_build_multi( @@ -686,29 +626,19 @@ static webgpu_command ggml_backend_webgpu_build_multi( const std::vector & pipelines, const std::vector> & params_list, const std::vector> & bind_group_entries_list, -<<<<<<< HEAD - const std::vector> & workgroups_list) { -======= const std::vector> & workgroups_list, const std::optional & set_rows_error_bufs = std::nullopt, bool split_passes = false) { ->>>>>>> 30923ffc9 (add vectorized flash attention) GGML_ASSERT(pipelines.size() == params_list.size()); GGML_ASSERT(pipelines.size() == bind_group_entries_list.size()); GGML_ASSERT(pipelines.size() == workgroups_list.size()); - std::vector params_bufs_list; - std::vector bind_groups; + std::vector params_bufs_list; + std::vector bind_groups; for (size_t i = 0; i < pipelines.size(); i++) { - wgpu::Buffer params_bufs = param_buf_pool.alloc_bufs(); + webgpu_pool_bufs params_bufs = param_buf_pool.alloc_bufs(); -<<<<<<< HEAD - std::vector entries = bind_group_entries_list[i]; - uint32_t params_binding_num = entries.size(); - entries.push_back( - { .binding = params_binding_num, .buffer = params_bufs, .offset = 0, .size = 
params_bufs.GetSize() }); -======= std::vector entries = bind_group_entries_list[i]; // Bindings can be sparse (e.g. 0,1,2,4,5), so params must use max(binding)+1. uint32_t params_binding_num = 0; @@ -721,7 +651,6 @@ static webgpu_command ggml_backend_webgpu_build_multi( .buffer = params_bufs.dev_buf, .offset = 0, .size = params_bufs.dev_buf.GetSize() }); ->>>>>>> 30923ffc9 (add vectorized flash attention) wgpu::BindGroupDescriptor bind_group_desc; bind_group_desc.layout = pipelines[i].pipeline.GetBindGroupLayout(0); @@ -735,7 +664,14 @@ static webgpu_command ggml_backend_webgpu_build_multi( wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder(); for (size_t i = 0; i < params_bufs_list.size(); i++) { - ctx->queue.WriteBuffer(params_bufs_list[i], 0, params_list[i].data(), params_list[i].size() * sizeof(uint32_t)); + ctx->queue.WriteBuffer(params_bufs_list[i].dev_buf, 0, params_list[i].data(), params_list[i].size() * sizeof(uint32_t)); + } + + // If there are SET_ROWS operations in this submission, copy their error + // buffers to the host. 
+ if (set_rows_error_bufs) { + encoder.CopyBufferToBuffer(set_rows_error_bufs->dev_buf, 0, set_rows_error_bufs->host_buf, 0, + set_rows_error_bufs->host_buf.GetSize()); } bool profile_pass = false; @@ -790,6 +726,7 @@ static webgpu_command ggml_backend_webgpu_build_multi( webgpu_command result = {}; result.commands = commands; result.params_bufs = params_bufs_list; + result.set_rows_error_bufs = set_rows_error_bufs; result.num_kernels = pipelines.size(); #ifdef GGML_WEBGPU_GPU_PROFILE if (profile_pass) { @@ -808,13 +745,13 @@ static webgpu_command ggml_backend_webgpu_build(webgpu_global_context & std::vector params, std::vector bind_group_entries, uint32_t wg_x, - uint32_t wg_y = 1) { + uint32_t wg_y = 1, + std::optional set_rows_error_bufs = std::nullopt) { return ggml_backend_webgpu_build_multi(ctx, param_buf_pool, { pipeline }, - { std::move(params) }, { std::move(bind_group_entries) }, - { { wg_x, wg_y } }); + { params }, { bind_group_entries }, { { wg_x, wg_y } }, set_rows_error_bufs); } static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx, @@ -831,9 +768,9 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx, webgpu_command command = ggml_backend_webgpu_build(ctx, ctx->memset_buf_pool, ctx->memset_pipelines[0], params, entries, wg_x); - std::vector commands = { command }; - std::vector sub = { ggml_backend_webgpu_submit(ctx, commands, ctx->memset_buf_pool) }; - ggml_backend_webgpu_wait(ctx, sub); + std::vector futures = { ggml_backend_webgpu_submit(ctx, { command }, + ctx->memset_buf_pool) }; + ggml_backend_webgpu_wait(ctx, futures); } /** End WebGPU Actions */ @@ -1054,6 +991,14 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, auto * decisions = static_cast(pipeline.context.get()); + std::optional error_bufs = std::nullopt; + if (decisions->i64_idx) { + error_bufs = ctx->set_rows_error_buf_pool.alloc_bufs(); + if (error_bufs->host_buf.GetMapState() == wgpu::BufferMapState::Mapped) { + 
error_bufs->host_buf.Unmap(); + } + } + std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)), @@ -1086,10 +1031,8 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, }; if (decisions->i64_idx) { - entries.push_back({ .binding = 3, - .buffer = ctx->set_rows_dev_error_buf, - .offset = 0, - .size = ctx->set_rows_dev_error_buf.GetSize() }); + entries.push_back( + { .binding = 3, .buffer = error_bufs->dev_buf, .offset = 0, .size = error_bufs->dev_buf.GetSize() }); } uint32_t threads; @@ -1099,7 +1042,8 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, threads = src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3]; } uint32_t wg_x = CEIL_DIV(threads, decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, 1); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, 1, + error_bufs); } // Workgroup size is a common constant @@ -1270,18 +1214,17 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; if (use_fast && is_vec) { - auto * decisions = static_cast(pipeline.context.get()); + auto decisions = static_cast(pipeline.context.get()); uint32_t batches = dst->ne[2] * dst->ne[3]; uint32_t output_groups = CEIL_DIV(dst->ne[0], decisions->outputs_per_wg); uint32_t total_wg = output_groups * batches; compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y); } else if (use_fast) { - auto * decisions = static_cast(pipeline.context.get()); + auto decisions = static_cast(pipeline.context.get()); // Fast-path tiled/subgroup calculations - uint32_t wg_m; - uint32_t wg_n; + uint32_t wg_m, wg_n; if (decisions->use_subgroup_matrix) { uint32_t wg_m_sg_tile = decisions->subgroup_m * 
decisions->subgroup_matrix_m * ctx->global_ctx->capabilities.sg_mat_m; @@ -1299,7 +1242,7 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y); } else { // legacy - auto * decisions = static_cast(pipeline.context.get()); + auto decisions = static_cast(pipeline.context.get()); uint32_t wg_size = decisions->wg_size; uint32_t total_wg = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], wg_size); compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y); @@ -1396,10 +1339,9 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0); const bool use_vec = (Q->ne[1] < 20) && - (Q->ne[0] % 32 == 0) && - (V->ne[0] % 4 == 0) && - (K->type == GGML_TYPE_F16) && - (Q->ne[2] == K->ne[2]); + (Q->ne[0] % 32 == 0) && + (V->ne[0] % 4 == 0) && + (K->type == GGML_TYPE_F16); const uint32_t vec_nwg_cap = std::max(1u, std::min(32u, ctx->global_ctx->capabilities.max_subgroup_size)); @@ -1959,7 +1901,6 @@ static webgpu_command ggml_webgpu_repeat(webgpu_context & ctx, ggml_tensor * src uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } - static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { int inplace = ggml_webgpu_tensor_equal(src, dst); @@ -2008,12 +1949,7 @@ static webgpu_command ggml_webgpu_rope(webgpu_context & ctx, const int mode = ((int32_t *) dst->op_params)[2]; const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; - float freq_base; - float freq_scale; - float ext_factor; - float attn_factor; - float beta_fast; - float beta_slow; + float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); memcpy(&ext_factor, (int32_t *) dst->op_params + 7, 
sizeof(float)); @@ -2564,12 +2500,19 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, case GGML_OP_SOFT_MAX: return ggml_webgpu_soft_max(ctx, src0, src1, src2, node); case GGML_OP_UNARY: + return ggml_webgpu_unary_op(ctx, src0, node); case GGML_OP_CLAMP: + return ggml_webgpu_unary_op(ctx, src0, node); case GGML_OP_FILL: + return ggml_webgpu_unary_op(ctx, src0, node); case GGML_OP_LOG: + return ggml_webgpu_unary_op(ctx, src0, node); case GGML_OP_SQR: + return ggml_webgpu_unary_op(ctx, src0, node); case GGML_OP_SQRT: + return ggml_webgpu_unary_op(ctx, src0, node); case GGML_OP_SIN: + return ggml_webgpu_unary_op(ctx, src0, node); case GGML_OP_COS: return ggml_webgpu_unary_op(ctx, src0, node); case GGML_OP_PAD: @@ -2577,6 +2520,7 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, case GGML_OP_ARGMAX: return ggml_webgpu_argmax(ctx, src0, node); case GGML_OP_ARGSORT: + return ggml_webgpu_argsort(ctx, src0, node); case GGML_OP_TOP_K: // we reuse the same argsort implementation for top_k return ggml_webgpu_argsort(ctx, src0, node); @@ -2598,22 +2542,11 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str WEBGPU_CPU_PROFILE_TOTAL_START(graph_compute); -<<<<<<< HEAD - std::vector commands; - std::vector subs; - uint32_t num_batched_kernels = 0; - bool contains_set_rows = false; - -======= std::vector commands; std::vector futures; uint32_t batch_param_bufs = 0; uint32_t num_batched_kernels = 0; ->>>>>>> 30923ffc9 (add vectorized flash attention) for (int i = 0; i < cgraph->n_nodes; i++) { - if (cgraph->nodes[i]->op == GGML_OP_SET_ROWS) { - contains_set_rows = true; - } if (auto cmd = ggml_webgpu_encode_node(ctx, cgraph->nodes[i])) { const uint32_t cmd_param_bufs = (uint32_t) cmd->params_bufs.size(); @@ -2634,46 +2567,25 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str num_batched_kernels += cmd->num_kernels; } -<<<<<<< HEAD - if (num_batched_kernels >= 
WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) { - num_batched_kernels = 0; - subs.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool)); -======= if (num_batched_kernels >= WEBGPU_COMMAND_SUBMIT_BATCH_SIZE || (!commands.empty() && batch_param_bufs + WEBGPU_MAX_PARAM_BUFS_PER_CMD > WEBGPU_NUM_PARAM_BUFS)) { futures.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool, &ctx->set_rows_error_buf_pool)); ->>>>>>> 30923ffc9 (add vectorized flash attention) // Process events and check for completed submissions ctx->global_ctx->instance.ProcessEvents(); - ggml_backend_webgpu_wait(ctx->global_ctx, subs, false); + ggml_backend_webgpu_wait(ctx->global_ctx, futures, false); commands.clear(); batch_param_bufs = 0; num_batched_kernels = 0; } } if (!commands.empty()) { - subs.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool)); - commands.clear(); - } - - // If there are SET_ROWS operations in this graph, copy the error buffers to the host for checking. 
- if (contains_set_rows) { - wgpu::CommandEncoder encoder = ctx->global_ctx->device.CreateCommandEncoder(); - encoder.CopyBufferToBuffer(ctx->set_rows_dev_error_buf, 0, ctx->set_rows_host_error_buf, 0, - ctx->set_rows_host_error_buf.GetSize()); - wgpu::CommandBuffer set_rows_commands = encoder.Finish(); - ctx->global_ctx->queue.Submit(1, &set_rows_commands); - ggml_backend_webgpu_map_buffer(ctx->global_ctx, ctx->set_rows_host_error_buf, wgpu::MapMode::Read, 0, - ctx->set_rows_host_error_buf.GetSize()); - const uint32_t * error_data = (const uint32_t *) ctx->set_rows_host_error_buf.GetConstMappedRange(); - if (*error_data) { - GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported."); - } - ctx->set_rows_host_error_buf.Unmap(); + webgpu_submission_futures new_futures = + ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool, &ctx->set_rows_error_buf_pool); + futures.push_back(new_futures); } - ggml_backend_webgpu_wait(ctx->global_ctx, subs); + ggml_backend_webgpu_wait(ctx->global_ctx, futures); WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx->global_ctx); return GGML_STATUS_SUCCESS; } @@ -3292,12 +3204,10 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) { webgpu_ctx->param_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES, wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform, wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite, true); - ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->set_rows_dev_error_buf, - WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES, - wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc, "set_rows_dev_error_buf"); - ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->set_rows_host_error_buf, - WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES, - wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "set_rows_host_error_buf"); + webgpu_ctx->set_rows_error_buf_pool.init(webgpu_ctx->global_ctx->device, 
WEBGPU_NUM_SET_ROWS_ERROR_BUFS, + WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES, + wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage, + wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead); ggml_webgpu_init_cpy_pipeline(webgpu_ctx); ggml_webgpu_init_rms_norm_pipeline(webgpu_ctx); From 360274357c680bce60b1e8c0a37b4296499e8b94 Mon Sep 17 00:00:00 2001 From: Zheyuan Chen Date: Sat, 7 Mar 2026 23:21:17 -0800 Subject: [PATCH 12/34] update flash_attn_vec_split.wgsl to reduce redundant workgroup barrier usage and use select --- .../wgsl-shaders/flash_attn_vec_split.wgsl | 22 +++++++------------ 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl index c5b7372ed30c..793a85e7e66d 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl @@ -172,19 +172,15 @@ var exp_sum_shmem: array; var blk_state_wg: u32; fn calc_softmax_term(kv_idx: u32, q_tile_row: u32, slope: f32, has_bias: bool, apply_mask: bool) -> f32 { - var v = FLOAT_MIN; - if (kv_idx < KV_TILE) { - v = f32(inter_shmem[kv_idx + q_tile_row * KV_TILE]) * params.scale; - } + var v = select(FLOAT_MIN, + f32(inter_shmem[kv_idx + q_tile_row * KV_TILE]) * params.scale, + kv_idx < KV_TILE); #ifdef LOGIT_SOFTCAP v = params.logit_softcap * tanh(v); #endif #ifdef MASK if (apply_mask) { - var mask_val = 0.0; - if (kv_idx < KV_TILE) { - mask_val = f32(mask_shmem[q_tile_row * KV_TILE + kv_idx]); - } + var mask_val = select(0.0,f32(mask_shmem[q_tile_row * KV_TILE + kv_idx]), kv_idx < KV_TILE); v += select(mask_val, slope * mask_val, has_bias); } #endif @@ -247,20 +243,18 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let has_bias = params.max_bias > 0.0; let slope = select(1.0, select(pow(params.m1, 2.0 * (head - params.n_head_log2) + 1.0), pow(params.m0, head + 1.0), head < params.n_head_log2), has_bias); - // Load Q
tile once and keep it in f16 to match scalar/Metal precision path. + // load q tile into shared memory for (var elem_idx = local_id.x; elem_idx < Q_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) { let q_row = elem_idx / HEAD_DIM_QK; let q_col = elem_idx % HEAD_DIM_QK; - let global_q_row = q_row_start + q_row; - let global_q_row_offset = q_head_offset + global_q_row * params.stride_q1; + let head_q_row = q_row_start + q_row; + let global_q_row_offset = q_head_offset + head_q_row * params.stride_q1; q_shmem[elem_idx] = f16(select( 0.0, Q[global_q_row_offset + q_col], - global_q_row < params.seq_len_q && q_col < HEAD_DIM_QK)); + head_q_row < params.seq_len_q && q_col < HEAD_DIM_QK)); } - workgroupBarrier(); - for (var kv_tile = iwg * KV_TILE; kv_tile < params.seq_len_kv; kv_tile += KV_TILE * params.nwg) { #ifdef BLK let q_blk = q_row_start / Q_TILE; From 356d6ff696afc14578c49859236635ce9b3303fc Mon Sep 17 00:00:00 2001 From: Zheyuan Chen Date: Sun, 8 Mar 2026 00:11:36 -0800 Subject: [PATCH 13/34] enable vec path for q4 and q8 --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 4 +- .../wgsl-shaders/flash_attn_vec_split.wgsl | 306 ++++++++++-------- 2 files changed, 175 insertions(+), 135 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index abf64ea87243..e2cb6464368f 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1338,10 +1338,12 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, bool kv_direct = (K->type == GGML_TYPE_F16) && (Q->ne[0] % ctx->global_ctx->capabilities.sg_mat_k == 0) && (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0); + const bool kv_vec_type_supported = K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0; const bool use_vec = (Q->ne[1] < 20) && (Q->ne[0] % 32 == 0) && (V->ne[0] % 4 == 0) && - (K->type == GGML_TYPE_F16); + kv_vec_type_supported && + (V->type == K->type); const uint32_t vec_nwg_cap = std::max(1u, 
std::min(32u, ctx->global_ctx->capabilities.max_subgroup_size)); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl index 793a85e7e66d..88950917f517 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl @@ -99,7 +99,11 @@ struct Params { @group(0) @binding(0) var Q: array; @group(0) @binding(1) var K: array; +#if defined(KV_Q4_0) || defined(KV_Q8_0) +@group(0) @binding(2) var V: array; +#else @group(0) @binding(2) var V: array>; +#endif #if defined(MASK) && defined(SINKS) @group(0) @binding(3) var mask: array; @group(0) @binding(4) var sinks: array; @@ -276,61 +280,74 @@ fn main(@builtin(workgroup_id) wg_id: vec3, } // load k tile into shared memory -// #if defined(KV_Q4_0) -// for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) { -// let blck_idx = elem_idx / BLOCK_SIZE; -// let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; -// let k_row = blck_idx / BLOCKS_K; -// let global_k_row = kv_tile + k_row; -// let block_k = blck_idx % BLOCKS_K; -// let row_offset = k_row * HEAD_DIM_QK; - -// if (global_k_row < params.seq_len_kv) { -// let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; -// let base_idx = global_block_idx * F16_PER_BLOCK; -// let d = K[base_idx]; // scale -// for (var j = 0u; j < F16_PER_THREAD; j += 2) { -// let q_0 = K[base_idx + 1u + block_offset + j]; -// let q_1 = K[base_idx + 1u + block_offset + j + 1]; -// let q_packed = bitcast(vec2(q_0, q_1)); -// for (var k = 0u; k < 4u; k++) { -// let q_byte = get_byte(q_packed, k); -// let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; -// let q_lo = (f16(q_byte & 0xF) - 8.0) * d; -// let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; -// kv_shmem[row_offset + idx] = q_lo; -// kv_shmem[row_offset + idx + 16u] = q_hi; -// } -// } -// } -// } -// 
#elif defined(KV_Q8_0) -// for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) { -// let blck_idx = elem_idx / BLOCK_SIZE; -// let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; -// let k_row = blck_idx / BLOCKS_K; -// let global_k_row = kv_tile + k_row; -// let block_k = blck_idx % BLOCKS_K; -// let row_offset = k_row * HEAD_DIM_QK; - -// if (global_k_row < params.seq_len_kv) { -// let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; -// let base_idx = global_block_idx * F16_PER_BLOCK; -// let d = K[base_idx]; // scale -// for (var j = 0u; j < F16_PER_THREAD; j += 2) { -// let q_0 = K[base_idx + 1u + block_offset + j]; -// let q_1 = K[base_idx + 1u + block_offset + j + 1]; -// let q_packed = bitcast(vec2(q_0, q_1)); -// for (var k = 0u; k < 4u; k++) { -// let q_byte = get_byte_i32(q_packed, k); -// let q_val = f16(q_byte) * d; -// let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; -// kv_shmem[row_offset + idx] = q_val; -// } -// } -// } -// } -// #endif +#if defined(KV_Q4_0) + for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) { + let blck_idx = elem_idx / BLOCK_SIZE; + let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; + let k_row = blck_idx / BLOCKS_K; + let global_k_row = kv_tile + k_row; + let block_k = blck_idx % BLOCKS_K; + let row_offset = k_row * HEAD_DIM_QK; + + if (global_k_row < params.seq_len_kv) { + let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; + let base_idx = global_block_idx * F16_PER_BLOCK; + let d = K[base_idx]; + for (var j = 0u; j < F16_PER_THREAD; j += 2) { + let q_0 = K[base_idx + 1u + block_offset + j]; + let q_1 = K[base_idx + 1u + block_offset + j + 1]; + let q_packed = bitcast(vec2(q_0, q_1)); + for (var k = 0u; k < 4u; k++) { + let q_byte = get_byte(q_packed, k); + let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; + let q_lo = (f16(q_byte & 0xF) - 
8.0) * d; + let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; + kv_shmem[row_offset + idx] = q_lo; + kv_shmem[row_offset + idx + 16u] = q_hi; + } + } + } + } +#elif defined(KV_Q8_0) + for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) { + let blck_idx = elem_idx / BLOCK_SIZE; + let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; + let k_row = blck_idx / BLOCKS_K; + let global_k_row = kv_tile + k_row; + let block_k = blck_idx % BLOCKS_K; + let row_offset = k_row * HEAD_DIM_QK; + + if (global_k_row < params.seq_len_kv) { + let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; + let base_idx = global_block_idx * F16_PER_BLOCK; + let d = K[base_idx]; + for (var j = 0u; j < F16_PER_THREAD; j += 2) { + let q_0 = K[base_idx + 1u + block_offset + j]; + let q_1 = K[base_idx + 1u + block_offset + j + 1]; + let q_packed = bitcast(vec2(q_0, q_1)); + for (var k = 0u; k < 4u; k++) { + let q_byte = get_byte_i32(q_packed, k); + let q_val = f16(q_byte) * d; + let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; + kv_shmem[row_offset + idx] = q_val; + } + } + } + } +#elif defined(KV_DIRECT) + // Direct global loads for KV +#else + for (var elem_idx = local_id.x; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) { + let k_row = elem_idx / HEAD_DIM_QK; + let k_col = elem_idx % HEAD_DIM_QK; + let global_k_row = kv_tile + k_row; + let global_k_row_offset = k_head_offset + global_k_row * params.stride_k1; + kv_shmem[elem_idx] = f16(select( + 0.0, + K[global_k_row_offset + k_col], + global_k_row < params.seq_len_kv && k_col < HEAD_DIM_QK)); + } +#endif workgroupBarrier(); @@ -356,19 +373,28 @@ fn main(@builtin(workgroup_id) wg_id: vec3, if (kv_valid) { for (var i = tx; i < (HEAD_DIM_QK / 4u); i += num_of_threads) { let q_off = local_q_row_offset + i * 4u; - let idx = k_head_offset + (kv_tile + kv_idx) * params.stride_k1 + (i * 4u); - let qv = vec4(f32(q_shmem[q_off + 0u]), - 
f32(q_shmem[q_off + 1u]), - f32(q_shmem[q_off + 2u]), - f32(q_shmem[q_off + 3u])); - let kv = vec4(f32(K[idx + 0u]), - f32(K[idx + 1u]), - f32(K[idx + 2u]), - f32(K[idx + 3u])); - + let qv = vec4( + f32(q_shmem[q_off + 0u]), + f32(q_shmem[q_off + 1u]), + f32(q_shmem[q_off + 2u]), + f32(q_shmem[q_off + 3u])); +#ifdef KV_DIRECT + let idx = k_head_offset + (kv_tile + kv_idx) * params.stride_k1 + (i * 4u); + let kv = vec4( + f32(K[idx + 0u]), + f32(K[idx + 1u]), + f32(K[idx + 2u]), + f32(K[idx + 3u])); +#else + let idx = kv_idx * HEAD_DIM_QK + (i * 4u); + let kv = vec4( + f32(kv_shmem[idx + 0u]), + f32(kv_shmem[idx + 1u]), + f32(kv_shmem[idx + 2u]), + f32(kv_shmem[idx + 3u])); +#endif partial_sum += dot(qv, kv); - } } var sum = partial_sum; @@ -465,74 +491,77 @@ fn main(@builtin(workgroup_id) wg_id: vec3, } // load v tile into shared memory -// #if defined(KV_Q4_0) -// for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) { -// let blck_idx = elem_idx / BLOCK_SIZE; -// let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; -// let v_row = blck_idx / BLOCKS_V; -// let global_v_row = kv_tile + v_row; -// let block_k = blck_idx % BLOCKS_V; -// let row_offset = v_row * HEAD_DIM_V; - -// if (global_v_row < params.seq_len_kv) { -// let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; -// let base_idx = global_block_idx * F16_PER_BLOCK; -// let d = V[base_idx]; // scale -// for (var j = 0u; j < F16_PER_THREAD; j += 2) { -// let q_0 = V[base_idx + 1u + block_offset + j]; -// let q_1 = V[base_idx + 1u + block_offset + j + 1]; -// let q_packed = bitcast(vec2(q_0, q_1)); -// for (var k = 0u; k < 4u; k++) { -// let q_byte = get_byte(q_packed, k); -// let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; -// let q_lo = (f16(q_byte & 0xF) - 8.0) * d; -// let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; -// kv_shmem[row_offset + idx] = q_lo; -// kv_shmem[row_offset + idx + 16u] = q_hi; -// } -// 
} -// } -// } -// #elif defined(KV_Q8_0) -// for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) { -// let blck_idx = elem_idx / BLOCK_SIZE; -// let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; -// let v_row = blck_idx / BLOCKS_V; -// let global_v_row = kv_tile + v_row; -// let block_k = blck_idx % BLOCKS_V; -// let row_offset = v_row * HEAD_DIM_V; - -// if (global_v_row < params.seq_len_kv) { -// let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; -// let base_idx = global_block_idx * F16_PER_BLOCK; -// let d = V[base_idx]; // scale -// for (var j = 0u; j < F16_PER_THREAD; j += 2) { -// let q_0 = V[base_idx + 1u + block_offset + j]; -// let q_1 = V[base_idx + 1u + block_offset + j + 1]; -// let q_packed = bitcast(vec2(q_0, q_1)); -// for (var k = 0u; k < 4u; k++) { -// let q_byte = get_byte_i32(q_packed, k); -// let q_val = f16(q_byte) * d; -// let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; -// kv_shmem[row_offset + idx] = q_val; -// } -// } -// } -// } -// #elif defined(KV_DIRECT) -// // Direct global loads for KV -// #else -// for (var elem_idx = local_id.x; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE) { -// let v_row = elem_idx / HEAD_DIM_V; -// let v_col = elem_idx % HEAD_DIM_V; -// let global_v_row = kv_tile + v_row; -// let global_v_row_offset = v_head_offset + global_v_row * params.stride_v1; -// kv_shmem[elem_idx] = f16(select( -// 0.0, -// V[global_v_row_offset + v_col], -// global_v_row < params.seq_len_kv && v_col < HEAD_DIM_V)); -// } -// #endif +#if defined(KV_Q4_0) + for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) { + let blck_idx = elem_idx / BLOCK_SIZE; + let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; + let v_row = blck_idx / BLOCKS_V; + let global_v_row = kv_tile + v_row; + let block_k = blck_idx % BLOCKS_V; + let row_offset = v_row * HEAD_DIM_V; + + if (global_v_row < 
params.seq_len_kv) { + let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; + let base_idx = global_block_idx * F16_PER_BLOCK; + let d = V[base_idx]; + for (var j = 0u; j < F16_PER_THREAD; j += 2) { + let q_0 = V[base_idx + 1u + block_offset + j]; + let q_1 = V[base_idx + 1u + block_offset + j + 1]; + let q_packed = bitcast(vec2(q_0, q_1)); + for (var k = 0u; k < 4u; k++) { + let q_byte = get_byte(q_packed, k); + let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; + let q_lo = (f16(q_byte & 0xF) - 8.0) * d; + let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; + kv_shmem[row_offset + idx] = q_lo; + kv_shmem[row_offset + idx + 16u] = q_hi; + } + } + } + } +#elif defined(KV_Q8_0) + for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) { + let blck_idx = elem_idx / BLOCK_SIZE; + let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; + let v_row = blck_idx / BLOCKS_V; + let global_v_row = kv_tile + v_row; + let block_k = blck_idx % BLOCKS_V; + let row_offset = v_row * HEAD_DIM_V; + + if (global_v_row < params.seq_len_kv) { + let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; + let base_idx = global_block_idx * F16_PER_BLOCK; + let d = V[base_idx]; + for (var j = 0u; j < F16_PER_THREAD; j += 2) { + let q_0 = V[base_idx + 1u + block_offset + j]; + let q_1 = V[base_idx + 1u + block_offset + j + 1]; + let q_packed = bitcast(vec2(q_0, q_1)); + for (var k = 0u; k < 4u; k++) { + let q_byte = get_byte_i32(q_packed, k); + let q_val = f16(q_byte) * d; + let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; + kv_shmem[row_offset + idx] = q_val; + } + } + } + } +#elif defined(KV_DIRECT) + // Direct global loads for KV +#else + for (var elem_idx = local_id.x * 4u; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * 4u) { + let v_row = elem_idx / HEAD_DIM_V; + let v_col = elem_idx % HEAD_DIM_V; + let global_v_row = kv_tile + v_row; + let 
global_v_row_offset = v_head_offset + global_v_row * params.stride_v1; + let in_bounds = global_v_row < params.seq_len_kv && (v_col + 3u) < HEAD_DIM_V; + let vec_idx = (global_v_row_offset + v_col) >> 2u; + let v4 = select(vec4(0.0), V[vec_idx], in_bounds); + kv_shmem[elem_idx + 0u] = f16(v4.x); + kv_shmem[elem_idx + 1u] = f16(v4.y); + kv_shmem[elem_idx + 2u] = f16(v4.z); + kv_shmem[elem_idx + 3u] = f16(v4.w); + } +#endif workgroupBarrier(); @@ -556,8 +585,17 @@ fn main(@builtin(workgroup_id) wg_id: vec3, } let p = f32(inter_shmem[kv_idx + q_tile_row * KV_TILE]); +#ifdef KV_DIRECT let v_idx = v_head_offset + v_row * params.stride_v1 + vec_col * 4u; let v4 = vec4(V[v_idx >> 2u]); +#else + let v_idx = kv_idx * HEAD_DIM_V + vec_col * 4u; + let v4 = vec4( + f32(kv_shmem[v_idx + 0u]), + f32(kv_shmem[v_idx + 1u]), + f32(kv_shmem[v_idx + 2u]), + f32(kv_shmem[v_idx + 3u])); +#endif lo += p * v4; } From 1ae041d4e4e1bce71b425c33f1deaceb922c7fe2 Mon Sep 17 00:00:00 2001 From: Zheyuan Chen Date: Sun, 8 Mar 2026 00:18:56 -0800 Subject: [PATCH 14/34] flash-attn vec nwg=1 fast path (skip tmp/reduce staging) --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 45 +++++++++++------ .../wgsl-shaders/flash_attn_vec_split.wgsl | 50 +++++++++---------- 2 files changed, 55 insertions(+), 40 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index e2cb6464368f..023fa3115f93 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1417,20 +1417,32 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, // For a single split workgroup there is nothing to merge. // Let vec split write final dst directly and skip reduce. 
const bool use_vec_reduce = nwg > 1u; - - const uint64_t tmp_data_elems = nrows * (uint64_t) V->ne[0] * nwg; - const uint64_t tmp_stats_base = tmp_data_elems; - const uint64_t tmp_stats_elems = nrows * 2u * nwg; - const uint64_t tmp_total_elems = tmp_data_elems + tmp_stats_elems; - const uint64_t tmp_size_bytes = - ROUNDUP_POW2(tmp_total_elems * sizeof(float), WEBGPU_STORAGE_BUF_BINDING_MULT); - GGML_ASSERT(tmp_stats_base <= UINT32_MAX); GGML_ASSERT(nrows <= UINT32_MAX); - wgpu::Buffer tmp_buf; - ggml_webgpu_create_buffer(ctx->global_ctx->device, tmp_buf, tmp_size_bytes, - wgpu::BufferUsage::Storage, - "flash_attn_vec_tmp"); + uint64_t tmp_stats_base = 0; + uint64_t tmp_size_bytes = 0; + wgpu::Buffer tmp_buf = {}; + uint64_t tmp_bind_offset = 0; + uint64_t tmp_bind_size = 0; + + if (use_vec_reduce) { + const uint64_t tmp_data_elems = nrows * (uint64_t) V->ne[0] * nwg; + tmp_stats_base = tmp_data_elems; + const uint64_t tmp_stats_elems = nrows * 2u * nwg; + const uint64_t tmp_total_elems = tmp_data_elems + tmp_stats_elems; + tmp_size_bytes = ROUNDUP_POW2(tmp_total_elems * sizeof(float), WEBGPU_STORAGE_BUF_BINDING_MULT); + GGML_ASSERT(tmp_stats_base <= UINT32_MAX); + + ggml_webgpu_create_buffer(ctx->global_ctx->device, tmp_buf, tmp_size_bytes, wgpu::BufferUsage::Storage, + "flash_attn_vec_tmp"); + tmp_bind_offset = 0; + tmp_bind_size = tmp_size_bytes; + } else { + // nwg==1 writes final dst directly in vec-split; keep tmp binding valid without extra allocation. 
+ tmp_buf = ggml_webgpu_tensor_buf(dst); + tmp_bind_offset = ggml_webgpu_tensor_align_offset(ctx, dst); + tmp_bind_size = ggml_webgpu_tensor_binding_size(ctx, dst); + } webgpu_pipeline blk_pipeline; std::vector blk_params; @@ -1532,8 +1544,8 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, } split_entries.push_back({ .binding = split_binding_index++, .buffer = tmp_buf, - .offset = 0, - .size = tmp_size_bytes }); + .offset = tmp_bind_offset, + .size = tmp_bind_size }); split_entries.push_back({ .binding = split_binding_index++, .buffer = ggml_webgpu_tensor_buf(dst), .offset = ggml_webgpu_tensor_align_offset(ctx, dst), @@ -1609,7 +1621,10 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, const bool split_passes = use_blk || use_vec_reduce; - std::vector retained_buffers = { tmp_buf }; + std::vector retained_buffers; + if (use_vec_reduce) { + retained_buffers.push_back(tmp_buf); + } if (use_blk) { retained_buffers.push_back(blk_buf); } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl index 88950917f517..045c25938e55 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl @@ -681,31 +681,6 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let global_q_row = q_row_start + q_tile_row; if (global_q_row >= params.seq_len_q) { break; } - let rid = batch_idx * rows_per_batch + head_idx * params.seq_len_q + global_q_row; - let tmp_row_data_base = params.tmp_data_base + rid * (HEAD_DIM_V * params.nwg) + iwg * HEAD_DIM_V; - let tmp_row_stats_base = params.tmp_stats_base + rid * (2u * params.nwg) + 2u * iwg; - - for (var elem_base = sg_inv_id * 4u; - elem_base < HEAD_DIM_V; - elem_base += subgroup_size * 4u) { - - let i0 = q_tile_row * HEAD_DIM_V + (elem_base + 0u); - let i1 = q_tile_row * HEAD_DIM_V + (elem_base + 1u); - let i2 = q_tile_row * HEAD_DIM_V + (elem_base + 2u); - 
let i3 = q_tile_row * HEAD_DIM_V + (elem_base + 3u); - - let tbase = tmp_row_data_base + elem_base; - tmp[tbase + 0u] = f32(o_shmem[i0]); - tmp[tbase + 1u] = f32(o_shmem[i1]); - tmp[tbase + 2u] = f32(o_shmem[i2]); - tmp[tbase + 3u] = f32(o_shmem[i3]); - } - - if (sg_inv_id == 0u) { - tmp[tmp_row_stats_base + 0u] = exp_sum_shmem[q_tile_row]; - tmp[tmp_row_stats_base + 1u] = row_max_shmem[q_tile_row]; - } - if (params.nwg == 1u) { let exp_sum = exp_sum_shmem[q_tile_row]; let scale = select(0.0, 1.0 / exp_sum, exp_sum != 0.0); @@ -728,6 +703,31 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let dst_vec_index: u32 = (row_base + elem_base) >> 2u; dst[dst_vec_index] = v; } + } else { + let rid = batch_idx * rows_per_batch + head_idx * params.seq_len_q + global_q_row; + let tmp_row_data_base = params.tmp_data_base + rid * (HEAD_DIM_V * params.nwg) + iwg * HEAD_DIM_V; + let tmp_row_stats_base = params.tmp_stats_base + rid * (2u * params.nwg) + 2u * iwg; + + for (var elem_base = sg_inv_id * 4u; + elem_base < HEAD_DIM_V; + elem_base += subgroup_size * 4u) { + + let i0 = q_tile_row * HEAD_DIM_V + (elem_base + 0u); + let i1 = q_tile_row * HEAD_DIM_V + (elem_base + 1u); + let i2 = q_tile_row * HEAD_DIM_V + (elem_base + 2u); + let i3 = q_tile_row * HEAD_DIM_V + (elem_base + 3u); + + let tbase = tmp_row_data_base + elem_base; + tmp[tbase + 0u] = f32(o_shmem[i0]); + tmp[tbase + 1u] = f32(o_shmem[i1]); + tmp[tbase + 2u] = f32(o_shmem[i2]); + tmp[tbase + 3u] = f32(o_shmem[i3]); + } + + if (sg_inv_id == 0u) { + tmp[tmp_row_stats_base + 0u] = exp_sum_shmem[q_tile_row]; + tmp[tmp_row_stats_base + 1u] = row_max_shmem[q_tile_row]; + } } } } From 33a547e17a2a9f010e3181723301eb2cc071f600 Mon Sep 17 00:00:00 2001 From: Zheyuan Chen Date: Sun, 8 Mar 2026 00:28:39 -0800 Subject: [PATCH 15/34] use packed f16 K loads in flash-attn vec split --- .../wgsl-shaders/flash_attn_vec_split.wgsl | 23 +++++++++++-------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git 
a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl index 045c25938e55..0ce96f0bdbe2 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl @@ -98,7 +98,11 @@ struct Params { }; @group(0) @binding(0) var Q: array; +#if defined(KV_Q4_0) || defined(KV_Q8_0) @group(0) @binding(1) var K: array; +#else +@group(0) @binding(1) var K: array>; +#endif #if defined(KV_Q4_0) || defined(KV_Q8_0) @group(0) @binding(2) var V: array; #else @@ -337,15 +341,18 @@ fn main(@builtin(workgroup_id) wg_id: vec3, #elif defined(KV_DIRECT) // Direct global loads for KV #else - for (var elem_idx = local_id.x; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) { + for (var elem_idx = local_id.x * 4u; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * 4u) { let k_row = elem_idx / HEAD_DIM_QK; let k_col = elem_idx % HEAD_DIM_QK; let global_k_row = kv_tile + k_row; let global_k_row_offset = k_head_offset + global_k_row * params.stride_k1; - kv_shmem[elem_idx] = f16(select( - 0.0, - K[global_k_row_offset + k_col], - global_k_row < params.seq_len_kv && k_col < HEAD_DIM_QK)); + let in_bounds = global_k_row < params.seq_len_kv && (k_col + 3u) < HEAD_DIM_QK; + let vec_idx = (global_k_row_offset + k_col) >> 2u; + let k4 = select(vec4(0.0), K[vec_idx], in_bounds); + kv_shmem[elem_idx + 0u] = f16(k4.x); + kv_shmem[elem_idx + 1u] = f16(k4.y); + kv_shmem[elem_idx + 2u] = f16(k4.z); + kv_shmem[elem_idx + 3u] = f16(k4.w); } #endif @@ -381,11 +388,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, f32(q_shmem[q_off + 3u])); #ifdef KV_DIRECT let idx = k_head_offset + (kv_tile + kv_idx) * params.stride_k1 + (i * 4u); - let kv = vec4( - f32(K[idx + 0u]), - f32(K[idx + 1u]), - f32(K[idx + 2u]), - f32(K[idx + 3u])); + let kv = vec4(K[idx >> 2u]); #else let idx = kv_idx * HEAD_DIM_QK + (i * 4u); let kv = vec4( From 638c49b4ed365e811dfc0cb89e2f4f1d09d4d2e5 
Mon Sep 17 00:00:00 2001 From: Zheyuan Chen Date: Sun, 8 Mar 2026 00:29:25 -0800 Subject: [PATCH 16/34] use packed f16 K loads in flash-attn vec split on host side --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 023fa3115f93..b5295808e8b1 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1335,7 +1335,12 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, .offset = ggml_webgpu_tensor_align_offset(ctx, dst), .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); - bool kv_direct = (K->type == GGML_TYPE_F16) && (Q->ne[0] % ctx->global_ctx->capabilities.sg_mat_k == 0) && + const uint32_t k_offset_elems = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type)); + const uint32_t v_offset_elems = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type)); + const bool f16_vec4_aligned = (k_offset_elems % 4u == 0u) && (v_offset_elems % 4u == 0u); + + bool kv_direct = (K->type == GGML_TYPE_F16) && f16_vec4_aligned && + (Q->ne[0] % ctx->global_ctx->capabilities.sg_mat_k == 0) && (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0); const bool kv_vec_type_supported = K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0; @@ -1343,6 +1348,7 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, (Q->ne[0] % 32 == 0) && (V->ne[0] % 4 == 0) && kv_vec_type_supported && + (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && (V->type == K->type); const uint32_t vec_nwg_cap = From 0abac3984b79e757eef0781f4b4aa3c154a619e4 Mon Sep 17 00:00:00 2001 From: Zheyuan Chen Date: Sun, 8 Mar 2026 00:32:36 -0800 Subject: [PATCH 17/34] tune flash-attn vec f16 VEC_NE by head dim --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 20 ++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git 
a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index d0ad4266006a..e85be9ad1e15 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -325,6 +325,23 @@ struct ggml_webgpu_flash_attn_shader_decisions { uint32_t wg_size = 0; }; +inline uint32_t ggml_webgpu_flash_attn_pick_vec_ne(const ggml_webgpu_flash_attn_pipeline_key & key) { + // Keep conservative defaults unless this is the f16 vec-split shape family. + if (key.kv_type != GGML_TYPE_F16 || key.head_dim_qk != key.head_dim_v) { + return 1u; + } + + // Head-dim specializations used by the tuned vec f16 path. + switch (key.head_dim_qk) { + case 64: return 2u; + case 96: return 4u; + case 128: return 1u; + case 192: return 2u; + case 576: return 2u; + default: return 1u; + } +} + struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key { uint32_t head_dim_v; uint32_t wg_size; @@ -529,7 +546,8 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader( kv_tile = (kv_tile / context.sg_mat_n) * context.sg_mat_n; GGML_ASSERT(kv_tile % context.sg_mat_n == 0); GGML_ASSERT(kv_tile <= max_kv_tile); - defines.push_back("VEC_NE=1"); + const uint32_t vec_ne = ggml_webgpu_flash_attn_pick_vec_ne(context.key); + defines.push_back(std::string("VEC_NE=") + std::to_string(vec_ne) + "u"); } if (context.key.kv_direct) { GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD); From 83a42b36b324ec833c7d4d9ffe7758e3f5b53b66 Mon Sep 17 00:00:00 2001 From: Zheyuan Chen Date: Wed, 11 Mar 2026 13:40:22 -0700 Subject: [PATCH 18/34] cleanup --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 18 +++------- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 36 +++++++------------ 2 files changed, 18 insertions(+), 36 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index e85be9ad1e15..e4929ec77ddb 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ 
b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -282,14 +282,11 @@ struct ggml_webgpu_flash_attn_pipeline_key { bool has_sinks; bool uses_logit_softcap; bool use_vec; - bool use_vec_split; - bool use_blk; bool operator==(const ggml_webgpu_flash_attn_pipeline_key & other) const { return kv_type == other.kv_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v && kv_direct == other.kv_direct && has_mask == other.has_mask && has_sinks == other.has_sinks && - uses_logit_softcap == other.uses_logit_softcap && use_vec == other.use_vec && - use_vec_split == other.use_vec_split && use_blk == other.use_blk; + uses_logit_softcap == other.uses_logit_softcap && use_vec == other.use_vec; } }; @@ -304,8 +301,6 @@ struct ggml_webgpu_flash_attn_pipeline_key_hash { ggml_webgpu_hash_combine(seed, key.has_sinks); ggml_webgpu_hash_combine(seed, key.uses_logit_softcap); ggml_webgpu_hash_combine(seed, key.use_vec); - ggml_webgpu_hash_combine(seed, key.use_vec_split); - ggml_webgpu_hash_combine(seed, key.use_blk); return seed; } }; @@ -422,8 +417,7 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_blk_shader variant += std::string("_kvt") + std::to_string(context.key.kv_tile); uint32_t wg_size = 1; - const uint32_t target_wg = std::min(32u, context.max_wg_size); - while ((wg_size << 1) <= target_wg) { + while ((wg_size << 1) <= context.max_wg_size) { wg_size <<= 1; } defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); @@ -520,7 +514,7 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader( defines.push_back("KV_DIRECT"); variant += "_kvdirect"; } - if (context.key.use_blk) { + if (context.key.has_mask && context.key.use_vec) { defines.push_back("BLK"); variant += "_blk"; } @@ -539,13 +533,11 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader( uint32_t q_tile = context.sg_mat_m; uint32_t kv_tile = std::min(ggml_webgpu_flash_attn_max_kv_tile(context), context.sg_mat_n * 
GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES); - if (context.key.use_vec_split) { + if (context.key.use_vec) { q_tile = 1; const uint32_t max_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(context); kv_tile = std::max(context.sg_mat_n, std::min(32u, max_kv_tile)); kv_tile = (kv_tile / context.sg_mat_n) * context.sg_mat_n; - GGML_ASSERT(kv_tile % context.sg_mat_n == 0); - GGML_ASSERT(kv_tile <= max_kv_tile); const uint32_t vec_ne = ggml_webgpu_flash_attn_pick_vec_ne(context.key); defines.push_back(std::string("VEC_NE=") + std::to_string(vec_ne) + "u"); } @@ -561,7 +553,7 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader( defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile)); uint32_t wg_size = 0; - if (context.key.use_vec_split) { + if (context.key.use_vec) { // Keep vec-split to exactly one subgroup to preserve thread mapping. wg_size = std::max(1u, std::min(32u, context.max_subgroup_size)); } else { diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index b5295808e8b1..ad883f9cd8c8 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1339,9 +1339,9 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, const uint32_t v_offset_elems = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type)); const bool f16_vec4_aligned = (k_offset_elems % 4u == 0u) && (v_offset_elems % 4u == 0u); - bool kv_direct = (K->type == GGML_TYPE_F16) && f16_vec4_aligned && - (Q->ne[0] % ctx->global_ctx->capabilities.sg_mat_k == 0) && - (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0); + const bool kv_direct = (K->type == GGML_TYPE_F16) && f16_vec4_aligned && + (Q->ne[0] % ctx->global_ctx->capabilities.sg_mat_k == 0) && + (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0); const bool kv_vec_type_supported = K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0; const bool use_vec = (Q->ne[1] < 20) && @@ -1353,8 +1353,7 @@ 
static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, const uint32_t vec_nwg_cap = std::max(1u, std::min(32u, ctx->global_ctx->capabilities.max_subgroup_size)); - const bool use_vec_split = use_vec && vec_nwg_cap > 1u; - const bool use_blk = use_vec_split && has_mask; + const bool use_blk = use_vec && has_mask; ggml_webgpu_flash_attn_pipeline_key key = { .kv_type = K->type, @@ -1365,8 +1364,6 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, .has_sinks = static_cast(has_sinks), .uses_logit_softcap = logit_softcap != 0.0f, .use_vec = use_vec, - .use_vec_split = use_vec_split, - .use_blk = use_blk, }; webgpu_pipeline pipeline; @@ -1384,7 +1381,7 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, }; ggml_webgpu_processed_shader processed; - if (use_vec_split) { + if (use_vec) { processed = ggml_webgpu_preprocess_flash_attn_shader(ctx->p, wgsl_flash_attn_vec_split, shader_lib_ctx); } else { processed = ggml_webgpu_preprocess_flash_attn_shader(ctx->p, wgsl_flash_attn, shader_lib_ctx); @@ -1400,24 +1397,20 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions->q_tile); uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches - uint32_t vec_nwg = 1u; - if (use_vec_split) { - const uint64_t kv_span = (uint64_t) std::max(1u, decisions->kv_tile); - while ((2u * vec_nwg * kv_span) < (uint64_t) K->ne[1] && vec_nwg < vec_nwg_cap) { - vec_nwg <<= 1; - } - vec_nwg = std::min(vec_nwg, vec_nwg_cap); - } - - bool have_blk_buf = false; wgpu::Buffer blk_buf = {}; uint64_t blk_size_bytes = 0; uint32_t blk_nblk0 = 0; uint32_t blk_nblk1 = 0; uint32_t blk_batch_count = 0; - if (use_vec_split) { - const uint32_t nwg = vec_nwg; + if (use_vec) { + uint32_t nwg = 1u; + const uint64_t kv_span = (uint64_t) std::max(1u, decisions->kv_tile); + while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) { + nwg <<= 1; + } 
+ nwg = std::min(nwg, vec_nwg_cap); + GGML_ASSERT(nwg <= ctx->global_ctx->capabilities.max_subgroup_size); const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3]; // For a single split workgroup there is nothing to merge. @@ -1466,8 +1459,6 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, ggml_webgpu_create_buffer(ctx->global_ctx->device, blk_buf, blk_size_bytes, wgpu::BufferUsage::Storage, "flash_attn_vec_blk"); - have_blk_buf = true; - const ggml_webgpu_flash_attn_blk_pipeline_key blk_key = { .q_tile = decisions->q_tile, .kv_tile = decisions->kv_tile, @@ -1542,7 +1533,6 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, sinks) }); } if (use_blk) { - GGML_ASSERT(have_blk_buf); split_entries.push_back({ .binding = split_binding_index++, .buffer = blk_buf, .offset = 0, From 2595b1acec157ca2e9554075cdc16dfc440f1b18 Mon Sep 17 00:00:00 2001 From: Zheyuan Chen Date: Wed, 11 Mar 2026 14:19:21 -0700 Subject: [PATCH 19/34] cleanup --- .../wgsl-shaders/flash_attn_vec_blk.wgsl | 1 - .../wgsl-shaders/flash_attn_vec_reduce.wgsl | 16 ++++----- .../wgsl-shaders/flash_attn_vec_split.wgsl | 36 +++++++++---------- 3 files changed, 24 insertions(+), 29 deletions(-) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl index afffc8f78fa7..82556d060bf8 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl @@ -8,7 +8,6 @@ struct Params { offset_mask: u32, seq_len_q: u32, seq_len_kv: u32, - // plane b base = offset_mask + b * stride_mask3. stride_mask3: u32, // Number of KV blocks and Q blocks per batch. // nblk0 = ceil(seq_len_kv / KV_TILE), nblk1 = ceil(seq_len_q / Q_TILE). 
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl index 43dbf7d3aa08..b71d13d4e499 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl @@ -42,25 +42,25 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let dst3_stride = dst2_stride * params.seq_len_q; let row_base = params.offset_dst + batch_idx * dst3_stride + q_row * dst2_stride + head_idx * HEAD_DIM_V; - let lane = sg_inv_id; + let thread = sg_inv_id; if (params.nwg > subgroup_size) { return; } let stats_base = params.tmp_stats_base + rid * (2u * params.nwg); - let active_lane = lane < params.nwg; - let si = select(0.0, tmp[stats_base + 2u * lane + 0u], active_lane); - let mi = select(FLOAT_MIN, tmp[stats_base + 2u * lane + 1u], active_lane); + let active_thread = thread < params.nwg; + let si = select(0.0, tmp[stats_base + 2u * thread + 0u], active_thread); + let mi = select(FLOAT_MIN, tmp[stats_base + 2u * thread + 1u], active_thread); let m = subgroupMax(mi); - let ms = select(0.0, exp(mi - m), active_lane); + let ms = select(0.0, exp(mi - m), active_thread); let s = subgroupAdd(si * ms); let inv_s = select(0.0, 1.0 / s, s != 0.0); let row_tmp_base = params.tmp_data_base + rid * (HEAD_DIM_V * params.nwg); for (var elem_base = subgroup_id * 4u; elem_base < HEAD_DIM_V; elem_base += num_subgroups * 4u) { var weighted = vec4(0.0, 0.0, 0.0, 0.0); - if (active_lane) { - let src = row_tmp_base + lane * HEAD_DIM_V + elem_base; + if (active_thread) { + let src = row_tmp_base + thread * HEAD_DIM_V + elem_base; weighted = vec4(tmp[src + 0u], tmp[src + 1u], tmp[src + 2u], tmp[src + 3u]) * ms; } @@ -69,7 +69,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let sum_z = subgroupAdd(weighted.z); let sum_w = subgroupAdd(weighted.w); - if (lane == 0u) { + if (thread == 0u) { let dst_vec_index = (row_base + elem_base) >> 2u; dst[dst_vec_index] = vec4(sum_x, 
sum_y, sum_z, sum_w) * inv_s; } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl index 0ce96f0bdbe2..b69fe8994a32 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl @@ -41,7 +41,6 @@ enable chromium_experimental_subgroup_matrix; #endif #define F16_PER_THREAD (NQ / WEIGHTS_PER_F16) -// Ok not to put these in a define block, compiler will remove if unused fn get_byte(value: u32, index: u32) -> u32 { return (value >> (index * 8)) & 0xFF; } @@ -162,15 +161,13 @@ const kv_shmem_size = KV_TILE * max(HEAD_DIM_QK, HEAD_DIM_V); var kv_shmem: array; #endif -var o_shmem: array; // output shmem +var o_shmem: array; #ifdef MASK // storage for mask values var mask_shmem: array; #endif -// storage for output of Q*K^T scores for online softmax (S matrix from paper) -// also storage for diagonal matrix during online softmax (P matrix from paper) // note that we reuse the same storage for both since we only need one at a time var inter_shmem: array; @@ -401,7 +398,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, } } var sum = partial_sum; - // Reduce over tx lanes (NL) for this ty stripe. + // Reduce over tx threads (NL) for this ty stripe. 
var tx_delta = num_of_threads >> 1u; loop { if (tx_delta == 0u) { @@ -452,7 +449,6 @@ fn main(@builtin(workgroup_id) wg_id: vec3, break; } - // initialize running max for this row var prev_max = row_max_shmem[q_tile_row]; var final_max = prev_max; // pass 1: compute final max across the full KV tile in chunks @@ -571,17 +567,17 @@ fn main(@builtin(workgroup_id) wg_id: vec3, if (!skip_tile) { // we have P (Q_TILE x KV_TILE) in inter_shmem and V (KV_TILE x head_dim_v) in kv_shmem // we want to compute O += P * V across the full KV tile - let ne_lanes : u32 = VEC_NE; - let nl_lanes = max(1u, subgroup_size / ne_lanes); - let tx_pv = sg_inv_id % nl_lanes; - let ty_pv = sg_inv_id / nl_lanes; + let ne_threads : u32 = VEC_NE; + let nl_threads = max(1u, subgroup_size / ne_threads); + let tx_pv = sg_inv_id % nl_threads; + let ty_pv = sg_inv_id / nl_threads; for (var q_tile_row = subgroup_id; q_tile_row < Q_TILE; q_tile_row += num_subgroups) { - for (var vec_col = tx_pv; vec_col < (HEAD_DIM_V / 4u); vec_col += nl_lanes) { + for (var vec_col = tx_pv; vec_col < (HEAD_DIM_V / 4u); vec_col += nl_threads) { var lo = vec4(0.0, 0.0, 0.0, 0.0); - for (var cc = 0u; cc < KV_TILE / ne_lanes; cc += 1u) { - let kv_idx = cc * ne_lanes + ty_pv; + for (var cc = 0u; cc < KV_TILE / ne_threads; cc += 1u) { + let kv_idx = cc * ne_threads + ty_pv; let v_row = kv_tile + kv_idx; if (v_row >= params.seq_len_kv) { continue; @@ -606,17 +602,17 @@ fn main(@builtin(workgroup_id) wg_id: vec3, var lo_y = lo.y; var lo_z = lo.z; var lo_w = lo.w; - // Reduce over ty lanes (NE) for this tx lane. - var ty_delta = ne_lanes >> 1u; + // Reduce over ty threads (NE) for this tx thread. 
+ var ty_delta = ne_threads >> 1u; loop { if (ty_delta == 0u) { break; } - let lane_delta = ty_delta * nl_lanes; - let shx = subgroupShuffleDown(lo_x, lane_delta); - let shy = subgroupShuffleDown(lo_y, lane_delta); - let shz = subgroupShuffleDown(lo_z, lane_delta); - let shw = subgroupShuffleDown(lo_w, lane_delta); + let thread_delta = ty_delta * nl_threads; + let shx = subgroupShuffleDown(lo_x, thread_delta); + let shy = subgroupShuffleDown(lo_y, thread_delta); + let shz = subgroupShuffleDown(lo_z, thread_delta); + let shw = subgroupShuffleDown(lo_w, thread_delta); if (ty_pv < ty_delta) { lo_x += shx; lo_y += shy; From 25096b9c40b9f4b80b36880995e24a7f6ef2c802 Mon Sep 17 00:00:00 2001 From: Zheyuan Chen Date: Wed, 11 Mar 2026 14:44:39 -0700 Subject: [PATCH 20/34] keep host side clean --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index ad883f9cd8c8..d4a51d3fc9ab 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -20,11 +20,18 @@ #include #include #include +#ifdef GGML_WEBGPU_GPU_PROFILE +# include +#endif +#if defined(GGML_WEBGPU_DEBUG) || defined(GGML_WEBGPU_CPU_PROFILE) || defined(GGML_WEBGPU_GPU_PROFILE) +# include +#endif #include #include #include #include #include +#include #include #define ROUNDUP_POW2(x, pow2) (((x) + ((pow2) - 1)) & ~((pow2) - 1)) @@ -69,23 +76,23 @@ static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim #endif // GGML_WEBGPU_CPU_PROFILE #ifdef GGML_WEBGPU_GPU_PROFILE -# define WEBGPU_NUM_TIMESTAMP_QUERY_BUFS 24 +# define WEBGPU_NUM_TIMESTAMP_QUERY_BUFS 32 # define WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES 16 // e.g. 
enough for two timestamps #endif /* Constants */ -#define WEBGPU_NUM_PARAM_BUFS 48u -#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 16u +#define WEBGPU_NUM_PARAM_BUFS 96u +#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 32u #define WEBGPU_WAIT_ANY_TIMEOUT_MS 0 // Maximum number of in-flight submissions per-thread, to avoid exhausting the // parameter buffer pool -#define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE +#define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD (WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) #define WEBGPU_MAX_PARAM_BUFS_PER_CMD 4u #define WEBGPU_PARAMS_BUF_SIZE_BYTES 128 // enough for 32 parameters #define WEBGPU_NUM_SET_ROWS_ERROR_BUFS 16 #define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4 -#define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4 +#define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4 // For operations which process a row in parallel, this seems like a reasonable // default @@ -497,9 +504,11 @@ static void ggml_backend_webgpu_wait(webgpu_global_context & ct break; case wgpu::WaitStatus::Error: GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n"); + i++; break; default: GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an unknown status\n"); + i++; break; } } From 3d6bfe0298e9d4e55bf8d583fc3f0aefbc04d6d6 Mon Sep 17 00:00:00 2001 From: Zheyuan Chen Date: Wed, 11 Mar 2026 15:08:34 -0700 Subject: [PATCH 21/34] cleanup host side --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 197 +++++++++------------------ 1 file changed, 61 insertions(+), 136 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index d4a51d3fc9ab..b9dc82fa870d 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -90,7 +90,6 @@ static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim #define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD 
(WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) #define WEBGPU_MAX_PARAM_BUFS_PER_CMD 4u #define WEBGPU_PARAMS_BUF_SIZE_BYTES 128 // enough for 32 parameters -#define WEBGPU_NUM_SET_ROWS_ERROR_BUFS 16 #define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4 #define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4 @@ -295,13 +294,9 @@ struct webgpu_command { uint32_t num_kernels; wgpu::CommandBuffer commands; std::vector params_bufs; - std::optional set_rows_error_bufs; - // Keep temporary resources alive until submitted work is complete. - std::vector retained_buffers; #ifdef GGML_WEBGPU_GPU_PROFILE webgpu_gpu_profile_bufs timestamp_query_bufs; std::string pipeline_name; - bool has_timestamp_query = false; #endif }; @@ -382,7 +377,8 @@ struct webgpu_context_struct { pre_wgsl::Preprocessor p; webgpu_buf_pool param_buf_pool; - webgpu_buf_pool set_rows_error_buf_pool; + wgpu::Buffer set_rows_dev_error_buf; + wgpu::Buffer set_rows_host_error_buf; std::map> cpy_pipelines; // src_type, dst_type std::unordered_map @@ -547,12 +543,9 @@ static void ggml_backend_webgpu_debug(webgpu_global_context & ctx) { static webgpu_submission_futures ggml_backend_webgpu_submit(webgpu_global_context ctx, std::vector commands, - webgpu_buf_pool & param_buf_pool, - webgpu_buf_pool * set_rows_error_buf_pool = nullptr) { + webgpu_buf_pool & param_buf_pool) { std::vector command_buffers; std::vector params_bufs; - std::vector set_rows_error_bufs; - std::vector retained_buffers; #ifdef GGML_WEBGPU_GPU_PROFILE std::vector> pipeline_name_and_ts_bufs; #endif @@ -560,10 +553,6 @@ static webgpu_submission_futures ggml_backend_webgpu_submit(webgpu_global_contex for (const auto & command : commands) { command_buffers.push_back(command.commands); params_bufs.insert(params_bufs.end(), command.params_bufs.begin(), command.params_bufs.end()); - retained_buffers.insert(retained_buffers.end(), command.retained_buffers.begin(), command.retained_buffers.end()); - if 
(command.set_rows_error_bufs) { - set_rows_error_bufs.push_back(command.set_rows_error_bufs.value()); - } } ctx->queue.Submit(command_buffers.size(), command_buffers.data()); @@ -571,41 +560,17 @@ static webgpu_submission_futures ggml_backend_webgpu_submit(webgpu_global_contex wgpu::Future p_f = ctx->queue.OnSubmittedWorkDone( wgpu::CallbackMode::AllowSpontaneous, - [¶m_buf_pool, params_bufs, retained_buffers](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { + [¶m_buf_pool, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { if (status != wgpu::QueueWorkDoneStatus::Success) { GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", std::string(message).c_str()); } - (void) retained_buffers; // Free the staged buffers param_buf_pool.free_bufs(params_bufs); }); futures.push_back({ p_f }); - for (const auto & bufs : set_rows_error_bufs) { - wgpu::Future f = bufs.host_buf.MapAsync( - wgpu::MapMode::Read, 0, bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous, - [set_rows_error_buf_pool, bufs](wgpu::MapAsyncStatus status, wgpu::StringView message) { - if (status != wgpu::MapAsyncStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to map error buffer: %s\n", std::string(message).c_str()); - } else { - const uint32_t * error_data = (const uint32_t *) bufs.host_buf.GetConstMappedRange(); - if (*error_data) { - GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported."); - } - // We can't unmap in here due to WebGPU reentrancy limitations. 
- if (set_rows_error_buf_pool) { - set_rows_error_buf_pool->free_bufs({ bufs }); - } - } - }); - futures.push_back({ f }); - } - #ifdef GGML_WEBGPU_GPU_PROFILE for (const auto & command : commands) { - if (!command.has_timestamp_query) { - continue; - } auto label = command.pipeline_name; auto ts_bufs = command.timestamp_query_bufs; @@ -635,9 +600,7 @@ static webgpu_command ggml_backend_webgpu_build_multi( const std::vector & pipelines, const std::vector> & params_list, const std::vector> & bind_group_entries_list, - const std::vector> & workgroups_list, - const std::optional & set_rows_error_bufs = std::nullopt, - bool split_passes = false) { + const std::vector> & workgroups_list) { GGML_ASSERT(pipelines.size() == params_list.size()); GGML_ASSERT(pipelines.size() == bind_group_entries_list.size()); GGML_ASSERT(pipelines.size() == workgroups_list.size()); @@ -675,73 +638,39 @@ static webgpu_command ggml_backend_webgpu_build_multi( for (size_t i = 0; i < params_bufs_list.size(); i++) { ctx->queue.WriteBuffer(params_bufs_list[i].dev_buf, 0, params_list[i].data(), params_list[i].size() * sizeof(uint32_t)); } - - // If there are SET_ROWS operations in this submission, copy their error - // buffers to the host. 
- if (set_rows_error_bufs) { - encoder.CopyBufferToBuffer(set_rows_error_bufs->dev_buf, 0, set_rows_error_bufs->host_buf, 0, - set_rows_error_bufs->host_buf.GetSize()); - } - - bool profile_pass = false; #ifdef GGML_WEBGPU_GPU_PROFILE - webgpu_gpu_profile_bufs ts_bufs = {}; - if (!split_passes) { - ts_bufs = ctx->timestamp_query_buf_pool.alloc_bufs(); - if (ts_bufs.host_buf.GetMapState() == wgpu::BufferMapState::Mapped) { - ts_bufs.host_buf.Unmap(); - } - profile_pass = true; + webgpu_gpu_profile_bufs ts_bufs = ctx->timestamp_query_buf_pool.alloc_bufs(); + if (ts_bufs.host_buf.GetMapState() == wgpu::BufferMapState::Mapped) { + ts_bufs.host_buf.Unmap(); } -#endif -#ifndef GGML_WEBGPU_GPU_PROFILE - GGML_UNUSED(profile_pass); -#endif - if (split_passes) { - for (size_t i = 0; i < pipelines.size(); i++) { - wgpu::ComputePassEncoder pass = encoder.BeginComputePass(); - pass.SetPipeline(pipelines[i].pipeline); - pass.SetBindGroup(0, bind_groups[i]); - pass.DispatchWorkgroups(workgroups_list[i].first, workgroups_list[i].second, 1); - pass.End(); - } - } else { -#ifdef GGML_WEBGPU_GPU_PROFILE - wgpu::PassTimestampWrites ts_writes = { .querySet = ts_bufs.query_set, - .beginningOfPassWriteIndex = 0, - .endOfPassWriteIndex = 1 }; - wgpu::ComputePassDescriptor pass_desc = { .timestampWrites = &ts_writes }; - wgpu::ComputePassEncoder pass = encoder.BeginComputePass(&pass_desc); + wgpu::PassTimestampWrites ts_writes = { .querySet = ts_bufs.query_set, + .beginningOfPassWriteIndex = 0, + .endOfPassWriteIndex = 1 }; + wgpu::ComputePassDescriptor pass_desc = { .timestampWrites = &ts_writes }; + wgpu::ComputePassEncoder pass = encoder.BeginComputePass(&pass_desc); #else - wgpu::ComputePassEncoder pass = encoder.BeginComputePass(); + wgpu::ComputePassEncoder pass = encoder.BeginComputePass(); #endif - for (size_t i = 0; i < pipelines.size(); i++) { - pass.SetPipeline(pipelines[i].pipeline); - pass.SetBindGroup(0, bind_groups[i]); - pass.DispatchWorkgroups(workgroups_list[i].first, 
workgroups_list[i].second, 1); - } - pass.End(); + for (size_t i = 0; i < pipelines.size(); i++) { + pass.SetPipeline(pipelines[i].pipeline); + pass.SetBindGroup(0, bind_groups[i]); + pass.DispatchWorkgroups(workgroups_list[i].first, workgroups_list[i].second, 1); } + pass.End(); #ifdef GGML_WEBGPU_GPU_PROFILE - if (profile_pass) { - encoder.ResolveQuerySet(ts_bufs.query_set, 0, 2, ts_bufs.dev_buf, 0); - encoder.CopyBufferToBuffer(ts_bufs.dev_buf, 0, ts_bufs.host_buf, 0, ts_bufs.host_buf.GetSize()); - } + encoder.ResolveQuerySet(ts_bufs.query_set, 0, 2, ts_bufs.dev_buf, 0); + encoder.CopyBufferToBuffer(ts_bufs.dev_buf, 0, ts_bufs.host_buf, 0, ts_bufs.host_buf.GetSize()); #endif wgpu::CommandBuffer commands = encoder.Finish(); webgpu_command result = {}; result.commands = commands; result.params_bufs = params_bufs_list; - result.set_rows_error_bufs = set_rows_error_bufs; result.num_kernels = pipelines.size(); #ifdef GGML_WEBGPU_GPU_PROFILE - if (profile_pass) { - result.timestamp_query_bufs = ts_bufs; - result.has_timestamp_query = true; - } + result.timestamp_query_bufs = ts_bufs; // TODO: handle multiple pipeline names result.pipeline_name = pipelines.front().name; #endif @@ -754,13 +683,11 @@ static webgpu_command ggml_backend_webgpu_build(webgpu_global_context & std::vector params, std::vector bind_group_entries, uint32_t wg_x, - uint32_t wg_y = 1, - std::optional set_rows_error_bufs = std::nullopt) { + uint32_t wg_y = 1) { return ggml_backend_webgpu_build_multi(ctx, param_buf_pool, { pipeline - }, - { params }, { bind_group_entries }, { { wg_x, wg_y } }, set_rows_error_bufs); + }, { params }, { bind_group_entries }, { { wg_x, wg_y } }); } static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx, @@ -1000,14 +927,6 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, auto * decisions = static_cast(pipeline.context.get()); - std::optional error_bufs = std::nullopt; - if (decisions->i64_idx) { - error_bufs = 
ctx->set_rows_error_buf_pool.alloc_bufs(); - if (error_bufs->host_buf.GetMapState() == wgpu::BufferMapState::Mapped) { - error_bufs->host_buf.Unmap(); - } - } - std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)), @@ -1041,7 +960,10 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, if (decisions->i64_idx) { entries.push_back( - { .binding = 3, .buffer = error_bufs->dev_buf, .offset = 0, .size = error_bufs->dev_buf.GetSize() }); + { .binding = 3, + .buffer = ctx->set_rows_dev_error_buf, + .offset = 0, + .size = ctx->set_rows_dev_error_buf.GetSize() }); } uint32_t threads; @@ -1051,8 +973,7 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, threads = src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3]; } uint32_t wg_x = CEIL_DIV(threads, decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, 1, - error_bufs); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, 1); } // Workgroup size is a common constant @@ -1624,22 +1545,8 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, workgroups_list.push_back({ (uint32_t) nrows, 1u }); } - const bool split_passes = use_blk || use_vec_reduce; - - std::vector retained_buffers; - if (use_vec_reduce) { - retained_buffers.push_back(tmp_buf); - } - if (use_blk) { - retained_buffers.push_back(blk_buf); - } - - webgpu_command cmd; - cmd = ggml_backend_webgpu_build_multi( - ctx->global_ctx, ctx->param_buf_pool, pipelines, params_list, entries_list, workgroups_list, - std::nullopt, split_passes); - cmd.retained_buffers = std::move(retained_buffers); - return cmd; + return ggml_backend_webgpu_build_multi( + ctx->global_ctx, ctx->param_buf_pool, pipelines, params_list, entries_list, workgroups_list); } return 
ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); @@ -2566,17 +2473,20 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str std::vector commands; std::vector futures; - uint32_t batch_param_bufs = 0; - uint32_t num_batched_kernels = 0; + uint32_t batch_param_bufs = 0; + uint32_t num_batched_kernels = 0; + bool contains_set_rows = false; for (int i = 0; i < cgraph->n_nodes; i++) { + if (cgraph->nodes[i]->op == GGML_OP_SET_ROWS) { + contains_set_rows = true; + } if (auto cmd = ggml_webgpu_encode_node(ctx, cgraph->nodes[i])) { const uint32_t cmd_param_bufs = (uint32_t) cmd->params_bufs.size(); // Leave room for the next command so alloc_bufs() never blocks waiting for // a submit that has not yet happened. if (!commands.empty() && batch_param_bufs + cmd_param_bufs + WEBGPU_MAX_PARAM_BUFS_PER_CMD > WEBGPU_NUM_PARAM_BUFS) { - futures.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool, - &ctx->set_rows_error_buf_pool)); + futures.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool)); ctx->global_ctx->instance.ProcessEvents(); ggml_backend_webgpu_wait(ctx->global_ctx, futures, false); commands.clear(); @@ -2591,8 +2501,7 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str if (num_batched_kernels >= WEBGPU_COMMAND_SUBMIT_BATCH_SIZE || (!commands.empty() && batch_param_bufs + WEBGPU_MAX_PARAM_BUFS_PER_CMD > WEBGPU_NUM_PARAM_BUFS)) { - futures.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool, - &ctx->set_rows_error_buf_pool)); + futures.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool)); // Process events and check for completed submissions ctx->global_ctx->instance.ProcessEvents(); ggml_backend_webgpu_wait(ctx->global_ctx, futures, false); @@ -2602,11 +2511,25 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t 
backend, str } } if (!commands.empty()) { - webgpu_submission_futures new_futures = - ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool, &ctx->set_rows_error_buf_pool); + webgpu_submission_futures new_futures = ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool); futures.push_back(new_futures); } + if (contains_set_rows) { + wgpu::CommandEncoder encoder = ctx->global_ctx->device.CreateCommandEncoder(); + encoder.CopyBufferToBuffer(ctx->set_rows_dev_error_buf, 0, ctx->set_rows_host_error_buf, 0, + ctx->set_rows_host_error_buf.GetSize()); + wgpu::CommandBuffer set_rows_commands = encoder.Finish(); + ctx->global_ctx->queue.Submit(1, &set_rows_commands); + ggml_backend_webgpu_map_buffer(ctx->global_ctx, ctx->set_rows_host_error_buf, wgpu::MapMode::Read, 0, + ctx->set_rows_host_error_buf.GetSize()); + const uint32_t * error_data = (const uint32_t *) ctx->set_rows_host_error_buf.GetConstMappedRange(); + if (*error_data) { + GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported."); + } + ctx->set_rows_host_error_buf.Unmap(); + } + ggml_backend_webgpu_wait(ctx->global_ctx, futures); WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx->global_ctx); return GGML_STATUS_SUCCESS; @@ -3226,10 +3149,12 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) { webgpu_ctx->param_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES, wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform, wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite, true); - webgpu_ctx->set_rows_error_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_SET_ROWS_ERROR_BUFS, - WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES, - wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage, - wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead); + ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->set_rows_dev_error_buf, + WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES, + 
wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc, "set_rows_dev_error_buf"); + ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->set_rows_host_error_buf, + WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES, + wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "set_rows_host_error_buf"); ggml_webgpu_init_cpy_pipeline(webgpu_ctx); ggml_webgpu_init_rms_norm_pipeline(webgpu_ctx); From 68fa272676fd82ed82d3690c5b345ca30b404afd Mon Sep 17 00:00:00 2001 From: Zheyuan Chen Date: Wed, 11 Mar 2026 15:23:34 -0700 Subject: [PATCH 22/34] change back to original host wait/submit behavior --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 196 +++++++++++++++++---------- 1 file changed, 125 insertions(+), 71 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index b9dc82fa870d..8249583ce2f6 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -129,11 +129,6 @@ struct webgpu_pool_bufs { wgpu::Buffer dev_buf; }; -// The futures to wait on for a single queue submission -struct webgpu_submission_futures { - std::vector futures; -}; - // Holds a pool of parameter buffers for WebGPU operations struct webgpu_buf_pool { std::vector free; @@ -368,6 +363,13 @@ struct webgpu_global_context_struct { typedef std::shared_ptr webgpu_global_context; +struct webgpu_submission { + wgpu::FutureWaitInfo submit_done; +#ifdef GGML_WEBGPU_GPU_PROFILE + std::vector profile_futures; +#endif +}; + // All the base objects needed to run operations on a WebGPU device struct webgpu_context_struct { // Points to global instances owned by ggml_backend_webgpu_reg_context @@ -478,34 +480,104 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device, /** WebGPU Actions */ +static bool ggml_backend_webgpu_handle_wait_status(wgpu::WaitStatus status, bool allow_timeout = false) { + switch (status) { + case wgpu::WaitStatus::Success: + return true; + case wgpu::WaitStatus::TimedOut: + if (allow_timeout) { + 
return false; + } + GGML_LOG_ERROR("ggml_webgpu: WaitAny timed out unexpectedly\n"); + return false; + case wgpu::WaitStatus::Error: + GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n"); + return false; + default: + GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an unknown status\n"); + return false; + } +} + +#ifdef GGML_WEBGPU_GPU_PROFILE +static void ggml_backend_webgpu_erase_completed_futures(std::vector & futures) { + futures.erase(std::remove_if(futures.begin(), futures.end(), + [](const wgpu::FutureWaitInfo & info) { return info.completed; }), + futures.end()); +} + +static void ggml_backend_webgpu_wait_profile_futures(webgpu_global_context & ctx, + std::vector & futures, + bool block) { + if (futures.empty()) { + return; + } + + uint64_t timeout_ms = block ? UINT64_MAX : 0; + if (block) { + while (!futures.empty()) { + auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms); + if (ggml_backend_webgpu_handle_wait_status(waitStatus)) { + ggml_backend_webgpu_erase_completed_futures(futures); + } + } + } else { + auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms); + if (ggml_backend_webgpu_handle_wait_status(waitStatus, true)) { + ggml_backend_webgpu_erase_completed_futures(futures); + } + } +} +#endif + // Wait for the queue to finish processing all submitted work -static void ggml_backend_webgpu_wait(webgpu_global_context & ctx, - std::vector & futures, - bool block = true) { +static void ggml_backend_webgpu_wait(webgpu_global_context & ctx, + std::vector & subs, + bool block = true) { // If we have too many in-flight submissions, wait on the oldest one first. - uint64_t timeout_ms = block ? 
UINT64_MAX : 0; - while (futures.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD) { - ctx->instance.WaitAny(futures[0].futures.size(), futures[0].futures.data(), UINT64_MAX); - futures.erase(futures.begin()); - } - size_t i = 0; - while (i < futures.size()) { - auto waitStatus = ctx->instance.WaitAny(futures[i].futures.size(), futures[i].futures.data(), timeout_ms); - switch (waitStatus) { - case wgpu::WaitStatus::Success: - futures.erase(futures.begin() + i); - break; - case wgpu::WaitStatus::TimedOut: - i++; - break; - case wgpu::WaitStatus::Error: - GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n"); - i++; - break; - default: - GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an unknown status\n"); - i++; - break; + if (subs.empty()) { + return; + } + while (subs.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD) { + auto waitStatus = ctx->instance.WaitAny(1, &subs[0].submit_done, UINT64_MAX); + if (ggml_backend_webgpu_handle_wait_status(waitStatus)) { +#ifdef GGML_WEBGPU_GPU_PROFILE + ggml_backend_webgpu_wait_profile_futures(ctx, subs[0].profile_futures, true); +#endif + subs.erase(subs.begin()); + } + } + + if (subs.empty()) { + return; + } + + if (block) { + for (auto & sub : subs) { + while (!sub.submit_done.completed) { + auto waitStatus = ctx->instance.WaitAny(1, &sub.submit_done, UINT64_MAX); + ggml_backend_webgpu_handle_wait_status(waitStatus); + } +#ifdef GGML_WEBGPU_GPU_PROFILE + ggml_backend_webgpu_wait_profile_futures(ctx, sub.profile_futures, true); +#endif + } + subs.clear(); + } else { + // Poll each submit future once and remove completed submissions. 
+ for (auto sub = subs.begin(); sub != subs.end();) { + auto waitStatus = ctx->instance.WaitAny(1, &sub->submit_done, 0); + ggml_backend_webgpu_handle_wait_status(waitStatus, true); +#ifdef GGML_WEBGPU_GPU_PROFILE + ggml_backend_webgpu_wait_profile_futures(ctx, sub->profile_futures, false); + if (sub->submit_done.completed && sub->profile_futures.empty()) { +#else + if (sub->submit_done.completed) { +#endif + sub = subs.erase(sub); + } else { + ++sub; + } } } } @@ -541,11 +613,12 @@ static void ggml_backend_webgpu_debug(webgpu_global_context & ctx) { } #endif -static webgpu_submission_futures ggml_backend_webgpu_submit(webgpu_global_context ctx, - std::vector commands, - webgpu_buf_pool & param_buf_pool) { +static webgpu_submission ggml_backend_webgpu_submit(webgpu_global_context & ctx, + std::vector & commands, + webgpu_buf_pool & param_buf_pool) { std::vector command_buffers; std::vector params_bufs; + webgpu_submission submission; #ifdef GGML_WEBGPU_GPU_PROFILE std::vector> pipeline_name_and_ts_bufs; #endif @@ -556,8 +629,6 @@ static webgpu_submission_futures ggml_backend_webgpu_submit(webgpu_global_contex } ctx->queue.Submit(command_buffers.size(), command_buffers.data()); - std::vector futures; - wgpu::Future p_f = ctx->queue.OnSubmittedWorkDone( wgpu::CallbackMode::AllowSpontaneous, [¶m_buf_pool, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { @@ -567,7 +638,7 @@ static webgpu_submission_futures ggml_backend_webgpu_submit(webgpu_global_contex // Free the staged buffers param_buf_pool.free_bufs(params_bufs); }); - futures.push_back({ p_f }); + submission.submit_done = { p_f }; #ifdef GGML_WEBGPU_GPU_PROFILE for (const auto & command : commands) { @@ -588,10 +659,10 @@ static webgpu_submission_futures ggml_backend_webgpu_submit(webgpu_global_contex ctx->timestamp_query_buf_pool.free_bufs({ ts_bufs }); } }); - futures.push_back({ f }); + submission.profile_futures.push_back({ f }); } #endif - return { futures }; + return submission; } 
static webgpu_command ggml_backend_webgpu_build_multi( @@ -704,9 +775,9 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx, webgpu_command command = ggml_backend_webgpu_build(ctx, ctx->memset_buf_pool, ctx->memset_pipelines[0], params, entries, wg_x); - std::vector futures = { ggml_backend_webgpu_submit(ctx, { command }, - ctx->memset_buf_pool) }; - ggml_backend_webgpu_wait(ctx, futures); + std::vector commands = { command }; + std::vector sub = { ggml_backend_webgpu_submit(ctx, commands, ctx->memset_buf_pool) }; + ggml_backend_webgpu_wait(ctx, sub); } /** End WebGPU Actions */ @@ -2471,48 +2542,31 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str WEBGPU_CPU_PROFILE_TOTAL_START(graph_compute); - std::vector commands; - std::vector futures; - uint32_t batch_param_bufs = 0; - uint32_t num_batched_kernels = 0; - bool contains_set_rows = false; + std::vector commands; + std::vector subs; + uint32_t num_batched_kernels = 0; + bool contains_set_rows = false; for (int i = 0; i < cgraph->n_nodes; i++) { if (cgraph->nodes[i]->op == GGML_OP_SET_ROWS) { contains_set_rows = true; } if (auto cmd = ggml_webgpu_encode_node(ctx, cgraph->nodes[i])) { - const uint32_t cmd_param_bufs = (uint32_t) cmd->params_bufs.size(); - - // Leave room for the next command so alloc_bufs() never blocks waiting for - // a submit that has not yet happened. 
- if (!commands.empty() && batch_param_bufs + cmd_param_bufs + WEBGPU_MAX_PARAM_BUFS_PER_CMD > WEBGPU_NUM_PARAM_BUFS) { - futures.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool)); - ctx->global_ctx->instance.ProcessEvents(); - ggml_backend_webgpu_wait(ctx->global_ctx, futures, false); - commands.clear(); - batch_param_bufs = 0; - num_batched_kernels = 0; - } - commands.push_back(*cmd); - batch_param_bufs += cmd_param_bufs; - num_batched_kernels += cmd->num_kernels; + num_batched_kernels += cmd.value().num_kernels; } - if (num_batched_kernels >= WEBGPU_COMMAND_SUBMIT_BATCH_SIZE || - (!commands.empty() && batch_param_bufs + WEBGPU_MAX_PARAM_BUFS_PER_CMD > WEBGPU_NUM_PARAM_BUFS)) { - futures.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool)); + if (num_batched_kernels >= WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) { + num_batched_kernels = 0; + subs.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool)); // Process events and check for completed submissions ctx->global_ctx->instance.ProcessEvents(); - ggml_backend_webgpu_wait(ctx->global_ctx, futures, false); + ggml_backend_webgpu_wait(ctx->global_ctx, subs, false); commands.clear(); - batch_param_bufs = 0; - num_batched_kernels = 0; } } if (!commands.empty()) { - webgpu_submission_futures new_futures = ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool); - futures.push_back(new_futures); + subs.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool)); + commands.clear(); } if (contains_set_rows) { @@ -2530,7 +2584,7 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str ctx->set_rows_host_error_buf.Unmap(); } - ggml_backend_webgpu_wait(ctx->global_ctx, futures); + ggml_backend_webgpu_wait(ctx->global_ctx, subs); WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx->global_ctx); return GGML_STATUS_SUCCESS; } From 
5065dc67187dcf33ec95337ccceac0816c68b95c Mon Sep 17 00:00:00 2001 From: Zheyuan Chen Date: Wed, 11 Mar 2026 15:25:52 -0700 Subject: [PATCH 23/34] formatting --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 150 ++++++++++++--------------- 1 file changed, 68 insertions(+), 82 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 8249583ce2f6..b86f3786bfc2 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -286,9 +286,9 @@ struct webgpu_gpu_profile_buf_pool { #endif struct webgpu_command { - uint32_t num_kernels; - wgpu::CommandBuffer commands; - std::vector params_bufs; + uint32_t num_kernels; + wgpu::CommandBuffer commands; + std::vector params_bufs; #ifdef GGML_WEBGPU_GPU_PROFILE webgpu_gpu_profile_bufs timestamp_query_bufs; std::string pipeline_name; @@ -385,12 +385,14 @@ struct webgpu_context_struct { std::map> cpy_pipelines; // src_type, dst_type std::unordered_map flash_attn_pipelines; - std::unordered_map flash_attn_vec_reduce_pipelines; - std::unordered_map - flash_attn_blk_pipelines; + flash_attn_blk_pipelines; std::map rms_norm_pipelines; // inplace std::map>> rope_pipelines; // type, ff, inplace std::map>> glu_pipelines; // glu_op, type, split @@ -682,7 +684,7 @@ static webgpu_command ggml_backend_webgpu_build_multi( for (size_t i = 0; i < pipelines.size(); i++) { webgpu_pool_bufs params_bufs = param_buf_pool.alloc_bufs(); - std::vector entries = bind_group_entries_list[i]; + std::vector entries = bind_group_entries_list[i]; // Bindings can be sparse (e.g. 0,1,2,4,5), so params must use max(binding)+1. 
uint32_t params_binding_num = 0; for (const auto & entry : entries) { @@ -707,7 +709,8 @@ static webgpu_command ggml_backend_webgpu_build_multi( wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder(); for (size_t i = 0; i < params_bufs_list.size(); i++) { - ctx->queue.WriteBuffer(params_bufs_list[i].dev_buf, 0, params_list[i].data(), params_list[i].size() * sizeof(uint32_t)); + ctx->queue.WriteBuffer(params_bufs_list[i].dev_buf, 0, params_list[i].data(), + params_list[i].size() * sizeof(uint32_t)); } #ifdef GGML_WEBGPU_GPU_PROFILE webgpu_gpu_profile_bufs ts_bufs = ctx->timestamp_query_buf_pool.alloc_bufs(); @@ -743,7 +746,7 @@ static webgpu_command ggml_backend_webgpu_build_multi( #ifdef GGML_WEBGPU_GPU_PROFILE result.timestamp_query_bufs = ts_bufs; // TODO: handle multiple pipeline names - result.pipeline_name = pipelines.front().name; + result.pipeline_name = pipelines.front().name; #endif return result; } @@ -758,7 +761,8 @@ static webgpu_command ggml_backend_webgpu_build(webgpu_global_context & return ggml_backend_webgpu_build_multi(ctx, param_buf_pool, { pipeline - }, { params }, { bind_group_entries }, { { wg_x, wg_y } }); + }, + { params }, { bind_group_entries }, { { wg_x, wg_y } }); } static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx, @@ -1030,11 +1034,10 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, }; if (decisions->i64_idx) { - entries.push_back( - { .binding = 3, - .buffer = ctx->set_rows_dev_error_buf, - .offset = 0, - .size = ctx->set_rows_dev_error_buf.GetSize() }); + entries.push_back({ .binding = 3, + .buffer = ctx->set_rows_dev_error_buf, + .offset = 0, + .size = ctx->set_rows_dev_error_buf.GetSize() }); } uint32_t threads; @@ -1336,25 +1339,21 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, .offset = ggml_webgpu_tensor_align_offset(ctx, dst), .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); - const uint32_t k_offset_elems = (uint32_t) 
(ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type)); - const uint32_t v_offset_elems = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type)); - const bool f16_vec4_aligned = (k_offset_elems % 4u == 0u) && (v_offset_elems % 4u == 0u); + const uint32_t k_offset_elems = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type)); + const uint32_t v_offset_elems = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type)); + const bool f16_vec4_aligned = (k_offset_elems % 4u == 0u) && (v_offset_elems % 4u == 0u); const bool kv_direct = (K->type == GGML_TYPE_F16) && f16_vec4_aligned && (Q->ne[0] % ctx->global_ctx->capabilities.sg_mat_k == 0) && (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0); - const bool kv_vec_type_supported = K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0; - const bool use_vec = (Q->ne[1] < 20) && - (Q->ne[0] % 32 == 0) && - (V->ne[0] % 4 == 0) && - kv_vec_type_supported && - (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && - (V->type == K->type); + const bool kv_vec_type_supported = + K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0; + const bool use_vec = (Q->ne[1] < 20) && (Q->ne[0] % 32 == 0) && (V->ne[0] % 4 == 0) && kv_vec_type_supported && + (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && (V->type == K->type); - const uint32_t vec_nwg_cap = - std::max(1u, std::min(32u, ctx->global_ctx->capabilities.max_subgroup_size)); - const bool use_blk = use_vec && has_mask; + const uint32_t vec_nwg_cap = std::max(1u, std::min(32u, ctx->global_ctx->capabilities.max_subgroup_size)); + const bool use_blk = use_vec && has_mask; ggml_webgpu_flash_attn_pipeline_key key = { .kv_type = K->type, @@ -1405,7 +1404,7 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, uint32_t blk_batch_count = 0; if (use_vec) { - uint32_t nwg = 1u; + uint32_t nwg = 1u; const uint64_t kv_span = (uint64_t) std::max(1u, 
decisions->kv_tile); while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) { nwg <<= 1; @@ -1413,24 +1412,24 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, nwg = std::min(nwg, vec_nwg_cap); GGML_ASSERT(nwg <= ctx->global_ctx->capabilities.max_subgroup_size); - const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3]; + const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3]; // For a single split workgroup there is nothing to merge. // Let vec split write final dst directly and skip reduce. - const bool use_vec_reduce = nwg > 1u; + const bool use_vec_reduce = nwg > 1u; GGML_ASSERT(nrows <= UINT32_MAX); - uint64_t tmp_stats_base = 0; - uint64_t tmp_size_bytes = 0; - wgpu::Buffer tmp_buf = {}; - uint64_t tmp_bind_offset = 0; - uint64_t tmp_bind_size = 0; + uint64_t tmp_stats_base = 0; + uint64_t tmp_size_bytes = 0; + wgpu::Buffer tmp_buf = {}; + uint64_t tmp_bind_offset = 0; + uint64_t tmp_bind_size = 0; if (use_vec_reduce) { const uint64_t tmp_data_elems = nrows * (uint64_t) V->ne[0] * nwg; tmp_stats_base = tmp_data_elems; const uint64_t tmp_stats_elems = nrows * 2u * nwg; const uint64_t tmp_total_elems = tmp_data_elems + tmp_stats_elems; - tmp_size_bytes = ROUNDUP_POW2(tmp_total_elems * sizeof(float), WEBGPU_STORAGE_BUF_BINDING_MULT); + tmp_size_bytes = ROUNDUP_POW2(tmp_total_elems * sizeof(float), WEBGPU_STORAGE_BUF_BINDING_MULT); GGML_ASSERT(tmp_stats_base <= UINT32_MAX); ggml_webgpu_create_buffer(ctx->global_ctx->device, tmp_buf, tmp_size_bytes, wgpu::BufferUsage::Storage, @@ -1444,21 +1443,20 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, tmp_bind_size = ggml_webgpu_tensor_binding_size(ctx, dst); } - webgpu_pipeline blk_pipeline; - std::vector blk_params; + webgpu_pipeline blk_pipeline; + std::vector blk_params; std::vector blk_entries; if (use_blk) { GGML_ASSERT(has_mask); - blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], decisions->kv_tile); - blk_nblk1 = CEIL_DIV((uint32_t) Q->ne[1], 
decisions->q_tile); + blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], decisions->kv_tile); + blk_nblk1 = CEIL_DIV((uint32_t) Q->ne[1], decisions->q_tile); const uint32_t stride_mask3 = (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)); - blk_batch_count = stride_mask3 > 0 ? (uint32_t) Q->ne[3] : 1u; + blk_batch_count = stride_mask3 > 0 ? (uint32_t) Q->ne[3] : 1u; const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count; - blk_size_bytes = ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT); - ggml_webgpu_create_buffer(ctx->global_ctx->device, blk_buf, blk_size_bytes, - wgpu::BufferUsage::Storage, + blk_size_bytes = ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT); + ggml_webgpu_create_buffer(ctx->global_ctx->device, blk_buf, blk_size_bytes, wgpu::BufferUsage::Storage, "flash_attn_vec_blk"); const ggml_webgpu_flash_attn_blk_pipeline_key blk_key = { .q_tile = decisions->q_tile, @@ -1472,8 +1470,8 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, .key = blk_key, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, }; - ggml_webgpu_processed_shader processed = ggml_webgpu_preprocess_flash_attn_blk_shader( - ctx->p, wgsl_flash_attn_vec_blk, blk_shader_ctx); + ggml_webgpu_processed_shader processed = + ggml_webgpu_preprocess_flash_attn_blk_shader(ctx->p, wgsl_flash_attn_vec_blk, blk_shader_ctx); blk_pipeline = ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); ctx->flash_attn_blk_pipelines.emplace(blk_key, blk_pipeline); @@ -1498,13 +1496,13 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, std::vector split_params = params; if (use_blk) { - split_params.push_back(0u); // blk_base - split_params.push_back(blk_nblk0); // blk_nblk0 - split_params.push_back(blk_nblk1); // blk_nblk1 + split_params.push_back(0u); // blk_base + split_params.push_back(blk_nblk0); // blk_nblk0 + 
split_params.push_back(blk_nblk1); // blk_nblk1 } - split_params.push_back(0u); // tmp_data_base + split_params.push_back(0u); // tmp_data_base split_params.push_back((uint32_t) tmp_stats_base); // tmp_stats_base - split_params.push_back(nwg); // nwg + split_params.push_back(nwg); // nwg std::vector split_entries = { { .binding = 0, @@ -1534,26 +1532,23 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, sinks) }); } if (use_blk) { - split_entries.push_back({ .binding = split_binding_index++, - .buffer = blk_buf, - .offset = 0, - .size = blk_size_bytes }); + split_entries.push_back( + { .binding = split_binding_index++, .buffer = blk_buf, .offset = 0, .size = blk_size_bytes }); } - split_entries.push_back({ .binding = split_binding_index++, - .buffer = tmp_buf, - .offset = tmp_bind_offset, - .size = tmp_bind_size }); + split_entries.push_back( + { .binding = split_binding_index++, .buffer = tmp_buf, .offset = tmp_bind_offset, .size = tmp_bind_size }); split_entries.push_back({ .binding = split_binding_index++, .buffer = ggml_webgpu_tensor_buf(dst), .offset = ggml_webgpu_tensor_align_offset(ctx, dst), .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); - webgpu_pipeline reduce_pipeline; - std::vector reduce_params; + webgpu_pipeline reduce_pipeline; + std::vector reduce_params; std::vector reduce_entries; if (use_vec_reduce) { const uint32_t reduce_wg_size = std::max( - 32u, std::min(nwg * 32u, ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup)); + 32u, + std::min(nwg * 32u, ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup)); const ggml_webgpu_flash_attn_vec_reduce_pipeline_key reduce_key = { .head_dim_v = (uint32_t) V->ne[0], .wg_size = reduce_wg_size, @@ -1574,13 +1569,13 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, } reduce_params = { - (uint32_t) nrows, // nrows - (uint32_t) Q->ne[1], // seq_len_q - (uint32_t) Q->ne[2], // n_heads 
- (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), // offset_dst - nwg, // nwg - 0u, // tmp_data_base - (uint32_t) tmp_stats_base, // tmp_stats_base + (uint32_t) nrows, // nrows + (uint32_t) Q->ne[1], // seq_len_q + (uint32_t) Q->ne[2], // n_heads + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), // offset_dst + nwg, // nwg + 0u, // tmp_data_base + (uint32_t) tmp_stats_base, // tmp_stats_base }; reduce_entries = { @@ -1616,8 +1611,8 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, workgroups_list.push_back({ (uint32_t) nrows, 1u }); } - return ggml_backend_webgpu_build_multi( - ctx->global_ctx, ctx->param_buf_pool, pipelines, params_list, entries_list, workgroups_list); + return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_buf_pool, pipelines, params_list, + entries_list, workgroups_list); } return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); @@ -2500,27 +2495,18 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, case GGML_OP_SOFT_MAX: return ggml_webgpu_soft_max(ctx, src0, src1, src2, node); case GGML_OP_UNARY: - return ggml_webgpu_unary_op(ctx, src0, node); case GGML_OP_CLAMP: - return ggml_webgpu_unary_op(ctx, src0, node); case GGML_OP_FILL: - return ggml_webgpu_unary_op(ctx, src0, node); case GGML_OP_LOG: - return ggml_webgpu_unary_op(ctx, src0, node); case GGML_OP_SQR: - return ggml_webgpu_unary_op(ctx, src0, node); case GGML_OP_SQRT: - return ggml_webgpu_unary_op(ctx, src0, node); case GGML_OP_SIN: - return ggml_webgpu_unary_op(ctx, src0, node); case GGML_OP_COS: - return ggml_webgpu_unary_op(ctx, src0, node); case GGML_OP_PAD: return ggml_webgpu_pad(ctx, src0, node); case GGML_OP_ARGMAX: return ggml_webgpu_argmax(ctx, src0, node); case GGML_OP_ARGSORT: - return ggml_webgpu_argsort(ctx, src0, node); case GGML_OP_TOP_K: // we reuse the same argsort implementation for top_k return 
ggml_webgpu_argsort(ctx, src0, node); @@ -3105,7 +3091,7 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix = valid_subgroup_matrix_config; #endif - // For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate. + // For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate. // Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter. ctx->webgpu_global_ctx->capabilities.max_subgroup_size = info.subgroupMaxSize; // Initialize device From 03d0625f85e3d9a745f5df33f276c8803352cbbe Mon Sep 17 00:00:00 2001 From: Zheyuan Chen Date: Wed, 11 Mar 2026 15:40:05 -0700 Subject: [PATCH 24/34] reverted param-buffer pool refactor --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 115 +++++++++++---------------- 1 file changed, 48 insertions(+), 67 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index b86f3786bfc2..aa11b29c63cd 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -88,7 +88,6 @@ static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim // Maximum number of in-flight submissions per-thread, to avoid exhausting the // parameter buffer pool #define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD (WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) -#define WEBGPU_MAX_PARAM_BUFS_PER_CMD 4u #define WEBGPU_PARAMS_BUF_SIZE_BYTES 128 // enough for 32 parameters #define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4 #define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4 @@ -124,14 +123,9 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device, wgpu::BufferUsage usage, const char * label); -struct webgpu_pool_bufs { - wgpu::Buffer host_buf; - wgpu::Buffer dev_buf; -}; - // Holds a pool of
parameter buffers for WebGPU operations struct webgpu_buf_pool { - std::vector free; + std::vector free; // The pool must be synchronized because // 1. The memset pool is shared globally by every ggml buffer, @@ -144,7 +138,6 @@ struct webgpu_buf_pool { size_t cur_pool_size; size_t max_pool_size; wgpu::Device device; - wgpu::BufferUsage host_buf_usage; wgpu::BufferUsage dev_buf_usage; size_t buf_size; bool should_grow; @@ -153,53 +146,47 @@ struct webgpu_buf_pool { int num_bufs, size_t buf_size, wgpu::BufferUsage dev_buf_usage, - wgpu::BufferUsage host_buf_usage, bool should_grow = false, size_t max_pool_size = WEBGPU_NUM_PARAM_BUFS * 2) { - this->max_pool_size = max_pool_size; - this->cur_pool_size = num_bufs; - this->device = device; - this->host_buf_usage = host_buf_usage; - this->dev_buf_usage = dev_buf_usage; - this->buf_size = buf_size; - this->should_grow = should_grow; + this->max_pool_size = max_pool_size; + this->cur_pool_size = num_bufs; + this->device = device; + this->dev_buf_usage = dev_buf_usage; + this->buf_size = buf_size; + this->should_grow = should_grow; for (int i = 0; i < num_bufs; i++) { - wgpu::Buffer host_buf; wgpu::Buffer dev_buf; - ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_pool_buf"); ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf"); - free.push_back({ host_buf, dev_buf }); + free.push_back(dev_buf); } } - webgpu_pool_bufs alloc_bufs() { + wgpu::Buffer alloc_bufs() { std::unique_lock lock(mutex); if (!free.empty()) { - webgpu_pool_bufs bufs = free.back(); + wgpu::Buffer buf = free.back(); free.pop_back(); - return bufs; + return buf; } // Try growing the pool if no free buffers if (free.empty() && cur_pool_size < max_pool_size && should_grow) { cur_pool_size++; - wgpu::Buffer host_buf; wgpu::Buffer dev_buf; - ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_pool_buf"); ggml_webgpu_create_buffer(device, 
dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf"); - if (!(host_buf && dev_buf)) { + if (!dev_buf) { GGML_ABORT("webgpu_buf_pool: failed to allocate buffers"); } - return webgpu_pool_bufs{ host_buf, dev_buf }; + return dev_buf; } cv.wait(lock, [this] { return !free.empty(); }); - webgpu_pool_bufs bufs = free.back(); + wgpu::Buffer buf = free.back(); free.pop_back(); - return bufs; + return buf; } - void free_bufs(std::vector bufs) { + void free_bufs(std::vector bufs) { std::lock_guard lock(mutex); free.insert(free.end(), bufs.begin(), bufs.end()); cv.notify_all(); @@ -207,12 +194,9 @@ struct webgpu_buf_pool { void cleanup() { std::lock_guard lock(mutex); - for (auto & bufs : free) { - if (bufs.host_buf) { - bufs.host_buf.Destroy(); - } - if (bufs.dev_buf) { - bufs.dev_buf.Destroy(); + for (auto & buf : free) { + if (buf) { + buf.Destroy(); } } free.clear(); @@ -286,9 +270,9 @@ struct webgpu_gpu_profile_buf_pool { #endif struct webgpu_command { - uint32_t num_kernels; - wgpu::CommandBuffer commands; - std::vector params_bufs; + uint32_t num_kernels; + wgpu::CommandBuffer commands; + std::vector params_bufs; #ifdef GGML_WEBGPU_GPU_PROFILE webgpu_gpu_profile_bufs timestamp_query_bufs; std::string pipeline_name; @@ -619,7 +603,7 @@ static webgpu_submission ggml_backend_webgpu_submit(webgpu_global_context & std::vector & commands, webgpu_buf_pool & param_buf_pool) { std::vector command_buffers; - std::vector params_bufs; + std::vector params_bufs; webgpu_submission submission; #ifdef GGML_WEBGPU_GPU_PROFILE std::vector> pipeline_name_and_ts_bufs; @@ -657,9 +641,9 @@ static webgpu_submission ggml_backend_webgpu_submit(webgpu_global_context & // WebGPU timestamps are in ns; convert to ms double elapsed_ms = double(ts_data[1] - ts_data[0]) * 1e-6; ctx->shader_gpu_time_ms[label] += elapsed_ms; - // We can't unmap in here due to WebGPU reentrancy limitations. 
- ctx->timestamp_query_buf_pool.free_bufs({ ts_bufs }); } + // We can't unmap in here due to WebGPU reentrancy limitations. + ctx->timestamp_query_buf_pool.free_bufs({ ts_bufs }); }); submission.profile_futures.push_back({ f }); } @@ -678,24 +662,16 @@ static webgpu_command ggml_backend_webgpu_build_multi( GGML_ASSERT(pipelines.size() == bind_group_entries_list.size()); GGML_ASSERT(pipelines.size() == workgroups_list.size()); - std::vector params_bufs_list; - std::vector bind_groups; + std::vector params_bufs_list; + std::vector bind_groups; for (size_t i = 0; i < pipelines.size(); i++) { - webgpu_pool_bufs params_bufs = param_buf_pool.alloc_bufs(); + wgpu::Buffer params_bufs = param_buf_pool.alloc_bufs(); std::vector entries = bind_group_entries_list[i]; - // Bindings can be sparse (e.g. 0,1,2,4,5), so params must use max(binding)+1. - uint32_t params_binding_num = 0; - for (const auto & entry : entries) { - if (entry.binding >= params_binding_num) { - params_binding_num = entry.binding + 1; - } - } - entries.push_back({ .binding = params_binding_num, - .buffer = params_bufs.dev_buf, - .offset = 0, - .size = params_bufs.dev_buf.GetSize() }); + uint32_t params_binding_num = entries.size(); + entries.push_back( + { .binding = params_binding_num, .buffer = params_bufs, .offset = 0, .size = params_bufs.GetSize() }); wgpu::BindGroupDescriptor bind_group_desc; bind_group_desc.layout = pipelines[i].pipeline.GetBindGroupLayout(0); @@ -709,8 +685,7 @@ static webgpu_command ggml_backend_webgpu_build_multi( wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder(); for (size_t i = 0; i < params_bufs_list.size(); i++) { - ctx->queue.WriteBuffer(params_bufs_list[i].dev_buf, 0, params_list[i].data(), - params_list[i].size() * sizeof(uint32_t)); + ctx->queue.WriteBuffer(params_bufs_list[i], 0, params_list[i].data(), params_list[i].size() * sizeof(uint32_t)); } #ifdef GGML_WEBGPU_GPU_PROFILE webgpu_gpu_profile_bufs ts_bufs = ctx->timestamp_query_buf_pool.alloc_bufs(); @@ 
-762,7 +737,8 @@ static webgpu_command ggml_backend_webgpu_build(webgpu_global_context & { pipeline }, - { params }, { bind_group_entries }, { { wg_x, wg_y } }); + { std::move(params) }, { std::move(bind_group_entries) }, + { { wg_x, wg_y } }); } static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx, @@ -1218,17 +1194,18 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; if (use_fast && is_vec) { - auto decisions = static_cast(pipeline.context.get()); + auto * decisions = static_cast(pipeline.context.get()); uint32_t batches = dst->ne[2] * dst->ne[3]; uint32_t output_groups = CEIL_DIV(dst->ne[0], decisions->outputs_per_wg); uint32_t total_wg = output_groups * batches; compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y); } else if (use_fast) { - auto decisions = static_cast(pipeline.context.get()); + auto * decisions = static_cast(pipeline.context.get()); // Fast-path tiled/subgroup calculations - uint32_t wg_m, wg_n; + uint32_t wg_m; + uint32_t wg_n; if (decisions->use_subgroup_matrix) { uint32_t wg_m_sg_tile = decisions->subgroup_m * decisions->subgroup_matrix_m * ctx->global_ctx->capabilities.sg_mat_m; @@ -1246,7 +1223,7 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y); } else { // legacy - auto decisions = static_cast(pipeline.context.get()); + auto * decisions = static_cast(pipeline.context.get()); uint32_t wg_size = decisions->wg_size; uint32_t total_wg = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], wg_size); compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y); @@ -1944,7 +1921,12 @@ static webgpu_command ggml_webgpu_rope(webgpu_context & ctx, const int mode = ((int32_t *) dst->op_params)[2]; const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; - float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, 
beta_slow; + float freq_base; + float freq_scale; + float ext_factor; + float attn_factor; + float beta_fast; + float beta_slow; memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); @@ -2502,6 +2484,7 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, case GGML_OP_SQRT: case GGML_OP_SIN: case GGML_OP_COS: + return ggml_webgpu_unary_op(ctx, src0, node); case GGML_OP_PAD: return ggml_webgpu_pad(ctx, src0, node); case GGML_OP_ARGMAX: @@ -3161,8 +3144,7 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { ggml_webgpu_init_memset_pipeline(ctx->webgpu_global_ctx); ctx->webgpu_global_ctx->memset_buf_pool.init(ctx->webgpu_global_ctx->device, 1, WEBGPU_PARAMS_BUF_SIZE_BYTES, - wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform, - wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite); + wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform); ctx->webgpu_global_ctx->queue = ctx->webgpu_global_ctx->device.GetQueue(); #ifdef GGML_WEBGPU_GPU_PROFILE @@ -3187,8 +3169,7 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) { webgpu_ctx->global_ctx = dev_ctx->webgpu_global_ctx; webgpu_ctx->shader_lib = std::make_unique(dev_ctx->webgpu_global_ctx->device); webgpu_ctx->param_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES, - wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform, - wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite, true); + wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform, true); ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->set_rows_dev_error_buf, WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES, wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc, "set_rows_dev_error_buf"); From 5dd2a4b1629ed7d0d272a5923d5fb9f0622599a3 Mon Sep 17 00:00:00 2001 From: Zheyuan Chen 
Date: Wed, 11 Mar 2026 15:54:48 -0700 Subject: [PATCH 25/34] add helper functions --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 137 +++++++++++++++------------ 1 file changed, 76 insertions(+), 61 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index aa11b29c63cd..16a031048f6a 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1233,6 +1233,77 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, } #ifndef __EMSCRIPTEN__ +static webgpu_pipeline ggml_webgpu_get_or_create_flash_attn_pipeline( + webgpu_context & ctx, const ggml_webgpu_flash_attn_pipeline_key & key) { + auto it = ctx->flash_attn_pipelines.find(key); + if (it != ctx->flash_attn_pipelines.end()) { + return it->second; + } + + ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = { + .key = key, + .sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m, + .sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n, + .sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k, + .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize, + .max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size + }; + + ggml_webgpu_processed_shader processed = ggml_webgpu_preprocess_flash_attn_shader( + ctx->p, key.use_vec ? 
wgsl_flash_attn_vec_split : wgsl_flash_attn, shader_lib_ctx); + webgpu_pipeline pipeline = + ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); + pipeline.context = processed.decisions; + ctx->flash_attn_pipelines.emplace(key, pipeline); + return pipeline; +} + +static webgpu_pipeline ggml_webgpu_get_or_create_flash_attn_blk_pipeline( + webgpu_context & ctx, uint32_t q_tile, uint32_t kv_tile) { + const ggml_webgpu_flash_attn_blk_pipeline_key blk_key = { + .q_tile = q_tile, + .kv_tile = kv_tile, + }; + auto blk_it = ctx->flash_attn_blk_pipelines.find(blk_key); + if (blk_it != ctx->flash_attn_blk_pipelines.end()) { + return blk_it->second; + } + + ggml_webgpu_flash_attn_blk_shader_lib_context blk_shader_ctx = { + .key = blk_key, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, + }; + ggml_webgpu_processed_shader processed = + ggml_webgpu_preprocess_flash_attn_blk_shader(ctx->p, wgsl_flash_attn_vec_blk, blk_shader_ctx); + webgpu_pipeline blk_pipeline = + ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); + ctx->flash_attn_blk_pipelines.emplace(blk_key, blk_pipeline); + return blk_pipeline; +} + +static webgpu_pipeline ggml_webgpu_get_or_create_flash_attn_vec_reduce_pipeline( + webgpu_context & ctx, uint32_t head_dim_v, uint32_t wg_size) { + const ggml_webgpu_flash_attn_vec_reduce_pipeline_key reduce_key = { + .head_dim_v = head_dim_v, + .wg_size = wg_size, + }; + auto reduce_it = ctx->flash_attn_vec_reduce_pipelines.find(reduce_key); + if (reduce_it != ctx->flash_attn_vec_reduce_pipelines.end()) { + return reduce_it->second; + } + + ggml_webgpu_flash_attn_vec_reduce_shader_lib_context reduce_shader_ctx = { + .key = reduce_key, + .max_wg_size = wg_size, + }; + ggml_webgpu_processed_shader processed = + ggml_webgpu_preprocess_flash_attn_vec_reduce_shader(ctx->p, wgsl_flash_attn_vec_reduce, reduce_shader_ctx); + 
webgpu_pipeline reduce_pipeline = + ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); + ctx->flash_attn_vec_reduce_pipelines.emplace(reduce_key, reduce_pipeline); + return reduce_pipeline; +} + static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, ggml_tensor * Q, ggml_tensor * K, @@ -1343,31 +1414,7 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, .use_vec = use_vec, }; - webgpu_pipeline pipeline; - auto it = ctx->flash_attn_pipelines.find(key); - if (it != ctx->flash_attn_pipelines.end()) { - pipeline = it->second; - } else { - ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = { - .key = key, - .sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m, - .sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n, - .sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k, - .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize, - .max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size - }; - - ggml_webgpu_processed_shader processed; - if (use_vec) { - processed = ggml_webgpu_preprocess_flash_attn_shader(ctx->p, wgsl_flash_attn_vec_split, shader_lib_ctx); - } else { - processed = ggml_webgpu_preprocess_flash_attn_shader(ctx->p, wgsl_flash_attn, shader_lib_ctx); - } - pipeline = - ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); - pipeline.context = processed.decisions; - ctx->flash_attn_pipelines.emplace(key, pipeline); - } + webgpu_pipeline pipeline = ggml_webgpu_get_or_create_flash_attn_pipeline(ctx, key); auto * decisions = static_cast(pipeline.context.get()); @@ -1435,24 +1482,8 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, blk_size_bytes = ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT); ggml_webgpu_create_buffer(ctx->global_ctx->device, blk_buf, blk_size_bytes, wgpu::BufferUsage::Storage, "flash_attn_vec_blk"); - const 
ggml_webgpu_flash_attn_blk_pipeline_key blk_key = { - .q_tile = decisions->q_tile, - .kv_tile = decisions->kv_tile, - }; - auto blk_it = ctx->flash_attn_blk_pipelines.find(blk_key); - if (blk_it != ctx->flash_attn_blk_pipelines.end()) { - blk_pipeline = blk_it->second; - } else { - ggml_webgpu_flash_attn_blk_shader_lib_context blk_shader_ctx = { - .key = blk_key, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - }; - ggml_webgpu_processed_shader processed = - ggml_webgpu_preprocess_flash_attn_blk_shader(ctx->p, wgsl_flash_attn_vec_blk, blk_shader_ctx); - blk_pipeline = ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), - processed.variant.c_str()); - ctx->flash_attn_blk_pipelines.emplace(blk_key, blk_pipeline); - } + blk_pipeline = + ggml_webgpu_get_or_create_flash_attn_blk_pipeline(ctx, decisions->q_tile, decisions->kv_tile); blk_params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)), // offset_mask @@ -1526,24 +1557,8 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, const uint32_t reduce_wg_size = std::max( 32u, std::min(nwg * 32u, ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup)); - const ggml_webgpu_flash_attn_vec_reduce_pipeline_key reduce_key = { - .head_dim_v = (uint32_t) V->ne[0], - .wg_size = reduce_wg_size, - }; - auto reduce_it = ctx->flash_attn_vec_reduce_pipelines.find(reduce_key); - if (reduce_it != ctx->flash_attn_vec_reduce_pipelines.end()) { - reduce_pipeline = reduce_it->second; - } else { - ggml_webgpu_flash_attn_vec_reduce_shader_lib_context reduce_shader_ctx = { - .key = reduce_key, - .max_wg_size = reduce_wg_size, - }; - ggml_webgpu_processed_shader processed = ggml_webgpu_preprocess_flash_attn_vec_reduce_shader( - ctx->p, wgsl_flash_attn_vec_reduce, reduce_shader_ctx); - reduce_pipeline = ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), - processed.variant.c_str()); 
- ctx->flash_attn_vec_reduce_pipelines.emplace(reduce_key, reduce_pipeline); - } + reduce_pipeline = + ggml_webgpu_get_or_create_flash_attn_vec_reduce_pipeline(ctx, (uint32_t) V->ne[0], reduce_wg_size); reduce_params = { (uint32_t) nrows, // nrows From 1e0d856beb2dd464a484a98bed706ef079956534 Mon Sep 17 00:00:00 2001 From: Zheyuan Chen Date: Wed, 18 Mar 2026 22:39:55 -0700 Subject: [PATCH 26/34] ggml-webgpu: move flash-attn vec pipeline caching back into shader lib --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 126 ++++++------------ ggml/src/ggml-webgpu/ggml-webgpu.cpp | 117 ++++------------ .../wgsl-shaders/flash_attn_vec_split.wgsl | 2 - 3 files changed, 66 insertions(+), 179 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index e4929ec77ddb..cf82da74b853 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -687,6 +687,14 @@ class ggml_webgpu_shader_lib { repeat_pipelines; // type std::unordered_map flash_attn_pipelines; + std::unordered_map + flash_attn_vec_reduce_pipelines; + std::unordered_map + flash_attn_blk_pipelines; std::unordered_map @@ -1468,103 +1476,47 @@ class ggml_webgpu_shader_lib { return repeat_pipelines[key]; } - webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context) { - const bool has_mask = context.src3 != nullptr; - const bool has_sinks = context.src4 != nullptr; - - bool kv_direct = (context.src1->type == GGML_TYPE_F16) && (context.src0->ne[0] % context.sg_mat_k == 0) && - (context.src1->ne[1] % context.sg_mat_n == 0); - - ggml_webgpu_flash_attn_pipeline_key key = { - .kv_type = context.src1->type, - .head_dim_qk = (uint32_t) context.src0->ne[0], - .head_dim_v = (uint32_t) context.src2->ne[0], - .kv_direct = kv_direct, - .has_mask = has_mask, - .has_sinks = has_sinks, - .uses_logit_softcap = (*(float *) &context.dst->op_params[2]) != 0.0f, - }; - - auto it = 
flash_attn_pipelines.find(key); + webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_flash_attn_shader_lib_context & context) { + auto it = flash_attn_pipelines.find(context.key); if (it != flash_attn_pipelines.end()) { return it->second; } - std::vector defines; - std::string variant = "flash_attn"; + ggml_webgpu_processed_shader processed = ggml_webgpu_preprocess_flash_attn_shader( + preprocessor, context.key.use_vec ? wgsl_flash_attn_vec_split : wgsl_flash_attn, context); - switch (key.kv_type) { - case GGML_TYPE_F32: - defines.push_back("KV_F32"); - break; - case GGML_TYPE_F16: - defines.push_back("KV_F16"); - break; - case GGML_TYPE_Q4_0: - defines.push_back("KV_Q4_0"); - break; - case GGML_TYPE_Q8_0: - defines.push_back("KV_Q8_0"); - break; - default: - GGML_ABORT("Unsupported KV type for flash attention shader"); - } - variant += std::string("_") + ggml_type_name(key.kv_type); - - if (key.has_mask) { - defines.push_back("MASK"); - variant += "_mask"; - } - if (key.has_sinks) { - defines.push_back("SINKS"); - variant += "_sinks"; - } - if (key.uses_logit_softcap) { - defines.push_back("LOGIT_SOFTCAP"); - variant += "_lgsc"; - } - if (key.kv_direct) { - defines.push_back("KV_DIRECT"); - variant += "_kvdirect"; - } + webgpu_pipeline pipeline = + ggml_webgpu_create_pipeline(device, processed.wgsl, processed.variant); + pipeline.context = processed.decisions; + flash_attn_pipelines[context.key] = pipeline; + return flash_attn_pipelines[context.key]; + } - defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(key.head_dim_qk)); - variant += std::string("_hsqk") + std::to_string(key.head_dim_qk); - - defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v)); - variant += std::string("_hsv") + std::to_string(key.head_dim_v); - - defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m)); - defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n)); - 
defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k)); - - uint32_t q_tile = context.sg_mat_m; - uint32_t kv_tile = - std::min(ggml_webgpu_flash_attn_max_kv_tile({ key, context.sg_mat_m, context.sg_mat_n, context.sg_mat_k, - context.wg_mem_limit_bytes, context.max_subgroup_size }), - context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES); - if (key.kv_direct) { - while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) { - kv_tile -= context.sg_mat_n; - } + webgpu_pipeline get_flash_attn_blk_pipeline(const ggml_webgpu_flash_attn_blk_shader_lib_context & context) { + auto it = flash_attn_blk_pipelines.find(context.key); + if (it != flash_attn_blk_pipelines.end()) { + return it->second; } - defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile)); - defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile)); - - uint32_t wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE); - defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); + ggml_webgpu_processed_shader processed = + ggml_webgpu_preprocess_flash_attn_blk_shader(preprocessor, wgsl_flash_attn_vec_blk, context); + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed.wgsl, processed.variant); + flash_attn_blk_pipelines[context.key] = pipeline; + return flash_attn_blk_pipelines[context.key]; + } - auto processed = preprocessor.preprocess(wgsl_flash_attn, defines); - auto decisions = std::make_shared(); - decisions->q_tile = q_tile; - decisions->kv_tile = kv_tile; - decisions->wg_size = wg_size; + webgpu_pipeline get_flash_attn_vec_reduce_pipeline( + const ggml_webgpu_flash_attn_vec_reduce_shader_lib_context & context) { + auto it = flash_attn_vec_reduce_pipelines.find(context.key); + if (it != flash_attn_vec_reduce_pipelines.end()) { + return it->second; + } - webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); - pipeline.context = decisions; - flash_attn_pipelines[key] = 
pipeline; - return flash_attn_pipelines[key]; + ggml_webgpu_processed_shader processed = + ggml_webgpu_preprocess_flash_attn_vec_reduce_shader(preprocessor, wgsl_flash_attn_vec_reduce, context); + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed.wgsl, processed.variant); + flash_attn_vec_reduce_pipelines[context.key] = pipeline; + return flash_attn_vec_reduce_pipelines[context.key]; } private: diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 16a031048f6a..c87089857bc7 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -8,7 +8,6 @@ #include "ggml-backend-impl.h" #include "ggml-impl.h" #include "ggml-webgpu-shader-lib.hpp" -#include "pre_wgsl.hpp" #ifdef __EMSCRIPTEN__ # include @@ -360,23 +359,12 @@ struct webgpu_context_struct { webgpu_global_context global_ctx; std::unique_ptr shader_lib; - pre_wgsl::Preprocessor p; webgpu_buf_pool param_buf_pool; wgpu::Buffer set_rows_dev_error_buf; wgpu::Buffer set_rows_host_error_buf; std::map> cpy_pipelines; // src_type, dst_type - std::unordered_map - flash_attn_pipelines; - std::unordered_map - flash_attn_vec_reduce_pipelines; - std::unordered_map - flash_attn_blk_pipelines; std::map rms_norm_pipelines; // inplace std::map>> rope_pipelines; // type, ff, inplace std::map>> glu_pipelines; // glu_op, type, split @@ -1232,78 +1220,6 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y); } -#ifndef __EMSCRIPTEN__ -static webgpu_pipeline ggml_webgpu_get_or_create_flash_attn_pipeline( - webgpu_context & ctx, const ggml_webgpu_flash_attn_pipeline_key & key) { - auto it = ctx->flash_attn_pipelines.find(key); - if (it != ctx->flash_attn_pipelines.end()) { - return it->second; - } - - ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = { - .key = key, - .sg_mat_m = 
ctx->global_ctx->capabilities.sg_mat_m, - .sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n, - .sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k, - .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize, - .max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size - }; - - ggml_webgpu_processed_shader processed = ggml_webgpu_preprocess_flash_attn_shader( - ctx->p, key.use_vec ? wgsl_flash_attn_vec_split : wgsl_flash_attn, shader_lib_ctx); - webgpu_pipeline pipeline = - ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); - pipeline.context = processed.decisions; - ctx->flash_attn_pipelines.emplace(key, pipeline); - return pipeline; -} - -static webgpu_pipeline ggml_webgpu_get_or_create_flash_attn_blk_pipeline( - webgpu_context & ctx, uint32_t q_tile, uint32_t kv_tile) { - const ggml_webgpu_flash_attn_blk_pipeline_key blk_key = { - .q_tile = q_tile, - .kv_tile = kv_tile, - }; - auto blk_it = ctx->flash_attn_blk_pipelines.find(blk_key); - if (blk_it != ctx->flash_attn_blk_pipelines.end()) { - return blk_it->second; - } - - ggml_webgpu_flash_attn_blk_shader_lib_context blk_shader_ctx = { - .key = blk_key, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - }; - ggml_webgpu_processed_shader processed = - ggml_webgpu_preprocess_flash_attn_blk_shader(ctx->p, wgsl_flash_attn_vec_blk, blk_shader_ctx); - webgpu_pipeline blk_pipeline = - ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); - ctx->flash_attn_blk_pipelines.emplace(blk_key, blk_pipeline); - return blk_pipeline; -} - -static webgpu_pipeline ggml_webgpu_get_or_create_flash_attn_vec_reduce_pipeline( - webgpu_context & ctx, uint32_t head_dim_v, uint32_t wg_size) { - const ggml_webgpu_flash_attn_vec_reduce_pipeline_key reduce_key = { - .head_dim_v = head_dim_v, - .wg_size = wg_size, - }; - auto reduce_it = 
ctx->flash_attn_vec_reduce_pipelines.find(reduce_key); - if (reduce_it != ctx->flash_attn_vec_reduce_pipelines.end()) { - return reduce_it->second; - } - - ggml_webgpu_flash_attn_vec_reduce_shader_lib_context reduce_shader_ctx = { - .key = reduce_key, - .max_wg_size = wg_size, - }; - ggml_webgpu_processed_shader processed = - ggml_webgpu_preprocess_flash_attn_vec_reduce_shader(ctx->p, wgsl_flash_attn_vec_reduce, reduce_shader_ctx); - webgpu_pipeline reduce_pipeline = - ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); - ctx->flash_attn_vec_reduce_pipelines.emplace(reduce_key, reduce_pipeline); - return reduce_pipeline; -} - static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, ggml_tensor * Q, ggml_tensor * K, @@ -1414,7 +1330,15 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, .use_vec = use_vec, }; - webgpu_pipeline pipeline = ggml_webgpu_get_or_create_flash_attn_pipeline(ctx, key); + ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = { + .key = key, + .sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m, + .sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n, + .sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k, + .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize, + .max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size, + }; + webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline(shader_lib_ctx); auto * decisions = static_cast(pipeline.context.get()); @@ -1482,8 +1406,15 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, blk_size_bytes = ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT); ggml_webgpu_create_buffer(ctx->global_ctx->device, blk_buf, blk_size_bytes, wgpu::BufferUsage::Storage, "flash_attn_vec_blk"); - blk_pipeline = - ggml_webgpu_get_or_create_flash_attn_blk_pipeline(ctx, decisions->q_tile, decisions->kv_tile); + 
ggml_webgpu_flash_attn_blk_shader_lib_context blk_shader_ctx = { + .key = + { + .q_tile = decisions->q_tile, + .kv_tile = decisions->kv_tile, + }, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, + }; + blk_pipeline = ctx->shader_lib->get_flash_attn_blk_pipeline(blk_shader_ctx); blk_params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)), // offset_mask @@ -1557,8 +1488,15 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, const uint32_t reduce_wg_size = std::max( 32u, std::min(nwg * 32u, ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup)); - reduce_pipeline = - ggml_webgpu_get_or_create_flash_attn_vec_reduce_pipeline(ctx, (uint32_t) V->ne[0], reduce_wg_size); + ggml_webgpu_flash_attn_vec_reduce_shader_lib_context reduce_shader_ctx = { + .key = + { + .head_dim_v = (uint32_t) V->ne[0], + .wg_size = reduce_wg_size, + }, + .max_wg_size = reduce_wg_size, + }; + reduce_pipeline = ctx->shader_lib->get_flash_attn_vec_reduce_pipeline(reduce_shader_ctx); reduce_params = { (uint32_t) nrows, // nrows @@ -1609,7 +1547,6 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } -#endif static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { bool is_unary = dst->op == GGML_OP_UNARY; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl index b69fe8994a32..1fb312e414d0 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl @@ -356,8 +356,6 @@ fn main(@builtin(workgroup_id) wg_id: vec3, workgroupBarrier(); // accumulate q block * k block into registers across the entire KV tile - // TODO: this loop seems to be the current largest 
bottleneck - // this bracket exists to scope the lifetime of variables, reducing register pressure if (!skip_tile) { let num_of_threads = subgroup_size / VEC_NE; let tx = sg_inv_id % num_of_threads; From 88bf352513b1b4607bf985a97a7b0a49c32ba6b5 Mon Sep 17 00:00:00 2001 From: Zheyuan Chen Date: Wed, 18 Mar 2026 22:46:07 -0700 Subject: [PATCH 27/34] ggml-webgpu: remove duplicate functions --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 212 ++++++++---------- 1 file changed, 88 insertions(+), 124 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index cf82da74b853..d1b8616c1d5f 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -453,125 +453,6 @@ inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile, return f16_elems * GGML_WEBGPU_F16_SIZE_BYTES + f32_elems * GGML_WEBGPU_F32_SIZE_BYTES; } -inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_flash_attn_shader_lib_context & context) { - const size_t limit_bytes = context.wg_mem_limit_bytes; - const size_t q_tile = context.sg_mat_m; - const size_t base_q_bytes = - (context.key.head_dim_qk + context.key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES + - 2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES; - size_t bytes_per_kv = 0; - if (!context.key.kv_direct) { - bytes_per_kv += std::max(context.key.head_dim_qk, context.key.head_dim_v); - } - if (context.key.has_mask) { - bytes_per_kv += q_tile; - } - bytes_per_kv += q_tile; - bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES; - const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv; - return (max_kv_tile / context.sg_mat_n) * context.sg_mat_n; -} - -inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader( - pre_wgsl::Preprocessor & preprocessor, - const char * shader_src, - const ggml_webgpu_flash_attn_shader_lib_context & context) { - std::vector defines; - std::string variant 
= "flash_attn"; - - switch (context.key.kv_type) { - case GGML_TYPE_F32: - defines.push_back("KV_F32"); - break; - case GGML_TYPE_F16: - defines.push_back("KV_F16"); - break; - case GGML_TYPE_Q4_0: - defines.push_back("KV_Q4_0"); - break; - case GGML_TYPE_Q8_0: - defines.push_back("KV_Q8_0"); - break; - default: - GGML_ABORT("Unsupported KV type for flash attention shader"); - } - variant += std::string("_") + ggml_type_name(context.key.kv_type); - - if (context.key.has_mask) { - defines.push_back("MASK"); - variant += "_mask"; - } - if (context.key.has_sinks) { - defines.push_back("SINKS"); - variant += "_sinks"; - } - if (context.key.uses_logit_softcap) { - defines.push_back("LOGIT_SOFTCAP"); - variant += "_lgsc"; - } - - if (context.key.kv_direct) { - defines.push_back("KV_DIRECT"); - variant += "_kvdirect"; - } - if (context.key.has_mask && context.key.use_vec) { - defines.push_back("BLK"); - variant += "_blk"; - } - - defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(context.key.head_dim_qk)); - variant += std::string("_hsqk") + std::to_string(context.key.head_dim_qk); - - defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.key.head_dim_v)); - variant += std::string("_hsv") + std::to_string(context.key.head_dim_v); - // For now these are not part of the variant name. - defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m)); - defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n)); - defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k)); - - // Add chosen Q/KV tile sizes. 
- uint32_t q_tile = context.sg_mat_m; - uint32_t kv_tile = std::min(ggml_webgpu_flash_attn_max_kv_tile(context), - context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES); - if (context.key.use_vec) { - q_tile = 1; - const uint32_t max_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(context); - kv_tile = std::max(context.sg_mat_n, std::min(32u, max_kv_tile)); - kv_tile = (kv_tile / context.sg_mat_n) * context.sg_mat_n; - const uint32_t vec_ne = ggml_webgpu_flash_attn_pick_vec_ne(context.key); - defines.push_back(std::string("VEC_NE=") + std::to_string(vec_ne) + "u"); - } - if (context.key.kv_direct) { - GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD); - // Avoids bounds checks for direct KV loads. - while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) { - kv_tile -= context.sg_mat_n; - } - } - - defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile)); - defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile)); - - uint32_t wg_size = 0; - if (context.key.use_vec) { - // Keep vec-split to exactly one subgroup to preserve thread mapping. - wg_size = std::max(1u, std::min(32u, context.max_subgroup_size)); - } else { - wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE); - } - defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); - - ggml_webgpu_processed_shader result; - result.wgsl = preprocessor.preprocess(shader_src, defines); - result.variant = variant; - auto decisions = std::make_shared(); - decisions->q_tile = q_tile; - decisions->kv_tile = kv_tile; - decisions->wg_size = wg_size; - result.decisions = decisions; - return result; -} - /** Matrix Multiplication **/ struct ggml_webgpu_legacy_mul_mat_pipeline_key { @@ -1482,13 +1363,96 @@ class ggml_webgpu_shader_lib { return it->second; } - ggml_webgpu_processed_shader processed = ggml_webgpu_preprocess_flash_attn_shader( - preprocessor, context.key.use_vec ? 
wgsl_flash_attn_vec_split : wgsl_flash_attn, context); + std::vector defines; + std::string variant = "flash_attn"; + + switch (context.key.kv_type) { + case GGML_TYPE_F32: + defines.push_back("KV_F32"); + break; + case GGML_TYPE_F16: + defines.push_back("KV_F16"); + break; + case GGML_TYPE_Q4_0: + defines.push_back("KV_Q4_0"); + break; + case GGML_TYPE_Q8_0: + defines.push_back("KV_Q8_0"); + break; + default: + GGML_ABORT("Unsupported KV type for flash attention shader"); + } + variant += std::string("_") + ggml_type_name(context.key.kv_type); + + if (context.key.has_mask) { + defines.push_back("MASK"); + variant += "_mask"; + } + if (context.key.has_sinks) { + defines.push_back("SINKS"); + variant += "_sinks"; + } + if (context.key.uses_logit_softcap) { + defines.push_back("LOGIT_SOFTCAP"); + variant += "_lgsc"; + } + if (context.key.kv_direct) { + defines.push_back("KV_DIRECT"); + variant += "_kvdirect"; + } + if (context.key.has_mask && context.key.use_vec) { + defines.push_back("BLK"); + variant += "_blk"; + } + + defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(context.key.head_dim_qk)); + variant += std::string("_hsqk") + std::to_string(context.key.head_dim_qk); + + defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.key.head_dim_v)); + variant += std::string("_hsv") + std::to_string(context.key.head_dim_v); + + defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m)); + defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n)); + defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k)); + + uint32_t q_tile = context.sg_mat_m; + uint32_t kv_tile = + std::min(ggml_webgpu_flash_attn_max_kv_tile(context), + context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES); + if (context.key.use_vec) { + q_tile = 1; + kv_tile = std::max(context.sg_mat_n, std::min(32u, ggml_webgpu_flash_attn_max_kv_tile(context))); + kv_tile = (kv_tile / context.sg_mat_n) * 
context.sg_mat_n; + const uint32_t vec_ne = ggml_webgpu_flash_attn_pick_vec_ne(context.key); + defines.push_back(std::string("VEC_NE=") + std::to_string(vec_ne) + "u"); + } + if (context.key.kv_direct) { + GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD); + while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) { + kv_tile -= context.sg_mat_n; + } + } + + defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile)); + defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile)); + + uint32_t wg_size = 0; + if (context.key.use_vec) { + wg_size = std::max(1u, std::min(32u, context.max_subgroup_size)); + } else { + wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE); + } + defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); + const char * shader_src = context.key.use_vec ? wgsl_flash_attn_vec_split : wgsl_flash_attn; webgpu_pipeline pipeline = - ggml_webgpu_create_pipeline(device, processed.wgsl, processed.variant); - pipeline.context = processed.decisions; - flash_attn_pipelines[context.key] = pipeline; + ggml_webgpu_create_pipeline(device, preprocessor.preprocess(shader_src, defines), variant); + auto decisions = std::make_shared(); + decisions->q_tile = q_tile; + decisions->kv_tile = kv_tile; + decisions->wg_size = wg_size; + pipeline.context = decisions; + flash_attn_pipelines[context.key] = pipeline; return flash_attn_pipelines[context.key]; } From 5c2fefeafc041036d23ec595e5f4bbf55f83b8a0 Mon Sep 17 00:00:00 2001 From: Zheyuan Chen Date: Wed, 18 Mar 2026 23:28:36 -0700 Subject: [PATCH 28/34] ggml-webgpu: reserve flash-attn vec scratch in dst buffer allocation --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 124 ++++++++++++++++++++++----- 1 file changed, 101 insertions(+), 23 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index c87089857bc7..23fbd609d4b5 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1315,9 
+1315,9 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0; const bool use_vec = (Q->ne[1] < 20) && (Q->ne[0] % 32 == 0) && (V->ne[0] % 4 == 0) && kv_vec_type_supported && (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && (V->type == K->type); - - const uint32_t vec_nwg_cap = std::max(1u, std::min(32u, ctx->global_ctx->capabilities.max_subgroup_size)); - const bool use_blk = use_vec && has_mask; + const uint32_t vec_nwg_cap = + std::max(1u, std::min(32u, ctx->global_ctx->capabilities.max_subgroup_size)); + const bool use_blk = use_vec && has_mask; ggml_webgpu_flash_attn_pipeline_key key = { .kv_type = K->type, @@ -1358,11 +1358,8 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, nwg <<= 1; } nwg = std::min(nwg, vec_nwg_cap); - GGML_ASSERT(nwg <= ctx->global_ctx->capabilities.max_subgroup_size); const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3]; - // For a single split workgroup there is nothing to merge. - // Let vec split write final dst directly and skip reduce. 
const bool use_vec_reduce = nwg > 1u; GGML_ASSERT(nrows <= UINT32_MAX); @@ -1371,19 +1368,21 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, wgpu::Buffer tmp_buf = {}; uint64_t tmp_bind_offset = 0; uint64_t tmp_bind_size = 0; + const size_t align_bytes = ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment; + const size_t dst_offset = ggml_webgpu_tensor_offset(dst); + size_t scratch_offset = ROUNDUP_POW2(dst_offset + ggml_nbytes(dst), align_bytes); if (use_vec_reduce) { const uint64_t tmp_data_elems = nrows * (uint64_t) V->ne[0] * nwg; - tmp_stats_base = tmp_data_elems; const uint64_t tmp_stats_elems = nrows * 2u * nwg; - const uint64_t tmp_total_elems = tmp_data_elems + tmp_stats_elems; - tmp_size_bytes = ROUNDUP_POW2(tmp_total_elems * sizeof(float), WEBGPU_STORAGE_BUF_BINDING_MULT); + tmp_stats_base = tmp_data_elems; + tmp_size_bytes = + ROUNDUP_POW2((tmp_data_elems + tmp_stats_elems) * sizeof(float), WEBGPU_STORAGE_BUF_BINDING_MULT); GGML_ASSERT(tmp_stats_base <= UINT32_MAX); - - ggml_webgpu_create_buffer(ctx->global_ctx->device, tmp_buf, tmp_size_bytes, wgpu::BufferUsage::Storage, - "flash_attn_vec_tmp"); - tmp_bind_offset = 0; + tmp_buf = ggml_webgpu_tensor_buf(dst); + tmp_bind_offset = scratch_offset; tmp_bind_size = tmp_size_bytes; + scratch_offset = ROUNDUP_POW2(scratch_offset + tmp_size_bytes, align_bytes); } else { // nwg==1 writes final dst directly in vec-split; keep tmp binding valid without extra allocation. 
tmp_buf = ggml_webgpu_tensor_buf(dst); @@ -1397,15 +1396,13 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, if (use_blk) { GGML_ASSERT(has_mask); - blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], decisions->kv_tile); - blk_nblk1 = CEIL_DIV((uint32_t) Q->ne[1], decisions->q_tile); + blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], decisions->kv_tile); + blk_nblk1 = CEIL_DIV((uint32_t) Q->ne[1], decisions->q_tile); + blk_buf = ggml_webgpu_tensor_buf(dst); const uint32_t stride_mask3 = (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)); blk_batch_count = stride_mask3 > 0 ? (uint32_t) Q->ne[3] : 1u; - - const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count; - blk_size_bytes = ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT); - ggml_webgpu_create_buffer(ctx->global_ctx->device, blk_buf, blk_size_bytes, wgpu::BufferUsage::Storage, - "flash_attn_vec_blk"); + const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count; + blk_size_bytes = ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT); ggml_webgpu_flash_attn_blk_shader_lib_context blk_shader_ctx = { .key = { @@ -1429,8 +1426,9 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, .buffer = ggml_webgpu_tensor_buf(mask), .offset = ggml_webgpu_tensor_align_offset(ctx, mask), .size = ggml_webgpu_tensor_binding_size(ctx, mask) }, - { .binding = 1, .buffer = blk_buf, .offset = 0, .size = blk_size_bytes }, + { .binding = 1, .buffer = blk_buf, .offset = scratch_offset, .size = blk_size_bytes }, }; + scratch_offset = ROUNDUP_POW2(scratch_offset + blk_size_bytes, align_bytes); } std::vector split_params = params; @@ -1472,7 +1470,7 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, } if (use_blk) { split_entries.push_back( - { .binding = split_binding_index++, .buffer = blk_buf, .offset = 0, .size = blk_size_bytes }); + { .binding = split_binding_index++, .buffer = blk_buf, .offset = 
blk_entries[1].offset, .size = blk_size_bytes }); } split_entries.push_back( { .binding = split_binding_index++, .buffer = tmp_buf, .offset = tmp_bind_offset, .size = tmp_bind_size }); @@ -1509,7 +1507,7 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, }; reduce_entries = { - { .binding = 0, .buffer = tmp_buf, .offset = 0, .size = tmp_size_bytes }, + { .binding = 0, .buffer = tmp_buf, .offset = tmp_bind_offset, .size = tmp_size_bytes }, { .binding = 1, .buffer = ggml_webgpu_tensor_buf(dst), .offset = ggml_webgpu_tensor_align_offset(ctx, dst), @@ -2751,6 +2749,86 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer } } break; + case GGML_OP_FLASH_ATTN_EXT: + { + const ggml_tensor * Q = tensor->src[0]; + const ggml_tensor * K = tensor->src[1]; + const ggml_tensor * V = tensor->src[2]; + const ggml_tensor * mask = tensor->src[3]; + const ggml_tensor * sinks = tensor->src[4]; + if (Q && K && V) { + GGML_UNUSED(sinks); + const bool kv_direct = (K->type == GGML_TYPE_F16) && + (Q->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k == 0) && + (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0); + const bool kv_vec_type_supported = + K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0; + const bool use_vec = + (Q->ne[1] < 20) && (Q->ne[0] % 32 == 0) && (V->ne[0] % 4 == 0) && kv_vec_type_supported && + (V->type == K->type); + if (use_vec) { + const uint32_t sg_mat_m = ctx->webgpu_global_ctx->capabilities.sg_mat_m; + const uint32_t sg_mat_n = ctx->webgpu_global_ctx->capabilities.sg_mat_n; + const size_t limit_bytes = + ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; + const size_t q_tile = sg_mat_m; + const size_t base_q_bytes = + (Q->ne[0] + V->ne[0]) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES + + 2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES; + size_t bytes_per_kv = 0; + if (!kv_direct) { + bytes_per_kv += std::max(Q->ne[0], V->ne[0]); + } + if (mask != nullptr) { + bytes_per_kv += 
q_tile; + } + bytes_per_kv += q_tile; + bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES; + uint32_t kv_tile = + ((limit_bytes - base_q_bytes) / bytes_per_kv / sg_mat_n) * sg_mat_n; + kv_tile = std::max(sg_mat_n, std::min(32u, kv_tile)); + kv_tile = (kv_tile / sg_mat_n) * sg_mat_n; + if (kv_direct) { + GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD); + while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) { + kv_tile -= sg_mat_n; + } + } + + const uint32_t vec_nwg_cap = std::max( + 1u, std::min(32u, ctx->webgpu_global_ctx->capabilities.max_subgroup_size)); + uint32_t nwg = 1u; + const uint64_t kv_span = (uint64_t) std::max(1u, kv_tile); + while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) { + nwg <<= 1; + } + nwg = std::min(nwg, vec_nwg_cap); + + const size_t align = ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment; + const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3]; + if (nwg > 1u) { + const uint64_t tmp_data_elems = nrows * (uint64_t) V->ne[0] * nwg; + const uint64_t tmp_stats_elems = nrows * 2u * nwg; + const size_t tmp_size_bytes = ROUNDUP_POW2( + (tmp_data_elems + tmp_stats_elems) * sizeof(float), WEBGPU_STORAGE_BUF_BINDING_MULT); + res += tmp_size_bytes + align; + } + if (mask != nullptr) { + const uint32_t blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], kv_tile); + const uint32_t blk_nblk1 = CEIL_DIV((uint32_t) Q->ne[1], 1u); + const uint32_t stride_mask3 = + (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)); + const uint32_t blk_batch_count = stride_mask3 > 0 ? 
(uint32_t) Q->ne[3] : 1u; + const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count; + const size_t blk_size_bytes = + ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT); + res += blk_size_bytes + align; + } + res = ROUNDUP_POW2(res, WEBGPU_STORAGE_BUF_BINDING_MULT); + } + } + } + break; default: break; } From 59aa7d88fc57fa9f9877eac40d83e40efb70075f Mon Sep 17 00:00:00 2001 From: Zheyuan Chen Date: Wed, 18 Mar 2026 23:35:10 -0700 Subject: [PATCH 29/34] ggml-webgpu: revert unrelated change --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 23fbd609d4b5..3d88f3ab95a1 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -3174,7 +3174,8 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { ggml_webgpu_init_memset_pipeline(ctx->webgpu_global_ctx); ctx->webgpu_global_ctx->memset_buf_pool.init(ctx->webgpu_global_ctx->device, 1, WEBGPU_PARAMS_BUF_SIZE_BYTES, - wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform); + wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform, + wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite); ctx->webgpu_global_ctx->queue = ctx->webgpu_global_ctx->device.GetQueue(); #ifdef GGML_WEBGPU_GPU_PROFILE @@ -3199,7 +3200,8 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) { webgpu_ctx->global_ctx = dev_ctx->webgpu_global_ctx; webgpu_ctx->shader_lib = std::make_unique(dev_ctx->webgpu_global_ctx->device); webgpu_ctx->param_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES, - wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform, true); + wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform, + wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite, true); 
ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->set_rows_dev_error_buf, WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES, wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc, "set_rows_dev_error_buf"); From cac85006f60c892ff02ca3a56d6db8714bdf2d78 Mon Sep 17 00:00:00 2001 From: Zheyuan Chen Date: Wed, 18 Mar 2026 23:36:18 -0700 Subject: [PATCH 30/34] ggml-webgpu: revert deleted comment --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 3d88f3ab95a1..112891cc2565 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -2488,6 +2488,7 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str commands.clear(); } + // If there are SET_ROWS operations in this graph, copy the error buffers to the host for checking. if (contains_set_rows) { wgpu::CommandEncoder encoder = ctx->global_ctx->device.CreateCommandEncoder(); encoder.CopyBufferToBuffer(ctx->set_rows_dev_error_buf, 0, ctx->set_rows_host_error_buf, 0, From 4e0100bb911bb61c333e0c2cdac6dfb6e8e4717a Mon Sep 17 00:00:00 2001 From: Zheyuan Chen Date: Tue, 31 Mar 2026 21:22:26 -0700 Subject: [PATCH 31/34] disable uniformity check --- ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp | 3 ++- ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl | 1 + ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl | 1 + 3 files changed, 4 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 51a851df85ea..3a476f09be33 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -99,7 +99,8 @@ struct ggml_webgpu_processed_shader { std::string wgsl; std::string variant; std::shared_ptr decisions; -} +}; + struct ggml_webgpu_ssm_conv_shader_decisions { uint32_t block_size; uint32_t 
tokens_per_wg; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl index 82556d060bf8..82d072be73a0 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl @@ -1,3 +1,4 @@ +diagnostic(off, subgroup_uniformity); enable f16; #define Q_TILE 1 diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl index b71d13d4e499..9a0de82a56a4 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl @@ -1,3 +1,4 @@ +diagnostic(off, subgroup_uniformity); enable f16; enable subgroups; From 56fee6e209f71f3284863aedb746097436675910 Mon Sep 17 00:00:00 2001 From: Zheyuan Chen Date: Tue, 31 Mar 2026 21:28:45 -0700 Subject: [PATCH 32/34] remove unnecessary change --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 9653c8d21fc4..97ca1baf768b 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -341,7 +341,6 @@ struct webgpu_global_context_struct { this->debug_dev_buf = nullptr; } #endif - } }; @@ -365,13 +364,6 @@ struct webgpu_context_struct { wgpu::Buffer set_rows_dev_error_buf; wgpu::Buffer set_rows_host_error_buf; - std::map> cpy_pipelines; // src_type, dst_type - - std::map>> rope_pipelines; // type, ff, inplace - std::map>> glu_pipelines; // glu_op, type, split - - std::map>> soft_max_pipelines; // mask_type, has_sink, inplace - size_t memset_bytes_per_thread; }; @@ -665,6 +657,7 @@ static webgpu_command ggml_backend_webgpu_build_multi( for (size_t i = 0; i < params_bufs_list.size(); i++) { ctx->queue.WriteBuffer(params_bufs_list[i], 0, params_list[i].data(), params_list[i].size() * 
sizeof(uint32_t)); } + #ifdef GGML_WEBGPU_GPU_PROFILE webgpu_gpu_profile_bufs ts_bufs = ctx->timestamp_query_buf_pool.alloc_bufs(); if (ts_bufs.host_buf.GetMapState() == wgpu::BufferMapState::Mapped) { From 29c09c2e1ff603b9cff576c47c338ac926e67eb2 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Wed, 1 Apr 2026 20:20:24 -0700 Subject: [PATCH 33/34] Update ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl --- ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl | 1 - 1 file changed, 1 deletion(-) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl index 1fb312e414d0..a52575871ae2 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl @@ -371,7 +371,6 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let kv_idx = kv_base + ty; var partial_sum: f32 = 0.0; let kv_valid = kv_idx < KV_TILE && (kv_tile + kv_idx) < params.seq_len_kv; - if (kv_valid) { for (var i = tx; i < (HEAD_DIM_QK / 4u); i += num_of_threads) { let q_off = local_q_row_offset + i * 4u; From f40c9e75c2101fed239d665a4f33a49c0b4071d4 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Wed, 1 Apr 2026 20:20:30 -0700 Subject: [PATCH 34/34] Update ggml/src/ggml-webgpu/ggml-webgpu.cpp --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 97ca1baf768b..f1a715a2c2e7 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -657,7 +657,6 @@ static webgpu_command ggml_backend_webgpu_build_multi( for (size_t i = 0; i < params_bufs_list.size(); i++) { ctx->queue.WriteBuffer(params_bufs_list[i], 0, params_list[i].data(), params_list[i].size() * sizeof(uint32_t)); } - #ifdef GGML_WEBGPU_GPU_PROFILE webgpu_gpu_profile_bufs ts_bufs = ctx->timestamp_query_buf_pool.alloc_bufs(); if 
(ts_bufs.host_buf.GetMapState() == wgpu::BufferMapState::Mapped) {