Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
193 changes: 129 additions & 64 deletions ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp

Large diffs are not rendered by default.

30 changes: 19 additions & 11 deletions ggml/src/ggml-webgpu/ggml-webgpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ struct webgpu_capabilities {
uint32_t sg_mat_k = 0;

uint32_t subgroup_size = 0;
uint32_t min_subgroup_size = 0;
uint32_t max_subgroup_size = 0;
size_t memset_bytes_per_thread;
};
Expand Down Expand Up @@ -1442,6 +1443,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m;
shader_lib_ctx.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n;
shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k;
shader_lib_ctx.min_subgroup_size = ctx->global_ctx->capabilities.min_subgroup_size;
shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size;

// Get or create pipeline
Expand Down Expand Up @@ -1750,6 +1752,7 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx,
shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m;
shader_lib_ctx.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n;
shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k;
shader_lib_ctx.min_subgroup_size = ctx->global_ctx->capabilities.min_subgroup_size;
shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size;
webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline(
shader_lib_ctx, ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment);
Expand Down Expand Up @@ -3469,6 +3472,7 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer
shader_lib_ctx.sg_mat_m = ctx->webgpu_global_ctx->capabilities.sg_mat_m;
shader_lib_ctx.sg_mat_n = ctx->webgpu_global_ctx->capabilities.sg_mat_n;
shader_lib_ctx.sg_mat_k = ctx->webgpu_global_ctx->capabilities.sg_mat_k;
shader_lib_ctx.min_subgroup_size = ctx->webgpu_global_ctx->capabilities.min_subgroup_size;
shader_lib_ctx.max_subgroup_size = ctx->webgpu_global_ctx->capabilities.max_subgroup_size;

const ggml_webgpu_flash_attn_decisions decisions = ggml_webgpu_flash_attn_get_decisions(
Expand Down Expand Up @@ -3667,8 +3671,9 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
#endif
ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix = valid_subgroup_matrix_config;

// For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate.
// Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter.
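// The runtime subgroup size can be any supported size between the adapter's reported minimum and maximum.
// Shaders that allocate per-lane register arrays must size them for the minimum, because a smaller
// subgroup leaves each lane with more elements to hold.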
ctx->webgpu_global_ctx->capabilities.min_subgroup_size = info.subgroupMinSize;
ctx->webgpu_global_ctx->capabilities.max_subgroup_size = info.subgroupMaxSize;
// Initialize device
std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16 };
Expand Down Expand Up @@ -4024,11 +4029,14 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
shader_lib_ctx.dst = const_cast<ggml_tensor *>(op);
shader_lib_ctx.supports_subgroups = ctx->webgpu_global_ctx->capabilities.supports_subgroups;
shader_lib_ctx.supports_subgroup_matrix = ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix;
shader_lib_ctx.max_wg_size =
ctx->webgpu_global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
shader_lib_ctx.wg_mem_limit_bytes =
ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
shader_lib_ctx.sg_mat_m = ctx->webgpu_global_ctx->capabilities.sg_mat_m;
shader_lib_ctx.sg_mat_n = ctx->webgpu_global_ctx->capabilities.sg_mat_n;
shader_lib_ctx.sg_mat_k = ctx->webgpu_global_ctx->capabilities.sg_mat_k;
shader_lib_ctx.min_subgroup_size = ctx->webgpu_global_ctx->capabilities.min_subgroup_size;
shader_lib_ctx.max_subgroup_size = ctx->webgpu_global_ctx->capabilities.max_subgroup_size;

const ggml_webgpu_flash_attn_decisions decisions = ggml_webgpu_flash_attn_get_decisions(
Expand All @@ -4040,19 +4048,19 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
break;
}
if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) {
const size_t min_bytes =
ggml_webgpu_flash_attn_wg_mem_bytes(decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0],
(uint32_t) src2->ne[0], has_mask, decisions.kv_direct);
const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(
decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask,
decisions.kv_direct, decisions.path);
if (min_bytes > limit_bytes) {
supports_op = false;
}
break;
}

if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) {
const size_t min_bytes =
ggml_webgpu_flash_attn_wg_mem_bytes(decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0],
(uint32_t) src2->ne[0], has_mask, decisions.kv_direct);
const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(
decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask,
decisions.kv_direct, decisions.path);
if (min_bytes > limit_bytes) {
supports_op = false;
}
Expand All @@ -4063,9 +4071,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
supports_op = false;
break;
}
const size_t min_bytes =
ggml_webgpu_flash_attn_wg_mem_bytes(decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0],
(uint32_t) src2->ne[0], has_mask, decisions.kv_direct);
const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(
decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask,
decisions.kv_direct, decisions.path);
if (min_bytes > limit_bytes) {
supports_op = false;
}
Expand Down
87 changes: 54 additions & 33 deletions ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl
Original file line number Diff line number Diff line change
@@ -1,12 +1,33 @@
enable f16;
enable subgroups;

// Element type of the Q storage buffer: f16 when Q_F16 is defined, f32 otherwise.
#ifdef Q_F16
#define Q_TYPE f16
#else
#define Q_TYPE f32
#endif

// Scalar type of the vec4 elements in the K and V storage buffers:
// f32 when KV_F32 is defined, f16 otherwise.
#ifdef KV_F32
#define KV_TYPE f32
#else
#define KV_TYPE f16
#endif

// Scalar type of the vec4 elements in the dst storage buffer:
// f16 when DST_F16 is defined, f32 otherwise.
#ifdef DST_F16
#define DST_TYPE f16
#else
#define DST_TYPE f32
#endif

// Head, tile and workgroup dimensions used to size the workgroup arrays and loops below.
// NOTE(review): presumably default values that the host-side shader library overrides
// per pipeline - confirm against the code that instantiates this shader.
#define HEAD_DIM_QK 64
#define HEAD_DIM_V 64
#define KV_STAGE_STRIDE 64
#define Q_TILE 4
#define KV_TILE 64
#define WG_SIZE 128
// Fall back to the maximum subgroup size when no minimum is supplied.
// The minimum is the divisor in the per-lane register counts (score and output
// registers), so it determines how many registers each lane declares.
// NOTE(review): the maximum subgroup size is not defined in this preamble;
// presumably it is injected by the host - confirm.
#ifndef MIN_SUBGROUP_SIZE
#define MIN_SUBGROUP_SIZE MAX_SUBGROUP_SIZE
#endif

struct Params {
offset_q: u32,
Expand Down Expand Up @@ -41,13 +62,13 @@ struct Params {
m1: f32,
};

@group(0) @binding(0) var<storage, read_write> Q: array<f32>;
@group(0) @binding(0) var<storage, read_write> Q: array<Q_TYPE>;
#ifdef KV_OVERLAP
@group(0) @binding(1) var<storage, read_write> K: array<vec4<f16>>;
@group(0) @binding(1) var<storage, read_write> K: array<vec4<KV_TYPE>>;
#define V K
#else
@group(0) @binding(1) var<storage, read_write> K: array<vec4<f16>>;
@group(0) @binding(2) var<storage, read_write> V: array<vec4<f16>>;
@group(0) @binding(1) var<storage, read_write> K: array<vec4<KV_TYPE>>;
@group(0) @binding(2) var<storage, read_write> V: array<vec4<KV_TYPE>>;
#endif

#if defined(MASK) && defined(SINKS)
Expand Down Expand Up @@ -92,17 +113,17 @@ struct Params {
#endif
#endif

@group(0) @binding(DST_BINDING) var<storage, read_write> dst: array<vec4<f32>>;
@group(0) @binding(DST_BINDING) var<storage, read_write> dst: array<vec4<DST_TYPE>>;
@group(0) @binding(PARAMS_BINDING) var<uniform> params: Params;

const FLOAT_MIN: f32 = -1.0e9;
const Q_CHUNKS: u32 = HEAD_DIM_QK / 4u;
const V_CHUNKS: u32 = HEAD_DIM_V / 4u;
const SCORE_REGS_PER_LANE: u32 = (KV_TILE + MAX_SUBGROUP_SIZE - 1u) / MAX_SUBGROUP_SIZE;
const OUT_REGS_PER_LANE: u32 = (V_CHUNKS + MAX_SUBGROUP_SIZE - 1u) / MAX_SUBGROUP_SIZE;
const SCORE_REGS_PER_LANE: u32 = (KV_TILE + MIN_SUBGROUP_SIZE - 1u) / MIN_SUBGROUP_SIZE;
const OUT_REGS_PER_LANE: u32 = (V_CHUNKS + MIN_SUBGROUP_SIZE - 1u) / MIN_SUBGROUP_SIZE;

var<workgroup> q_shmem: array<f16, Q_TILE * HEAD_DIM_QK>;
var<workgroup> kv_shmem: array<f16, KV_TILE * KV_STAGE_STRIDE>;
var<workgroup> q_shmem: array<f32, Q_TILE * HEAD_DIM_QK>;
var<workgroup> kv_shmem: array<f32, KV_TILE * KV_STAGE_STRIDE>;
var<workgroup> p_shmem: array<f32, Q_TILE * KV_TILE>;

@compute @workgroup_size(WG_SIZE)
Expand Down Expand Up @@ -158,10 +179,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
let q_col = elem_idx % HEAD_DIM_QK;
let head_q_row = q_row_start + q_tile_row;
let global_q_row_offset = q_head_offset + head_q_row * params.stride_q1;
q_shmem[elem_idx] = f16(select(
q_shmem[elem_idx] = select(
0.0,
Q[global_q_row_offset + q_col] * params.scale,
head_q_row < params.seq_len_q));
f32(Q[global_q_row_offset + q_col]) * params.scale,
head_q_row < params.seq_len_q);
}

workgroupBarrier();
Expand Down Expand Up @@ -192,10 +213,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
let k_vec_index = (k_head_offset + global_k_row * params.stride_k1 + chunk * 4u) >> 2u;
let k4 = K[k_vec_index];
let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u;
kv_shmem[kv_off + 0u] = k4.x;
kv_shmem[kv_off + 1u] = k4.y;
kv_shmem[kv_off + 2u] = k4.z;
kv_shmem[kv_off + 3u] = k4.w;
kv_shmem[kv_off + 0u] = f32(k4.x);
kv_shmem[kv_off + 1u] = f32(k4.y);
kv_shmem[kv_off + 2u] = f32(k4.z);
kv_shmem[kv_off + 3u] = f32(k4.w);
}

workgroupBarrier();
Expand All @@ -213,16 +234,16 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
for (var chunk = 0u; chunk < Q_CHUNKS; chunk += 1u) {
let q_off = q_base + chunk * 4u;
let qv = vec4<f32>(
f32(q_shmem[q_off + 0u]),
f32(q_shmem[q_off + 1u]),
f32(q_shmem[q_off + 2u]),
f32(q_shmem[q_off + 3u]));
q_shmem[q_off + 0u],
q_shmem[q_off + 1u],
q_shmem[q_off + 2u],
q_shmem[q_off + 3u]);
let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u;
let kv = vec4<f32>(
f32(kv_shmem[kv_off + 0u]),
f32(kv_shmem[kv_off + 1u]),
f32(kv_shmem[kv_off + 2u]),
f32(kv_shmem[kv_off + 3u]));
kv_shmem[kv_off + 0u],
kv_shmem[kv_off + 1u],
kv_shmem[kv_off + 2u],
kv_shmem[kv_off + 3u]);
dot_val += dot(qv, kv);
}
#ifdef LOGIT_SOFTCAP
Expand Down Expand Up @@ -264,10 +285,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
let v_vec_index = (v_head_offset + global_v_row * params.stride_v1 + chunk * 4u) >> 2u;
let v4 = V[v_vec_index];
let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u;
kv_shmem[kv_off + 0u] = v4.x;
kv_shmem[kv_off + 1u] = v4.y;
kv_shmem[kv_off + 2u] = v4.z;
kv_shmem[kv_off + 3u] = v4.w;
kv_shmem[kv_off + 0u] = f32(v4.x);
kv_shmem[kv_off + 1u] = f32(v4.y);
kv_shmem[kv_off + 2u] = f32(v4.z);
kv_shmem[kv_off + 3u] = f32(v4.w);
}

workgroupBarrier();
Expand All @@ -288,10 +309,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
let p = p_shmem[subgroup_p_offset + kv_local];
let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u;
let v4 = vec4<f32>(
f32(kv_shmem[kv_off + 0u]),
f32(kv_shmem[kv_off + 1u]),
f32(kv_shmem[kv_off + 2u]),
f32(kv_shmem[kv_off + 3u]));
kv_shmem[kv_off + 0u],
kv_shmem[kv_off + 1u],
kv_shmem[kv_off + 2u],
kv_shmem[kv_off + 3u]);
acc += p * v4;
}
out_regs[reg_idx] = acc;
Expand Down Expand Up @@ -324,7 +345,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
continue;
}
let dst_vec_index = (row_base + chunk * 4u) >> 2u;
dst[dst_vec_index] = out_regs[reg_idx] * inv_exp_sum;
dst[dst_vec_index] = vec4<DST_TYPE>(out_regs[reg_idx] * inv_exp_sum);
}
}
}
10 changes: 8 additions & 2 deletions ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@ diagnostic(off, subgroup_uniformity);
enable f16;
enable subgroups;

// Scalar type of the vec4 elements in the dst storage buffer:
// f16 when DST_F16 is defined, f32 otherwise.
#ifdef DST_F16
#define DST_TYPE f16
#else
#define DST_TYPE f32
#endif

// Default values
#define HEAD_DIM_V 64
#define WG_SIZE 128
Expand All @@ -17,7 +23,7 @@ struct Params {
};

@group(0) @binding(0) var<storage, read_write> tmp: array<f32>;
@group(0) @binding(1) var<storage, read_write> dst: array<vec4<f32>>;
@group(0) @binding(1) var<storage, read_write> dst: array<vec4<DST_TYPE>>;
@group(0) @binding(2) var<uniform> params: Params;

const FLOAT_MIN: f32 = -1.0e9;
Expand Down Expand Up @@ -72,7 +78,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,

if (thread == 0u) {
let dst_vec_index = (row_base + elem_base) >> 2u;
dst[dst_vec_index] = vec4<f32>(sum_x, sum_y, sum_z, sum_w) * inv_s;
dst[dst_vec_index] = vec4<DST_TYPE>(vec4<f32>(sum_x, sum_y, sum_z, sum_w) * inv_s);
}
}
}
Loading
Loading