From 210766ce1721d740209ddf8ed2efadd06cf10601 Mon Sep 17 00:00:00 2001 From: Constannnnnt Date: Tue, 5 May 2026 17:15:58 -0400 Subject: [PATCH 1/9] fix(mixed-types): use f32 for precision and update the shared memory calculation logic for f32 --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 174 ++++++++++++------ ggml/src/ggml-webgpu/ggml-webgpu.cpp | 20 +- .../wgsl-shaders/flash_attn_tile.wgsl | 78 +++++--- .../wgsl-shaders/flash_attn_vec_reduce.wgsl | 10 +- .../wgsl-shaders/flash_attn_vec_split.wgsl | 112 ++++++----- 5 files changed, 248 insertions(+), 146 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index c6dc2c21147..7662a099127 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -531,7 +531,9 @@ enum ggml_webgpu_flash_attn_path : uint32_t { }; struct ggml_webgpu_flash_attn_pipeline_key { + ggml_type q_type; ggml_type kv_type; + ggml_type dst_type; uint32_t head_dim_qk; uint32_t head_dim_v; bool kv_direct; @@ -542,16 +544,19 @@ struct ggml_webgpu_flash_attn_pipeline_key { uint32_t path; bool operator==(const ggml_webgpu_flash_attn_pipeline_key & other) const { - return kv_type == other.kv_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v && - kv_direct == other.kv_direct && kv_overlap == other.kv_overlap && has_mask == other.has_mask && - has_sinks == other.has_sinks && uses_logit_softcap == other.uses_logit_softcap && path == other.path; + return q_type == other.q_type && kv_type == other.kv_type && dst_type == other.dst_type && + head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v && kv_direct == other.kv_direct && + kv_overlap == other.kv_overlap && has_mask == other.has_mask && has_sinks == other.has_sinks && + uses_logit_softcap == other.uses_logit_softcap && path == other.path; } }; struct ggml_webgpu_flash_attn_pipeline_key_hash { size_t operator()(const ggml_webgpu_flash_attn_pipeline_key & key) const { size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.q_type); ggml_webgpu_hash_combine(seed, key.kv_type); + ggml_webgpu_hash_combine(seed, key.dst_type); ggml_webgpu_hash_combine(seed, key.head_dim_qk); ggml_webgpu_hash_combine(seed, key.head_dim_v); ggml_webgpu_hash_combine(seed, key.kv_direct); @@ -611,7 +616,9 @@ inline ggml_webgpu_flash_attn_pipeline_key ggml_webgpu_flash_attn_make_pipeline_ } ggml_webgpu_flash_attn_pipeline_key key = {}; + key.q_type = context.src0->type; key.kv_type = context.src1->type; + key.dst_type = context.dst->type; key.head_dim_qk = (uint32_t) context.src0->ne[0]; key.head_dim_v = (uint32_t) context.src2->ne[0]; key.kv_direct = kv_direct; @@ -624,8 +631,9 @@ inline ggml_webgpu_flash_attn_pipeline_key ggml_webgpu_flash_attn_make_pipeline_ } struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key { - uint32_t head_dim_v; - uint32_t wg_size; + uint32_t head_dim_v; + uint32_t wg_size; + ggml_type dst_type; }; struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key_hash { @@ -633,13 +641,14 @@ struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key_hash { size_t seed = 0; ggml_webgpu_hash_combine(seed, key.head_dim_v); ggml_webgpu_hash_combine(seed, key.wg_size); + ggml_webgpu_hash_combine(seed, key.dst_type); return seed; } }; inline bool operator==(const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & lhs, const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & rhs) { - return lhs.head_dim_v == rhs.head_dim_v && lhs.wg_size == rhs.wg_size; + return lhs.head_dim_v == rhs.head_dim_v && lhs.wg_size == rhs.wg_size && lhs.dst_type == rhs.dst_type; } struct ggml_webgpu_flash_attn_blk_pipeline_key { @@ -662,19 +671,32 @@ inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile, uint32_t head_dim_qk, uint32_t head_dim_v, bool has_mask, - bool kv_direct) { + bool kv_direct, + uint32_t path = GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX) { const uint32_t max_head_dim = std::max(head_dim_qk, head_dim_v); size_t f16_elems = 0; size_t f32_elems = 0; - f16_elems += q_tile * head_dim_qk; // q_shmem + if (path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { + f32_elems += head_dim_qk; // q_shmem + if (!kv_direct) { + f32_elems += kv_tile * max_head_dim; // kv_shmem + } + f32_elems += head_dim_v; // o_shmem + if (has_mask) { + f32_elems += kv_tile; // mask_shmem + } + f32_elems += kv_tile; // inter_shmem + return f32_elems * GGML_WEBGPU_F32_SIZE_BYTES; + } + f32_elems += q_tile * head_dim_qk; // q_shmem if (!kv_direct) { - f16_elems += kv_tile * max_head_dim; // kv_shmem + f32_elems += kv_tile * max_head_dim; // kv_shmem } - f16_elems += q_tile * head_dim_v; // o_shmem + f32_elems += q_tile * head_dim_v; // o_shmem if (has_mask) { - f16_elems += q_tile * kv_tile; // mask_shmem + f32_elems += q_tile * kv_tile; // mask_shmem } - f16_elems += q_tile * kv_tile; // inter_shmem + f32_elems += q_tile * kv_tile; // inter_shmem f32_elems += q_tile; // row_max_shmem f32_elems += q_tile; // exp_sum_shmem return f16_elems * GGML_WEBGPU_F16_SIZE_BYTES + f32_elems * GGML_WEBGPU_F32_SIZE_BYTES; @@ -684,27 +706,27 @@ inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_shader_lib_ const ggml_webgpu_flash_attn_pipeline_key & key) { const size_t limit_bytes = context.wg_mem_limit_bytes; uint32_t q_tile = context.sg_mat_m; - uint32_t kv_granularity = context.sg_mat_n; + uint32_t kv_granularity = std::max(1u, context.sg_mat_n); if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) { q_tile = GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE; - kv_granularity = std::max(1u, context.max_subgroup_size); + kv_granularity = 1u; } else if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { q_tile = 1u; kv_granularity = 8u; } - const size_t base_q_bytes = (key.head_dim_qk + key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES + - 2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES; - size_t bytes_per_kv = 0; - if (!key.kv_direct) { - bytes_per_kv += std::max(key.head_dim_qk, key.head_dim_v); + const size_t base_q_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(q_tile, 0, key.head_dim_qk, key.head_dim_v, + key.has_mask, key.kv_direct, key.path); + if (limit_bytes <= base_q_bytes) { + return 0; } - if (key.has_mask) { - bytes_per_kv += q_tile; + const size_t one_kv_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(q_tile, 1, key.head_dim_qk, key.head_dim_v, + key.has_mask, key.kv_direct, key.path); + const size_t bytes_per_kv = one_kv_bytes - base_q_bytes; + if (bytes_per_kv == 0) { + return 0; } - bytes_per_kv += q_tile; - bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES; - const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv; - return (max_kv_tile / kv_granularity) * kv_granularity; + const size_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv; + return (uint32_t) ((max_kv_tile / kv_granularity) * kv_granularity); } inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions( @@ -731,14 +753,18 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions( (v_offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u); const bool kv_vec_type_supported = K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0; - const bool use_vec = context.supports_subgroups && (context.src0->ne[1] < 20) && (context.src0->ne[0] % 32 == 0) && - (context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && - kv_vec_type_supported && (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && - (context.src2->type == K->type); + const bool use_vec = context.supports_subgroups && (context.src0->ne[1] < 20) && (context.src0->ne[0] % 32 == 0) && + (context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && + kv_vec_type_supported && (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && + (context.src2->type == K->type); + const bool tile_can_dispatch_all_q_rows = + context.max_subgroup_size > 0 && + context.max_wg_size >= GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * context.max_subgroup_size; const bool use_tile = context.supports_subgroups && !context.supports_subgroup_matrix && K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && f16_vec4_aligned && (context.src0->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && - (context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && !use_vec; + (context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && + tile_can_dispatch_all_q_rows && !use_vec; decisions.path = use_vec ? GGML_WEBGPU_FLASH_ATTN_PATH_VEC : use_tile ? GGML_WEBGPU_FLASH_ATTN_PATH_TILE : @@ -777,14 +803,15 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions( decisions.kv_tile = decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? std::min(64u, max_kv_tile) : std::min(max_kv_tile, context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES); - decisions.wg_size = decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? - GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE : - std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE); - - if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) { - const uint32_t tile_kv_granularity = std::max(1u, context.max_subgroup_size); - decisions.kv_tile = - std::max(tile_kv_granularity, (decisions.kv_tile / tile_kv_granularity) * tile_kv_granularity); + decisions.wg_size = + decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? + std::min(std::max(1u, context.max_wg_size), + std::max(GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE, + GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * std::max(1u, context.max_subgroup_size))) : + std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE); + + if (decisions.kv_tile == 0) { + return decisions; } if (decisions.kv_direct) { @@ -1577,7 +1604,7 @@ class ggml_webgpu_shader_lib { key.type = context.dst->type; key.d_state = (int) context.src0->ne[0]; key.xbc_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src4) && - ggml_webgpu_tensor_overlap(context.src1, context.src5); + ggml_webgpu_tensor_overlap(context.src1, context.src5); auto it = ssm_scan_pipelines.find(key); if (it != ssm_scan_pipelines.end()) { @@ -1694,10 +1721,10 @@ class ggml_webgpu_shader_lib { ggml_webgpu_mul_mat_vec_pipeline_key key = {}; key.src0_type = context.src0->type; key.src1_type = context.src1->type; - key.vectorized = (context.src0->ne[0] % 4 == 0 && + key.vectorized = (context.src0->ne[0] % 4 == 0 && (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? - 1 : - 0; + 1 : + 0; auto it = mul_mat_vec_pipelines.find(key); if (it != mul_mat_vec_pipelines.end()) { @@ -1805,13 +1832,13 @@ class ggml_webgpu_shader_lib { webgpu_pipeline get_mul_mat_fast_pipeline(const ggml_webgpu_shader_lib_context & context) { ggml_webgpu_mul_mat_pipeline_key key = {}; - key.src0_type = context.src0->type; - key.src1_type = context.src1->type; - key.vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 && - (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? - 1 : - 0; - key.use_subgroup_matrix = context.supports_subgroup_matrix; + key.src0_type = context.src0->type; + key.src1_type = context.src1->type; + key.vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 && + (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? + 1 : + 0; + key.use_subgroup_matrix = context.supports_subgroup_matrix; auto it = mul_mat_fast_pipelines.find(key); if (it != mul_mat_fast_pipelines.end()) { @@ -2074,10 +2101,10 @@ class ggml_webgpu_shader_lib { key.src0_type = context.src0->type; key.src1_type = context.src1->type; key.n_experts = context.src0->ne[2]; - key.vectorized = (context.src0->ne[0] % 4 == 0 && context.src0->ne[1] % 4 == 0 && + key.vectorized = (context.src0->ne[0] % 4 == 0 && context.src0->ne[1] % 4 == 0 && (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? - 1 : - 0; + 1 : + 0; auto it = mul_mat_id_pipelines.find(key); if (it != mul_mat_id_pipelines.end()) { @@ -2194,10 +2221,10 @@ class ggml_webgpu_shader_lib { key.src0_type = context.src0->type; key.src1_type = context.src1->type; key.n_experts = context.src0->ne[2]; - key.vectorized = (context.src0->ne[0] % 4 == 0 && + key.vectorized = (context.src0->ne[0] % 4 == 0 && (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? - 1 : - 0; + 1 : + 0; auto it = mul_mat_id_vec_pipelines.find(key); if (it != mul_mat_id_vec_pipelines.end()) { @@ -2586,6 +2613,30 @@ class ggml_webgpu_shader_lib { } variant += std::string("_") + ggml_type_name(key.kv_type); + switch (key.q_type) { + case GGML_TYPE_F32: + defines.push_back("Q_F32"); + break; + case GGML_TYPE_F16: + defines.push_back("Q_F16"); + break; + default: + GGML_ABORT("Unsupported Q type for flash attention shader"); + } + variant += std::string("_q") + ggml_type_name(key.q_type); + + switch (key.dst_type) { + case GGML_TYPE_F32: + defines.push_back("DST_F32"); + break; + case GGML_TYPE_F16: + defines.push_back("DST_F16"); + break; + default: + GGML_ABORT("Unsupported dst type for flash attention shader"); + } + variant += std::string("_dst") + ggml_type_name(key.dst_type); + if (key.has_mask) { defines.push_back("MASK"); if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { @@ -2677,6 +2728,7 @@ class ggml_webgpu_shader_lib { webgpu_pipeline get_flash_attn_vec_reduce_pipeline(const ggml_webgpu_shader_lib_context & context) { ggml_webgpu_flash_attn_vec_reduce_pipeline_key key = {}; key.head_dim_v = (uint32_t) context.src2->ne[0]; + key.dst_type = context.dst->type; key.wg_size = context.max_wg_size; auto it = flash_attn_vec_reduce_pipelines.find(key); if (it != flash_attn_vec_reduce_pipelines.end()) { @@ -2686,6 +2738,18 @@ class ggml_webgpu_shader_lib { std::vector defines; std::string variant = "flash_attn_vec_reduce"; + switch (key.dst_type) { + case GGML_TYPE_F32: + defines.push_back("DST_F32"); + break; + case GGML_TYPE_F16: + defines.push_back("DST_F16"); + break; + default: + GGML_ABORT("Unsupported dst type for flash attention vec reduce shader"); + } + variant += std::string("_dst") + ggml_type_name(key.dst_type); + defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v)); variant += std::string("_hsv") + std::to_string(key.head_dim_v); diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 12f60a9900e..c9e42f37d7b 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -4024,6 +4024,8 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const shader_lib_ctx.dst = const_cast(op); shader_lib_ctx.supports_subgroups = ctx->webgpu_global_ctx->capabilities.supports_subgroups; shader_lib_ctx.supports_subgroup_matrix = ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix; + shader_lib_ctx.max_wg_size = + ctx->webgpu_global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; shader_lib_ctx.wg_mem_limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; shader_lib_ctx.sg_mat_m = ctx->webgpu_global_ctx->capabilities.sg_mat_m; @@ -4040,9 +4042,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const break; } if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { - const size_t min_bytes = - ggml_webgpu_flash_attn_wg_mem_bytes(decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0], - (uint32_t) src2->ne[0], has_mask, decisions.kv_direct); + const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes( + decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, + decisions.kv_direct, decisions.path); if (min_bytes > limit_bytes) { supports_op = false; } @@ -4050,9 +4052,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const } if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) { - const size_t min_bytes = - ggml_webgpu_flash_attn_wg_mem_bytes(decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0], - (uint32_t) src2->ne[0], has_mask, decisions.kv_direct); + const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes( + decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, + decisions.kv_direct, decisions.path); if (min_bytes > limit_bytes) { supports_op = false; } @@ -4063,9 +4065,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const supports_op = false; break; } - const size_t min_bytes = - ggml_webgpu_flash_attn_wg_mem_bytes(decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0], - (uint32_t) src2->ne[0], has_mask, decisions.kv_direct); + const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes( + decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, + decisions.kv_direct, decisions.path); if (min_bytes > limit_bytes) { supports_op = false; } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl index 37ea23b80c8..979fadafb55 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl @@ -1,6 +1,24 @@ enable f16; enable subgroups; +#ifdef Q_F16 +#define Q_TYPE f16 +#else +#define Q_TYPE f32 +#endif + +#ifdef KV_F32 +#define KV_TYPE f32 +#else +#define KV_TYPE f16 +#endif + +#ifdef DST_F16 +#define DST_TYPE f16 +#else +#define DST_TYPE f32 +#endif + #define HEAD_DIM_QK 64 #define HEAD_DIM_V 64 #define KV_STAGE_STRIDE 64 @@ -41,12 +59,12 @@ struct Params { m1: f32, }; -@group(0) @binding(0) var Q: array; +@group(0) @binding(0) var Q: array; #ifdef KV_OVERLAP -@group(0) @binding(1) var K: array>; +@group(0) @binding(1) var K: array>; #define V K #else -@group(0) @binding(1) var K: array>; +@group(0) @binding(1) var K: array>; @group(0) @binding(2) var V: array>; #endif @@ -92,7 +110,7 @@ struct Params { #endif #endif -@group(0) @binding(DST_BINDING) var dst: array>; +@group(0) @binding(DST_BINDING) var dst: array>; @group(0) @binding(PARAMS_BINDING) var params: Params; const FLOAT_MIN: f32 = -1.0e9; @@ -101,8 +119,8 @@ const V_CHUNKS: u32 = HEAD_DIM_V / 4u; const SCORE_REGS_PER_LANE: u32 = (KV_TILE + MAX_SUBGROUP_SIZE - 1u) / MAX_SUBGROUP_SIZE; const OUT_REGS_PER_LANE: u32 = (V_CHUNKS + MAX_SUBGROUP_SIZE - 1u) / MAX_SUBGROUP_SIZE; -var q_shmem: array; -var kv_shmem: array; +var q_shmem: array; +var kv_shmem: array; var p_shmem: array; @compute @workgroup_size(WG_SIZE) @@ -158,10 +176,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let q_col = elem_idx % HEAD_DIM_QK; let head_q_row = q_row_start + q_tile_row; let global_q_row_offset = q_head_offset + head_q_row * params.stride_q1; - q_shmem[elem_idx] = f16(select( + q_shmem[elem_idx] = select( 0.0, - Q[global_q_row_offset + q_col] * params.scale, - head_q_row < params.seq_len_q)); + f32(Q[global_q_row_offset + q_col]) * params.scale, + head_q_row < params.seq_len_q); } workgroupBarrier(); @@ -192,10 +210,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let k_vec_index = (k_head_offset + global_k_row * params.stride_k1 + chunk * 4u) >> 2u; let k4 = K[k_vec_index]; let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u; - kv_shmem[kv_off + 0u] = k4.x; - kv_shmem[kv_off + 1u] = k4.y; - kv_shmem[kv_off + 2u] = k4.z; - kv_shmem[kv_off + 3u] = k4.w; + kv_shmem[kv_off + 0u] = f32(k4.x); + kv_shmem[kv_off + 1u] = f32(k4.y); + kv_shmem[kv_off + 2u] = f32(k4.z); + kv_shmem[kv_off + 3u] = f32(k4.w); } workgroupBarrier(); @@ -213,16 +231,16 @@ fn main(@builtin(workgroup_id) wg_id: vec3, for (var chunk = 0u; chunk < Q_CHUNKS; chunk += 1u) { let q_off = q_base + chunk * 4u; let qv = vec4( - f32(q_shmem[q_off + 0u]), - f32(q_shmem[q_off + 1u]), - f32(q_shmem[q_off + 2u]), - f32(q_shmem[q_off + 3u])); + q_shmem[q_off + 0u], + q_shmem[q_off + 1u], + q_shmem[q_off + 2u], + q_shmem[q_off + 3u]); let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u; let kv = vec4( - f32(kv_shmem[kv_off + 0u]), - f32(kv_shmem[kv_off + 1u]), - f32(kv_shmem[kv_off + 2u]), - f32(kv_shmem[kv_off + 3u])); + kv_shmem[kv_off + 0u], + kv_shmem[kv_off + 1u], + kv_shmem[kv_off + 2u], + kv_shmem[kv_off + 3u]); dot_val += dot(qv, kv); } #ifdef LOGIT_SOFTCAP @@ -264,10 +282,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let v_vec_index = (v_head_offset + global_v_row * params.stride_v1 + chunk * 4u) >> 2u; let v4 = V[v_vec_index]; let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u; - kv_shmem[kv_off + 0u] = v4.x; - kv_shmem[kv_off + 1u] = v4.y; - kv_shmem[kv_off + 2u] = v4.z; - kv_shmem[kv_off + 3u] = v4.w; + kv_shmem[kv_off + 0u] = f32(v4.x); + kv_shmem[kv_off + 1u] = f32(v4.y); + kv_shmem[kv_off + 2u] = f32(v4.z); + kv_shmem[kv_off + 3u] = f32(v4.w); } workgroupBarrier(); @@ -288,10 +306,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let p = p_shmem[subgroup_p_offset + kv_local]; let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u; let v4 = vec4( - f32(kv_shmem[kv_off + 0u]), - f32(kv_shmem[kv_off + 1u]), - f32(kv_shmem[kv_off + 2u]), - f32(kv_shmem[kv_off + 3u])); + kv_shmem[kv_off + 0u], + kv_shmem[kv_off + 1u], + kv_shmem[kv_off + 2u], + kv_shmem[kv_off + 3u]); acc += p * v4; } out_regs[reg_idx] = acc; @@ -324,7 +342,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, continue; } let dst_vec_index = (row_base + chunk * 4u) >> 2u; - dst[dst_vec_index] = out_regs[reg_idx] * inv_exp_sum; + dst[dst_vec_index] = vec4(out_regs[reg_idx] * inv_exp_sum); } } } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl index 9a0de82a56a..1091d744073 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl @@ -2,6 +2,12 @@ diagnostic(off, subgroup_uniformity); enable f16; enable subgroups; +#ifdef DST_F16 +#define DST_TYPE f16 +#else +#define DST_TYPE f32 +#endif + // Default values #define HEAD_DIM_V 64 #define WG_SIZE 128 @@ -17,7 +23,7 @@ struct Params { }; @group(0) @binding(0) var tmp: array; -@group(0) @binding(1) var dst: array>; +@group(0) @binding(1) var dst: array>; @group(0) @binding(2) var params: Params; const FLOAT_MIN: f32 = -1.0e9; @@ -72,7 +78,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, if (thread == 0u) { let dst_vec_index = (row_base + elem_base) >> 2u; - dst[dst_vec_index] = vec4(sum_x, sum_y, sum_z, sum_w) * inv_s; + dst[dst_vec_index] = vec4(vec4(sum_x, sum_y, sum_z, sum_w) * inv_s); } } } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl index b1e234784a8..30ebbebe772 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl @@ -8,6 +8,18 @@ enable subgroups; #define KV_TYPE f16 #endif +#ifdef Q_F16 +#define Q_TYPE f16 +#else +#define Q_TYPE f32 +#endif + +#ifdef DST_F16 +#define DST_TYPE f16 +#else +#define DST_TYPE f32 +#endif + #define HEAD_DIM_QK 64 #define HEAD_DIM_V 64 @@ -89,7 +101,7 @@ struct Params { nwg: u32, }; -@group(0) @binding(0) var Q: array; +@group(0) @binding(0) var Q: array; #ifdef KV_OVERLAP #if defined(KV_Q4_0) || defined(KV_Q8_0) @group(0) @binding(1) var K: array; @@ -191,41 +203,41 @@ struct Params { @group(0) @binding(BLK_BINDING) var blk: array; #endif @group(0) @binding(TMP_BINDING) var tmp: array; -@group(0) @binding(DST_BINDING) var dst: array>; +@group(0) @binding(DST_BINDING) var dst: array>; @group(0) @binding(PARAMS_BINDING) var params: Params; // Just a very small float value. const FLOAT_MIN: f32 = -1.0e9; -var q_shmem: array; +var q_shmem: array; #ifndef KV_DIRECT const kv_shmem_size = KV_TILE * max(HEAD_DIM_QK, HEAD_DIM_V); // we can reuse the same shmem for K and V since we only need one at a time -var kv_shmem: array; +var kv_shmem: array; #endif -var o_shmem: array; +var o_shmem: array; #ifdef MASK // storage for mask values -var mask_shmem: array; +var mask_shmem: array; #endif // note that we reuse the same storage for both since we only need one at a time -var inter_shmem: array; +var inter_shmem: array; // Storage for row max and exp sum during online softmax fn calc_softmax_term(kv_idx: u32, slope: f32, has_bias: bool, apply_mask: bool) -> f32 { var v = select(FLOAT_MIN, - f32(inter_shmem[kv_idx]) * params.scale, + inter_shmem[kv_idx] * params.scale, kv_idx < KV_TILE); #ifdef LOGIT_SOFTCAP v = params.logit_softcap * tanh(v); #endif #ifdef MASK if (apply_mask) { - var mask_val = select(0.0, f32(mask_shmem[kv_idx]), kv_idx < KV_TILE); + var mask_val = select(0.0, mask_shmem[kv_idx], kv_idx < KV_TILE); v += select(mask_val, slope * mask_val, has_bias); } #endif @@ -289,10 +301,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3, // load the single Q row into shared memory for (var elem_idx = local_id.x; elem_idx < HEAD_DIM_QK; elem_idx += WG_SIZE) { let global_q_row_offset = q_head_offset + q_row_start * params.stride_q1; - q_shmem[elem_idx] = f16(select( + q_shmem[elem_idx] = select( 0.0, - Q[global_q_row_offset + elem_idx], - q_row_start < params.seq_len_q)); + f32(Q[global_q_row_offset + elem_idx]), + q_row_start < params.seq_len_q); } for (var kv_tile = iwg * KV_TILE; kv_tile < params.seq_len_kv; kv_tile += KV_TILE * params.nwg) { @@ -308,7 +320,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let blk_state = blk_state_local; let skip_tile = blk_state == 0u; for (var elem_idx = local_id.x; elem_idx < KV_TILE; elem_idx += WG_SIZE) { - inter_shmem[elem_idx] = f16(0.0); + inter_shmem[elem_idx] = 0.0; } // load k tile into shared memory @@ -331,8 +343,8 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let q_packed = bitcast(vec2(q_0, q_1)); for (var k = 0u; k < 4u; k++) { let q_byte = get_byte(q_packed, k); - let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; - let q_lo = (f16(q_byte & 0xF) - 8.0) * d; + let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * f32(d); + let q_lo = (f32(q_byte & 0xF) - 8.0) * f32(d); let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; kv_shmem[row_offset + idx] = q_lo; kv_shmem[row_offset + idx + 16u] = q_hi; @@ -359,7 +371,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let q_packed = bitcast(vec2(q_0, q_1)); for (var k = 0u; k < 4u; k++) { let q_byte = get_byte_i32(q_packed, k); - let q_val = f16(q_byte) * d; + let q_val = f32(q_byte) * f32(d); let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; kv_shmem[row_offset + idx] = q_val; } @@ -377,10 +389,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let in_bounds = global_k_row < params.seq_len_kv && (k_col + 3u) < HEAD_DIM_QK; let vec_idx = (global_k_row_offset + k_col) >> 2u; let k4 = select(vec4(0.0), K[vec_idx], in_bounds); - kv_shmem[elem_idx + 0u] = f16(k4.x); - kv_shmem[elem_idx + 1u] = f16(k4.y); - kv_shmem[elem_idx + 2u] = f16(k4.z); - kv_shmem[elem_idx + 3u] = f16(k4.w); + kv_shmem[elem_idx + 0u] = f32(k4.x); + kv_shmem[elem_idx + 1u] = f32(k4.y); + kv_shmem[elem_idx + 2u] = f32(k4.z); + kv_shmem[elem_idx + 3u] = f32(k4.w); } #endif @@ -401,20 +413,20 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let q_off = i * 4u; let qv = vec4( - f32(q_shmem[q_off + 0u]), - f32(q_shmem[q_off + 1u]), - f32(q_shmem[q_off + 2u]), - f32(q_shmem[q_off + 3u])); + q_shmem[q_off + 0u], + q_shmem[q_off + 1u], + q_shmem[q_off + 2u], + q_shmem[q_off + 3u]); #ifdef KV_DIRECT let idx = k_head_offset + (kv_tile + kv_idx) * params.stride_k1 + (i * 4u); let kv = vec4(K[idx >> 2u]); #else let idx = kv_idx * HEAD_DIM_QK + (i * 4u); let kv = vec4( - f32(kv_shmem[idx + 0u]), - f32(kv_shmem[idx + 1u]), - f32(kv_shmem[idx + 2u]), - f32(kv_shmem[idx + 3u])); + kv_shmem[idx + 0u], + kv_shmem[idx + 1u], + kv_shmem[idx + 2u], + kv_shmem[idx + 3u]); #endif partial_sum += dot(qv, kv); } @@ -435,7 +447,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let sum_bcast = subgroupShuffle(sum, num_of_threads * ty); if (tx == 0u && kv_valid) { - inter_shmem[kv_idx] = f16(sum_bcast); + inter_shmem[kv_idx] = sum_bcast; } } } @@ -450,7 +462,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let global_k_col = kv_tile + elem_idx; let mask_in_bounds = q_row_start < params.seq_len_q && global_k_col < params.seq_len_kv; let mask_idx = mask_global_offset + global_k_col; - mask_shmem[elem_idx] = select(0.0, mask[mask_idx], mask_in_bounds); + mask_shmem[elem_idx] = select(0.0f, f32(mask[mask_idx]), mask_in_bounds); } } #else @@ -483,7 +495,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE); total_exp_term += subgroupAdd(cur_p); if (kv_idx < KV_TILE) { - inter_shmem[kv_idx] = f16(cur_p); + inter_shmem[kv_idx] = cur_p; } } @@ -493,7 +505,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, exp_sum = exp_sum * cur_exp + total_exp_term; for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) { - o_shmem[elem_idx] = f16(f32(o_shmem[elem_idx]) * cur_exp); + o_shmem[elem_idx] = o_shmem[elem_idx] * cur_exp; } } @@ -517,8 +529,8 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let q_packed = bitcast(vec2(q_0, q_1)); for (var k = 0u; k < 4u; k++) { let q_byte = get_byte(q_packed, k); - let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; - let q_lo = (f16(q_byte & 0xF) - 8.0) * d; + let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * f32(d); + let q_lo = (f32(q_byte & 0xF) - 8.0) * f32(d); let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; kv_shmem[row_offset + idx] = q_lo; kv_shmem[row_offset + idx + 16u] = q_hi; @@ -545,7 +557,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let q_packed = bitcast(vec2(q_0, q_1)); for (var k = 0u; k < 4u; k++) { let q_byte = get_byte_i32(q_packed, k); - let q_val = f16(q_byte) * d; + let q_val = f32(q_byte) * f32(d); let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; kv_shmem[row_offset + idx] = q_val; } @@ -563,10 +575,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let in_bounds = global_v_row < params.seq_len_kv && (v_col + 3u) < HEAD_DIM_V; let vec_idx = (global_v_row_offset + v_col) >> 2u; let v4 = select(vec4(0.0), V[vec_idx], in_bounds); - kv_shmem[elem_idx + 0u] = f16(v4.x); - kv_shmem[elem_idx + 1u] = f16(v4.y); - kv_shmem[elem_idx + 2u] = f16(v4.z); - kv_shmem[elem_idx + 3u] = f16(v4.w); + kv_shmem[elem_idx + 0u] = f32(v4.x); + kv_shmem[elem_idx + 1u] = f32(v4.y); + kv_shmem[elem_idx + 2u] = f32(v4.z); + kv_shmem[elem_idx + 3u] = f32(v4.w); } #endif @@ -589,17 +601,17 @@ fn main(@builtin(workgroup_id) wg_id: vec3, continue; } - let p = f32(inter_shmem[kv_idx]); + let p = inter_shmem[kv_idx]; #ifdef KV_DIRECT let v_idx = v_head_offset + v_row * params.stride_v1 + vec_col * 4u; let v4 = vec4(V[v_idx >> 2u]); #else let v_idx = kv_idx * HEAD_DIM_V + vec_col * 4u; let v4 = vec4( - f32(kv_shmem[v_idx + 0u]), - f32(kv_shmem[v_idx + 1u]), - f32(kv_shmem[v_idx + 2u]), - f32(kv_shmem[v_idx + 3u])); + kv_shmem[v_idx + 0u], + kv_shmem[v_idx + 1u], + kv_shmem[v_idx + 2u], + kv_shmem[v_idx + 3u]); #endif lo += p * v4; } @@ -630,10 +642,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3, if (ty_pv == 0u) { let elem_base = vec_col * 4u; - o_shmem[elem_base + 0u] = f16(f32(o_shmem[elem_base + 0u]) + lo_x); - o_shmem[elem_base + 1u] = f16(f32(o_shmem[elem_base + 1u]) + lo_y); - o_shmem[elem_base + 2u] = f16(f32(o_shmem[elem_base + 2u]) + lo_z); - o_shmem[elem_base + 3u] = f16(f32(o_shmem[elem_base + 3u]) + lo_w); + o_shmem[elem_base + 0u] = o_shmem[elem_base + 0u] + lo_x; + o_shmem[elem_base + 1u] = o_shmem[elem_base + 1u] + lo_y; + o_shmem[elem_base + 2u] = o_shmem[elem_base + 2u] + lo_z; + o_shmem[elem_base + 3u] = o_shmem[elem_base + 3u] + lo_w; } } } @@ -660,7 +672,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, exp_sum = exp_sum * max_exp + sink_exp_sum; for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) { - o_shmem[elem_idx] = f16(f32(o_shmem[elem_idx]) * max_exp); + o_shmem[elem_idx] = o_shmem[elem_idx] * max_exp; } } workgroupBarrier(); @@ -681,7 +693,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, ); let dst_vec_index: u32 = (row_base + elem_base) >> 2u; - dst[dst_vec_index] = v; + dst[dst_vec_index] = vec4(v); } } else { let rid = batch_idx * rows_per_batch + head_idx * params.seq_len_q + q_row_start; From 2380a06202c0ea6fef69744d5d6a44665d297843 Mon Sep 17 00:00:00 2001 From: Constannnnnt Date: Tue, 5 May 2026 20:52:26 -0400 Subject: [PATCH 2/9] fix(unary): correct the gelu, gelu quick and gelu erf functions --- ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl | 34 ++++++++++---------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl index b8f1bca1284..680c47045e4 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl @@ -50,6 +50,20 @@ struct Params { @group(0) @binding(PARAMS_BINDING) var params: Params; +fn erf_approx(x: f32) -> f32 { + let s = select(-1.0, 1.0, x >= 0.0); + let ax = abs(x); + + let t = 1.0 / (1.0 + 0.3275911 * ax); + + let y = 1.0 - + (((((1.061405429 * t - 1.453152027) * t + 1.421413741) * t + - 0.284496736) * t + 0.254829592) * t) * + exp(-ax * ax); + + return s * y; +} + @compute @workgroup_size(WG_SIZE) fn main(@builtin(global_invocation_id) gid: vec3) { if (gid.x >= params.ne) { @@ -124,27 +138,13 @@ fn main(@builtin(global_invocation_id) gid: vec3) { min(1.0, max(0.0, (src[params.offset_src + src_idx] + 3.0) / 6.0)); #endif #ifdef GELU - let res = 0.5 * src[params.offset_src + src_idx] * - (1.0 + tanh(clamp(sqrt(2.0 / 3.14159265) * - (src[params.offset_src + src_idx] + - 0.044715 * pow(src[params.offset_src + src_idx], 3.0)), - -9.010913, 9.010913))); + let res = 0.5 * src[params.offset_src + src_idx] * (1.0 + tanh(0.7978845608 * (src[params.offset_src + src_idx] + 0.044715 * src[params.offset_src + src_idx] * src[params.offset_src + src_idx] * src[params.offset_src + src_idx]))); #endif #ifdef GELU_QUICK - let res = src[params.offset_src + src_idx] * 0.5 * - (1.0 + tanh(clamp(0.79788456 * - (src[params.offset_src + src_idx] + - 0.044715 * src[params.offset_src + src_idx] * - src[params.offset_src + src_idx] * src[params.offset_src + src_idx]), - -9.010913, 9.010913))); + let res = src[params.offset_src + src_idx] * (1.0 / (1.0 + exp(-1.702 * src[params.offset_src + src_idx]))); #endif #ifdef GELU_ERF - let res = 0.5 * src[params.offset_src + src_idx] * - (1.0 + tanh(clamp(0.79788456 * - (src[params.offset_src + src_idx] + - 0.044715 * src[params.offset_src + src_idx] * - src[params.offset_src + src_idx] * src[params.offset_src + src_idx]), - -9.010913, 9.010913))); + let res = 0.5 * src[params.offset_src + src_idx] * (1.0 + erf_approx(src[params.offset_src + src_idx] * 0.70710678)); #endif #ifdef XIELU let val = f32(src[params.offset_src + src_idx]); From f5f940fc36cfcb4ade2372f4f5a0ee0fe97db521 Mon Sep 17 00:00:00 2001 From: Constannnnnt Date: Tue, 5 May 2026 23:14:49 -0400 Subject: [PATCH 3/9] fix(flash-attn-tile): fix the hardcode v type --- ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl index 979fadafb55..4bd4ec670d4 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl @@ -65,7 +65,7 @@ struct Params { #define V K #else @group(0) @binding(1) var K: array>; -@group(0) @binding(2) var V: array>; +@group(0) @binding(2) var V: array>; #endif #if defined(MASK) && defined(SINKS) From f7b1560df0aeb580b593171f5588d40ce7b807cc Mon Sep 17 00:00:00 2001 From: Constannnnnt Date: Thu, 7 May 2026 12:25:42 -0400 Subject: [PATCH 4/9] fix(flash_attn): fix tile path --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 92 ++++++++++++++----- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 10 +- .../wgsl-shaders/flash_attn_tile.wgsl | 7 +- 3 files changed, 80 insertions(+), 29 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 7662a099127..c6a350526de 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -91,6 +91,7 @@ struct ggml_webgpu_shader_lib_context { uint32_t sg_mat_m = 0; uint32_t sg_mat_n = 0; uint32_t sg_mat_k = 0; + uint32_t min_subgroup_size = 0; uint32_t max_subgroup_size = 0; }; @@ -542,12 +543,23 @@ struct ggml_webgpu_flash_attn_pipeline_key { bool has_sinks; bool uses_logit_softcap; uint32_t path; + uint32_t q_tile; + uint32_t kv_tile; + uint32_t wg_size; + uint32_t min_subgroup_size; + uint32_t max_subgroup_size; + uint32_t sg_mat_m; + uint32_t sg_mat_n; + uint32_t sg_mat_k; bool operator==(const ggml_webgpu_flash_attn_pipeline_key & other) const { return q_type == other.q_type && kv_type == other.kv_type && dst_type == other.dst_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v && kv_direct == other.kv_direct && kv_overlap == other.kv_overlap && has_mask == other.has_mask && has_sinks == other.has_sinks && - uses_logit_softcap == other.uses_logit_softcap && path == other.path; + uses_logit_softcap == other.uses_logit_softcap && path == other.path && q_tile == other.q_tile && + kv_tile == other.kv_tile && wg_size == other.wg_size && min_subgroup_size == other.min_subgroup_size && + max_subgroup_size == other.max_subgroup_size && sg_mat_m == other.sg_mat_m && + sg_mat_n == other.sg_mat_n && sg_mat_k == other.sg_mat_k; } }; @@ -565,6 +577,14 @@ struct ggml_webgpu_flash_attn_pipeline_key_hash { ggml_webgpu_hash_combine(seed, key.has_sinks); ggml_webgpu_hash_combine(seed, key.uses_logit_softcap); ggml_webgpu_hash_combine(seed, key.path); + ggml_webgpu_hash_combine(seed, key.q_tile); + ggml_webgpu_hash_combine(seed, key.kv_tile); + ggml_webgpu_hash_combine(seed, key.wg_size); + ggml_webgpu_hash_combine(seed, key.min_subgroup_size); + ggml_webgpu_hash_combine(seed, key.max_subgroup_size); + ggml_webgpu_hash_combine(seed, key.sg_mat_m); + ggml_webgpu_hash_combine(seed, key.sg_mat_n); + ggml_webgpu_hash_combine(seed, key.sg_mat_k); return seed; } }; @@ -581,6 +601,20 @@ struct ggml_webgpu_flash_attn_decisions { inline constexpr uint32_t GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH = 4u; inline constexpr uint32_t GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE = 4u; +inline uint32_t ggml_webgpu_effective_min_subgroup_size(const ggml_webgpu_shader_lib_context & context) { + if (context.min_subgroup_size > 0) { + return context.min_subgroup_size; + } + return std::max(1u, context.max_subgroup_size); +} + +inline uint32_t ggml_webgpu_effective_max_subgroup_size(const ggml_webgpu_shader_lib_context & context) { + if (context.max_subgroup_size > 0) { + return context.max_subgroup_size; + } + return std::max(1u, context.min_subgroup_size); +} + inline uint32_t ggml_webgpu_flash_attn_pick_vec_ne(const ggml_webgpu_flash_attn_pipeline_key & key) { if (key.path != GGML_WEBGPU_FLASH_ATTN_PATH_VEC || key.kv_type != GGML_TYPE_F16 || key.head_dim_qk != key.head_dim_v) { @@ -600,14 +634,14 @@ inline uint32_t ggml_webgpu_flash_attn_pick_vec_ne(const ggml_webgpu_flash_attn_ } inline ggml_webgpu_flash_attn_pipeline_key ggml_webgpu_flash_attn_make_pipeline_key( - const ggml_webgpu_shader_lib_context & context, - uint32_t path) { + const ggml_webgpu_shader_lib_context & context, + const ggml_webgpu_flash_attn_decisions & decisions) { const bool has_mask = context.src3 != nullptr; const bool has_sinks = context.src4 != nullptr; bool kv_direct = false; - if (path != GGML_WEBGPU_FLASH_ATTN_PATH_TILE) { + if (decisions.path != GGML_WEBGPU_FLASH_ATTN_PATH_TILE) { uint32_t kv_direct_align = GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH; - if (path == GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX) { + if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX) { kv_direct_align = context.sg_mat_k; } kv_direct = (context.src1->type == GGML_TYPE_F16) && @@ -626,7 +660,15 @@ inline ggml_webgpu_flash_attn_pipeline_key ggml_webgpu_flash_attn_make_pipeline_ key.has_mask = has_mask; key.has_sinks = has_sinks; key.uses_logit_softcap = ggml_get_op_params_f32(context.dst, 2) != 0.0f; - key.path = path; + key.q_tile = decisions.q_tile; + key.kv_tile = decisions.kv_tile; + key.wg_size = decisions.wg_size; + key.min_subgroup_size = ggml_webgpu_effective_min_subgroup_size(context); + key.max_subgroup_size = ggml_webgpu_effective_max_subgroup_size(context); + key.sg_mat_m = context.sg_mat_m; + key.sg_mat_n = context.sg_mat_n; + key.sg_mat_k = context.sg_mat_k; + key.path = decisions.path; return key; } @@ -738,6 +780,8 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions( const auto * V = context.src2; GGML_ASSERT(K != nullptr); GGML_ASSERT(V != nullptr); + const uint32_t min_subgroup_size = ggml_webgpu_effective_min_subgroup_size(context); + const uint32_t max_subgroup_size = ggml_webgpu_effective_max_subgroup_size(context); const auto flash_attn_tensor_offset = [](const ggml_tensor * tensor) -> size_t { constexpr uintptr_t ptr_base_addr = 0x1000u; @@ -758,16 +802,15 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions( kv_vec_type_supported && (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && (context.src2->type == K->type); const bool tile_can_dispatch_all_q_rows = - context.max_subgroup_size > 0 && - context.max_wg_size >= GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * context.max_subgroup_size; + max_subgroup_size > 0 && context.max_wg_size >= GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * max_subgroup_size; const bool use_tile = context.supports_subgroups && !context.supports_subgroup_matrix && K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && f16_vec4_aligned && (context.src0->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && (context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && - tile_can_dispatch_all_q_rows && !use_vec; + tile_can_dispatch_all_q_rows; - decisions.path = use_vec ? GGML_WEBGPU_FLASH_ATTN_PATH_VEC : - use_tile ? GGML_WEBGPU_FLASH_ATTN_PATH_TILE : + decisions.path = use_tile ? GGML_WEBGPU_FLASH_ATTN_PATH_TILE : + use_vec ? GGML_WEBGPU_FLASH_ATTN_PATH_VEC : context.supports_subgroup_matrix ? GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX : GGML_WEBGPU_FLASH_ATTN_PATH_NONE; @@ -775,7 +818,7 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions( return decisions; } - const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context, decisions.path); + const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context, decisions); decisions.kv_direct = key.kv_direct; const uint32_t max_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(context, key); // invalidate if even the smallest kv_tile doesn't fit in shared memory @@ -788,7 +831,7 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions( decisions.q_tile = 1u; decisions.kv_tile = std::max(8u, std::min(32u, max_kv_tile)); decisions.kv_tile = (decisions.kv_tile / 8u) * 8u; - decisions.wg_size = std::max(1u, std::min(32u, context.max_subgroup_size)); + decisions.wg_size = std::max(1u, std::min(32u, max_subgroup_size)); if (decisions.kv_direct) { decisions.kv_tile = std::min(decisions.kv_tile, GGML_WEBGPU_KV_SEQ_PAD); while (GGML_WEBGPU_KV_SEQ_PAD % decisions.kv_tile != 0) { @@ -803,12 +846,11 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions( decisions.kv_tile = decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? std::min(64u, max_kv_tile) : std::min(max_kv_tile, context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES); - decisions.wg_size = - decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? - std::min(std::max(1u, context.max_wg_size), - std::max(GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE, - GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * std::max(1u, context.max_subgroup_size))) : - std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE); + decisions.wg_size = decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? + std::min(std::max(1u, context.max_wg_size), + std::max(GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE, + GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * max_subgroup_size)) : + std::max(max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE); if (decisions.kv_tile == 0) { return decisions; @@ -817,9 +859,8 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions( if (decisions.kv_direct) { GGML_ASSERT(decisions.kv_tile <= GGML_WEBGPU_KV_SEQ_PAD); while (GGML_WEBGPU_KV_SEQ_PAD % decisions.kv_tile != 0) { - decisions.kv_tile -= decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? - std::max(1u, context.max_subgroup_size) : - context.sg_mat_n; + decisions.kv_tile -= + decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? min_subgroup_size : context.sg_mat_n; } } return decisions; @@ -2585,7 +2626,7 @@ class ggml_webgpu_shader_lib { const ggml_webgpu_flash_attn_decisions decisions = ggml_webgpu_flash_attn_get_decisions(context, storage_offset_alignment); GGML_ASSERT(decisions.path != GGML_WEBGPU_FLASH_ATTN_PATH_NONE); - ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context, decisions.path); + ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context, decisions); auto it = flash_attn_pipelines.find(key); if (it != flash_attn_pipelines.end()) { return it->second; @@ -2676,9 +2717,10 @@ class ggml_webgpu_shader_lib { shader_src = wgsl_flash_attn_vec_split; } else if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) { shader_src = wgsl_flash_attn_tile; - defines.push_back("MAX_SUBGROUP_SIZE=" + std::to_string(context.max_subgroup_size)); + defines.push_back("MIN_SUBGROUP_SIZE=" + std::to_string(key.min_subgroup_size) + "u"); + defines.push_back("MAX_SUBGROUP_SIZE=" + std::to_string(key.max_subgroup_size) + "u"); defines.push_back("KV_STAGE_STRIDE=" + std::to_string(std::max(key.head_dim_qk, key.head_dim_v))); - variant += "_tile"; + variant += "_tile_sg" + std::to_string(key.min_subgroup_size) + "_" + std::to_string(key.max_subgroup_size); } else { defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m)); defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n)); diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index c9e42f37d7b..02414bfc8b6 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -187,6 +187,7 @@ struct webgpu_capabilities { uint32_t sg_mat_k = 0; uint32_t subgroup_size = 0; + uint32_t min_subgroup_size = 0; uint32_t max_subgroup_size = 0; size_t memset_bytes_per_thread; }; @@ -1442,6 +1443,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m; shader_lib_ctx.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n; shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k; + shader_lib_ctx.min_subgroup_size = ctx->global_ctx->capabilities.min_subgroup_size; shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size; // Get or create pipeline @@ -1750,6 +1752,7 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m; shader_lib_ctx.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n; shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k; + shader_lib_ctx.min_subgroup_size = ctx->global_ctx->capabilities.min_subgroup_size; shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size; webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline( shader_lib_ctx, ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment); @@ -3469,6 +3472,7 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer shader_lib_ctx.sg_mat_m = ctx->webgpu_global_ctx->capabilities.sg_mat_m; shader_lib_ctx.sg_mat_n = ctx->webgpu_global_ctx->capabilities.sg_mat_n; shader_lib_ctx.sg_mat_k = ctx->webgpu_global_ctx->capabilities.sg_mat_k; + shader_lib_ctx.min_subgroup_size = ctx->webgpu_global_ctx->capabilities.min_subgroup_size; shader_lib_ctx.max_subgroup_size = ctx->webgpu_global_ctx->capabilities.max_subgroup_size; const ggml_webgpu_flash_attn_decisions decisions = ggml_webgpu_flash_attn_get_decisions( @@ -3667,8 +3671,9 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { #endif ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix = valid_subgroup_matrix_config; - // For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate. - // Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter. + // Runtime subgroup size can be any supported size in this range. Shaders + // that allocate per-lane register arrays must size them for the minimum. + ctx->webgpu_global_ctx->capabilities.min_subgroup_size = info.subgroupMinSize; ctx->webgpu_global_ctx->capabilities.max_subgroup_size = info.subgroupMaxSize; // Initialize device std::vector required_features = { wgpu::FeatureName::ShaderF16 }; @@ -4031,6 +4036,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const shader_lib_ctx.sg_mat_m = ctx->webgpu_global_ctx->capabilities.sg_mat_m; shader_lib_ctx.sg_mat_n = ctx->webgpu_global_ctx->capabilities.sg_mat_n; shader_lib_ctx.sg_mat_k = ctx->webgpu_global_ctx->capabilities.sg_mat_k; + shader_lib_ctx.min_subgroup_size = ctx->webgpu_global_ctx->capabilities.min_subgroup_size; shader_lib_ctx.max_subgroup_size = ctx->webgpu_global_ctx->capabilities.max_subgroup_size; const ggml_webgpu_flash_attn_decisions decisions = ggml_webgpu_flash_attn_get_decisions( diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl index 4bd4ec670d4..ae8036b9ac5 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl @@ -25,6 +25,9 @@ enable subgroups; #define Q_TILE 4 #define KV_TILE 64 #define WG_SIZE 128 +#ifndef MIN_SUBGROUP_SIZE +#define MIN_SUBGROUP_SIZE MAX_SUBGROUP_SIZE +#endif struct Params { offset_q: u32, @@ -116,8 +119,8 @@ struct Params { const FLOAT_MIN: f32 = -1.0e9; const Q_CHUNKS: u32 = HEAD_DIM_QK / 4u; const V_CHUNKS: u32 = HEAD_DIM_V / 4u; -const SCORE_REGS_PER_LANE: u32 = (KV_TILE + MAX_SUBGROUP_SIZE - 1u) / MAX_SUBGROUP_SIZE; -const OUT_REGS_PER_LANE: u32 = (V_CHUNKS + MAX_SUBGROUP_SIZE - 1u) / MAX_SUBGROUP_SIZE; +const SCORE_REGS_PER_LANE: u32 = (KV_TILE + MIN_SUBGROUP_SIZE - 1u) / MIN_SUBGROUP_SIZE; +const OUT_REGS_PER_LANE: u32 = (V_CHUNKS + MIN_SUBGROUP_SIZE - 1u) / MIN_SUBGROUP_SIZE; var q_shmem: array; var kv_shmem: array; From d2bd5ebdc55355baaeda39e8638865d28ffb5471 Mon Sep 17 00:00:00 2001 From: Constannnnnt Date: Mon, 11 May 2026 12:04:57 -0400 Subject: [PATCH 5/9] fix: pass editorconfig and address the type conflicts --- ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl | 23 +++++++++----------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl index 680c47045e4..36f02872939 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl @@ -50,9 +50,10 @@ struct Params { @group(0) @binding(PARAMS_BINDING) var params: Params; -fn erf_approx(x: f32) -> f32 { - let s = select(-1.0, 1.0, x >= 0.0); - let ax = abs(x); +fn erf_approx(x: TYPE) -> TYPE { + let x_f32 = f32(x); + let s = select(-1.0, 1.0, x_f32 >= 0.0); + let ax = abs(x_f32); let t = 1.0 / (1.0 + 0.3275911 * ax); @@ -61,13 +62,13 @@ fn erf_approx(x: f32) -> f32 { - 0.284496736) * t + 0.254829592) * t) * exp(-ax * ax); - return s * y; + return TYPE(s * y); } @compute @workgroup_size(WG_SIZE) fn main(@builtin(global_invocation_id) gid: vec3) { if (gid.x >= params.ne) { - return; + return; } var i = gid.x; let ne2 = params.ne2; @@ -85,15 +86,13 @@ fn main(@builtin(global_invocation_id) gid: vec3) { let i1 = i / ne0; let i0 = i % ne0; - let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 + - i2 * params.stride_src2 + i3 * params.stride_src3; + let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 + i2 * params.stride_src2 + i3 * params.stride_src3; #ifdef ABS let res = abs(src[params.offset_src + src_idx]); #endif #ifdef SGN - let res = select(TYPE(select(0.0, -1.0, src[params.offset_src + src_idx] < 0.0)), TYPE(1.0), - src[params.offset_src + src_idx] > 0.0); + let res = select(TYPE(select(0.0, -1.0, src[params.offset_src + src_idx] < 0.0)), TYPE(1.0), src[params.offset_src + src_idx] > 0.0); #endif #ifdef NEG let res = -src[params.offset_src + src_idx]; @@ -108,8 +107,7 @@ fn main(@builtin(global_invocation_id) gid: vec3) { let res = select(0.0, src[params.offset_src + src_idx], src[params.offset_src + src_idx] > 0.0); #endif #ifdef ELU - let res = select(exp(src[params.offset_src + src_idx]) - 1.0, src[params.offset_src + src_idx], - src[params.offset_src + src_idx] > 0.0); + let res = select(exp(src[params.offset_src + src_idx]) - 1.0, src[params.offset_src + src_idx], src[params.offset_src + src_idx] > 0.0); #endif #ifdef HARDSIGMOID let res = min(1.0, max(0.0, (src[params.offset_src + src_idx] + 3.0) / 6.0)); @@ -134,8 +132,7 @@ fn main(@builtin(global_invocation_id) gid: vec3) { let res = TYPE(params.fill_val); #endif #ifdef HARDSWISH - let res = src[params.offset_src + src_idx] * - min(1.0, max(0.0, (src[params.offset_src + src_idx] + 3.0) / 6.0)); + let res = src[params.offset_src + src_idx] * min(1.0, max(0.0, (src[params.offset_src + src_idx] + 3.0) / 6.0)); #endif #ifdef GELU let res = 0.5 * src[params.offset_src + src_idx] * (1.0 + tanh(0.7978845608 * (src[params.offset_src + src_idx] + 0.044715 * src[params.offset_src + src_idx] * src[params.offset_src + src_idx] * src[params.offset_src + src_idx]))); From ab59cd5a5cda11dd7e53cb83743550873e046a7f Mon Sep 17 00:00:00 2001 From: Constannnnnt Date: Mon, 11 May 2026 12:10:43 -0400 Subject: [PATCH 6/9] fix: remove reduant pipeline keys --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 29 +------------------ 1 file changed, 1 insertion(+), 28 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index c6a350526de..19880c70678 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -543,23 +543,12 @@ struct ggml_webgpu_flash_attn_pipeline_key { bool has_sinks; bool uses_logit_softcap; uint32_t path; - uint32_t q_tile; - uint32_t kv_tile; - uint32_t wg_size; - uint32_t min_subgroup_size; - uint32_t max_subgroup_size; - uint32_t sg_mat_m; - uint32_t sg_mat_n; - uint32_t sg_mat_k; bool operator==(const ggml_webgpu_flash_attn_pipeline_key & other) const { return q_type == other.q_type && kv_type == other.kv_type && dst_type == other.dst_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v && kv_direct == other.kv_direct && kv_overlap == other.kv_overlap && has_mask == other.has_mask && has_sinks == other.has_sinks && - uses_logit_softcap == other.uses_logit_softcap && path == other.path && q_tile == other.q_tile && - kv_tile == other.kv_tile && wg_size == other.wg_size && min_subgroup_size == other.min_subgroup_size && - max_subgroup_size == other.max_subgroup_size && sg_mat_m == other.sg_mat_m && - sg_mat_n == other.sg_mat_n && sg_mat_k == other.sg_mat_k; + uses_logit_softcap == other.uses_logit_softcap && path == other.path; } }; @@ -577,14 +566,6 @@ struct ggml_webgpu_flash_attn_pipeline_key_hash { ggml_webgpu_hash_combine(seed, key.has_sinks); ggml_webgpu_hash_combine(seed, key.uses_logit_softcap); ggml_webgpu_hash_combine(seed, key.path); - ggml_webgpu_hash_combine(seed, key.q_tile); - ggml_webgpu_hash_combine(seed, key.kv_tile); - ggml_webgpu_hash_combine(seed, key.wg_size); - ggml_webgpu_hash_combine(seed, key.min_subgroup_size); - ggml_webgpu_hash_combine(seed, key.max_subgroup_size); - ggml_webgpu_hash_combine(seed, key.sg_mat_m); - ggml_webgpu_hash_combine(seed, key.sg_mat_n); - ggml_webgpu_hash_combine(seed, key.sg_mat_k); return seed; } }; @@ -660,14 +641,6 @@ inline ggml_webgpu_flash_attn_pipeline_key ggml_webgpu_flash_attn_make_pipeline_ key.has_mask = has_mask; key.has_sinks = has_sinks; key.uses_logit_softcap = ggml_get_op_params_f32(context.dst, 2) != 0.0f; - key.q_tile = decisions.q_tile; - key.kv_tile = decisions.kv_tile; - key.wg_size = decisions.wg_size; - key.min_subgroup_size = ggml_webgpu_effective_min_subgroup_size(context); - key.max_subgroup_size = ggml_webgpu_effective_max_subgroup_size(context); - key.sg_mat_m = context.sg_mat_m; - key.sg_mat_n = context.sg_mat_n; - key.sg_mat_k = context.sg_mat_k; key.path = decisions.path; return key; } From 5a6f9c633c08e40a71d1911dd1304f3c41e5590c Mon Sep 17 00:00:00 2001 From: Constannnnnt Date: Mon, 11 May 2026 13:00:55 -0400 Subject: [PATCH 7/9] fix: remove inline min/max group size functions and revert the flash attn path order --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 40 ++++++------------- 1 file changed, 13 insertions(+), 27 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 19880c70678..932a01d385e 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -582,20 +582,6 @@ struct ggml_webgpu_flash_attn_decisions { inline constexpr uint32_t GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH = 4u; inline constexpr uint32_t GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE = 4u; -inline uint32_t ggml_webgpu_effective_min_subgroup_size(const ggml_webgpu_shader_lib_context & context) { - if (context.min_subgroup_size > 0) { - return context.min_subgroup_size; - } - return std::max(1u, context.max_subgroup_size); -} - -inline uint32_t ggml_webgpu_effective_max_subgroup_size(const ggml_webgpu_shader_lib_context & context) { - if (context.max_subgroup_size > 0) { - return context.max_subgroup_size; - } - return std::max(1u, context.min_subgroup_size); -} - inline uint32_t ggml_webgpu_flash_attn_pick_vec_ne(const ggml_webgpu_flash_attn_pipeline_key & key) { if (key.path != GGML_WEBGPU_FLASH_ATTN_PATH_VEC || key.kv_type != GGML_TYPE_F16 || key.head_dim_qk != key.head_dim_v) { @@ -753,8 +739,6 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions( const auto * V = context.src2; GGML_ASSERT(K != nullptr); GGML_ASSERT(V != nullptr); - const uint32_t min_subgroup_size = ggml_webgpu_effective_min_subgroup_size(context); - const uint32_t max_subgroup_size = ggml_webgpu_effective_max_subgroup_size(context); const auto flash_attn_tensor_offset = [](const ggml_tensor * tensor) -> size_t { constexpr uintptr_t ptr_base_addr = 0x1000u; @@ -775,15 +759,16 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions( kv_vec_type_supported && (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && (context.src2->type == K->type); const bool tile_can_dispatch_all_q_rows = - max_subgroup_size > 0 && context.max_wg_size >= GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * max_subgroup_size; + context.max_subgroup_size > 0 && + context.max_wg_size >= GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * context.max_subgroup_size; const bool use_tile = context.supports_subgroups && !context.supports_subgroup_matrix && K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && f16_vec4_aligned && (context.src0->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && (context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && - tile_can_dispatch_all_q_rows; + tile_can_dispatch_all_q_rows && !use_vec; - decisions.path = use_tile ? GGML_WEBGPU_FLASH_ATTN_PATH_TILE : - use_vec ? GGML_WEBGPU_FLASH_ATTN_PATH_VEC : + decisions.path = use_vec ? GGML_WEBGPU_FLASH_ATTN_PATH_VEC : + use_tile ? GGML_WEBGPU_FLASH_ATTN_PATH_TILE : context.supports_subgroup_matrix ? GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX : GGML_WEBGPU_FLASH_ATTN_PATH_NONE; @@ -804,7 +789,7 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions( decisions.q_tile = 1u; decisions.kv_tile = std::max(8u, std::min(32u, max_kv_tile)); decisions.kv_tile = (decisions.kv_tile / 8u) * 8u; - decisions.wg_size = std::max(1u, std::min(32u, max_subgroup_size)); + decisions.wg_size = std::max(1u, std::min(32u, context.max_subgroup_size)); if (decisions.kv_direct) { decisions.kv_tile = std::min(decisions.kv_tile, GGML_WEBGPU_KV_SEQ_PAD); while (GGML_WEBGPU_KV_SEQ_PAD % decisions.kv_tile != 0) { @@ -822,8 +807,8 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions( decisions.wg_size = decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? std::min(std::max(1u, context.max_wg_size), std::max(GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE, - GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * max_subgroup_size)) : - std::max(max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE); + GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * context.max_subgroup_size)) : + std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE); if (decisions.kv_tile == 0) { return decisions; @@ -833,7 +818,7 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions( GGML_ASSERT(decisions.kv_tile <= GGML_WEBGPU_KV_SEQ_PAD); while (GGML_WEBGPU_KV_SEQ_PAD % decisions.kv_tile != 0) { decisions.kv_tile -= - decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? min_subgroup_size : context.sg_mat_n; + decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? context.min_subgroup_size : context.sg_mat_n; } } return decisions; @@ -2690,10 +2675,11 @@ class ggml_webgpu_shader_lib { shader_src = wgsl_flash_attn_vec_split; } else if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) { shader_src = wgsl_flash_attn_tile; - defines.push_back("MIN_SUBGROUP_SIZE=" + std::to_string(key.min_subgroup_size) + "u"); - defines.push_back("MAX_SUBGROUP_SIZE=" + std::to_string(key.max_subgroup_size) + "u"); + defines.push_back("MIN_SUBGROUP_SIZE=" + std::to_string(context.min_subgroup_size) + "u"); + defines.push_back("MAX_SUBGROUP_SIZE=" + std::to_string(context.max_subgroup_size) + "u"); defines.push_back("KV_STAGE_STRIDE=" + std::to_string(std::max(key.head_dim_qk, key.head_dim_v))); - variant += "_tile_sg" + std::to_string(key.min_subgroup_size) + "_" + std::to_string(key.max_subgroup_size); + variant += "_tile_sg" + std::to_string(context.min_subgroup_size) + "_" + + std::to_string(context.max_subgroup_size); } else { defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m)); defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n)); From e115d5446bc8a8f626100e74b7efef672a408477 Mon Sep 17 00:00:00 2001 From: Constannnnnt Date: Mon, 11 May 2026 18:34:53 -0400 Subject: [PATCH 8/9] fix: use clamp to avoid NaN for GELU --- ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl index 36f02872939..8d176b32d56 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl @@ -135,13 +135,13 @@ fn main(@builtin(global_invocation_id) gid: vec3) { let res = src[params.offset_src + src_idx] * min(1.0, max(0.0, (src[params.offset_src + src_idx] + 3.0) / 6.0)); #endif #ifdef GELU - let res = 0.5 * src[params.offset_src + src_idx] * (1.0 + tanh(0.7978845608 * (src[params.offset_src + src_idx] + 0.044715 * src[params.offset_src + src_idx] * src[params.offset_src + src_idx] * src[params.offset_src + src_idx]))); + let res = 0.5 * src[params.offset_src + src_idx] * (1.0 + tanh(clamp(0.7978845608028654 * (src[params.offset_src + src_idx] + 0.044715 * src[params.offset_src + src_idx] * src[params.offset_src + src_idx] * src[params.offset_src + src_idx]), -9.010913, 9.010913))); #endif #ifdef GELU_QUICK - let res = src[params.offset_src + src_idx] * (1.0 / (1.0 + exp(-1.702 * src[params.offset_src + src_idx]))); + let res = src[params.offset_src + src_idx] * (1.0 / (1.0 + exp(clamp(-1.702 * src[params.offset_src + src_idx], -9.010913, 9.010913)))); #endif #ifdef GELU_ERF - let res = 0.5 * src[params.offset_src + src_idx] * (1.0 + erf_approx(src[params.offset_src + src_idx] * 0.70710678)); + let res = 0.5 * src[params.offset_src + src_idx] * (1.0 + erf_approx(src[params.offset_src + src_idx] * 0.7071067811865476)); #endif #ifdef XIELU let val = f32(src[params.offset_src + src_idx]); From 6797d633f35e94d4449499a9808eff033a631c5d Mon Sep 17 00:00:00 2001 From: Constannnnnt Date: Mon, 11 May 2026 18:44:49 -0400 Subject: [PATCH 9/9] fix: use the right range for exp, 80 is safer for f32 exp --- ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl index 8d176b32d56..8e34e1c9ca0 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl @@ -138,7 +138,7 @@ fn main(@builtin(global_invocation_id) gid: vec3) { let res = 0.5 * src[params.offset_src + src_idx] * (1.0 + tanh(clamp(0.7978845608028654 * (src[params.offset_src + src_idx] + 0.044715 * src[params.offset_src + src_idx] * src[params.offset_src + src_idx] * src[params.offset_src + src_idx]), -9.010913, 9.010913))); #endif #ifdef GELU_QUICK - let res = src[params.offset_src + src_idx] * (1.0 / (1.0 + exp(clamp(-1.702 * src[params.offset_src + src_idx], -9.010913, 9.010913)))); + let res = src[params.offset_src + src_idx] * (1.0 / (1.0 + exp(clamp(-1.702 * src[params.offset_src + src_idx], -80.0, 80.0)))); #endif #ifdef GELU_ERF let res = 0.5 * src[params.offset_src + src_idx] * (1.0 + erf_approx(src[params.offset_src + src_idx] * 0.7071067811865476));