From e1a0bb34da78450fb6eeb653151014a9cfe53488 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Wed, 27 May 2026 13:52:40 -0700 Subject: [PATCH 1/7] Start work on flash_attn refactor --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 236 +++++++---- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 379 +++++++++--------- 2 files changed, 343 insertions(+), 272 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 60e98a607410..3e72093b11d5 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -18,6 +18,9 @@ #define GGML_WEBGPU_F32_SIZE_BYTES 4 #define GGML_WEBGPU_I32_SIZE_BYTES 4 #define GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES 8u +#define GGML_WEBGPU_FLASH_ATTN_VEC_MAX_SEQ_LEN 20u +#define GGML_WEBGPU_FLASH_ATTN_VEC_MAX_KV_TILE 32u +#define GGML_WEBGPU_FLASH_ATTN_TILE_MAX_KV_TILE 64u #define GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE 128u // Matches GGML_PAD(..., 256) in src/llama-context.cpp for KV cache sizing. #define GGML_WEBGPU_KV_SEQ_PAD 256u @@ -618,17 +621,35 @@ inline uint32_t ggml_webgpu_flash_attn_pick_vec_ne(const ggml_webgpu_flash_attn_ } } +inline size_t ggml_webgpu_flash_attn_tensor_offset(const ggml_tensor * tensor) { + constexpr uintptr_t ptr_base_addr = 0x1000u; + const ggml_tensor * base = tensor->view_src != nullptr ? tensor->view_src : tensor; + return reinterpret_cast(base->data) - ptr_base_addr + tensor->view_offs; +} + +inline bool ggml_webgpu_flash_attn_f16_vec4_aligned(const ggml_tensor * K, + const ggml_tensor * V, + size_t storage_offset_alignment) { + const size_t alignment = std::max(1u, storage_offset_alignment); + const uint32_t k_offset_elems = + (uint32_t) ((ggml_webgpu_flash_attn_tensor_offset(K) & (alignment - 1)) / ggml_type_size(K->type)); + const uint32_t v_offset_elems = + (uint32_t) ((ggml_webgpu_flash_attn_tensor_offset(V) & (alignment - 1)) / ggml_type_size(V->type)); + return (k_offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u) && + (v_offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u); +} + inline ggml_webgpu_flash_attn_pipeline_key ggml_webgpu_flash_attn_make_pipeline_key( const ggml_webgpu_shader_lib_context & context, const ggml_webgpu_flash_attn_decisions & decisions) { const bool has_mask = context.src3 != nullptr; const bool has_sinks = context.src4 != nullptr; - bool kv_direct = false; + bool kv_direct = false; + uint32_t kv_direct_align = GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH; + if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX) { + kv_direct_align = context.sg_mat_k; + } if (decisions.path != GGML_WEBGPU_FLASH_ATTN_PATH_TILE) { - uint32_t kv_direct_align = GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH; - if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX) { - kv_direct_align = context.sg_mat_k; - } kv_direct = (context.src1->type == GGML_TYPE_F16) && (context.src0->ne[0] % std::max(1u, kv_direct_align) == 0) && (context.src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0); @@ -684,29 +705,18 @@ struct ggml_webgpu_flash_attn_blk_pipeline_key_hash { } }; -// This is exposed because it's necessary in supports_op +// Note: this will slightly overestimate memory usage for vec path +// since row_max and exp_sum shmem are not needed. inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile, uint32_t kv_tile, uint32_t head_dim_qk, uint32_t head_dim_v, bool has_mask, - bool kv_direct, - uint32_t path = GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX) { + bool kv_direct) { const uint32_t max_head_dim = std::max(head_dim_qk, head_dim_v); size_t f16_elems = 0; size_t f32_elems = 0; - if (path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { - f32_elems += head_dim_qk; // q_shmem - if (!kv_direct) { - f32_elems += kv_tile * max_head_dim; // kv_shmem - } - f32_elems += head_dim_v; // o_shmem - if (has_mask) { - f32_elems += kv_tile; // mask_shmem - } - f32_elems += kv_tile; // inter_shmem - return f32_elems * GGML_WEBGPU_F32_SIZE_BYTES; - } + f32_elems += q_tile * head_dim_qk; // q_shmem if (!kv_direct) { f32_elems += kv_tile * max_head_dim; // kv_shmem @@ -721,25 +731,27 @@ inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile, return f16_elems * GGML_WEBGPU_F16_SIZE_BYTES + f32_elems * GGML_WEBGPU_F32_SIZE_BYTES; } -inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_shader_lib_context & context, - const ggml_webgpu_flash_attn_pipeline_key & key) { - const size_t limit_bytes = context.wg_mem_limit_bytes; - uint32_t q_tile = context.sg_mat_m; - uint32_t kv_granularity = std::max(1u, context.sg_mat_n); - if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) { - q_tile = GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE; - kv_granularity = 1u; - } else if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { - q_tile = 1u; - kv_granularity = 8u; - } - const size_t base_q_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(q_tile, 0, key.head_dim_qk, key.head_dim_v, - key.has_mask, key.kv_direct, key.path); +inline bool ggml_webgpu_flash_attn_kv_direct(const ggml_tensor * Q, + const ggml_tensor * K, + uint32_t kv_direct_align) { + return K->type == GGML_TYPE_F16 && (Q->ne[0] % std::max(1u, kv_direct_align) == 0) && + (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0); +} + +inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(size_t limit_bytes, + uint32_t q_tile, + uint32_t kv_granularity, + uint32_t head_dim_qk, + uint32_t head_dim_v, + bool has_mask, + bool kv_direct) { + const size_t base_q_bytes = + ggml_webgpu_flash_attn_wg_mem_bytes(q_tile, 0, head_dim_qk, head_dim_v, has_mask, kv_direct); if (limit_bytes <= base_q_bytes) { return 0; } - const size_t one_kv_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(q_tile, 1, key.head_dim_qk, key.head_dim_v, - key.has_mask, key.kv_direct, key.path); + const size_t one_kv_bytes = + ggml_webgpu_flash_attn_wg_mem_bytes(q_tile, 1, head_dim_qk, head_dim_v, has_mask, kv_direct); const size_t bytes_per_kv = one_kv_bytes - base_q_bytes; if (bytes_per_kv == 0) { return 0; @@ -748,87 +760,118 @@ inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_shader_lib_ return (uint32_t) ((max_kv_tile / kv_granularity) * kv_granularity); } -inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions( +inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_vec_decisions( const ggml_webgpu_shader_lib_context & context, size_t storage_offset_alignment) { ggml_webgpu_flash_attn_decisions decisions = {}; - const size_t alignment = std::max(1u, storage_offset_alignment); - const auto * K = context.src1; - const auto * V = context.src2; + const auto * K = context.src1; + const auto * V = context.src2; GGML_ASSERT(K != nullptr); GGML_ASSERT(V != nullptr); - const auto flash_attn_tensor_offset = [](const ggml_tensor * tensor) -> size_t { - constexpr uintptr_t ptr_base_addr = 0x1000u; - const ggml_tensor * base = tensor->view_src != nullptr ? tensor->view_src : tensor; - return reinterpret_cast(base->data) - ptr_base_addr + tensor->view_offs; - }; - - const uint32_t k_offset_elems = - (uint32_t) ((flash_attn_tensor_offset(K) & (alignment - 1)) / ggml_type_size(K->type)); - const uint32_t v_offset_elems = - (uint32_t) ((flash_attn_tensor_offset(V) & (alignment - 1)) / ggml_type_size(V->type)); - const bool f16_vec4_aligned = (k_offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u) && - (v_offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u); + const bool f16_vec4_aligned = ggml_webgpu_flash_attn_f16_vec4_aligned(K, V, storage_offset_alignment); const bool kv_vec_type_supported = K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0; const uint32_t kv_vec_head_align = K->type == GGML_TYPE_F16 ? GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH : (uint32_t) ggml_blck_size(K->type); const bool kv_vec_head_dims_aligned = context.src0->ne[0] % kv_vec_head_align == 0 && context.src2->ne[0] % kv_vec_head_align == 0; - // Compile with enough invocations to cover the largest reported subgroup. - const bool use_vec = context.supports_subgroups && (context.src0->ne[1] < 20) && + const bool use_vec = context.supports_subgroups && + (context.src0->ne[1] < GGML_WEBGPU_FLASH_ATTN_VEC_MAX_SEQ_LEN) && kv_vec_head_dims_aligned && kv_vec_type_supported && (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && (context.src2->type == K->type); - const bool tile_can_dispatch_all_q_rows = - context.max_subgroup_size > 0 && - context.max_wg_size >= GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * context.max_subgroup_size; - const bool use_subgroup_matrix = - context.supports_subgroup_matrix && context.sg_mat_k > 0 && context.sg_mat_n > 0 && - context.src0->ne[0] % context.sg_mat_k == 0 && context.src2->ne[0] % context.sg_mat_n == 0; - const bool use_tile = context.supports_subgroups && !use_subgroup_matrix && K->type == GGML_TYPE_F16 && - V->type == GGML_TYPE_F16 && f16_vec4_aligned && - (context.src0->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && - (context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && - tile_can_dispatch_all_q_rows && !use_vec; - - decisions.path = use_vec ? GGML_WEBGPU_FLASH_ATTN_PATH_VEC : - use_tile ? GGML_WEBGPU_FLASH_ATTN_PATH_TILE : - use_subgroup_matrix ? GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX : - GGML_WEBGPU_FLASH_ATTN_PATH_NONE; - - if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_NONE) { + if (!use_vec) { return decisions; } + decisions.path = GGML_WEBGPU_FLASH_ATTN_PATH_VEC; + const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context, decisions); decisions.kv_direct = key.kv_direct; - const uint32_t max_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(context, key); - // invalidate if even the smallest kv_tile doesn't fit in shared memory + const uint32_t max_kv_tile = ggml_webgpu_flash_attn_max_kv_tile( + context.wg_mem_limit_bytes, 1u, 1u, key.head_dim_qk, key.head_dim_v, key.has_mask, key.kv_direct); if (max_kv_tile == 0) { decisions.path = GGML_WEBGPU_FLASH_ATTN_PATH_NONE; return decisions; } - if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { - decisions.q_tile = 1u; - decisions.kv_tile = std::max(8u, std::min(32u, max_kv_tile)); - decisions.kv_tile = (decisions.kv_tile / 8u) * 8u; - decisions.wg_size = context.max_subgroup_size; - if (decisions.kv_direct) { - decisions.kv_tile = std::min(decisions.kv_tile, GGML_WEBGPU_KV_SEQ_PAD); - while (GGML_WEBGPU_KV_SEQ_PAD % decisions.kv_tile != 0) { - decisions.kv_tile -= 8u; - } + decisions.q_tile = 1u; + decisions.kv_tile = std::min(GGML_WEBGPU_FLASH_ATTN_VEC_MAX_KV_TILE, max_kv_tile); + decisions.wg_size = context.max_subgroup_size; + if (decisions.kv_direct) { + decisions.kv_tile = std::min(decisions.kv_tile, GGML_WEBGPU_KV_SEQ_PAD); + while (GGML_WEBGPU_KV_SEQ_PAD % decisions.kv_tile != 0) { + decisions.kv_tile -= 1u; } + } + return decisions; +} + +inline bool ggml_webgpu_flash_attn_can_use_subgroup_matrix_path(bool supports_subgroup_matrix, + uint32_t sg_mat_k, + uint32_t sg_mat_n, + const ggml_tensor * Q, + const ggml_tensor * V) { + return supports_subgroup_matrix && Q->ne[0] % sg_mat_k == 0 && V->ne[0] % sg_mat_n == 0; +} + +inline bool ggml_webgpu_flash_attn_can_use_tile_path(bool supports_subgroups, + uint32_t max_wg_size, + uint32_t max_subgroup_size, + const ggml_tensor * Q, + const ggml_tensor * K, + const ggml_tensor * V, + size_t storage_offset_alignment) { + GGML_ASSERT(Q != nullptr); + GGML_ASSERT(K != nullptr); + GGML_ASSERT(V != nullptr); + + const bool f16_vec4_aligned = ggml_webgpu_flash_attn_f16_vec4_aligned(K, V, storage_offset_alignment); + const bool tile_can_dispatch_all_q_rows = + max_subgroup_size > 0 && + max_wg_size >= GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * max_subgroup_size; + + return supports_subgroups && K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && f16_vec4_aligned && + (Q->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && + (V->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && + tile_can_dispatch_all_q_rows; +} + +inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_non_vec_decisions( + const ggml_webgpu_shader_lib_context & context, + size_t storage_offset_alignment) { + ggml_webgpu_flash_attn_decisions decisions = {}; + const bool can_use_subgroup_matrix = ggml_webgpu_flash_attn_can_use_subgroup_matrix_path( + context.supports_subgroup_matrix, context.sg_mat_k, context.sg_mat_n, context.src0, context.src2); + const bool can_use_tile = + !can_use_subgroup_matrix && ggml_webgpu_flash_attn_can_use_tile_path( + context.supports_subgroups, context.max_wg_size, context.max_subgroup_size, context.src0, context.src1, + context.src2, storage_offset_alignment); + + decisions.path = can_use_tile ? GGML_WEBGPU_FLASH_ATTN_PATH_TILE : + can_use_subgroup_matrix ? GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX : + GGML_WEBGPU_FLASH_ATTN_PATH_NONE; + + if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_NONE) { return decisions; } decisions.q_tile = decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE : context.sg_mat_m; + const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context, decisions); + decisions.kv_direct = key.kv_direct; + const uint32_t max_kv_tile = ggml_webgpu_flash_attn_max_kv_tile( + context.wg_mem_limit_bytes, decisions.q_tile, + decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? 1u : std::max(1u, context.sg_mat_n), key.head_dim_qk, + key.head_dim_v, key.has_mask, key.kv_direct); + if (max_kv_tile == 0) { + decisions.path = GGML_WEBGPU_FLASH_ATTN_PATH_NONE; + return decisions; + } + decisions.kv_tile = decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? - std::min(64u, max_kv_tile) : + std::min(GGML_WEBGPU_FLASH_ATTN_TILE_MAX_KV_TILE, max_kv_tile) : std::min(max_kv_tile, context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES); decisions.wg_size = decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? std::min(std::max(1u, context.max_wg_size), @@ -2644,10 +2687,8 @@ class ggml_webgpu_shader_lib { return repeat_pipelines[key]; } - webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context, - size_t storage_offset_alignment) { - const ggml_webgpu_flash_attn_decisions decisions = - ggml_webgpu_flash_attn_get_decisions(context, storage_offset_alignment); + webgpu_pipeline get_flash_attn_pipeline_for_decisions(const ggml_webgpu_shader_lib_context & context, + const ggml_webgpu_flash_attn_decisions & decisions) { GGML_ASSERT(decisions.path != GGML_WEBGPU_FLASH_ATTN_PATH_NONE); ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context, decisions); auto it = flash_attn_pipelines.find(key); @@ -2764,6 +2805,23 @@ class ggml_webgpu_shader_lib { return flash_attn_pipelines[key]; } + webgpu_pipeline get_flash_attn_non_vec_pipeline(const ggml_webgpu_shader_lib_context & context, + size_t storage_offset_alignment) { + const ggml_webgpu_flash_attn_decisions decisions = + ggml_webgpu_flash_attn_get_non_vec_decisions(context, storage_offset_alignment); + GGML_ASSERT(decisions.path != GGML_WEBGPU_FLASH_ATTN_PATH_NONE); + GGML_ASSERT(decisions.path != GGML_WEBGPU_FLASH_ATTN_PATH_VEC); + return get_flash_attn_pipeline_for_decisions(context, decisions); + } + + webgpu_pipeline get_flash_attn_vec_pipeline(const ggml_webgpu_shader_lib_context & context, + size_t storage_offset_alignment) { + const ggml_webgpu_flash_attn_decisions decisions = + ggml_webgpu_flash_attn_get_vec_decisions(context, storage_offset_alignment); + GGML_ASSERT(decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC); + return get_flash_attn_pipeline_for_decisions(context, decisions); + } + webgpu_pipeline get_flash_attn_blk_pipeline(const ggml_webgpu_shader_lib_context & context, uint32_t kv_tile) { ggml_webgpu_flash_attn_blk_pipeline_key key = {}; key.kv_tile = kv_tile; diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index f6d17a073bee..a940484928e1 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1759,13 +1759,45 @@ static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx, return ggml_backend_webgpu_build_multi(ctx, dispatches); } -static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, - ggml_tensor * Q, - ggml_tensor * K, - ggml_tensor * V, - ggml_tensor * mask, - ggml_tensor * sinks, - ggml_tensor * dst) { +struct ggml_webgpu_flash_attn_op { + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + std::vector params; + std::vector entries; + size_t kv_bind_offset = 0; + size_t kv_bind_size = 0; + bool has_mask = false; + bool has_sinks = false; + bool kv_overlap = false; +}; + +static bool ggml_webgpu_flash_attn_use_vec_path(const webgpu_global_context & global_ctx, + const ggml_tensor * Q, + const ggml_tensor * K, + const ggml_tensor * V) { + GGML_ASSERT(K != nullptr); + GGML_ASSERT(V != nullptr); + GGML_ASSERT(Q != nullptr); + + const bool f16_vec4_aligned = + ggml_webgpu_flash_attn_f16_vec4_aligned(K, V, global_ctx->capabilities.limits.minStorageBufferOffsetAlignment); + const bool kv_vec_type_supported = + K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0; + const uint32_t kv_vec_head_align = + K->type == GGML_TYPE_F16 ? GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH : (uint32_t) ggml_blck_size(K->type); + const bool kv_vec_head_dims_aligned = Q->ne[0] % kv_vec_head_align == 0 && V->ne[0] % kv_vec_head_align == 0; + + return global_ctx->capabilities.supports_subgroups && + (Q->ne[1] < GGML_WEBGPU_FLASH_ATTN_VEC_MAX_SEQ_LEN) && kv_vec_head_dims_aligned && + kv_vec_type_supported && (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && (V->type == K->type); +} + +static ggml_webgpu_flash_attn_op ggml_webgpu_flash_attn_prepare(webgpu_context & ctx, + ggml_tensor * Q, + ggml_tensor * K, + ggml_tensor * V, + ggml_tensor * mask, + ggml_tensor * sinks, + ggml_tensor * dst) { float scale = ggml_get_op_params_f32(dst, 0); float max_bias = ggml_get_op_params_f32(dst, 1); float logit_softcap = ggml_get_op_params_f32(dst, 2); @@ -1776,47 +1808,43 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, float m0 = powf(2.0f, -(max_bias) / n_head_log2); float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - ggml_webgpu_shader_lib_context shader_lib_ctx = {}; - shader_lib_ctx.src0 = Q; - shader_lib_ctx.src1 = K; - shader_lib_ctx.src2 = V; - shader_lib_ctx.src3 = mask; - shader_lib_ctx.src4 = sinks; - shader_lib_ctx.dst = dst; - shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups; - shader_lib_ctx.supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix; - shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; - shader_lib_ctx.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; - shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m; - shader_lib_ctx.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n; - shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k; - shader_lib_ctx.min_subgroup_size = ctx->global_ctx->capabilities.min_subgroup_size; - shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size; - webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline( - shader_lib_ctx, ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment); - auto * decisions = static_cast(pipeline.context.get()); - const int has_mask = (mask != nullptr); - const int has_sinks = (sinks != nullptr); - const bool kv_overlap = decisions->kv_overlap; - - uint32_t offset_k = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type)); - uint32_t offset_v = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type)); - size_t kv_bind_offset = 0; - size_t kv_bind_size = 0; - if (kv_overlap) { + ggml_webgpu_flash_attn_op op = {}; + op.shader_lib_ctx.src0 = Q; + op.shader_lib_ctx.src1 = K; + op.shader_lib_ctx.src2 = V; + op.shader_lib_ctx.src3 = mask; + op.shader_lib_ctx.src4 = sinks; + op.shader_lib_ctx.dst = dst; + op.shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups; + op.shader_lib_ctx.supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix; + op.shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + op.shader_lib_ctx.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; + op.shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m; + op.shader_lib_ctx.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n; + op.shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k; + op.shader_lib_ctx.min_subgroup_size = ctx->global_ctx->capabilities.min_subgroup_size; + op.shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size; + + op.has_mask = mask != nullptr; + op.has_sinks = sinks != nullptr; + op.kv_overlap = ggml_webgpu_tensor_overlap(K, V); + + uint32_t offset_k = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type)); + uint32_t offset_v = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type)); + if (op.kv_overlap) { const ggml_webgpu_merged_binding_range merged_range = ggml_webgpu_tensor_merged_binding_range(ctx, { K, V }); - kv_bind_offset = merged_range.offset; - kv_bind_size = merged_range.size; + op.kv_bind_offset = merged_range.offset; + op.kv_bind_size = merged_range.size; offset_k = ggml_webgpu_tensor_merged_element_offset(K, merged_range); offset_v = ggml_webgpu_tensor_merged_element_offset(V, merged_range); } - std::vector params = { + op.params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, Q) / ggml_type_size(Q->type)), offset_k, offset_v, - has_mask ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)) : 0, - has_sinks ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, sinks) / ggml_type_size(sinks->type)) : 0, + op.has_mask ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)) : 0, + op.has_sinks ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, sinks) / ggml_type_size(sinks->type)) : 0, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), (uint32_t) Q->ne[2], // number of heads (uint32_t) Q->ne[1], // sequence length (Q) @@ -1830,7 +1858,7 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, (uint32_t) (V->nb[1] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 1 (uint32_t) (V->nb[2] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 2 (uint32_t) (V->nb[3] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 3 - has_mask ? (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)) : 0, // stride of mask dim 3 + op.has_mask ? (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)) : 0, // stride of mask dim 3 (uint32_t) (Q->ne[2] / K->ne[2]), // repeat factor for K/V in dim 2 (MHA/MQA/GQA) ggml_webgpu_u32_from_f32(scale), // scale (possibly adjusted for logit softcap) ggml_webgpu_u32_from_f32(max_bias), @@ -1838,32 +1866,49 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, ggml_webgpu_u32_from_f32(n_head_log2), ggml_webgpu_u32_from_f32(m0), ggml_webgpu_u32_from_f32(m1) - }; - std::vector entries = { + op.entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, Q), }; - if (kv_overlap) { - entries.push_back( - ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K), kv_bind_offset, kv_bind_size)); + if (op.kv_overlap) { + op.entries.push_back( + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K), op.kv_bind_offset, op.kv_bind_size)); } else { - entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, K)); - entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, V)); + op.entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, K)); + op.entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, V)); } - uint32_t binding_index = kv_overlap ? 2u : 3u; - if (has_mask) { - entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, mask)); + uint32_t binding_index = op.kv_overlap ? 2u : 3u; + if (op.has_mask) { + op.entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, mask)); } - if (has_sinks) { - entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, sinks)); + if (op.has_sinks) { + op.entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, sinks)); } - entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, dst)); + op.entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, dst)); - if (decisions->path != GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { - uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions->q_tile); - uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches - return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); - } + return op; +} + +static webgpu_encoded_op ggml_webgpu_flash_attn_direct(webgpu_context & ctx, const ggml_webgpu_flash_attn_op & op) { + webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_non_vec_pipeline( + op.shader_lib_ctx, ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment); + auto * decisions = static_cast(pipeline.context.get()); + uint32_t wg_per_head = CEIL_DIV(op.shader_lib_ctx.src0->ne[1], decisions->q_tile); + uint32_t wg_x = wg_per_head * op.shader_lib_ctx.src0->ne[2] * op.shader_lib_ctx.src0->ne[3]; + return ggml_backend_webgpu_build(ctx, pipeline, op.params, op.entries, wg_x); +} + +static webgpu_encoded_op ggml_webgpu_flash_attn_vec(webgpu_context & ctx, + ggml_tensor * Q, + ggml_tensor * K, + ggml_tensor * V, + ggml_tensor * mask, + ggml_tensor * sinks, + ggml_tensor * dst, + ggml_webgpu_flash_attn_op op) { + webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_vec_pipeline( + op.shader_lib_ctx, ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment); + auto * decisions = static_cast(pipeline.context.get()); wgpu::Buffer blk_buf = {}; uint64_t blk_size_bytes = 0; @@ -1914,7 +1959,7 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, webgpu_pipeline blk_pipeline; std::vector blk_params; std::vector blk_entries; - if (has_mask) { + if (op.has_mask) { blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], decisions->kv_tile); blk_nblk1 = (uint32_t) Q->ne[1]; blk_buf = ggml_webgpu_tensor_buf(dst); @@ -1922,7 +1967,7 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, blk_batch_count = stride_mask3 > 0 ? (uint32_t) Q->ne[3] : 1u; const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count; blk_size_bytes = ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT); - const ggml_webgpu_shader_lib_context blk_shader_ctx = shader_lib_ctx; + const ggml_webgpu_shader_lib_context blk_shader_ctx = op.shader_lib_ctx; blk_pipeline = ctx->shader_lib->get_flash_attn_blk_pipeline(blk_shader_ctx, decisions->kv_tile); blk_params = { @@ -1942,8 +1987,8 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, scratch_offset = ROUNDUP_POW2(scratch_offset + blk_size_bytes, align_bytes); } - std::vector split_params = params; - if (has_mask) { + std::vector split_params = op.params; + if (op.has_mask) { split_params.push_back(0u); // blk_base split_params.push_back(blk_nblk0); // blk_nblk0 split_params.push_back(blk_nblk1); // blk_nblk1 @@ -1956,9 +2001,9 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(Q), ggml_webgpu_tensor_align_offset(ctx, Q), ggml_webgpu_tensor_binding_size(ctx, Q)), }; - if (kv_overlap) { + if (op.kv_overlap) { split_entries.push_back( - ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K), kv_bind_offset, kv_bind_size)); + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K), op.kv_bind_offset, op.kv_bind_size)); } else { split_entries.push_back(ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K), ggml_webgpu_tensor_align_offset(ctx, K), @@ -1967,18 +2012,18 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, ggml_webgpu_tensor_align_offset(ctx, V), ggml_webgpu_tensor_binding_size(ctx, V))); } - uint32_t split_binding_index = kv_overlap ? 2u : 3u; - if (has_mask) { + uint32_t split_binding_index = op.kv_overlap ? 2u : 3u; + if (op.has_mask) { split_entries.push_back(ggml_webgpu_make_bind_group_entry(split_binding_index++, ggml_webgpu_tensor_buf(mask), ggml_webgpu_tensor_align_offset(ctx, mask), ggml_webgpu_tensor_binding_size(ctx, mask))); } - if (has_sinks) { + if (op.has_sinks) { split_entries.push_back(ggml_webgpu_make_bind_group_entry(split_binding_index++, ggml_webgpu_tensor_buf(sinks), ggml_webgpu_tensor_align_offset(ctx, sinks), ggml_webgpu_tensor_binding_size(ctx, sinks))); } - if (has_mask) { + if (op.has_mask) { split_entries.push_back( ggml_webgpu_make_bind_group_entry(split_binding_index++, blk_buf, blk_entries[1].offset, blk_size_bytes)); } @@ -1997,7 +2042,7 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, reduce_sg_size, (uint32_t) std::min((uint64_t) nwg * reduce_sg_size, ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup)); - ggml_webgpu_shader_lib_context reduce_shader_ctx = shader_lib_ctx; + ggml_webgpu_shader_lib_context reduce_shader_ctx = op.shader_lib_ctx; reduce_shader_ctx.max_wg_size = reduce_wg_size; reduce_pipeline = ctx->shader_lib->get_flash_attn_vec_reduce_pipeline(reduce_shader_ctx); @@ -2024,7 +2069,7 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, std::vector dispatches; - if (has_mask) { + if (op.has_mask) { dispatches.push_back({ blk_pipeline, std::move(blk_params), std::move(blk_entries), { blk_nblk0, blk_nblk1 * blk_batch_count } }); @@ -2041,6 +2086,20 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, return ggml_backend_webgpu_build_multi(ctx, dispatches); } +static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, + ggml_tensor * Q, + ggml_tensor * K, + ggml_tensor * V, + ggml_tensor * mask, + ggml_tensor * sinks, + ggml_tensor * dst) { + ggml_webgpu_flash_attn_op op = ggml_webgpu_flash_attn_prepare(ctx, Q, K, V, mask, sinks, dst); + if (ggml_webgpu_flash_attn_use_vec_path(ctx->global_ctx, Q, K, V)) { + return ggml_webgpu_flash_attn_vec(ctx, Q, K, V, mask, sinks, dst, std::move(op)); + } + return ggml_webgpu_flash_attn_direct(ctx, op); +} + static webgpu_encoded_op ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { bool is_unary = dst->op == GGML_OP_UNARY; @@ -3561,66 +3620,52 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer const ggml_tensor * K = tensor->src[1]; const ggml_tensor * V = tensor->src[2]; const ggml_tensor * mask = tensor->src[3]; - const ggml_tensor * sinks = tensor->src[4]; - if (Q && K && V) { - ggml_webgpu_shader_lib_context shader_lib_ctx = {}; - shader_lib_ctx.src0 = const_cast(Q); - shader_lib_ctx.src1 = const_cast(K); - shader_lib_ctx.src2 = const_cast(V); - shader_lib_ctx.src3 = const_cast(mask); - shader_lib_ctx.src4 = const_cast(sinks); - shader_lib_ctx.dst = const_cast(tensor); - shader_lib_ctx.max_wg_size = - ctx->webgpu_global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; - shader_lib_ctx.wg_mem_limit_bytes = - ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; - shader_lib_ctx.supports_subgroups = ctx->webgpu_global_ctx->capabilities.supports_subgroups; - shader_lib_ctx.supports_subgroup_matrix = - ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix; - shader_lib_ctx.sg_mat_m = ctx->webgpu_global_ctx->capabilities.sg_mat_m; - shader_lib_ctx.sg_mat_n = ctx->webgpu_global_ctx->capabilities.sg_mat_n; - shader_lib_ctx.sg_mat_k = ctx->webgpu_global_ctx->capabilities.sg_mat_k; - shader_lib_ctx.min_subgroup_size = ctx->webgpu_global_ctx->capabilities.min_subgroup_size; - shader_lib_ctx.max_subgroup_size = ctx->webgpu_global_ctx->capabilities.max_subgroup_size; - - const ggml_webgpu_flash_attn_decisions decisions = ggml_webgpu_flash_attn_get_decisions( - shader_lib_ctx, ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment); - - if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { - const uint32_t kv_tile = decisions.kv_tile; - - const uint32_t vec_nwg_cap = ctx->webgpu_global_ctx->capabilities.min_subgroup_size; - uint32_t nwg = 1u; - const uint64_t kv_span = (uint64_t) std::max(1u, kv_tile); - while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) { - nwg <<= 1; + const auto & capabilities = ctx->webgpu_global_ctx->capabilities; + if (ggml_webgpu_flash_attn_use_vec_path(ctx->webgpu_global_ctx, Q, K, V)) { + const bool kv_direct = + ggml_webgpu_flash_attn_kv_direct(Q, K, GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH); + const uint32_t max_kv_tile = ggml_webgpu_flash_attn_max_kv_tile( + capabilities.limits.maxComputeWorkgroupStorageSize, 1u, 1u, (uint32_t) Q->ne[0], + (uint32_t) V->ne[0], mask != nullptr, kv_direct); + GGML_ASSERT(max_kv_tile > 0); + uint32_t kv_tile = std::min(GGML_WEBGPU_FLASH_ATTN_VEC_MAX_KV_TILE, max_kv_tile); + if (kv_direct) { + kv_tile = std::min(kv_tile, GGML_WEBGPU_KV_SEQ_PAD); + while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) { + kv_tile -= 1u; } - nwg = std::min(nwg, vec_nwg_cap); - - const size_t align = - ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment; - const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3]; - if (nwg > 1u) { - const uint64_t tmp_data_elems = nrows * (uint64_t) V->ne[0] * nwg; - const uint64_t tmp_stats_elems = nrows * 2u * nwg; - const size_t tmp_size_bytes = ROUNDUP_POW2( - (tmp_data_elems + tmp_stats_elems) * sizeof(float), WEBGPU_STORAGE_BUF_BINDING_MULT); - res += tmp_size_bytes + align; - } else { - res += WEBGPU_STORAGE_BUF_BINDING_MULT + align; - } - if (mask != nullptr) { - const uint32_t blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], kv_tile); - const uint32_t blk_nblk1 = CEIL_DIV((uint32_t) Q->ne[1], 1u); - const uint32_t stride_mask3 = (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)); - const uint32_t blk_batch_count = stride_mask3 > 0 ? (uint32_t) Q->ne[3] : 1u; - const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count; - const size_t blk_size_bytes = - ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT); - res += blk_size_bytes + align; - } - res = ROUNDUP_POW2(res, WEBGPU_STORAGE_BUF_BINDING_MULT); } + + const uint32_t vec_nwg_cap = capabilities.min_subgroup_size; + uint32_t nwg = 1u; + const uint64_t kv_span = (uint64_t) std::max(1u, kv_tile); + while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) { + nwg <<= 1; + } + nwg = std::min(nwg, vec_nwg_cap); + + const size_t align = capabilities.limits.minStorageBufferOffsetAlignment; + const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3]; + if (nwg > 1u) { + const uint64_t tmp_data_elems = nrows * (uint64_t) V->ne[0] * nwg; + const uint64_t tmp_stats_elems = nrows * 2u * nwg; + const size_t tmp_size_bytes = ROUNDUP_POW2( + (tmp_data_elems + tmp_stats_elems) * sizeof(float), WEBGPU_STORAGE_BUF_BINDING_MULT); + res += tmp_size_bytes + align; + } else { + res += WEBGPU_STORAGE_BUF_BINDING_MULT + align; + } + if (mask != nullptr) { + const uint32_t blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], kv_tile); + const uint32_t blk_nblk1 = CEIL_DIV((uint32_t) Q->ne[1], 1u); + const uint32_t stride_mask3 = (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)); + const uint32_t blk_batch_count = stride_mask3 > 0 ? (uint32_t) Q->ne[3] : 1u; + const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count; + const size_t blk_size_bytes = + ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT); + res += blk_size_bytes + align; + } + res = ROUNDUP_POW2(res, WEBGPU_STORAGE_BUF_BINDING_MULT); } } break; @@ -4147,6 +4192,8 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const break; case GGML_OP_FLASH_ATTN_EXT: { + // conservative support checks for whether the more resource-intensive shader paths + // can be used, to avoid cases where flash_attn is assigned to the CPU later on supports_op = src0->type == GGML_TYPE_F32 && (src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_Q4_0 || src1->type == GGML_TYPE_Q8_0) && @@ -4154,63 +4201,29 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const if (!supports_op) { break; } - ggml_webgpu_shader_lib_context shader_lib_ctx = {}; - shader_lib_ctx.src0 = src0; - shader_lib_ctx.src1 = src1; - shader_lib_ctx.src2 = src2; - shader_lib_ctx.src3 = op->src[3]; - shader_lib_ctx.src4 = op->src[4]; - shader_lib_ctx.dst = const_cast(op); - shader_lib_ctx.supports_subgroups = ctx->webgpu_global_ctx->capabilities.supports_subgroups; - shader_lib_ctx.supports_subgroup_matrix = ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix; - shader_lib_ctx.max_wg_size = - ctx->webgpu_global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; - shader_lib_ctx.wg_mem_limit_bytes = - ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; - shader_lib_ctx.sg_mat_m = ctx->webgpu_global_ctx->capabilities.sg_mat_m; - shader_lib_ctx.sg_mat_n = ctx->webgpu_global_ctx->capabilities.sg_mat_n; - shader_lib_ctx.sg_mat_k = ctx->webgpu_global_ctx->capabilities.sg_mat_k; - shader_lib_ctx.min_subgroup_size = ctx->webgpu_global_ctx->capabilities.min_subgroup_size; - shader_lib_ctx.max_subgroup_size = ctx->webgpu_global_ctx->capabilities.max_subgroup_size; - - const ggml_webgpu_flash_attn_decisions decisions = ggml_webgpu_flash_attn_get_decisions( - shader_lib_ctx, ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment); - const size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; - const bool has_mask = op->src[3] != nullptr; - if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_NONE) { + const auto & capabilities = ctx->webgpu_global_ctx->capabilities; + const size_t storage_offset_alignment = capabilities.limits.minStorageBufferOffsetAlignment; + const bool use_subgroup_matrix = ggml_webgpu_flash_attn_can_use_subgroup_matrix_path( + capabilities.supports_subgroup_matrix, capabilities.sg_mat_k, capabilities.sg_mat_n, src0, src2); + const bool use_tile = + !use_subgroup_matrix && + ggml_webgpu_flash_attn_can_use_tile_path( + capabilities.supports_subgroups, capabilities.limits.maxComputeInvocationsPerWorkgroup, + capabilities.max_subgroup_size, src0, src1, src2, storage_offset_alignment); + if (!use_subgroup_matrix && !use_tile) { supports_op = false; break; } - if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { - const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes( - decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, - decisions.kv_direct, decisions.path); - if (min_bytes > limit_bytes) { - supports_op = false; - } - break; - } - - if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) { - const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes( - decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, - decisions.kv_direct, decisions.path); - if (min_bytes > limit_bytes) { - supports_op = false; - } - break; - } - - if (!ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) { - supports_op = false; - break; - } - const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes( - decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, - decisions.kv_direct, decisions.path); - if (min_bytes > limit_bytes) { - supports_op = false; - } + const uint32_t q_tile = + use_subgroup_matrix ? capabilities.sg_mat_m : GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE; + const uint32_t kv_granularity = use_subgroup_matrix ? capabilities.sg_mat_n : 1u; + const bool kv_direct = use_subgroup_matrix ? + ggml_webgpu_flash_attn_kv_direct(src0, src1, capabilities.sg_mat_k) : + false; + const uint32_t max_kv_tile = ggml_webgpu_flash_attn_max_kv_tile( + capabilities.limits.maxComputeWorkgroupStorageSize, q_tile, kv_granularity, (uint32_t) src0->ne[0], + (uint32_t) src2->ne[0], op->src[3] != nullptr, kv_direct); + supports_op = max_kv_tile > 0; break; } case GGML_OP_RMS_NORM: From 8d61b5cdd68483cd54286ef975b6da85ef3faaf4 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Wed, 27 May 2026 15:11:11 -0700 Subject: [PATCH 2/7] Refactor --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 593 ++++++++---------- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 58 +- 2 files changed, 304 insertions(+), 347 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 3e72093b11d5..755cfed80a85 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -545,14 +545,7 @@ struct ggml_webgpu_unary_pipeline_key_hash { /** FlashAttention */ -enum ggml_webgpu_flash_attn_path : uint32_t { - GGML_WEBGPU_FLASH_ATTN_PATH_NONE = 0u, - GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX = 1u, - GGML_WEBGPU_FLASH_ATTN_PATH_TILE = 2u, - GGML_WEBGPU_FLASH_ATTN_PATH_VEC = 3u, -}; - -struct ggml_webgpu_flash_attn_pipeline_key { +struct ggml_webgpu_flash_attn_common_pipeline_key { ggml_type q_type; ggml_type kv_type; ggml_type dst_type; @@ -563,64 +556,78 @@ struct ggml_webgpu_flash_attn_pipeline_key { bool has_mask; bool has_sinks; bool uses_logit_softcap; - uint32_t path; - bool operator==(const ggml_webgpu_flash_attn_pipeline_key & other) const { + bool operator==(const ggml_webgpu_flash_attn_common_pipeline_key & other) const { return q_type == other.q_type && kv_type == other.kv_type && dst_type == other.dst_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v && kv_direct == other.kv_direct && kv_overlap == other.kv_overlap && has_mask == other.has_mask && has_sinks == other.has_sinks && - uses_logit_softcap == other.uses_logit_softcap && path == other.path; + uses_logit_softcap == other.uses_logit_softcap; + } +}; + +inline void ggml_webgpu_flash_attn_hash_common_pipeline_key(size_t & seed, + const ggml_webgpu_flash_attn_common_pipeline_key & key) { + ggml_webgpu_hash_combine(seed, key.q_type); + ggml_webgpu_hash_combine(seed, key.kv_type); + ggml_webgpu_hash_combine(seed, key.dst_type); + ggml_webgpu_hash_combine(seed, key.head_dim_qk); + ggml_webgpu_hash_combine(seed, key.head_dim_v); + ggml_webgpu_hash_combine(seed, key.kv_direct); + ggml_webgpu_hash_combine(seed, key.kv_overlap); + ggml_webgpu_hash_combine(seed, key.has_mask); + ggml_webgpu_hash_combine(seed, key.has_sinks); + ggml_webgpu_hash_combine(seed, key.uses_logit_softcap); +} + +struct ggml_webgpu_flash_attn_vec_pipeline_key { + ggml_webgpu_flash_attn_common_pipeline_key common; + + bool operator==(const ggml_webgpu_flash_attn_vec_pipeline_key & other) const { + return common == other.common; + } +}; + +struct ggml_webgpu_flash_attn_vec_pipeline_key_hash { + size_t operator()(const ggml_webgpu_flash_attn_vec_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_flash_attn_hash_common_pipeline_key(seed, key.common); + return seed; + } +}; + +struct ggml_webgpu_flash_attn_pipeline_key { + ggml_webgpu_flash_attn_common_pipeline_key common; + bool use_sg_matrix; + + bool operator==(const ggml_webgpu_flash_attn_pipeline_key & other) const { + return common == other.common && use_sg_matrix == other.use_sg_matrix; } }; struct ggml_webgpu_flash_attn_pipeline_key_hash { size_t operator()(const ggml_webgpu_flash_attn_pipeline_key & key) const { size_t seed = 0; - ggml_webgpu_hash_combine(seed, key.q_type); - ggml_webgpu_hash_combine(seed, key.kv_type); - ggml_webgpu_hash_combine(seed, key.dst_type); - ggml_webgpu_hash_combine(seed, key.head_dim_qk); - ggml_webgpu_hash_combine(seed, key.head_dim_v); - ggml_webgpu_hash_combine(seed, key.kv_direct); - ggml_webgpu_hash_combine(seed, key.kv_overlap); - ggml_webgpu_hash_combine(seed, key.has_mask); - ggml_webgpu_hash_combine(seed, key.has_sinks); - ggml_webgpu_hash_combine(seed, key.uses_logit_softcap); - ggml_webgpu_hash_combine(seed, key.path); + ggml_webgpu_flash_attn_hash_common_pipeline_key(seed, key.common); + ggml_webgpu_hash_combine(seed, key.use_sg_matrix); return seed; } }; +struct ggml_webgpu_flash_attn_vec_decisions { + uint32_t kv_tile = 0; + uint32_t wg_size = 0; +}; + struct ggml_webgpu_flash_attn_decisions { - uint32_t path = GGML_WEBGPU_FLASH_ATTN_PATH_NONE; - uint32_t q_tile = 0; - uint32_t kv_tile = 0; - uint32_t wg_size = 0; - bool kv_direct = false; - bool kv_overlap = false; + bool use_sg_matrix = false; + uint32_t q_tile = 0; + uint32_t kv_tile = 0; + uint32_t wg_size = 0; }; inline constexpr uint32_t GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH = 4u; inline constexpr uint32_t GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE = 4u; -inline uint32_t ggml_webgpu_flash_attn_pick_vec_ne(const ggml_webgpu_flash_attn_pipeline_key & key) { - if (key.path != GGML_WEBGPU_FLASH_ATTN_PATH_VEC || key.kv_type != GGML_TYPE_F16 || - key.head_dim_qk != key.head_dim_v) { - return 1u; - } - - switch (key.head_dim_qk) { - case 64: - case 192: - case 576: - return 2u; - case 96: - return 4u; - default: - return 1u; - } -} - inline size_t ggml_webgpu_flash_attn_tensor_offset(const ggml_tensor * tensor) { constexpr uintptr_t ptr_base_addr = 0x1000u; const ggml_tensor * base = tensor->view_src != nullptr ? tensor->view_src : tensor; @@ -639,37 +646,114 @@ inline bool ggml_webgpu_flash_attn_f16_vec4_aligned(const ggml_tensor * K, (v_offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u); } -inline ggml_webgpu_flash_attn_pipeline_key ggml_webgpu_flash_attn_make_pipeline_key( - const ggml_webgpu_shader_lib_context & context, - const ggml_webgpu_flash_attn_decisions & decisions) { - const bool has_mask = context.src3 != nullptr; - const bool has_sinks = context.src4 != nullptr; - bool kv_direct = false; - uint32_t kv_direct_align = GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH; - if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX) { - kv_direct_align = context.sg_mat_k; - } - if (decisions.path != GGML_WEBGPU_FLASH_ATTN_PATH_TILE) { - kv_direct = (context.src1->type == GGML_TYPE_F16) && - (context.src0->ne[0] % std::max(1u, kv_direct_align) == 0) && - (context.src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0); - } - - ggml_webgpu_flash_attn_pipeline_key key = {}; - key.q_type = context.src0->type; - key.kv_type = context.src1->type; - key.dst_type = context.dst->type; - key.head_dim_qk = (uint32_t) context.src0->ne[0]; - key.head_dim_v = (uint32_t) context.src2->ne[0]; - key.kv_direct = kv_direct; - key.kv_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src2); - key.has_mask = has_mask; - key.has_sinks = has_sinks; - key.uses_logit_softcap = ggml_get_op_params_f32(context.dst, 2) != 0.0f; - key.path = decisions.path; +inline bool ggml_webgpu_flash_attn_kv_direct(const ggml_tensor * Q, + const ggml_tensor * K, + uint32_t kv_direct_align) { + return K->type == GGML_TYPE_F16 && (Q->ne[0] % std::max(1u, kv_direct_align) == 0) && + (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0); +} + +inline ggml_webgpu_flash_attn_common_pipeline_key ggml_webgpu_flash_attn_make_common_pipeline_key( + const ggml_webgpu_shader_lib_context & context, + uint32_t kv_direct_align) { + ggml_webgpu_flash_attn_common_pipeline_key key = {}; + key.q_type = context.src0->type; + key.kv_type = context.src1->type; + key.dst_type = context.dst->type; + key.head_dim_qk = (uint32_t) context.src0->ne[0]; + key.head_dim_v = (uint32_t) context.src2->ne[0]; + key.kv_direct = ggml_webgpu_flash_attn_kv_direct(context.src0, context.src1, kv_direct_align); + key.kv_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src2); + key.has_mask = context.src3 != nullptr; + key.has_sinks = context.src4 != nullptr; + key.uses_logit_softcap = ggml_get_op_params_f32(context.dst, 2) != 0.0f; return key; } +inline std::vector ggml_webgpu_flash_attn_common_defines( + const ggml_webgpu_flash_attn_common_pipeline_key & key, + std::string & variant, + uint32_t q_tile, + uint32_t kv_tile, + uint32_t wg_size) { + std::vector defines; + + switch (key.kv_type) { + case GGML_TYPE_F32: + defines.push_back("KV_F32"); + break; + case GGML_TYPE_F16: + defines.push_back("KV_F16"); + break; + case GGML_TYPE_Q4_0: + defines.push_back("KV_Q4_0"); + break; + case GGML_TYPE_Q8_0: + defines.push_back("KV_Q8_0"); + break; + default: + GGML_ABORT("Unsupported KV type for flash attention shader"); + } + variant += std::string("_") + ggml_type_name(key.kv_type); + + switch (key.q_type) { + case GGML_TYPE_F32: + defines.push_back("Q_F32"); + break; + case GGML_TYPE_F16: + defines.push_back("Q_F16"); + break; + default: + GGML_ABORT("Unsupported Q type for flash attention shader"); + } + variant += std::string("_q") + ggml_type_name(key.q_type); + + switch (key.dst_type) { + case GGML_TYPE_F32: + defines.push_back("DST_F32"); + break; + case GGML_TYPE_F16: + defines.push_back("DST_F16"); + break; + default: + GGML_ABORT("Unsupported dst type for flash attention shader"); + } + variant += std::string("_dst") + ggml_type_name(key.dst_type); + + if (key.has_mask) { + defines.push_back("MASK"); + variant += "_mask"; + } + if (key.has_sinks) { + defines.push_back("SINKS"); + variant += "_sinks"; + } + if (key.uses_logit_softcap) { + defines.push_back("LOGIT_SOFTCAP"); + variant += "_lgsc"; + } + if (key.kv_direct) { + defines.push_back("KV_DIRECT"); + variant += "_kvdirect"; + } + if (key.kv_overlap) { + defines.push_back("KV_OVERLAP"); + variant += "_kv_overlap"; + } + + defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(key.head_dim_qk)); + variant += std::string("_hsqk") + std::to_string(key.head_dim_qk); + + defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v)); + variant += std::string("_hsv") + std::to_string(key.head_dim_v); + + defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile)); + defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile)); + defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); + + return defines; +} + struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key { uint32_t head_dim_v; uint32_t wg_size; @@ -731,12 +815,7 @@ inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile, return f16_elems * GGML_WEBGPU_F16_SIZE_BYTES + f32_elems * GGML_WEBGPU_F32_SIZE_BYTES; } -inline bool ggml_webgpu_flash_attn_kv_direct(const ggml_tensor * Q, - const ggml_tensor * K, - uint32_t kv_direct_align) { - return K->type == GGML_TYPE_F16 && (Q->ne[0] % std::max(1u, kv_direct_align) == 0) && - (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0); -} + inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(size_t limit_bytes, uint32_t q_tile, @@ -760,52 +839,24 @@ inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(size_t limit_bytes, return (uint32_t) ((max_kv_tile / kv_granularity) * kv_granularity); } -inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_vec_decisions( - const ggml_webgpu_shader_lib_context & context, - size_t storage_offset_alignment) { - ggml_webgpu_flash_attn_decisions decisions = {}; - const auto * K = context.src1; - const auto * V = context.src2; - GGML_ASSERT(K != nullptr); - GGML_ASSERT(V != nullptr); - - const bool f16_vec4_aligned = ggml_webgpu_flash_attn_f16_vec4_aligned(K, V, storage_offset_alignment); - const bool kv_vec_type_supported = - K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0; - const uint32_t kv_vec_head_align = K->type == GGML_TYPE_F16 ? GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH : - (uint32_t) ggml_blck_size(K->type); - const bool kv_vec_head_dims_aligned = context.src0->ne[0] % kv_vec_head_align == 0 && - context.src2->ne[0] % kv_vec_head_align == 0; - const bool use_vec = context.supports_subgroups && - (context.src0->ne[1] < GGML_WEBGPU_FLASH_ATTN_VEC_MAX_SEQ_LEN) && - kv_vec_head_dims_aligned && kv_vec_type_supported && - (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && - (context.src2->type == K->type); - if (!use_vec) { - return decisions; - } - - decisions.path = GGML_WEBGPU_FLASH_ATTN_PATH_VEC; - - const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context, decisions); - decisions.kv_direct = key.kv_direct; - const uint32_t max_kv_tile = ggml_webgpu_flash_attn_max_kv_tile( - context.wg_mem_limit_bytes, 1u, 1u, key.head_dim_qk, key.head_dim_v, key.has_mask, key.kv_direct); - if (max_kv_tile == 0) { - decisions.path = GGML_WEBGPU_FLASH_ATTN_PATH_NONE; - return decisions; - } - - decisions.q_tile = 1u; - decisions.kv_tile = std::min(GGML_WEBGPU_FLASH_ATTN_VEC_MAX_KV_TILE, max_kv_tile); - decisions.wg_size = context.max_subgroup_size; - if (decisions.kv_direct) { - decisions.kv_tile = std::min(decisions.kv_tile, GGML_WEBGPU_KV_SEQ_PAD); - while (GGML_WEBGPU_KV_SEQ_PAD % decisions.kv_tile != 0) { - decisions.kv_tile -= 1u; - } - } - return decisions; +inline uint32_t ggml_webgpu_flash_attn_get_vec_kv_tile(size_t wg_mem_limit_bytes, + uint32_t head_dim_qk, + uint32_t head_dim_v, + bool has_mask, + bool kv_direct) { + const uint32_t max_kv_tile = + ggml_webgpu_flash_attn_max_kv_tile(wg_mem_limit_bytes, 1u, 1u, head_dim_qk, head_dim_v, has_mask, kv_direct); + GGML_ASSERT(max_kv_tile > 0); + + uint32_t kv_tile = std::min(GGML_WEBGPU_FLASH_ATTN_VEC_MAX_KV_TILE, max_kv_tile); + if (kv_direct) { + kv_tile = std::min(kv_tile, GGML_WEBGPU_KV_SEQ_PAD); + while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) { + kv_tile -= 1u; + } + } + + return kv_tile; } inline bool ggml_webgpu_flash_attn_can_use_subgroup_matrix_path(bool supports_subgroup_matrix, @@ -816,83 +867,6 @@ inline bool ggml_webgpu_flash_attn_can_use_subgroup_matrix_path(bool return supports_subgroup_matrix && Q->ne[0] % sg_mat_k == 0 && V->ne[0] % sg_mat_n == 0; } -inline bool ggml_webgpu_flash_attn_can_use_tile_path(bool supports_subgroups, - uint32_t max_wg_size, - uint32_t max_subgroup_size, - const ggml_tensor * Q, - const ggml_tensor * K, - const ggml_tensor * V, - size_t storage_offset_alignment) { - GGML_ASSERT(Q != nullptr); - GGML_ASSERT(K != nullptr); - GGML_ASSERT(V != nullptr); - - const bool f16_vec4_aligned = ggml_webgpu_flash_attn_f16_vec4_aligned(K, V, storage_offset_alignment); - const bool tile_can_dispatch_all_q_rows = - max_subgroup_size > 0 && - max_wg_size >= GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * max_subgroup_size; - - return supports_subgroups && K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && f16_vec4_aligned && - (Q->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && - (V->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && - tile_can_dispatch_all_q_rows; -} - -inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_non_vec_decisions( - const ggml_webgpu_shader_lib_context & context, - size_t storage_offset_alignment) { - ggml_webgpu_flash_attn_decisions decisions = {}; - const bool can_use_subgroup_matrix = ggml_webgpu_flash_attn_can_use_subgroup_matrix_path( - context.supports_subgroup_matrix, context.sg_mat_k, context.sg_mat_n, context.src0, context.src2); - const bool can_use_tile = - !can_use_subgroup_matrix && ggml_webgpu_flash_attn_can_use_tile_path( - context.supports_subgroups, context.max_wg_size, context.max_subgroup_size, context.src0, context.src1, - context.src2, storage_offset_alignment); - - decisions.path = can_use_tile ? GGML_WEBGPU_FLASH_ATTN_PATH_TILE : - can_use_subgroup_matrix ? GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX : - GGML_WEBGPU_FLASH_ATTN_PATH_NONE; - - if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_NONE) { - return decisions; - } - - decisions.q_tile = - decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE : context.sg_mat_m; - const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context, decisions); - decisions.kv_direct = key.kv_direct; - const uint32_t max_kv_tile = ggml_webgpu_flash_attn_max_kv_tile( - context.wg_mem_limit_bytes, decisions.q_tile, - decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? 1u : std::max(1u, context.sg_mat_n), key.head_dim_qk, - key.head_dim_v, key.has_mask, key.kv_direct); - if (max_kv_tile == 0) { - decisions.path = GGML_WEBGPU_FLASH_ATTN_PATH_NONE; - return decisions; - } - - decisions.kv_tile = decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? - std::min(GGML_WEBGPU_FLASH_ATTN_TILE_MAX_KV_TILE, max_kv_tile) : - std::min(max_kv_tile, context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES); - decisions.wg_size = decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? - std::min(std::max(1u, context.max_wg_size), - std::max(GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE, - GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * context.max_subgroup_size)) : - std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE); - - if (decisions.kv_tile == 0) { - return decisions; - } - - if (decisions.kv_direct) { - GGML_ASSERT(decisions.kv_tile <= GGML_WEBGPU_KV_SEQ_PAD); - while (GGML_WEBGPU_KV_SEQ_PAD % decisions.kv_tile != 0) { - decisions.kv_tile -= - decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? context.min_subgroup_size : context.sg_mat_n; - } - } - return decisions; -} - /** Matrix Multiplication **/ struct ggml_webgpu_mul_mat_vec_pipeline_key { @@ -1163,7 +1137,13 @@ class ggml_webgpu_shader_lib { concat_pipelines; // type std::unordered_map repeat_pipelines; // type - std::unordered_map + std::unordered_map + flash_attn_vec_pipelines; + std::unordered_map flash_attn_pipelines; std::unordered_mapsecond; - } - std::vector defines; - std::string variant = decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC ? "flash_attn_vec" : - decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? "flash_attn_tile" : - "flash_attn"; - - switch (key.kv_type) { - case GGML_TYPE_F32: - defines.push_back("KV_F32"); - break; - case GGML_TYPE_F16: - defines.push_back("KV_F16"); - break; - case GGML_TYPE_Q4_0: - defines.push_back("KV_Q4_0"); - break; - case GGML_TYPE_Q8_0: - defines.push_back("KV_Q8_0"); - break; - default: - GGML_ABORT("Unsupported KV type for flash attention shader"); - } - variant += std::string("_") + ggml_type_name(key.kv_type); - - switch (key.q_type) { - case GGML_TYPE_F32: - defines.push_back("Q_F32"); - break; - case GGML_TYPE_F16: - defines.push_back("Q_F16"); - break; - default: - GGML_ABORT("Unsupported Q type for flash attention shader"); - } - variant += std::string("_q") + ggml_type_name(key.q_type); - - switch (key.dst_type) { - case GGML_TYPE_F32: - defines.push_back("DST_F32"); - break; - case GGML_TYPE_F16: - defines.push_back("DST_F16"); - break; - default: - GGML_ABORT("Unsupported dst type for flash attention shader"); - } - variant += std::string("_dst") + ggml_type_name(key.dst_type); - - if (key.has_mask) { - defines.push_back("MASK"); - if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { - defines.push_back("BLK"); - variant += "_mask_blk"; - } else { - variant += "_mask"; + webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context) { + const bool can_use_subgroup_matrix = ggml_webgpu_flash_attn_can_use_subgroup_matrix_path( + context.supports_subgroup_matrix, context.sg_mat_k, context.sg_mat_n, context.src0, context.src2); + ggml_webgpu_flash_attn_decisions decisions = {}; + decisions.use_sg_matrix = can_use_subgroup_matrix; + decisions.q_tile = + decisions.use_sg_matrix ? context.sg_mat_m : GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE; + + ggml_webgpu_flash_attn_pipeline_key key = {}; + key.common = ggml_webgpu_flash_attn_make_common_pipeline_key( + context, decisions.use_sg_matrix ? context.sg_mat_k : 1u); + key.common.kv_direct = decisions.use_sg_matrix && key.common.kv_direct; + key.use_sg_matrix = decisions.use_sg_matrix; + + const uint32_t max_kv_tile = ggml_webgpu_flash_attn_max_kv_tile( + context.wg_mem_limit_bytes, decisions.q_tile, + decisions.use_sg_matrix ? context.sg_mat_n : 1u, key.common.head_dim_qk, + key.common.head_dim_v, key.common.has_mask, key.common.kv_direct); + GGML_ASSERT(max_kv_tile > 0); + + decisions.kv_tile = decisions.use_sg_matrix ? + std::min(max_kv_tile, context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES) : + std::min(GGML_WEBGPU_FLASH_ATTN_TILE_MAX_KV_TILE, max_kv_tile); + decisions.wg_size = decisions.use_sg_matrix ? + std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE) : + std::min(context.max_wg_size, + std::max(GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE, + GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * context.max_subgroup_size)); + + if (key.common.kv_direct) { + decisions.kv_tile = std::min(decisions.kv_tile, GGML_WEBGPU_KV_SEQ_PAD); + while (GGML_WEBGPU_KV_SEQ_PAD % decisions.kv_tile != 0) { + decisions.kv_tile -= decisions.use_sg_matrix ? context.sg_mat_n : context.min_subgroup_size; } } - if (key.has_sinks) { - defines.push_back("SINKS"); - variant += "_sinks"; - } - if (key.uses_logit_softcap) { - defines.push_back("LOGIT_SOFTCAP"); - variant += "_lgsc"; - } - if (key.kv_direct) { - defines.push_back("KV_DIRECT"); - variant += "_kvdirect"; - } - if (key.kv_overlap) { - defines.push_back("KV_OVERLAP"); - variant += "_kv_overlap"; - } - - defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(key.head_dim_qk)); - variant += std::string("_hsqk") + std::to_string(key.head_dim_qk); - defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v)); - variant += std::string("_hsv") + std::to_string(key.head_dim_v); + auto it = flash_attn_pipelines.find(key); + if (it != flash_attn_pipelines.end()) { + return it->second; + } - const char * shader_src = wgsl_flash_attn; - if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { - defines.push_back("KV_GRANULARITY=8"); - defines.push_back(std::string("VEC_NE=") + std::to_string(ggml_webgpu_flash_attn_pick_vec_ne(key)) + "u"); - shader_src = wgsl_flash_attn_vec_split; - } else if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) { + std::string variant = decisions.use_sg_matrix ? "flash_attn" : "flash_attn_tile"; + std::vector defines = + ggml_webgpu_flash_attn_common_defines(key.common, variant, decisions.q_tile, decisions.kv_tile, + decisions.wg_size); + const char * shader_src = nullptr; + if (!key.use_sg_matrix) { shader_src = wgsl_flash_attn_tile; defines.push_back("MIN_SUBGROUP_SIZE=" + std::to_string(context.min_subgroup_size) + "u"); defines.push_back("MAX_SUBGROUP_SIZE=" + std::to_string(context.max_subgroup_size) + "u"); - defines.push_back("KV_STAGE_STRIDE=" + std::to_string(std::max(key.head_dim_qk, key.head_dim_v))); + defines.push_back("KV_STAGE_STRIDE=" + std::to_string(std::max(key.common.head_dim_qk, key.common.head_dim_v))); variant += "_tile_sg" + std::to_string(context.min_subgroup_size) + "_" + std::to_string(context.max_subgroup_size); } else { + shader_src = wgsl_flash_attn; defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m)); defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n)); defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k)); } - - auto pipeline_decisions = std::make_shared(decisions); - pipeline_decisions->kv_overlap = key.kv_overlap; - defines.push_back(std::string("Q_TILE=") + std::to_string(decisions.q_tile)); - defines.push_back(std::string("KV_TILE=") + std::to_string(decisions.kv_tile)); - defines.push_back(std::string("WG_SIZE=") + std::to_string(decisions.wg_size)); - + auto pipeline_decisions = std::make_shared(decisions); webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, preprocessor.preprocess(shader_src, defines), variant); - pipeline.context = pipeline_decisions; + pipeline.context = pipeline_decisions; flash_attn_pipelines[key] = pipeline; return flash_attn_pipelines[key]; } - webgpu_pipeline get_flash_attn_non_vec_pipeline(const ggml_webgpu_shader_lib_context & context, - size_t storage_offset_alignment) { - const ggml_webgpu_flash_attn_decisions decisions = - ggml_webgpu_flash_attn_get_non_vec_decisions(context, storage_offset_alignment); - GGML_ASSERT(decisions.path != GGML_WEBGPU_FLASH_ATTN_PATH_NONE); - GGML_ASSERT(decisions.path != GGML_WEBGPU_FLASH_ATTN_PATH_VEC); - return get_flash_attn_pipeline_for_decisions(context, decisions); - } + webgpu_pipeline get_flash_attn_vec_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_flash_attn_vec_pipeline_key key = {}; + key.common = ggml_webgpu_flash_attn_make_common_pipeline_key( + context, GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH); - webgpu_pipeline get_flash_attn_vec_pipeline(const ggml_webgpu_shader_lib_context & context, - size_t storage_offset_alignment) { - const ggml_webgpu_flash_attn_decisions decisions = - ggml_webgpu_flash_attn_get_vec_decisions(context, storage_offset_alignment); - GGML_ASSERT(decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC); - return get_flash_attn_pipeline_for_decisions(context, decisions); + auto it = flash_attn_vec_pipelines.find(key); + if (it != flash_attn_vec_pipelines.end()) { + return it->second; + } + + ggml_webgpu_flash_attn_vec_decisions decisions = {}; + decisions.kv_tile = ggml_webgpu_flash_attn_get_vec_kv_tile( + context.wg_mem_limit_bytes, key.common.head_dim_qk, key.common.head_dim_v, key.common.has_mask, + key.common.kv_direct); + decisions.wg_size = context.max_subgroup_size; + + std::string variant = "flash_attn_vec"; + std::vector defines = + ggml_webgpu_flash_attn_common_defines(key.common, variant, 1u, decisions.kv_tile, decisions.wg_size); + if (key.common.has_mask) { + defines.push_back("BLK"); + variant.resize(variant.size() - (sizeof("_mask") - 1)); + variant += "_mask_blk"; + } + uint32_t vec_ne = 1u; + if (key.common.kv_type == GGML_TYPE_F16 && key.common.head_dim_qk == key.common.head_dim_v) { + switch (key.common.head_dim_qk) { + case 64: + case 192: + case 576: + vec_ne = 2u; + break; + case 96: + vec_ne = 4u; + break; + default: + break; + } + } + defines.push_back(std::string("VEC_NE=") + std::to_string(vec_ne) + "u"); + + auto pipeline_decisions = std::make_shared(decisions); + webgpu_pipeline pipeline = + ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_flash_attn_vec_split, defines), variant); + pipeline.context = pipeline_decisions; + flash_attn_vec_pipelines[key] = pipeline; + return flash_attn_vec_pipelines[key]; } webgpu_pipeline get_flash_attn_blk_pipeline(const ggml_webgpu_shader_lib_context & context, uint32_t kv_tile) { diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index a940484928e1..2ca98760fdb8 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1889,9 +1889,17 @@ static ggml_webgpu_flash_attn_op ggml_webgpu_flash_attn_prepare(webgpu_context & return op; } +static uint32_t ggml_webgpu_flash_attn_vec_nwg(uint32_t vec_nwg_cap, uint32_t kv_tile, uint32_t seq_len_kv) { + uint32_t nwg = 1u; + const uint64_t kv_span = (uint64_t) kv_tile; + while ((2u * nwg * kv_span) < (uint64_t) seq_len_kv && nwg < vec_nwg_cap) { + nwg <<= 1; + } + return std::min(nwg, vec_nwg_cap); +} + static webgpu_encoded_op ggml_webgpu_flash_attn_direct(webgpu_context & ctx, const ggml_webgpu_flash_attn_op & op) { - webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_non_vec_pipeline( - op.shader_lib_ctx, ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment); + webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline(op.shader_lib_ctx); auto * decisions = static_cast(pipeline.context.get()); uint32_t wg_per_head = CEIL_DIV(op.shader_lib_ctx.src0->ne[1], decisions->q_tile); uint32_t wg_x = wg_per_head * op.shader_lib_ctx.src0->ne[2] * op.shader_lib_ctx.src0->ne[3]; @@ -1906,9 +1914,8 @@ static webgpu_encoded_op ggml_webgpu_flash_attn_vec(webgpu_context & ct ggml_tensor * sinks, ggml_tensor * dst, ggml_webgpu_flash_attn_op op) { - webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_vec_pipeline( - op.shader_lib_ctx, ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment); - auto * decisions = static_cast(pipeline.context.get()); + webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_vec_pipeline(op.shader_lib_ctx); + auto * decisions = static_cast(pipeline.context.get()); wgpu::Buffer blk_buf = {}; uint64_t blk_size_bytes = 0; @@ -1917,12 +1924,7 @@ static webgpu_encoded_op ggml_webgpu_flash_attn_vec(webgpu_context & ct uint32_t blk_batch_count = 0; const uint32_t vec_nwg_cap = ctx->global_ctx->capabilities.min_subgroup_size; - uint32_t nwg = 1u; - const uint64_t kv_span = (uint64_t) std::max(1u, decisions->kv_tile); - while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) { - nwg <<= 1; - } - nwg = std::min(nwg, vec_nwg_cap); + uint32_t nwg = ggml_webgpu_flash_attn_vec_nwg(vec_nwg_cap, decisions->kv_tile, (uint32_t) K->ne[1]); const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3]; const bool use_vec_reduce = nwg > 1u; GGML_ASSERT(nrows <= UINT32_MAX); @@ -3624,25 +3626,12 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer if (ggml_webgpu_flash_attn_use_vec_path(ctx->webgpu_global_ctx, Q, K, V)) { const bool kv_direct = ggml_webgpu_flash_attn_kv_direct(Q, K, GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH); - const uint32_t max_kv_tile = ggml_webgpu_flash_attn_max_kv_tile( - capabilities.limits.maxComputeWorkgroupStorageSize, 1u, 1u, (uint32_t) Q->ne[0], - (uint32_t) V->ne[0], mask != nullptr, kv_direct); - GGML_ASSERT(max_kv_tile > 0); - uint32_t kv_tile = std::min(GGML_WEBGPU_FLASH_ATTN_VEC_MAX_KV_TILE, max_kv_tile); - if (kv_direct) { - kv_tile = std::min(kv_tile, GGML_WEBGPU_KV_SEQ_PAD); - while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) { - kv_tile -= 1u; - } - } + const uint32_t kv_tile = ggml_webgpu_flash_attn_get_vec_kv_tile( + capabilities.limits.maxComputeWorkgroupStorageSize, (uint32_t) Q->ne[0], (uint32_t) V->ne[0], + mask != nullptr, kv_direct); const uint32_t vec_nwg_cap = capabilities.min_subgroup_size; - uint32_t nwg = 1u; - const uint64_t kv_span = (uint64_t) std::max(1u, kv_tile); - while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) { - nwg <<= 1; - } - nwg = std::min(nwg, vec_nwg_cap); + uint32_t nwg = ggml_webgpu_flash_attn_vec_nwg(vec_nwg_cap, kv_tile, (uint32_t) K->ne[1]); const size_t align = capabilities.limits.minStorageBufferOffsetAlignment; const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3]; @@ -4205,11 +4194,18 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const const size_t storage_offset_alignment = capabilities.limits.minStorageBufferOffsetAlignment; const bool use_subgroup_matrix = ggml_webgpu_flash_attn_can_use_subgroup_matrix_path( capabilities.supports_subgroup_matrix, capabilities.sg_mat_k, capabilities.sg_mat_n, src0, src2); + const bool f16_vec4_aligned = + ggml_webgpu_flash_attn_f16_vec4_aligned(src1, src2, storage_offset_alignment); + const bool tile_can_dispatch_all_q_rows = + capabilities.limits.maxComputeInvocationsPerWorkgroup >= + GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * capabilities.max_subgroup_size; const bool use_tile = !use_subgroup_matrix && - ggml_webgpu_flash_attn_can_use_tile_path( - capabilities.supports_subgroups, capabilities.limits.maxComputeInvocationsPerWorkgroup, - capabilities.max_subgroup_size, src0, src1, src2, storage_offset_alignment); + capabilities.supports_subgroups && src1->type == GGML_TYPE_F16 && + src2->type == GGML_TYPE_F16 && f16_vec4_aligned && + (src0->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && + (src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && + tile_can_dispatch_all_q_rows; if (!use_subgroup_matrix && !use_tile) { supports_op = false; break; From 19f150a5b88fafe5dbefb2b7dfb7afb51c5088ad Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Wed, 27 May 2026 15:46:43 -0700 Subject: [PATCH 3/7] Split k/v quantization --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 62 +++++++--- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 89 ++++++++------ .../ggml-webgpu/wgsl-shaders/flash_attn.wgsl | 111 +++++++++++------- .../wgsl-shaders/flash_attn_tile.wgsl | 48 ++++---- .../wgsl-shaders/flash_attn_vec_split.wgsl | 107 ++++++++++------- 5 files changed, 253 insertions(+), 164 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 755cfed80a85..971d5ad5439c 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -547,7 +547,8 @@ struct ggml_webgpu_unary_pipeline_key_hash { struct ggml_webgpu_flash_attn_common_pipeline_key { ggml_type q_type; - ggml_type kv_type; + ggml_type k_type; + ggml_type v_type; ggml_type dst_type; uint32_t head_dim_qk; uint32_t head_dim_v; @@ -558,7 +559,8 @@ struct ggml_webgpu_flash_attn_common_pipeline_key { bool uses_logit_softcap; bool operator==(const ggml_webgpu_flash_attn_common_pipeline_key & other) const { - return q_type == other.q_type && kv_type == other.kv_type && dst_type == other.dst_type && + return q_type == other.q_type && k_type == other.k_type && v_type == other.v_type && + dst_type == other.dst_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v && kv_direct == other.kv_direct && kv_overlap == other.kv_overlap && has_mask == other.has_mask && has_sinks == other.has_sinks && uses_logit_softcap == other.uses_logit_softcap; @@ -568,7 +570,8 @@ struct ggml_webgpu_flash_attn_common_pipeline_key { inline void ggml_webgpu_flash_attn_hash_common_pipeline_key(size_t & seed, const ggml_webgpu_flash_attn_common_pipeline_key & key) { ggml_webgpu_hash_combine(seed, key.q_type); - ggml_webgpu_hash_combine(seed, key.kv_type); + ggml_webgpu_hash_combine(seed, key.k_type); + ggml_webgpu_hash_combine(seed, key.v_type); ggml_webgpu_hash_combine(seed, key.dst_type); ggml_webgpu_hash_combine(seed, key.head_dim_qk); ggml_webgpu_hash_combine(seed, key.head_dim_v); @@ -635,15 +638,18 @@ inline size_t ggml_webgpu_flash_attn_tensor_offset(const ggml_tensor * tensor) { } inline bool ggml_webgpu_flash_attn_f16_vec4_aligned(const ggml_tensor * K, - const ggml_tensor * V, size_t storage_offset_alignment) { const size_t alignment = std::max(1u, storage_offset_alignment); - const uint32_t k_offset_elems = + const uint32_t offset_elems = (uint32_t) ((ggml_webgpu_flash_attn_tensor_offset(K) & (alignment - 1)) / ggml_type_size(K->type)); - const uint32_t v_offset_elems = - (uint32_t) ((ggml_webgpu_flash_attn_tensor_offset(V) & (alignment - 1)) / ggml_type_size(V->type)); - return (k_offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u) && - (v_offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u); + return offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u; +} + +inline bool ggml_webgpu_flash_attn_f16_vec4_aligned(const ggml_tensor * K, + const ggml_tensor * V, + size_t storage_offset_alignment) { + return ggml_webgpu_flash_attn_f16_vec4_aligned(K, storage_offset_alignment) && + ggml_webgpu_flash_attn_f16_vec4_aligned(V, storage_offset_alignment); } inline bool ggml_webgpu_flash_attn_kv_direct(const ggml_tensor * Q, @@ -658,7 +664,8 @@ inline ggml_webgpu_flash_attn_common_pipeline_key ggml_webgpu_flash_attn_make_co uint32_t kv_direct_align) { ggml_webgpu_flash_attn_common_pipeline_key key = {}; key.q_type = context.src0->type; - key.kv_type = context.src1->type; + key.k_type = context.src1->type; + key.v_type = context.src2->type; key.dst_type = context.dst->type; key.head_dim_qk = (uint32_t) context.src0->ne[0]; key.head_dim_v = (uint32_t) context.src2->ne[0]; @@ -678,23 +685,41 @@ inline std::vector ggml_webgpu_flash_attn_common_defines( uint32_t wg_size) { std::vector defines; - switch (key.kv_type) { + switch (key.k_type) { + case GGML_TYPE_F32: + defines.push_back("K_F32"); + break; + case GGML_TYPE_F16: + defines.push_back("K_F16"); + break; + case GGML_TYPE_Q4_0: + defines.push_back("K_Q4_0"); + break; + case GGML_TYPE_Q8_0: + defines.push_back("K_Q8_0"); + break; + default: + GGML_ABORT("Unsupported K type for flash attention shader"); + } + variant += std::string("_k") + ggml_type_name(key.k_type); + + switch (key.v_type) { case GGML_TYPE_F32: - defines.push_back("KV_F32"); + defines.push_back("V_F32"); break; case GGML_TYPE_F16: - defines.push_back("KV_F16"); + defines.push_back("V_F16"); break; case GGML_TYPE_Q4_0: - defines.push_back("KV_Q4_0"); + defines.push_back("V_Q4_0"); break; case GGML_TYPE_Q8_0: - defines.push_back("KV_Q8_0"); + defines.push_back("V_Q8_0"); break; default: - GGML_ABORT("Unsupported KV type for flash attention shader"); + GGML_ABORT("Unsupported V type for flash attention shader"); } - variant += std::string("_") + ggml_type_name(key.kv_type); + variant += std::string("_v") + ggml_type_name(key.v_type); switch (key.q_type) { case GGML_TYPE_F32: @@ -2759,7 +2784,8 @@ class ggml_webgpu_shader_lib { variant += "_mask_blk"; } uint32_t vec_ne = 1u; - if (key.common.kv_type == GGML_TYPE_F16 && key.common.head_dim_qk == key.common.head_dim_v) { + if (key.common.k_type == GGML_TYPE_F16 && key.common.v_type == GGML_TYPE_F16 && + key.common.head_dim_qk == key.common.head_dim_v) { switch (key.common.head_dim_qk) { case 64: case 192: diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 2ca98760fdb8..a47f1fd71f0c 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1778,17 +1778,24 @@ static bool ggml_webgpu_flash_attn_use_vec_path(const webgpu_global_context & gl GGML_ASSERT(V != nullptr); GGML_ASSERT(Q != nullptr); - const bool f16_vec4_aligned = - ggml_webgpu_flash_attn_f16_vec4_aligned(K, V, global_ctx->capabilities.limits.minStorageBufferOffsetAlignment); - const bool kv_vec_type_supported = + const size_t storage_offset_alignment = global_ctx->capabilities.limits.minStorageBufferOffsetAlignment; + const bool k_f16_vec4_aligned = + K->type != GGML_TYPE_F16 || ggml_webgpu_flash_attn_f16_vec4_aligned(K, storage_offset_alignment); + const bool v_f16_vec4_aligned = + V->type != GGML_TYPE_F16 || ggml_webgpu_flash_attn_f16_vec4_aligned(V, storage_offset_alignment); + const bool k_vec_type_supported = K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0; - const uint32_t kv_vec_head_align = + const bool v_vec_type_supported = + V->type == GGML_TYPE_F16 || V->type == GGML_TYPE_Q4_0 || V->type == GGML_TYPE_Q8_0; + const uint32_t k_vec_head_align = K->type == GGML_TYPE_F16 ? GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH : (uint32_t) ggml_blck_size(K->type); - const bool kv_vec_head_dims_aligned = Q->ne[0] % kv_vec_head_align == 0 && V->ne[0] % kv_vec_head_align == 0; + const uint32_t v_vec_head_align = + V->type == GGML_TYPE_F16 ? GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH : (uint32_t) ggml_blck_size(V->type); + const bool kv_vec_head_dims_aligned = Q->ne[0] % k_vec_head_align == 0 && V->ne[0] % v_vec_head_align == 0; - return global_ctx->capabilities.supports_subgroups && - (Q->ne[1] < GGML_WEBGPU_FLASH_ATTN_VEC_MAX_SEQ_LEN) && kv_vec_head_dims_aligned && - kv_vec_type_supported && (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && (V->type == K->type); + return global_ctx->capabilities.supports_subgroups && (Q->ne[1] < GGML_WEBGPU_FLASH_ATTN_VEC_MAX_SEQ_LEN) && + kv_vec_head_dims_aligned && k_vec_type_supported && v_vec_type_supported && k_f16_vec4_aligned && + v_f16_vec4_aligned; } static ggml_webgpu_flash_attn_op ggml_webgpu_flash_attn_prepare(webgpu_context & ctx, @@ -1899,10 +1906,10 @@ static uint32_t ggml_webgpu_flash_attn_vec_nwg(uint32_t vec_nwg_cap, uint32_t kv } static webgpu_encoded_op ggml_webgpu_flash_attn_direct(webgpu_context & ctx, const ggml_webgpu_flash_attn_op & op) { - webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline(op.shader_lib_ctx); - auto * decisions = static_cast(pipeline.context.get()); - uint32_t wg_per_head = CEIL_DIV(op.shader_lib_ctx.src0->ne[1], decisions->q_tile); - uint32_t wg_x = wg_per_head * op.shader_lib_ctx.src0->ne[2] * op.shader_lib_ctx.src0->ne[3]; + webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline(op.shader_lib_ctx); + auto * decisions = static_cast(pipeline.context.get()); + uint32_t wg_per_head = CEIL_DIV(op.shader_lib_ctx.src0->ne[1], decisions->q_tile); + uint32_t wg_x = wg_per_head * op.shader_lib_ctx.src0->ne[2] * op.shader_lib_ctx.src0->ne[3]; return ggml_backend_webgpu_build(ctx, pipeline, op.params, op.entries, wg_x); } @@ -1914,8 +1921,8 @@ static webgpu_encoded_op ggml_webgpu_flash_attn_vec(webgpu_context & ct ggml_tensor * sinks, ggml_tensor * dst, ggml_webgpu_flash_attn_op op) { - webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_vec_pipeline(op.shader_lib_ctx); - auto * decisions = static_cast(pipeline.context.get()); + webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_vec_pipeline(op.shader_lib_ctx); + auto * decisions = static_cast(pipeline.context.get()); wgpu::Buffer blk_buf = {}; uint64_t blk_size_bytes = 0; @@ -1925,7 +1932,7 @@ static webgpu_encoded_op ggml_webgpu_flash_attn_vec(webgpu_context & ct const uint32_t vec_nwg_cap = ctx->global_ctx->capabilities.min_subgroup_size; uint32_t nwg = ggml_webgpu_flash_attn_vec_nwg(vec_nwg_cap, decisions->kv_tile, (uint32_t) K->ne[1]); - const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3]; + const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3]; const bool use_vec_reduce = nwg > 1u; GGML_ASSERT(nrows <= UINT32_MAX); @@ -3618,11 +3625,11 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer break; case GGML_OP_FLASH_ATTN_EXT: { - const ggml_tensor * Q = tensor->src[0]; - const ggml_tensor * K = tensor->src[1]; - const ggml_tensor * V = tensor->src[2]; - const ggml_tensor * mask = tensor->src[3]; - const auto & capabilities = ctx->webgpu_global_ctx->capabilities; + const ggml_tensor * Q = tensor->src[0]; + const ggml_tensor * K = tensor->src[1]; + const ggml_tensor * V = tensor->src[2]; + const ggml_tensor * mask = tensor->src[3]; + const auto & capabilities = ctx->webgpu_global_ctx->capabilities; if (ggml_webgpu_flash_attn_use_vec_path(ctx->webgpu_global_ctx, Q, K, V)) { const bool kv_direct = ggml_webgpu_flash_attn_kv_direct(Q, K, GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH); @@ -3631,15 +3638,15 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer mask != nullptr, kv_direct); const uint32_t vec_nwg_cap = capabilities.min_subgroup_size; - uint32_t nwg = ggml_webgpu_flash_attn_vec_nwg(vec_nwg_cap, kv_tile, (uint32_t) K->ne[1]); + uint32_t nwg = ggml_webgpu_flash_attn_vec_nwg(vec_nwg_cap, kv_tile, (uint32_t) K->ne[1]); - const size_t align = capabilities.limits.minStorageBufferOffsetAlignment; + const size_t align = capabilities.limits.minStorageBufferOffsetAlignment; const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3]; if (nwg > 1u) { const uint64_t tmp_data_elems = nrows * (uint64_t) V->ne[0] * nwg; const uint64_t tmp_stats_elems = nrows * 2u * nwg; - const size_t tmp_size_bytes = ROUNDUP_POW2( - (tmp_data_elems + tmp_stats_elems) * sizeof(float), WEBGPU_STORAGE_BUF_BINDING_MULT); + const size_t tmp_size_bytes = ROUNDUP_POW2((tmp_data_elems + tmp_stats_elems) * sizeof(float), + WEBGPU_STORAGE_BUF_BINDING_MULT); res += tmp_size_bytes + align; } else { res += WEBGPU_STORAGE_BUF_BINDING_MULT + align; @@ -4186,26 +4193,31 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const supports_op = src0->type == GGML_TYPE_F32 && (src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_Q4_0 || src1->type == GGML_TYPE_Q8_0) && - src2->type == src1->type && op->type == GGML_TYPE_F32; + (src2->type == GGML_TYPE_F32 || src2->type == GGML_TYPE_F16 || + src2->type == GGML_TYPE_Q4_0 || src2->type == GGML_TYPE_Q8_0) && + op->type == GGML_TYPE_F32; if (!supports_op) { break; } + if (ggml_webgpu_tensor_overlap(src1, src2) && src1->type != src2->type && + !ggml_is_quantized(src1->type) && !ggml_is_quantized(src2->type)) { + supports_op = false; + break; + } const auto & capabilities = ctx->webgpu_global_ctx->capabilities; const size_t storage_offset_alignment = capabilities.limits.minStorageBufferOffsetAlignment; - const bool use_subgroup_matrix = ggml_webgpu_flash_attn_can_use_subgroup_matrix_path( + const bool use_subgroup_matrix = ggml_webgpu_flash_attn_can_use_subgroup_matrix_path( capabilities.supports_subgroup_matrix, capabilities.sg_mat_k, capabilities.sg_mat_n, src0, src2); - const bool f16_vec4_aligned = + const bool f16_vec4_aligned = ggml_webgpu_flash_attn_f16_vec4_aligned(src1, src2, storage_offset_alignment); - const bool tile_can_dispatch_all_q_rows = + const bool tile_can_dispatch_all_q_rows = capabilities.limits.maxComputeInvocationsPerWorkgroup >= GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * capabilities.max_subgroup_size; - const bool use_tile = - !use_subgroup_matrix && - capabilities.supports_subgroups && src1->type == GGML_TYPE_F16 && - src2->type == GGML_TYPE_F16 && f16_vec4_aligned && - (src0->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && - (src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && - tile_can_dispatch_all_q_rows; + const bool use_tile = !use_subgroup_matrix && capabilities.supports_subgroups && + src1->type == GGML_TYPE_F16 && src2->type == GGML_TYPE_F16 && f16_vec4_aligned && + (src0->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && + (src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && + tile_can_dispatch_all_q_rows; if (!use_subgroup_matrix && !use_tile) { supports_op = false; break; @@ -4213,10 +4225,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const const uint32_t q_tile = use_subgroup_matrix ? capabilities.sg_mat_m : GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE; const uint32_t kv_granularity = use_subgroup_matrix ? capabilities.sg_mat_n : 1u; - const bool kv_direct = use_subgroup_matrix ? - ggml_webgpu_flash_attn_kv_direct(src0, src1, capabilities.sg_mat_k) : - false; - const uint32_t max_kv_tile = ggml_webgpu_flash_attn_max_kv_tile( + const bool kv_direct = + use_subgroup_matrix ? ggml_webgpu_flash_attn_kv_direct(src0, src1, capabilities.sg_mat_k) : false; + const uint32_t max_kv_tile = ggml_webgpu_flash_attn_max_kv_tile( capabilities.limits.maxComputeWorkgroupStorageSize, q_tile, kv_granularity, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], op->src[3] != nullptr, kv_direct); supports_op = max_kv_tile > 0; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl index 6d5d69fb8de7..c3abb3082b9d 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl @@ -4,12 +4,20 @@ enable f16; enable subgroups; enable chromium_experimental_subgroup_matrix; -#ifdef KV_F32 -#define KV_TYPE f32 -#elif defined(KV_Q4_0) || defined(KV_Q8_0) -#define KV_TYPE u32 +#ifdef K_F32 +#define K_TYPE f32 +#elif defined(K_Q4_0) || defined(K_Q8_0) +#define K_TYPE u32 #else -#define KV_TYPE f16 +#define K_TYPE f16 +#endif + +#ifdef V_F32 +#define V_TYPE f32 +#elif defined(V_Q4_0) || defined(V_Q8_0) +#define V_TYPE u32 +#else +#define V_TYPE f16 #endif // Default values @@ -35,20 +43,35 @@ enable chromium_experimental_subgroup_matrix; #define BLOCKS_K ((HEAD_DIM_QK + BLOCK_SIZE - 1) / BLOCK_SIZE) #define BLOCKS_V ((HEAD_DIM_V + BLOCK_SIZE - 1) / BLOCK_SIZE) // number of quantized elements processed per thread -#if defined(KV_Q4_0) -#define NQ 16 -// Q4_0 has 32 elements, 1 f16 for scale, 8 f16 for 4-bit weights -#define F16_PER_BLOCK 9 -#define BLOCK_SIZE_BYTES 18u -#define WEIGHTS_PER_F16 4 -#elif defined(KV_Q8_0) -#define NQ 8 -// Q8_0 has 32 elements, 1 f16 for scale, 16 f16 for 8-bit weights -#define F16_PER_BLOCK 17 -#define BLOCK_SIZE_BYTES 34u -#define WEIGHTS_PER_F16 2 +#if defined(K_Q4_0) +#define K_NQ 16 +#define K_F16_PER_BLOCK 9 +#define K_BLOCK_SIZE_BYTES 18u +#define K_WEIGHTS_PER_F16 4 +#elif defined(K_Q8_0) +#define K_NQ 8 +#define K_F16_PER_BLOCK 17 +#define K_BLOCK_SIZE_BYTES 34u +#define K_WEIGHTS_PER_F16 2 +#endif +#if defined(K_Q4_0) || defined(K_Q8_0) +#define K_F16_PER_THREAD (K_NQ / K_WEIGHTS_PER_F16) +#endif + +#if defined(V_Q4_0) +#define V_NQ 16 +#define V_F16_PER_BLOCK 9 +#define V_BLOCK_SIZE_BYTES 18u +#define V_WEIGHTS_PER_F16 4 +#elif defined(V_Q8_0) +#define V_NQ 8 +#define V_F16_PER_BLOCK 17 +#define V_BLOCK_SIZE_BYTES 34u +#define V_WEIGHTS_PER_F16 2 +#endif +#if defined(V_Q4_0) || defined(V_Q8_0) +#define V_F16_PER_THREAD (V_NQ / V_WEIGHTS_PER_F16) #endif -#define F16_PER_THREAD (NQ / WEIGHTS_PER_F16) // Ok not to put these in a define block, compiler will remove if unused fn get_byte(value: u32, index: u32) -> u32 { @@ -59,7 +82,7 @@ fn get_byte_i32(value: u32, index: u32) -> i32 { return bitcast(((value >> (index * 8)) & 0xFF) << 24) >> 24; } -#if defined(KV_Q4_0) || defined(KV_Q8_0) +#if defined(K_Q4_0) || defined(K_Q8_0) fn load_k_u16_at(byte_offset: u32) -> u32 { let word = K[byte_offset / 4u]; let shift = (byte_offset & 2u) * 8u; @@ -76,7 +99,9 @@ fn load_k_u32_at(byte_offset: u32) -> u32 { let hi = K[word_idx + 1u]; return (lo >> shift) | (hi << (32u - shift)); } +#endif +#if defined(V_Q4_0) || defined(V_Q8_0) fn load_v_u16_at(byte_offset: u32) -> u32 { let word = V[byte_offset / 4u]; let shift = (byte_offset & 2u) * 8u; @@ -139,11 +164,11 @@ struct Params { @group(0) @binding(0) var Q: array; #ifdef KV_OVERLAP -@group(0) @binding(1) var K: array; +@group(0) @binding(1) var K: array; #define V K #else -@group(0) @binding(1) var K: array; -@group(0) @binding(2) var V: array; +@group(0) @binding(1) var K: array; +@group(0) @binding(2) var V: array; #endif #if defined(MASK) && defined(SINKS) @@ -238,7 +263,7 @@ fn load_f32x4(buf: ptr>, read_write>, scalar_index: u32 return (*buf)[scalar_index >> 2u]; } -fn load_kvx4(buf: ptr>, read_write>, scalar_index: u32) -> vec4 { +fn load_kx4(buf: ptr>, read_write>, scalar_index: u32) -> vec4 { return (*buf)[scalar_index >> 2u]; } @@ -317,10 +342,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3, } // load k tile into shared memory -#if defined(KV_Q4_0) - for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) { +#if defined(K_Q4_0) + for (var elem_idx = local_id.x * K_NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * K_NQ) { let blck_idx = elem_idx / BLOCK_SIZE; - let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; + let block_offset = (elem_idx % BLOCK_SIZE) / K_WEIGHTS_PER_F16; let k_row = blck_idx / BLOCKS_K; let global_k_row = kv_tile + k_row; let block_k = blck_idx % BLOCKS_K; @@ -328,9 +353,9 @@ fn main(@builtin(workgroup_id) wg_id: vec3, if (global_k_row < params.seq_len_kv) { let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; - let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES; + let block_byte_base = global_block_idx * K_BLOCK_SIZE_BYTES; let d = f16_from_u16(load_k_u16_at(block_byte_base)); - for (var j = 0u; j < F16_PER_THREAD; j += 2) { + for (var j = 0u; j < K_F16_PER_THREAD; j += 2) { let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); let q_packed = load_k_u32_at(q_byte_offset); for (var k = 0u; k < 4u; k++) { @@ -344,10 +369,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3, } } } -#elif defined(KV_Q8_0) - for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) { +#elif defined(K_Q8_0) + for (var elem_idx = local_id.x * K_NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * K_NQ) { let blck_idx = elem_idx / BLOCK_SIZE; - let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; + let block_offset = (elem_idx % BLOCK_SIZE) / K_WEIGHTS_PER_F16; let k_row = blck_idx / BLOCKS_K; let global_k_row = kv_tile + k_row; let block_k = blck_idx % BLOCKS_K; @@ -355,9 +380,9 @@ fn main(@builtin(workgroup_id) wg_id: vec3, if (global_k_row < params.seq_len_kv) { let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; - let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES; + let block_byte_base = global_block_idx * K_BLOCK_SIZE_BYTES; let d = f16_from_u16(load_k_u16_at(block_byte_base)); - for (var j = 0u; j < F16_PER_THREAD; j += 2) { + for (var j = 0u; j < K_F16_PER_THREAD; j += 2) { let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); let q_packed = load_k_u32_at(q_byte_offset); for (var k = 0u; k < 4u; k++) { @@ -520,10 +545,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3, } // load v tile into shared memory -#if defined(KV_Q4_0) - for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) { +#if defined(V_Q4_0) + for (var elem_idx = local_id.x * V_NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * V_NQ) { let blck_idx = elem_idx / BLOCK_SIZE; - let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; + let block_offset = (elem_idx % BLOCK_SIZE) / V_WEIGHTS_PER_F16; let v_row = blck_idx / BLOCKS_V; let global_v_row = kv_tile + v_row; let block_k = blck_idx % BLOCKS_V; @@ -531,9 +556,9 @@ fn main(@builtin(workgroup_id) wg_id: vec3, if (global_v_row < params.seq_len_kv) { let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; - let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES; + let block_byte_base = global_block_idx * V_BLOCK_SIZE_BYTES; let d = f16_from_u16(load_v_u16_at(block_byte_base)); - for (var j = 0u; j < F16_PER_THREAD; j += 2) { + for (var j = 0u; j < V_F16_PER_THREAD; j += 2) { let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); let q_packed = load_v_u32_at(q_byte_offset); for (var k = 0u; k < 4u; k++) { @@ -547,10 +572,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3, } } } -#elif defined(KV_Q8_0) - for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) { +#elif defined(V_Q8_0) + for (var elem_idx = local_id.x * V_NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * V_NQ) { let blck_idx = elem_idx / BLOCK_SIZE; - let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; + let block_offset = (elem_idx % BLOCK_SIZE) / V_WEIGHTS_PER_F16; let v_row = blck_idx / BLOCKS_V; let global_v_row = kv_tile + v_row; let block_k = blck_idx % BLOCKS_V; @@ -558,9 +583,9 @@ fn main(@builtin(workgroup_id) wg_id: vec3, if (global_v_row < params.seq_len_kv) { let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; - let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES; + let block_byte_base = global_block_idx * V_BLOCK_SIZE_BYTES; let d = f16_from_u16(load_v_u16_at(block_byte_base)); - for (var j = 0u; j < F16_PER_THREAD; j += 2) { + for (var j = 0u; j < V_F16_PER_THREAD; j += 2) { let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); let q_packed = load_v_u32_at(q_byte_offset); for (var k = 0u; k < 4u; k++) { diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl index 4133f0ab5644..dfae83bcb481 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl @@ -7,10 +7,16 @@ enable subgroups; #define Q_TYPE f32 #endif -#ifdef KV_F32 -#define KV_TYPE f32 +#ifdef K_F32 +#define K_TYPE f32 #else -#define KV_TYPE f16 +#define K_TYPE f16 +#endif + +#ifdef V_F32 +#define V_TYPE f32 +#else +#define V_TYPE f16 #endif #ifdef DST_F16 @@ -64,11 +70,11 @@ struct Params { @group(0) @binding(0) var Q: array; #ifdef KV_OVERLAP -@group(0) @binding(1) var K: array>; +@group(0) @binding(1) var K: array>; #define V K #else -@group(0) @binding(1) var K: array>; -@group(0) @binding(2) var V: array>; +@group(0) @binding(1) var K: array>; +@group(0) @binding(2) var V: array>; #endif #if defined(MASK) && defined(SINKS) @@ -123,8 +129,8 @@ const SCORE_REGS_PER_LANE: u32 = (KV_TILE + MIN_SUBGROUP_SIZE - 1u) / MIN_SUBGRO const OUT_REGS_PER_LANE: u32 = (V_CHUNKS + MIN_SUBGROUP_SIZE - 1u) / MIN_SUBGROUP_SIZE; var q_shmem: array; -var kv_shmem: array; -var p_shmem: array; +var kv_shmem: array; +var p_shmem: array; @compute @workgroup_size(WG_SIZE) fn main(@builtin(workgroup_id) wg_id: vec3, @@ -213,10 +219,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let k_vec_index = (k_head_offset + global_k_row * params.stride_k1 + chunk * 4u) >> 2u; let k4 = K[k_vec_index]; let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u; - kv_shmem[kv_off + 0u] = KV_TYPE(k4.x); - kv_shmem[kv_off + 1u] = KV_TYPE(k4.y); - kv_shmem[kv_off + 2u] = KV_TYPE(k4.z); - kv_shmem[kv_off + 3u] = KV_TYPE(k4.w); + kv_shmem[kv_off + 0u] = f16(k4.x); + kv_shmem[kv_off + 1u] = f16(k4.y); + kv_shmem[kv_off + 2u] = f16(k4.z); + kv_shmem[kv_off + 3u] = f16(k4.w); } workgroupBarrier(); @@ -239,7 +245,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, q_shmem[q_off + 2u], q_shmem[q_off + 3u]); let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u; - let kv = vec4( + let kv = vec4( kv_shmem[kv_off + 0u], kv_shmem[kv_off + 1u], kv_shmem[kv_off + 2u], @@ -271,7 +277,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let kv_local = sg_inv_id + slot * subgroup_size; if (row_active && kv_local < kv_count) { let p = exp(local_scores[slot] - new_max); - p_shmem[subgroup_p_offset + kv_local] = KV_TYPE(p); + p_shmem[subgroup_p_offset + kv_local] = f16(p); local_sum += p; } } @@ -285,10 +291,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let v_vec_index = (v_head_offset + global_v_row * params.stride_v1 + chunk * 4u) >> 2u; let v4 = V[v_vec_index]; let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u; - kv_shmem[kv_off + 0u] = KV_TYPE(v4.x); - kv_shmem[kv_off + 1u] = KV_TYPE(v4.y); - kv_shmem[kv_off + 2u] = KV_TYPE(v4.z); - kv_shmem[kv_off + 3u] = KV_TYPE(v4.w); + kv_shmem[kv_off + 0u] = f16(v4.x); + kv_shmem[kv_off + 1u] = f16(v4.y); + kv_shmem[kv_off + 2u] = f16(v4.z); + kv_shmem[kv_off + 3u] = f16(v4.w); } workgroupBarrier(); @@ -306,14 +312,14 @@ fn main(@builtin(workgroup_id) wg_id: vec3, var acc = out_regs[reg_idx]; for (var kv_local = 0u; kv_local < kv_count; kv_local += 1u) { - let p = p_shmem[subgroup_p_offset + kv_local]; + let p = f32(p_shmem[subgroup_p_offset + kv_local]); let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u; - let v4 = vec4( + let v4 = vec4( kv_shmem[kv_off + 0u], kv_shmem[kv_off + 1u], kv_shmem[kv_off + 2u], kv_shmem[kv_off + 3u]); - acc += f32(p) * vec4(v4); + acc += p * vec4(v4); } out_regs[reg_idx] = acc; } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl index 30ebbebe7720..02e2dad4dbf7 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl @@ -2,10 +2,16 @@ diagnostic(off, subgroup_uniformity); enable f16; enable subgroups; -#ifdef KV_F32 -#define KV_TYPE f32 +#ifdef K_F32 +#define K_TYPE f32 #else -#define KV_TYPE f16 +#define K_TYPE f16 +#endif + +#ifdef V_F32 +#define V_TYPE f32 +#else +#define V_TYPE f16 #endif #ifdef Q_F16 @@ -35,16 +41,31 @@ enable subgroups; #define BLOCK_SIZE 32 #define BLOCKS_K ((HEAD_DIM_QK + BLOCK_SIZE - 1) / BLOCK_SIZE) #define BLOCKS_V ((HEAD_DIM_V + BLOCK_SIZE - 1) / BLOCK_SIZE) -#if defined(KV_Q4_0) -#define NQ 16 -#define F16_PER_BLOCK 9 -#define WEIGHTS_PER_F16 4 -#elif defined(KV_Q8_0) -#define NQ 8 -#define F16_PER_BLOCK 17 -#define WEIGHTS_PER_F16 2 +#if defined(K_Q4_0) +#define K_NQ 16 +#define K_F16_PER_BLOCK 9 +#define K_WEIGHTS_PER_F16 4 +#elif defined(K_Q8_0) +#define K_NQ 8 +#define K_F16_PER_BLOCK 17 +#define K_WEIGHTS_PER_F16 2 +#endif +#if defined(K_Q4_0) || defined(K_Q8_0) +#define K_F16_PER_THREAD (K_NQ / K_WEIGHTS_PER_F16) +#endif + +#if defined(V_Q4_0) +#define V_NQ 16 +#define V_F16_PER_BLOCK 9 +#define V_WEIGHTS_PER_F16 4 +#elif defined(V_Q8_0) +#define V_NQ 8 +#define V_F16_PER_BLOCK 17 +#define V_WEIGHTS_PER_F16 2 +#endif +#if defined(V_Q4_0) || defined(V_Q8_0) +#define V_F16_PER_THREAD (V_NQ / V_WEIGHTS_PER_F16) #endif -#define F16_PER_THREAD (NQ / WEIGHTS_PER_F16) fn get_byte(value: u32, index: u32) -> u32 { return (value >> (index * 8)) & 0xFF; @@ -103,22 +124,22 @@ struct Params { @group(0) @binding(0) var Q: array; #ifdef KV_OVERLAP -#if defined(KV_Q4_0) || defined(KV_Q8_0) -@group(0) @binding(1) var K: array; +#if defined(K_Q4_0) || defined(K_Q8_0) +@group(0) @binding(1) var K: array; #else -@group(0) @binding(1) var K: array>; +@group(0) @binding(1) var K: array>; #endif #define V K #else -#if defined(KV_Q4_0) || defined(KV_Q8_0) -@group(0) @binding(1) var K: array; +#if defined(K_Q4_0) || defined(K_Q8_0) +@group(0) @binding(1) var K: array; #else -@group(0) @binding(1) var K: array>; +@group(0) @binding(1) var K: array>; #endif -#if defined(KV_Q4_0) || defined(KV_Q8_0) -@group(0) @binding(2) var V: array; +#if defined(V_Q4_0) || defined(V_Q8_0) +@group(0) @binding(2) var V: array; #else -@group(0) @binding(2) var V: array>; +@group(0) @binding(2) var V: array>; #endif #endif #if defined(MASK) && defined(SINKS) @@ -324,10 +345,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3, } // load k tile into shared memory -#if defined(KV_Q4_0) - for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) { +#if defined(K_Q4_0) + for (var elem_idx = local_id.x * K_NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * K_NQ) { let blck_idx = elem_idx / BLOCK_SIZE; - let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; + let block_offset = (elem_idx % BLOCK_SIZE) / K_WEIGHTS_PER_F16; let k_row = blck_idx / BLOCKS_K; let global_k_row = kv_tile + k_row; let block_k = blck_idx % BLOCKS_K; @@ -335,9 +356,9 @@ fn main(@builtin(workgroup_id) wg_id: vec3, if (global_k_row < params.seq_len_kv) { let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; - let base_idx = global_block_idx * F16_PER_BLOCK; + let base_idx = global_block_idx * K_F16_PER_BLOCK; let d = K[base_idx]; - for (var j = 0u; j < F16_PER_THREAD; j += 2) { + for (var j = 0u; j < K_F16_PER_THREAD; j += 2) { let q_0 = K[base_idx + 1u + block_offset + j]; let q_1 = K[base_idx + 1u + block_offset + j + 1]; let q_packed = bitcast(vec2(q_0, q_1)); @@ -352,10 +373,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3, } } } -#elif defined(KV_Q8_0) - for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) { +#elif defined(K_Q8_0) + for (var elem_idx = local_id.x * K_NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * K_NQ) { let blck_idx = elem_idx / BLOCK_SIZE; - let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; + let block_offset = (elem_idx % BLOCK_SIZE) / K_WEIGHTS_PER_F16; let k_row = blck_idx / BLOCKS_K; let global_k_row = kv_tile + k_row; let block_k = blck_idx % BLOCKS_K; @@ -363,9 +384,9 @@ fn main(@builtin(workgroup_id) wg_id: vec3, if (global_k_row < params.seq_len_kv) { let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; - let base_idx = global_block_idx * F16_PER_BLOCK; + let base_idx = global_block_idx * K_F16_PER_BLOCK; let d = K[base_idx]; - for (var j = 0u; j < F16_PER_THREAD; j += 2) { + for (var j = 0u; j < K_F16_PER_THREAD; j += 2) { let q_0 = K[base_idx + 1u + block_offset + j]; let q_1 = K[base_idx + 1u + block_offset + j + 1]; let q_packed = bitcast(vec2(q_0, q_1)); @@ -388,7 +409,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let global_k_row_offset = k_head_offset + global_k_row * params.stride_k1; let in_bounds = global_k_row < params.seq_len_kv && (k_col + 3u) < HEAD_DIM_QK; let vec_idx = (global_k_row_offset + k_col) >> 2u; - let k4 = select(vec4(0.0), K[vec_idx], in_bounds); + let k4 = select(vec4(0.0), K[vec_idx], in_bounds); kv_shmem[elem_idx + 0u] = f32(k4.x); kv_shmem[elem_idx + 1u] = f32(k4.y); kv_shmem[elem_idx + 2u] = f32(k4.z); @@ -510,10 +531,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3, } // load v tile into shared memory -#if defined(KV_Q4_0) - for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) { +#if defined(V_Q4_0) + for (var elem_idx = local_id.x * V_NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * V_NQ) { let blck_idx = elem_idx / BLOCK_SIZE; - let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; + let block_offset = (elem_idx % BLOCK_SIZE) / V_WEIGHTS_PER_F16; let v_row = blck_idx / BLOCKS_V; let global_v_row = kv_tile + v_row; let block_k = blck_idx % BLOCKS_V; @@ -521,9 +542,9 @@ fn main(@builtin(workgroup_id) wg_id: vec3, if (global_v_row < params.seq_len_kv) { let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; - let base_idx = global_block_idx * F16_PER_BLOCK; + let base_idx = global_block_idx * V_F16_PER_BLOCK; let d = V[base_idx]; - for (var j = 0u; j < F16_PER_THREAD; j += 2) { + for (var j = 0u; j < V_F16_PER_THREAD; j += 2) { let q_0 = V[base_idx + 1u + block_offset + j]; let q_1 = V[base_idx + 1u + block_offset + j + 1]; let q_packed = bitcast(vec2(q_0, q_1)); @@ -538,10 +559,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3, } } } -#elif defined(KV_Q8_0) - for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) { +#elif defined(V_Q8_0) + for (var elem_idx = local_id.x * V_NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * V_NQ) { let blck_idx = elem_idx / BLOCK_SIZE; - let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; + let block_offset = (elem_idx % BLOCK_SIZE) / V_WEIGHTS_PER_F16; let v_row = blck_idx / BLOCKS_V; let global_v_row = kv_tile + v_row; let block_k = blck_idx % BLOCKS_V; @@ -549,9 +570,9 @@ fn main(@builtin(workgroup_id) wg_id: vec3, if (global_v_row < params.seq_len_kv) { let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; - let base_idx = global_block_idx * F16_PER_BLOCK; + let base_idx = global_block_idx * V_F16_PER_BLOCK; let d = V[base_idx]; - for (var j = 0u; j < F16_PER_THREAD; j += 2) { + for (var j = 0u; j < V_F16_PER_THREAD; j += 2) { let q_0 = V[base_idx + 1u + block_offset + j]; let q_1 = V[base_idx + 1u + block_offset + j + 1]; let q_packed = bitcast(vec2(q_0, q_1)); @@ -574,7 +595,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let global_v_row_offset = v_head_offset + global_v_row * params.stride_v1; let in_bounds = global_v_row < params.seq_len_kv && (v_col + 3u) < HEAD_DIM_V; let vec_idx = (global_v_row_offset + v_col) >> 2u; - let v4 = select(vec4(0.0), V[vec_idx], in_bounds); + let v4 = select(vec4(0.0), V[vec_idx], in_bounds); kv_shmem[elem_idx + 0u] = f32(v4.x); kv_shmem[elem_idx + 1u] = f32(v4.y); kv_shmem[elem_idx + 2u] = f32(v4.z); From 94a1fddef6d1f0b1f2ba6d5e8b1bf7883de00df4 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Wed, 27 May 2026 21:34:51 -0700 Subject: [PATCH 4/7] Refactor and abstract quantization logic for flash_attn and mul_mat --- ggml/src/ggml-webgpu/pre_wgsl.hpp | 1406 +++++++++-------- .../ggml-webgpu/wgsl-shaders/flash_attn.wgsl | 147 +- .../wgsl-shaders/flash_attn_quant_blocks.tmpl | 91 ++ .../wgsl-shaders/flash_attn_vec_split.wgsl | 189 +-- .../wgsl-shaders/mul_mat_decls.tmpl | 20 +- .../wgsl-shaders/quant_inner_loops.tmpl | 19 + 6 files changed, 919 insertions(+), 953 deletions(-) create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_blocks.tmpl create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl diff --git a/ggml/src/ggml-webgpu/pre_wgsl.hpp b/ggml/src/ggml-webgpu/pre_wgsl.hpp index 4d4359463cac..702e6cda9391 100644 --- a/ggml/src/ggml-webgpu/pre_wgsl.hpp +++ b/ggml/src/ggml-webgpu/pre_wgsl.hpp @@ -17,762 +17,772 @@ namespace pre_wgsl { // Options //============================================================== struct Options { - std::string include_path = "."; - std::vector macros; + std::string include_path = "."; + std::vector macros; }; //============================================================== // Utility: trim //============================================================== -static std::string trim(const std::string & s) { - size_t a = 0; - while (a < s.size() && std::isspace((unsigned char) s[a])) { - a++; - } - size_t b = s.size(); - while (b > a && std::isspace((unsigned char) s[b - 1])) { - b--; - } - return s.substr(a, b - a); +static std::string trim(const std::string &s) { + size_t a = 0; + while (a < s.size() && std::isspace((unsigned char)s[a])) + a++; + size_t b = s.size(); + while (b > a && std::isspace((unsigned char)s[b - 1])) + b--; + return s.substr(a, b - a); } -static std::string trim_value(std::istream & is) { - std::string str; - std::getline(is, str); - return trim(str); +static std::string trim_value(std::istream &is) { + std::ostringstream ss; + ss << is.rdbuf(); + return trim(ss.str()); } static bool isIdentChar(char c) { - return std::isalnum(static_cast(c)) || c == '_'; + return std::isalnum(static_cast(c)) || c == '_'; } -static std::string expandMacrosRecursiveInternal(const std::string & line, - const std::unordered_map & macros, - std::unordered_set & visiting); - -static std::string expandMacroValue(const std::string & name, - const std::unordered_map & macros, - std::unordered_set & visiting) { - if (visiting.count(name)) { - throw std::runtime_error("Recursive macro: " + name); - } - visiting.insert(name); +static bool endsWithContinuation(const std::string &line) { + size_t i = line.size(); + while (i > 0 && std::isspace((unsigned char)line[i - 1])) + i--; + return i > 0 && line[i - 1] == '\\'; +} - auto it = macros.find(name); - if (it == macros.end()) { - visiting.erase(name); - return name; - } +static void stripContinuation(std::string &line) { + size_t i = line.size(); + while (i > 0 && std::isspace((unsigned char)line[i - 1])) + i--; + if (i > 0 && line[i - 1] == '\\') { + line.erase(i - 1); + } +} - const std::string & value = it->second; - if (value.empty()) { - visiting.erase(name); - return ""; - } +static std::string expandMacrosRecursiveInternal( + const std::string &line, + const std::unordered_map ¯os, + std::unordered_set &visiting); + +static std::string +expandMacroValue(const std::string &name, + const std::unordered_map ¯os, + std::unordered_set &visiting) { + if (visiting.count(name)) + throw std::runtime_error("Recursive macro: " + name); + visiting.insert(name); + + auto it = macros.find(name); + if (it == macros.end()) { + visiting.erase(name); + return name; + } - std::string expanded = expandMacrosRecursiveInternal(value, macros, visiting); + const std::string &value = it->second; + if (value.empty()) { visiting.erase(name); - return expanded; -} + return ""; + } -static std::string expandMacrosRecursiveInternal(const std::string & line, - const std::unordered_map & macros, - std::unordered_set & visiting) { - std::string result; - result.reserve(line.size()); - - size_t i = 0; - while (i < line.size()) { - if (isIdentChar(line[i])) { - size_t start = i; - while (i < line.size() && isIdentChar(line[i])) { - i++; - } - std::string token = line.substr(start, i - start); - - auto it = macros.find(token); - if (it != macros.end()) { - result += expandMacroValue(token, macros, visiting); - } else { - result += token; - } - } else { - result += line[i]; - i++; - } - } + std::string expanded = expandMacrosRecursiveInternal(value, macros, visiting); + visiting.erase(name); + return expanded; +} - return result; +static std::string expandMacrosRecursiveInternal( + const std::string &line, + const std::unordered_map ¯os, + std::unordered_set &visiting) { + std::string result; + result.reserve(line.size()); + + size_t i = 0; + while (i < line.size()) { + if (isIdentChar(line[i])) { + size_t start = i; + while (i < line.size() && isIdentChar(line[i])) { + i++; + } + std::string token = line.substr(start, i - start); + + auto it = macros.find(token); + if (it != macros.end()) { + result += expandMacroValue(token, macros, visiting); + } else { + result += token; + } + } else { + result += line[i]; + i++; + } + } + + return result; } -static std::string expandMacrosRecursive(const std::string & line, - const std::unordered_map & macros) { - std::unordered_set visiting; - return expandMacrosRecursiveInternal(line, macros, visiting); +static std::string expandMacrosRecursive( + const std::string &line, + const std::unordered_map ¯os) { + std::unordered_set visiting; + return expandMacrosRecursiveInternal(line, macros, visiting); } //============================================================== // Tokenizer for expressions in #if/#elif //============================================================== class ExprLexer { - public: - enum Kind { END, IDENT, NUMBER, OP, LPAREN, RPAREN }; +public: + enum Kind { END, IDENT, NUMBER, OP, LPAREN, RPAREN }; - struct Tok { - Kind kind; - std::string text; - }; + struct Tok { + Kind kind; + std::string text; + }; - explicit ExprLexer(std::string_view sv) : src(sv), pos(0) {} + explicit ExprLexer(std::string_view sv) : src(sv), pos(0) {} - Tok next() { - skipWS(); - if (pos >= src.size()) { - return { END, "" }; - } - - char c = src[pos]; - - // number - if (std::isdigit((unsigned char) c)) { - size_t start = pos; - while (pos < src.size() && std::isdigit((unsigned char) src[pos])) { - pos++; - } - return { NUMBER, std::string(src.substr(start, pos - start)) }; - } + Tok next() { + skipWS(); + if (pos >= src.size()) + return {END, ""}; - // identifier - if (std::isalpha((unsigned char) c) || c == '_') { - size_t start = pos; - while (pos < src.size() && (std::isalnum((unsigned char) src[pos]) || src[pos] == '_')) { - pos++; - } - return { IDENT, std::string(src.substr(start, pos - start)) }; - } - - if (c == '(') { - pos++; - return { LPAREN, "(" }; - } - if (c == ')') { - pos++; - return { RPAREN, ")" }; - } + char c = src[pos]; - // multi-char operators - static const char * two_ops[] = { "==", "!=", "<=", ">=", "&&", "||", "<<", ">>" }; - for (auto op : two_ops) { - if (src.substr(pos, 2) == op) { - pos += 2; - return { OP, std::string(op) }; - } - } - - // single-char operators - if (std::string("+-*/%<>!").find(c) != std::string::npos) { - pos++; - return { OP, std::string(1, c) }; - } - - // unexpected + // number + if (std::isdigit((unsigned char)c)) { + size_t start = pos; + while (pos < src.size() && std::isdigit((unsigned char)src[pos])) pos++; - return { END, "" }; + return {NUMBER, std::string(src.substr(start, pos - start))}; } - private: - std::string_view src; - size_t pos; - - void skipWS() { - while (pos < src.size() && std::isspace((unsigned char) src[pos])) { - pos++; - } - } -}; - -//============================================================== -// Expression Parser (recursive descent) -//============================================================== -class ExprParser { - public: - ExprParser(std::string_view expr, - const std::unordered_map & macros, - std::unordered_set & visiting) : - lex(expr), - macros(macros), - visiting(visiting) { - advance(); + // identifier + if (std::isalpha((unsigned char)c) || c == '_') { + size_t start = pos; + while (pos < src.size() && + (std::isalnum((unsigned char)src[pos]) || src[pos] == '_')) + pos++; + return {IDENT, std::string(src.substr(start, pos - start))}; } - int parse() { return parseLogicalOr(); } - - private: - ExprLexer lex; - ExprLexer::Tok tok; - const std::unordered_map & macros; - std::unordered_set & visiting; - - void advance() { tok = lex.next(); } - - bool acceptOp(const std::string & s) { - if (tok.kind == ExprLexer::OP && tok.text == s) { - advance(); - return true; - } - return false; + if (c == '(') { + pos++; + return {LPAREN, "("}; } - - bool acceptKind(ExprLexer::Kind k) { - if (tok.kind == k) { - advance(); - return true; - } - return false; + if (c == ')') { + pos++; + return {RPAREN, ")"}; } - int parseLogicalOr() { - int v = parseLogicalAnd(); - while (acceptOp("||")) { - int rhs = parseLogicalAnd(); - v = (v || rhs); - } - return v; + // multi-char operators + static const char *two_ops[] = { + "==", "!=", "<=", ">=", "&&", "||", "<<", ">>"}; + for (auto op : two_ops) { + if (src.substr(pos, 2) == op) { + pos += 2; + return {OP, std::string(op)}; + } } - int parseLogicalAnd() { - int v = parseEquality(); - while (acceptOp("&&")) { - int rhs = parseEquality(); - v = (v && rhs); - } - return v; - } - - int parseEquality() { - int v = parseRelational(); - for (;;) { - if (acceptOp("==")) { - int rhs = parseRelational(); - v = (v == rhs); - } else if (acceptOp("!=")) { - int rhs = parseRelational(); - v = (v != rhs); - } else { - break; - } - } - return v; - } - - int parseRelational() { - int v = parseShift(); - for (;;) { - if (acceptOp("<")) { - int rhs = parseShift(); - v = (v < rhs); - } else if (acceptOp(">")) { - int rhs = parseShift(); - v = (v > rhs); - } else if (acceptOp("<=")) { - int rhs = parseShift(); - v = (v <= rhs); - } else if (acceptOp(">=")) { - int rhs = parseShift(); - v = (v >= rhs); - } else { - break; - } - } - return v; - } - - int parseShift() { - int v = parseAdd(); - for (;;) { - if (acceptOp("<<")) { - int rhs = parseAdd(); - v = (v << rhs); - } else if (acceptOp(">>")) { - int rhs = parseAdd(); - v = (v >> rhs); - } else { - break; - } - } - return v; - } - - int parseAdd() { - int v = parseMult(); - for (;;) { - if (acceptOp("+")) { - int rhs = parseMult(); - v = (v + rhs); - } else if (acceptOp("-")) { - int rhs = parseMult(); - v = (v - rhs); - } else { - break; - } - } - return v; - } - - int parseMult() { - int v = parseUnary(); - for (;;) { - if (acceptOp("*")) { - int rhs = parseUnary(); - v = (v * rhs); - } else if (acceptOp("/")) { - int rhs = parseUnary(); - v = (rhs == 0 ? 0 : v / rhs); - } else if (acceptOp("%")) { - int rhs = parseUnary(); - v = (rhs == 0 ? 0 : v % rhs); - } else { - break; - } - } - return v; - } - - int parseUnary() { - if (acceptOp("!")) { - return !parseUnary(); - } - if (acceptOp("-")) { - return -parseUnary(); - } - if (acceptOp("+")) { - return +parseUnary(); - } - return parsePrimary(); + // single-char operators + if (std::string("+-*/%<>!").find(c) != std::string::npos) { + pos++; + return {OP, std::string(1, c)}; } - int parsePrimary() { - // '(' expr ')' - if (acceptKind(ExprLexer::LPAREN)) { - int v = parse(); - if (!acceptKind(ExprLexer::RPAREN)) { - throw std::runtime_error("missing ')'"); - } - return v; - } + // unexpected + pos++; + return {END, ""}; + } - // number - if (tok.kind == ExprLexer::NUMBER) { - int v = std::stoi(tok.text); - advance(); - return v; - } +private: + std::string_view src; + size_t pos; - // defined(identifier) - if (tok.kind == ExprLexer::IDENT && tok.text == "defined") { - advance(); - if (acceptKind(ExprLexer::LPAREN)) { - if (tok.kind != ExprLexer::IDENT) { - throw std::runtime_error("expected identifier in defined()"); - } - std::string name = tok.text; - advance(); - if (!acceptKind(ExprLexer::RPAREN)) { - throw std::runtime_error("missing ) in defined()"); - } - return macros.count(name) ? 1 : 0; - } else { - // defined NAME - if (tok.kind != ExprLexer::IDENT) { - throw std::runtime_error("expected identifier in defined NAME"); - } - std::string name = tok.text; - advance(); - return macros.count(name) ? 1 : 0; - } - } + void skipWS() { + while (pos < src.size() && std::isspace((unsigned char)src[pos])) + pos++; + } +}; - // identifier -> treat as integer, if defined use its value else 0 - if (tok.kind == ExprLexer::IDENT) { - std::string name = tok.text; - advance(); - auto it = macros.find(name); - if (it == macros.end()) { - return 0; - } - if (it->second.empty()) { - return 1; - } - return evalMacroExpression(name, it->second); - } +//============================================================== +// Expression Parser (recursive descent) +//============================================================== +class ExprParser { +public: + ExprParser(std::string_view expr, + const std::unordered_map ¯os, + std::unordered_set &visiting) + : lex(expr), macros(macros), visiting(visiting) { + advance(); + } + + int parse() { return parseLogicalOr(); } + +private: + ExprLexer lex; + ExprLexer::Tok tok; + const std::unordered_map ¯os; + std::unordered_set &visiting; + + void advance() { tok = lex.next(); } + + bool acceptOp(const std::string &s) { + if (tok.kind == ExprLexer::OP && tok.text == s) { + advance(); + return true; + } + return false; + } + + bool acceptKind(ExprLexer::Kind k) { + if (tok.kind == k) { + advance(); + return true; + } + return false; + } + + int parseLogicalOr() { + int v = parseLogicalAnd(); + while (acceptOp("||")) { + int rhs = parseLogicalAnd(); + v = (v || rhs); + } + return v; + } + + int parseLogicalAnd() { + int v = parseEquality(); + while (acceptOp("&&")) { + int rhs = parseEquality(); + v = (v && rhs); + } + return v; + } + + int parseEquality() { + int v = parseRelational(); + for (;;) { + if (acceptOp("==")) { + int rhs = parseRelational(); + v = (v == rhs); + } else if (acceptOp("!=")) { + int rhs = parseRelational(); + v = (v != rhs); + } else + break; + } + return v; + } + + int parseRelational() { + int v = parseShift(); + for (;;) { + if (acceptOp("<")) { + int rhs = parseShift(); + v = (v < rhs); + } else if (acceptOp(">")) { + int rhs = parseShift(); + v = (v > rhs); + } else if (acceptOp("<=")) { + int rhs = parseShift(); + v = (v <= rhs); + } else if (acceptOp(">=")) { + int rhs = parseShift(); + v = (v >= rhs); + } else + break; + } + return v; + } + + int parseShift() { + int v = parseAdd(); + for (;;) { + if (acceptOp("<<")) { + int rhs = parseAdd(); + v = (v << rhs); + } else if (acceptOp(">>")) { + int rhs = parseAdd(); + v = (v >> rhs); + } else + break; + } + return v; + } + + int parseAdd() { + int v = parseMult(); + for (;;) { + if (acceptOp("+")) { + int rhs = parseMult(); + v = (v + rhs); + } else if (acceptOp("-")) { + int rhs = parseMult(); + v = (v - rhs); + } else + break; + } + return v; + } + + int parseMult() { + int v = parseUnary(); + for (;;) { + if (acceptOp("*")) { + int rhs = parseUnary(); + v = (v * rhs); + } else if (acceptOp("/")) { + int rhs = parseUnary(); + v = (rhs == 0 ? 0 : v / rhs); + } else if (acceptOp("%")) { + int rhs = parseUnary(); + v = (rhs == 0 ? 0 : v % rhs); + } else + break; + } + return v; + } + + int parseUnary() { + if (acceptOp("!")) + return !parseUnary(); + if (acceptOp("-")) + return -parseUnary(); + if (acceptOp("+")) + return +parseUnary(); + return parsePrimary(); + } + + int parsePrimary() { + // '(' expr ')' + if (acceptKind(ExprLexer::LPAREN)) { + int v = parse(); + if (!acceptKind(ExprLexer::RPAREN)) + throw std::runtime_error("missing ')'"); + return v; + } + + // number + if (tok.kind == ExprLexer::NUMBER) { + int v = std::stoi(tok.text); + advance(); + return v; + } + + // defined(identifier) + if (tok.kind == ExprLexer::IDENT && tok.text == "defined") { + advance(); + if (acceptKind(ExprLexer::LPAREN)) { + if (tok.kind != ExprLexer::IDENT) + throw std::runtime_error("expected identifier in defined()"); + std::string name = tok.text; + advance(); + if (!acceptKind(ExprLexer::RPAREN)) + throw std::runtime_error("missing ) in defined()"); + return macros.count(name) ? 1 : 0; + } else { + // defined NAME + if (tok.kind != ExprLexer::IDENT) + throw std::runtime_error("expected identifier in defined NAME"); + std::string name = tok.text; + advance(); + return macros.count(name) ? 1 : 0; + } + } - // unexpected + // identifier -> treat as integer, if defined use its value else 0 + if (tok.kind == ExprLexer::IDENT) { + std::string name = tok.text; + advance(); + auto it = macros.find(name); + if (it == macros.end()) return 0; + if (it->second.empty()) + return 1; + return evalMacroExpression(name, it->second); } - int evalMacroExpression(const std::string & name, const std::string & value) { - if (visiting.count(name)) { - throw std::runtime_error("Recursive macro: " + name); - } + // unexpected + return 0; + } - visiting.insert(name); - ExprParser ep(value, macros, visiting); - int v = ep.parse(); - visiting.erase(name); - return v; - } + int evalMacroExpression(const std::string &name, const std::string &value) { + if (visiting.count(name)) + throw std::runtime_error("Recursive macro: " + name); + + visiting.insert(name); + ExprParser ep(value, macros, visiting); + int v = ep.parse(); + visiting.erase(name); + return v; + } }; //============================================================== // Preprocessor //============================================================== class Preprocessor { - public: - explicit Preprocessor(Options opts = {}) : opts_(std::move(opts)) { - // Treat empty include path as current directory - if (opts_.include_path.empty()) { - opts_.include_path = "."; - } - parseMacroDefinitions(opts_.macros); - } - - std::string preprocess_file(const std::string & filename, const std::vector & additional_macros = {}) { - std::unordered_map macros; - std::unordered_set predefined; - std::unordered_set include_stack; - buildMacros(additional_macros, macros, predefined); - - std::string result = processFile(filename, macros, predefined, include_stack, DirectiveMode::All); - return result; - } - - std::string preprocess(const std::string & contents, const std::vector & additional_macros = {}) { - std::unordered_map macros; - std::unordered_set predefined; - std::unordered_set include_stack; - buildMacros(additional_macros, macros, predefined); - - std::string result = processString(contents, macros, predefined, include_stack, DirectiveMode::All); - return result; - } - - std::string preprocess_includes_file(const std::string & filename) { - std::unordered_map macros; - std::unordered_set predefined; - std::unordered_set include_stack; - std::string result = processFile(filename, macros, predefined, include_stack, DirectiveMode::IncludesOnly); - return result; - } - - std::string preprocess_includes(const std::string & contents) { - std::unordered_map macros; - std::unordered_set predefined; - std::unordered_set include_stack; - std::string result = processString(contents, macros, predefined, include_stack, DirectiveMode::IncludesOnly); - return result; - } - - private: - Options opts_; - std::unordered_map global_macros; - - enum class DirectiveMode { All, IncludesOnly }; - - struct Cond { - bool parent_active; - bool active; - bool taken; - }; - - //---------------------------------------------------------- - // Parse macro definitions into global_macros - //---------------------------------------------------------- - void parseMacroDefinitions(const std::vector & macro_defs) { - for (const auto & def : macro_defs) { - size_t eq_pos = def.find('='); - if (eq_pos != std::string::npos) { - // Format: NAME=VALUE - std::string name = trim(def.substr(0, eq_pos)); - std::string value = trim(def.substr(eq_pos + 1)); - global_macros[name] = value; - } else { - // Format: NAME - std::string name = trim(def); - global_macros[name] = ""; - } - } - } - - //---------------------------------------------------------- - // Build combined macro map and predefined set for a preprocessing operation - //---------------------------------------------------------- - void buildMacros(const std::vector & additional_macros, - std::unordered_map & macros, - std::unordered_set & predefined) { - macros = global_macros; - predefined.clear(); - - for (const auto & [name, value] : global_macros) { - predefined.insert(name); - } - - for (const auto & def : additional_macros) { - size_t eq_pos = def.find('='); - std::string name, value; - if (eq_pos != std::string::npos) { - name = trim(def.substr(0, eq_pos)); - value = trim(def.substr(eq_pos + 1)); - } else { - name = trim(def); - value = ""; - } - - // Add to macros map (will override global if same name) - macros[name] = value; - predefined.insert(name); - } - } - - //---------------------------------------------------------- - // Helpers - //---------------------------------------------------------- - std::string loadFile(const std::string & fname) { - std::ifstream f(fname); - if (!f.is_open()) { - throw std::runtime_error("Could not open file: " + fname); - } - std::stringstream ss; - ss << f.rdbuf(); - return ss.str(); - } - - bool condActive(const std::vector & cond) const { - if (cond.empty()) { - return true; - } - return cond.back().active; - } - - //---------------------------------------------------------- - // Process a file - //---------------------------------------------------------- - std::string processFile(const std::string & name, - std::unordered_map & macros, - const std::unordered_set & predefined_macros, - std::unordered_set & include_stack, - DirectiveMode mode) { - if (include_stack.count(name)) { - throw std::runtime_error("Recursive include: " + name); - } - - include_stack.insert(name); - std::string shader_code = loadFile(name); - std::string out = processString(shader_code, macros, predefined_macros, include_stack, mode); - include_stack.erase(name); - return out; - } - - std::string processIncludeFile(const std::string & fname, - std::unordered_map & macros, - const std::unordered_set & predefined_macros, - std::unordered_set & include_stack, - DirectiveMode mode) { - std::string full_path = opts_.include_path + "/" + fname; - return processFile(full_path, macros, predefined_macros, include_stack, mode); - } - - //---------------------------------------------------------- - // Process text - //---------------------------------------------------------- - std::string processString(const std::string & shader_code, - std::unordered_map & macros, - const std::unordered_set & predefined_macros, - std::unordered_set & include_stack, - DirectiveMode mode) { - std::vector cond; // Conditional stack for this shader - std::stringstream out; - std::istringstream in(shader_code); - std::string line; - - while (std::getline(in, line)) { - std::string t = trim(line); - - if (!t.empty() && t[0] == '#') { - bool handled = handleDirective(t, out, macros, predefined_macros, cond, include_stack, mode); - if (mode == DirectiveMode::IncludesOnly && !handled) { - out << line << "\n"; - } - } else { - if (mode == DirectiveMode::IncludesOnly) { - out << line << "\n"; - } else if (condActive(cond)) { - // Expand macros in the line before outputting - std::string expanded = expandMacrosRecursive(line, macros); - out << expanded << "\n"; - } - } - } - - if (mode == DirectiveMode::All && !cond.empty()) { - throw std::runtime_error("Unclosed #if directive"); - } - - return out.str(); - } - - //---------------------------------------------------------- - // Directive handler - //---------------------------------------------------------- - bool handleDirective(const std::string & t, - std::stringstream & out, - std::unordered_map & macros, - const std::unordered_set & predefined_macros, - std::vector & cond, - std::unordered_set & include_stack, - DirectiveMode mode) { - // split into tokens - std::string body = t.substr(1); - std::istringstream iss(body); - std::string cmd; - iss >> cmd; - - if (cmd == "include") { - if (mode == DirectiveMode::All && !condActive(cond)) { - return true; - } - std::string file; - iss >> file; - if (file.size() >= 2 && file.front() == '"' && file.back() == '"') { - file = file.substr(1, file.size() - 2); - } - out << processIncludeFile(file, macros, predefined_macros, include_stack, mode); - return true; - } - +public: + explicit Preprocessor(Options opts = {}) : opts_(std::move(opts)) { + // Treat empty include path as current directory + if (opts_.include_path.empty()) { + opts_.include_path = "."; + } + parseMacroDefinitions(opts_.macros); + } + + std::string + preprocess_file(const std::string &filename, + const std::vector &additional_macros = {}) { + std::unordered_map macros; + std::unordered_set predefined; + std::unordered_set include_stack; + buildMacros(additional_macros, macros, predefined); + + std::string result = processFile(filename, macros, predefined, + include_stack, DirectiveMode::All); + return result; + } + + std::string + preprocess(const std::string &contents, + const std::vector &additional_macros = {}) { + std::unordered_map macros; + std::unordered_set predefined; + std::unordered_set include_stack; + buildMacros(additional_macros, macros, predefined); + + std::string result = processString(contents, macros, predefined, + include_stack, DirectiveMode::All); + return result; + } + + std::string preprocess_includes_file(const std::string &filename) { + std::unordered_map macros; + std::unordered_set predefined; + std::unordered_set include_stack; + std::string result = + processFile(filename, macros, predefined, include_stack, + DirectiveMode::IncludesOnly); + return result; + } + + std::string preprocess_includes(const std::string &contents) { + std::unordered_map macros; + std::unordered_set predefined; + std::unordered_set include_stack; + std::string result = + processString(contents, macros, predefined, include_stack, + DirectiveMode::IncludesOnly); + return result; + } + +private: + Options opts_; + std::unordered_map global_macros; + + enum class DirectiveMode { All, IncludesOnly }; + + struct Cond { + bool parent_active; + bool active; + bool taken; + }; + + //---------------------------------------------------------- + // Parse macro definitions into global_macros + //---------------------------------------------------------- + void parseMacroDefinitions(const std::vector ¯o_defs) { + for (const auto &def : macro_defs) { + size_t eq_pos = def.find('='); + if (eq_pos != std::string::npos) { + // Format: NAME=VALUE + std::string name = trim(def.substr(0, eq_pos)); + std::string value = trim(def.substr(eq_pos + 1)); + global_macros[name] = value; + } else { + // Format: NAME + std::string name = trim(def); + global_macros[name] = ""; + } + } + } + + //---------------------------------------------------------- + // Build combined macro map and predefined set for a preprocessing operation + //---------------------------------------------------------- + void buildMacros(const std::vector &additional_macros, + std::unordered_map ¯os, + std::unordered_set &predefined) { + macros = global_macros; + predefined.clear(); + + for (const auto &[name, value] : global_macros) { + predefined.insert(name); + } + + for (const auto &def : additional_macros) { + size_t eq_pos = def.find('='); + std::string name, value; + if (eq_pos != std::string::npos) { + name = trim(def.substr(0, eq_pos)); + value = trim(def.substr(eq_pos + 1)); + } else { + name = trim(def); + value = ""; + } + + // Add to macros map (will override global if same name) + macros[name] = value; + predefined.insert(name); + } + } + + //---------------------------------------------------------- + // Helpers + //---------------------------------------------------------- + std::string loadFile(const std::string &fname) { + std::ifstream f(fname); + if (!f.is_open()) + throw std::runtime_error("Could not open file: " + fname); + std::stringstream ss; + ss << f.rdbuf(); + return ss.str(); + } + + bool condActive(const std::vector &cond) const { + if (cond.empty()) + return true; + return cond.back().active; + } + + //---------------------------------------------------------- + // Process a file + //---------------------------------------------------------- + std::string + processFile(const std::string &name, + std::unordered_map ¯os, + const std::unordered_set &predefined_macros, + std::unordered_set &include_stack, + DirectiveMode mode) { + if (include_stack.count(name)) + throw std::runtime_error("Recursive include: " + name); + + include_stack.insert(name); + std::string shader_code = loadFile(name); + std::string out = processString(shader_code, macros, predefined_macros, + include_stack, mode); + include_stack.erase(name); + return out; + } + + std::string + processIncludeFile(const std::string &fname, + std::unordered_map ¯os, + const std::unordered_set &predefined_macros, + std::unordered_set &include_stack, + DirectiveMode mode) { + std::string full_path = opts_.include_path + "/" + fname; + return processFile(full_path, macros, predefined_macros, include_stack, + mode); + } + + //---------------------------------------------------------- + // Process text + //---------------------------------------------------------- + std::string + processString(const std::string &shader_code, + std::unordered_map ¯os, + const std::unordered_set &predefined_macros, + std::unordered_set &include_stack, + DirectiveMode mode) { + std::vector cond; // Conditional stack for this shader + std::stringstream out; + std::istringstream in(shader_code); + std::string line; + + while (std::getline(in, line)) { + std::string logical = line; + std::string t = trim(logical); + if (!t.empty() && t[0] == '#') { + while (endsWithContinuation(logical)) { + stripContinuation(logical); + if (!std::getline(in, line)) + break; + logical += "\n"; + logical += line; + } + t = trim(logical); + } + + if (!t.empty() && t[0] == '#') { + bool handled = handleDirective(t, out, macros, predefined_macros, cond, + include_stack, mode); + if (mode == DirectiveMode::IncludesOnly && !handled) { + out << logical << "\n"; + } + } else { if (mode == DirectiveMode::IncludesOnly) { - return false; - } - - if (cmd == "define") { - if (!condActive(cond)) { - return true; - } - std::string name; - iss >> name; - // Don't override predefined macros from options - if (predefined_macros.count(name)) { - return true; - } - std::string value = trim_value(iss); - macros[name] = value; - return true; - } - - if (cmd == "undef") { - if (!condActive(cond)) { - return true; - } - std::string name; - iss >> name; - // Don't undef predefined macros from options - if (predefined_macros.count(name)) { - return true; - } - macros.erase(name); - return true; - } - - if (cmd == "ifdef") { - std::string name; - iss >> name; - bool p = condActive(cond); - bool v = macros.count(name); - cond.push_back({ p, p && v, p && v }); - return true; - } - - if (cmd == "ifndef") { - std::string name; - iss >> name; - bool p = condActive(cond); - bool v = !macros.count(name); - cond.push_back({ p, p && v, p && v }); - return true; - } - - if (cmd == "if") { - std::string expr = trim_value(iss); - bool p = condActive(cond); - bool v = false; - if (p) { - std::unordered_set visiting; - ExprParser ep(expr, macros, visiting); - v = ep.parse() != 0; - } - cond.push_back({ p, p && v, p && v }); - return true; - } - - if (cmd == "elif") { - std::string expr = trim_value(iss); - - if (cond.empty()) { - throw std::runtime_error("#elif without #if"); - } - - Cond & c = cond.back(); - if (!c.parent_active) { - c.active = false; - return true; - } - - if (c.taken) { - c.active = false; - return true; - } - - std::unordered_set visiting; - ExprParser ep(expr, macros, visiting); - bool v = ep.parse() != 0; - c.active = v; - if (v) { - c.taken = true; - } - return true; - } - - if (cmd == "else") { - if (cond.empty()) { - throw std::runtime_error("#else without #if"); - } - - Cond & c = cond.back(); - if (!c.parent_active) { - c.active = false; - return true; - } - if (c.taken) { - c.active = false; - } else { - c.active = true; - c.taken = true; - } - return true; - } - - if (cmd == "endif") { - if (cond.empty()) { - throw std::runtime_error("#endif without #if"); - } - cond.pop_back(); - return true; - } - - // Unknown directive - throw std::runtime_error("Unknown directive: #" + cmd); - } + out << logical << "\n"; + } else if (condActive(cond)) { + // Expand macros in the line before outputting + std::string expanded = expandMacrosRecursive(logical, macros); + out << expanded << "\n"; + } + } + } + + if (mode == DirectiveMode::All && !cond.empty()) + throw std::runtime_error("Unclosed #if directive"); + + return out.str(); + } + + //---------------------------------------------------------- + // Directive handler + //---------------------------------------------------------- + bool handleDirective(const std::string &t, std::stringstream &out, + std::unordered_map ¯os, + const std::unordered_set &predefined_macros, + std::vector &cond, + std::unordered_set &include_stack, + DirectiveMode mode) { + // split into tokens + std::string body = t.substr(1); + std::istringstream iss(body); + std::string cmd; + iss >> cmd; + + if (cmd == "include") { + if (mode == DirectiveMode::All && !condActive(cond)) + return true; + std::string file; + iss >> file; + if (file.size() >= 2 && file.front() == '"' && file.back() == '"') + file = file.substr(1, file.size() - 2); + out << processIncludeFile(file, macros, predefined_macros, include_stack, + mode); + return true; + } + + if (mode == DirectiveMode::IncludesOnly) + return false; + + if (cmd == "define") { + if (!condActive(cond)) + return true; + std::string name; + iss >> name; + // Don't override predefined macros from options + if (predefined_macros.count(name)) + return true; + std::string value = trim_value(iss); + macros[name] = value; + return true; + } + + if (cmd == "undef") { + if (!condActive(cond)) + return true; + std::string name; + iss >> name; + // Don't undef predefined macros from options + if (predefined_macros.count(name)) + return true; + macros.erase(name); + return true; + } + + if (cmd == "ifdef") { + std::string name; + iss >> name; + bool p = condActive(cond); + bool v = macros.count(name); + cond.push_back({p, p && v, p && v}); + return true; + } + + if (cmd == "ifndef") { + std::string name; + iss >> name; + bool p = condActive(cond); + bool v = !macros.count(name); + cond.push_back({p, p && v, p && v}); + return true; + } + + if (cmd == "if") { + std::string expr = trim_value(iss); + bool p = condActive(cond); + bool v = false; + if (p) { + std::unordered_set visiting; + ExprParser ep(expr, macros, visiting); + v = ep.parse() != 0; + } + cond.push_back({p, p && v, p && v}); + return true; + } + + if (cmd == "elif") { + std::string expr = trim_value(iss); + + if (cond.empty()) + throw std::runtime_error("#elif without #if"); + + Cond &c = cond.back(); + if (!c.parent_active) { + c.active = false; + return true; + } + + if (c.taken) { + c.active = false; + return true; + } + + std::unordered_set visiting; + ExprParser ep(expr, macros, visiting); + bool v = ep.parse() != 0; + c.active = v; + if (v) + c.taken = true; + return true; + } + + if (cmd == "else") { + if (cond.empty()) + throw std::runtime_error("#else without #if"); + + Cond &c = cond.back(); + if (!c.parent_active) { + c.active = false; + return true; + } + if (c.taken) { + c.active = false; + } else { + c.active = true; + c.taken = true; + } + return true; + } + + if (cmd == "endif") { + if (cond.empty()) + throw std::runtime_error("#endif without #if"); + cond.pop_back(); + return true; + } + + // Unknown directive + throw std::runtime_error("Unknown directive: #" + cmd); + } }; -} // namespace pre_wgsl +} // namespace pre_wgsl -#endif // PRE_WGSL_HPP +#endif // PRE_WGSL_HPP diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl index c3abb3082b9d..7410a53332be 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl @@ -4,6 +4,9 @@ enable f16; enable subgroups; enable chromium_experimental_subgroup_matrix; +#define BYTE_HELPERS +#include "common_decls.tmpl" + #ifdef K_F32 #define K_TYPE f32 #elif defined(K_Q4_0) || defined(K_Q8_0) @@ -47,41 +50,30 @@ enable chromium_experimental_subgroup_matrix; #define K_NQ 16 #define K_F16_PER_BLOCK 9 #define K_BLOCK_SIZE_BYTES 18u -#define K_WEIGHTS_PER_F16 4 +#define K_BYTES_PER_THREAD 8u +#define K_BYTES_PER_INNER_LOOP 4u #elif defined(K_Q8_0) -#define K_NQ 8 +#define K_NQ 16 #define K_F16_PER_BLOCK 17 #define K_BLOCK_SIZE_BYTES 34u -#define K_WEIGHTS_PER_F16 2 -#endif -#if defined(K_Q4_0) || defined(K_Q8_0) -#define K_F16_PER_THREAD (K_NQ / K_WEIGHTS_PER_F16) +#define K_BYTES_PER_THREAD 16u +#define K_BYTES_PER_INNER_LOOP 4u #endif #if defined(V_Q4_0) #define V_NQ 16 #define V_F16_PER_BLOCK 9 #define V_BLOCK_SIZE_BYTES 18u -#define V_WEIGHTS_PER_F16 4 +#define V_BYTES_PER_THREAD 8u +#define V_BYTES_PER_INNER_LOOP 4u #elif defined(V_Q8_0) -#define V_NQ 8 +#define V_NQ 16 #define V_F16_PER_BLOCK 17 #define V_BLOCK_SIZE_BYTES 34u -#define V_WEIGHTS_PER_F16 2 -#endif -#if defined(V_Q4_0) || defined(V_Q8_0) -#define V_F16_PER_THREAD (V_NQ / V_WEIGHTS_PER_F16) +#define V_BYTES_PER_THREAD 16u +#define V_BYTES_PER_INNER_LOOP 4u #endif -// Ok not to put these in a define block, compiler will remove if unused -fn get_byte(value: u32, index: u32) -> u32 { - return (value >> (index * 8)) & 0xFF; -} - -fn get_byte_i32(value: u32, index: u32) -> i32 { - return bitcast(((value >> (index * 8)) & 0xFF) << 24) >> 24; -} - #if defined(K_Q4_0) || defined(K_Q8_0) fn load_k_u16_at(byte_offset: u32) -> u32 { let word = K[byte_offset / 4u]; @@ -118,12 +110,12 @@ fn load_v_u32_at(byte_offset: u32) -> u32 { let hi = V[word_idx + 1u]; return (lo >> shift) | (hi << (32u - shift)); } +#endif fn f16_from_u16(bits: u32) -> f16 { let packed = unpack2x16float(bits); return f16(packed[0]); } -#endif struct Params { offset_q: u32, @@ -267,6 +259,13 @@ fn load_kx4(buf: ptr>, read_write>, scalar_index: u3 return (*buf)[scalar_index >> 2u]; } +#ifndef KV_DIRECT +#define QUANT_SHMEM kv_shmem +#define QUANT_OUT_TYPE f16 +#include "quant_inner_loops.tmpl" +#include "flash_attn_quant_blocks.tmpl" +#endif + @compute @workgroup_size(WG_SIZE) fn main(@builtin(workgroup_id) wg_id: vec3, @builtin(local_invocation_id) local_id: vec3, @@ -343,57 +342,9 @@ fn main(@builtin(workgroup_id) wg_id: vec3, // load k tile into shared memory #if defined(K_Q4_0) - for (var elem_idx = local_id.x * K_NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * K_NQ) { - let blck_idx = elem_idx / BLOCK_SIZE; - let block_offset = (elem_idx % BLOCK_SIZE) / K_WEIGHTS_PER_F16; - let k_row = blck_idx / BLOCKS_K; - let global_k_row = kv_tile + k_row; - let block_k = blck_idx % BLOCKS_K; - let row_offset = k_row * HEAD_DIM_QK; - - if (global_k_row < params.seq_len_kv) { - let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; - let block_byte_base = global_block_idx * K_BLOCK_SIZE_BYTES; - let d = f16_from_u16(load_k_u16_at(block_byte_base)); - for (var j = 0u; j < K_F16_PER_THREAD; j += 2) { - let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); - let q_packed = load_k_u32_at(q_byte_offset); - for (var k = 0u; k < 4u; k++) { - let q_byte = get_byte(q_packed, k); - let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; - let q_lo = (f16(q_byte & 0xF) - 8.0) * d; - let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; - kv_shmem[row_offset + idx] = q_lo; - kv_shmem[row_offset + idx + 16u] = q_hi; - } - } - } - } + LOAD_K_Q4_0_TILE_BLOCK #elif defined(K_Q8_0) - for (var elem_idx = local_id.x * K_NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * K_NQ) { - let blck_idx = elem_idx / BLOCK_SIZE; - let block_offset = (elem_idx % BLOCK_SIZE) / K_WEIGHTS_PER_F16; - let k_row = blck_idx / BLOCKS_K; - let global_k_row = kv_tile + k_row; - let block_k = blck_idx % BLOCKS_K; - let row_offset = k_row * HEAD_DIM_QK; - - if (global_k_row < params.seq_len_kv) { - let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; - let block_byte_base = global_block_idx * K_BLOCK_SIZE_BYTES; - let d = f16_from_u16(load_k_u16_at(block_byte_base)); - for (var j = 0u; j < K_F16_PER_THREAD; j += 2) { - let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); - let q_packed = load_k_u32_at(q_byte_offset); - for (var k = 0u; k < 4u; k++) { - let q_byte = get_byte_i32(q_packed, k); - let q_val = f16(q_byte) * d; - let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; - kv_shmem[row_offset + idx] = q_val; - } - } - } - } + LOAD_K_Q8_0_TILE_BLOCK #elif defined(KV_DIRECT) // Direct global loads for KV #else @@ -546,57 +497,9 @@ fn main(@builtin(workgroup_id) wg_id: vec3, // load v tile into shared memory #if defined(V_Q4_0) - for (var elem_idx = local_id.x * V_NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * V_NQ) { - let blck_idx = elem_idx / BLOCK_SIZE; - let block_offset = (elem_idx % BLOCK_SIZE) / V_WEIGHTS_PER_F16; - let v_row = blck_idx / BLOCKS_V; - let global_v_row = kv_tile + v_row; - let block_k = blck_idx % BLOCKS_V; - let row_offset = v_row * HEAD_DIM_V; - - if (global_v_row < params.seq_len_kv) { - let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; - let block_byte_base = global_block_idx * V_BLOCK_SIZE_BYTES; - let d = f16_from_u16(load_v_u16_at(block_byte_base)); - for (var j = 0u; j < V_F16_PER_THREAD; j += 2) { - let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); - let q_packed = load_v_u32_at(q_byte_offset); - for (var k = 0u; k < 4u; k++) { - let q_byte = get_byte(q_packed, k); - let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; - let q_lo = (f16(q_byte & 0xF) - 8.0) * d; - let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; - kv_shmem[row_offset + idx] = q_lo; - kv_shmem[row_offset + idx + 16u] = q_hi; - } - } - } - } + LOAD_V_Q4_0_TILE_BLOCK #elif defined(V_Q8_0) - for (var elem_idx = local_id.x * V_NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * V_NQ) { - let blck_idx = elem_idx / BLOCK_SIZE; - let block_offset = (elem_idx % BLOCK_SIZE) / V_WEIGHTS_PER_F16; - let v_row = blck_idx / BLOCKS_V; - let global_v_row = kv_tile + v_row; - let block_k = blck_idx % BLOCKS_V; - let row_offset = v_row * HEAD_DIM_V; - - if (global_v_row < params.seq_len_kv) { - let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; - let block_byte_base = global_block_idx * V_BLOCK_SIZE_BYTES; - let d = f16_from_u16(load_v_u16_at(block_byte_base)); - for (var j = 0u; j < V_F16_PER_THREAD; j += 2) { - let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); - let q_packed = load_v_u32_at(q_byte_offset); - for (var k = 0u; k < 4u; k++) { - let q_byte = get_byte_i32(q_packed, k); - let q_val = f16(q_byte) * d; - let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; - kv_shmem[row_offset + idx] = q_val; - } - } - } - } + LOAD_V_Q8_0_TILE_BLOCK #elif defined(KV_DIRECT) // Direct global loads for KV #else diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_blocks.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_blocks.tmpl new file mode 100644 index 000000000000..cc43f9cc7b28 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_blocks.tmpl @@ -0,0 +1,91 @@ +#define LOAD_K_Q4_0_TILE_BLOCK \ +for (var elem_idx = local_id.x * K_NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * K_NQ) { \ + let blck_idx = elem_idx / BLOCK_SIZE; \ + let block_offset = (elem_idx % BLOCK_SIZE) / K_NQ; \ + let k_row = blck_idx / BLOCKS_K; \ + let global_k_row = kv_tile + k_row; \ + let block_k = blck_idx % BLOCKS_K; \ + let row_offset = k_row * HEAD_DIM_QK; \ + \ + if (global_k_row < params.seq_len_kv) { \ + let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; \ + let block_byte_base = global_block_idx * K_BLOCK_SIZE_BYTES; \ + let d = f16_from_u16(load_k_u16_at(block_byte_base)); \ + let thread_byte_offset = block_offset * K_BYTES_PER_THREAD; \ + let shmem_idx = row_offset + block_k * BLOCK_SIZE + thread_byte_offset; \ + for (var j = 0u; j < K_BYTES_PER_THREAD / K_BYTES_PER_INNER_LOOP; j += 1u) { \ + let q_byte_offset = block_byte_base + 2u + thread_byte_offset + j * K_BYTES_PER_INNER_LOOP; \ + let q_packed = load_k_u32_at(q_byte_offset); \ + dequant_q4_0_packed_to_shmem(q_packed, d, shmem_idx + j * K_BYTES_PER_INNER_LOOP); \ + } \ + } \ +} + +#define LOAD_K_Q8_0_TILE_BLOCK \ +for (var elem_idx = local_id.x * K_NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * K_NQ) { \ + let blck_idx = elem_idx / BLOCK_SIZE; \ + let block_offset = (elem_idx % BLOCK_SIZE) / K_NQ; \ + let k_row = blck_idx / BLOCKS_K; \ + let global_k_row = kv_tile + k_row; \ + let block_k = blck_idx % BLOCKS_K; \ + let row_offset = k_row * HEAD_DIM_QK; \ + \ + if (global_k_row < params.seq_len_kv) { \ + let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; \ + let block_byte_base = global_block_idx * K_BLOCK_SIZE_BYTES; \ + let d = f16_from_u16(load_k_u16_at(block_byte_base)); \ + let thread_byte_offset = block_offset * K_BYTES_PER_THREAD; \ + let shmem_idx = row_offset + block_k * BLOCK_SIZE + thread_byte_offset; \ + for (var j = 0u; j < K_BYTES_PER_THREAD / K_BYTES_PER_INNER_LOOP; j += 1u) { \ + let q_byte_offset = block_byte_base + 2u + thread_byte_offset + j * K_BYTES_PER_INNER_LOOP; \ + let q_packed = load_k_u32_at(q_byte_offset); \ + dequant_q8_0_packed_to_shmem(q_packed, d, shmem_idx + j * K_BYTES_PER_INNER_LOOP); \ + } \ + } \ +} + +#define LOAD_V_Q4_0_TILE_BLOCK \ +for (var elem_idx = local_id.x * V_NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * V_NQ) { \ + let blck_idx = elem_idx / BLOCK_SIZE; \ + let block_offset = (elem_idx % BLOCK_SIZE) / V_NQ; \ + let v_row = blck_idx / BLOCKS_V; \ + let global_v_row = kv_tile + v_row; \ + let block_k = blck_idx % BLOCKS_V; \ + let row_offset = v_row * HEAD_DIM_V; \ + \ + if (global_v_row < params.seq_len_kv) { \ + let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; \ + let block_byte_base = global_block_idx * V_BLOCK_SIZE_BYTES; \ + let d = f16_from_u16(load_v_u16_at(block_byte_base)); \ + let thread_byte_offset = block_offset * V_BYTES_PER_THREAD; \ + let shmem_idx = row_offset + block_k * BLOCK_SIZE + thread_byte_offset; \ + for (var j = 0u; j < V_BYTES_PER_THREAD / V_BYTES_PER_INNER_LOOP; j += 1u) { \ + let q_byte_offset = block_byte_base + 2u + thread_byte_offset + j * V_BYTES_PER_INNER_LOOP; \ + let q_packed = load_v_u32_at(q_byte_offset); \ + dequant_q4_0_packed_to_shmem(q_packed, d, shmem_idx + j * V_BYTES_PER_INNER_LOOP); \ + } \ + } \ +} + +#define LOAD_V_Q8_0_TILE_BLOCK \ +for (var elem_idx = local_id.x * V_NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * V_NQ) { \ + let blck_idx = elem_idx / BLOCK_SIZE; \ + let block_offset = (elem_idx % BLOCK_SIZE) / V_NQ; \ + let v_row = blck_idx / BLOCKS_V; \ + let global_v_row = kv_tile + v_row; \ + let block_k = blck_idx % BLOCKS_V; \ + let row_offset = v_row * HEAD_DIM_V; \ + \ + if (global_v_row < params.seq_len_kv) { \ + let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; \ + let block_byte_base = global_block_idx * V_BLOCK_SIZE_BYTES; \ + let d = f16_from_u16(load_v_u16_at(block_byte_base)); \ + let thread_byte_offset = block_offset * V_BYTES_PER_THREAD; \ + let shmem_idx = row_offset + block_k * BLOCK_SIZE + thread_byte_offset; \ + for (var j = 0u; j < V_BYTES_PER_THREAD / V_BYTES_PER_INNER_LOOP; j += 1u) { \ + let q_byte_offset = block_byte_base + 2u + thread_byte_offset + j * V_BYTES_PER_INNER_LOOP; \ + let q_packed = load_v_u32_at(q_byte_offset); \ + dequant_q8_0_packed_to_shmem(q_packed, d, shmem_idx + j * V_BYTES_PER_INNER_LOOP); \ + } \ + } \ +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl index 02e2dad4dbf7..d3efebac5d20 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl @@ -2,14 +2,21 @@ diagnostic(off, subgroup_uniformity); enable f16; enable subgroups; +#define BYTE_HELPERS +#include "common_decls.tmpl" + #ifdef K_F32 #define K_TYPE f32 +#elif defined(K_Q4_0) || defined(K_Q8_0) +#define K_TYPE u32 #else #define K_TYPE f16 #endif #ifdef V_F32 #define V_TYPE f32 +#elif defined(V_Q4_0) || defined(V_Q8_0) +#define V_TYPE u32 #else #define V_TYPE f16 #endif @@ -44,35 +51,72 @@ enable subgroups; #if defined(K_Q4_0) #define K_NQ 16 #define K_F16_PER_BLOCK 9 -#define K_WEIGHTS_PER_F16 4 +#define K_BLOCK_SIZE_BYTES 18u +#define K_BYTES_PER_THREAD 8u +#define K_BYTES_PER_INNER_LOOP 4u #elif defined(K_Q8_0) -#define K_NQ 8 +#define K_NQ 16 #define K_F16_PER_BLOCK 17 -#define K_WEIGHTS_PER_F16 2 -#endif -#if defined(K_Q4_0) || defined(K_Q8_0) -#define K_F16_PER_THREAD (K_NQ / K_WEIGHTS_PER_F16) +#define K_BLOCK_SIZE_BYTES 34u +#define K_BYTES_PER_THREAD 16u +#define K_BYTES_PER_INNER_LOOP 4u #endif #if defined(V_Q4_0) #define V_NQ 16 #define V_F16_PER_BLOCK 9 -#define V_WEIGHTS_PER_F16 4 +#define V_BLOCK_SIZE_BYTES 18u +#define V_BYTES_PER_THREAD 8u +#define V_BYTES_PER_INNER_LOOP 4u #elif defined(V_Q8_0) -#define V_NQ 8 +#define V_NQ 16 #define V_F16_PER_BLOCK 17 -#define V_WEIGHTS_PER_F16 2 +#define V_BLOCK_SIZE_BYTES 34u +#define V_BYTES_PER_THREAD 16u +#define V_BYTES_PER_INNER_LOOP 4u #endif -#if defined(V_Q4_0) || defined(V_Q8_0) -#define V_F16_PER_THREAD (V_NQ / V_WEIGHTS_PER_F16) + +#if defined(K_Q4_0) || defined(K_Q8_0) +fn load_k_u16_at(byte_offset: u32) -> u32 { + let word = K[byte_offset / 4u]; + let shift = (byte_offset & 2u) * 8u; + return (word >> shift) & 0xFFFFu; +} + +fn load_k_u32_at(byte_offset: u32) -> u32 { + let word_idx = byte_offset / 4u; + let shift = (byte_offset & 3u) * 8u; + let lo = K[word_idx]; + if (shift == 0u) { + return lo; + } + let hi = K[word_idx + 1u]; + return (lo >> shift) | (hi << (32u - shift)); +} #endif -fn get_byte(value: u32, index: u32) -> u32 { - return (value >> (index * 8)) & 0xFF; +#if defined(V_Q4_0) || defined(V_Q8_0) +fn load_v_u16_at(byte_offset: u32) -> u32 { + let word = V[byte_offset / 4u]; + let shift = (byte_offset & 2u) * 8u; + return (word >> shift) & 0xFFFFu; +} + +fn load_v_u32_at(byte_offset: u32) -> u32 { + let word_idx = byte_offset / 4u; + let shift = (byte_offset & 3u) * 8u; + let lo = V[word_idx]; + if (shift == 0u) { + return lo; + } + let hi = V[word_idx + 1u]; + return (lo >> shift) | (hi << (32u - shift)); } +#endif -fn get_byte_i32(value: u32, index: u32) -> i32 { - return bitcast(((value >> (index * 8)) & 0xFF) << 24) >> 24; +fn f16_from_u16(bits: u32) -> f16 { + let packed = unpack2x16float(bits); + return f16(packed[0]); } struct Params { @@ -265,6 +309,13 @@ fn calc_softmax_term(kv_idx: u32, slope: f32, has_bias: bool, apply_mask: bool) return v; } +#ifndef KV_DIRECT +#define QUANT_SHMEM kv_shmem +#define QUANT_OUT_TYPE f32 +#include "quant_inner_loops.tmpl" +#include "flash_attn_quant_blocks.tmpl" +#endif + @compute @workgroup_size(WG_SIZE) fn main(@builtin(workgroup_id) wg_id: vec3, @builtin(local_invocation_id) local_id: vec3, @@ -346,59 +397,9 @@ fn main(@builtin(workgroup_id) wg_id: vec3, // load k tile into shared memory #if defined(K_Q4_0) - for (var elem_idx = local_id.x * K_NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * K_NQ) { - let blck_idx = elem_idx / BLOCK_SIZE; - let block_offset = (elem_idx % BLOCK_SIZE) / K_WEIGHTS_PER_F16; - let k_row = blck_idx / BLOCKS_K; - let global_k_row = kv_tile + k_row; - let block_k = blck_idx % BLOCKS_K; - let row_offset = k_row * HEAD_DIM_QK; - - if (global_k_row < params.seq_len_kv) { - let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; - let base_idx = global_block_idx * K_F16_PER_BLOCK; - let d = K[base_idx]; - for (var j = 0u; j < K_F16_PER_THREAD; j += 2) { - let q_0 = K[base_idx + 1u + block_offset + j]; - let q_1 = K[base_idx + 1u + block_offset + j + 1]; - let q_packed = bitcast(vec2(q_0, q_1)); - for (var k = 0u; k < 4u; k++) { - let q_byte = get_byte(q_packed, k); - let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * f32(d); - let q_lo = (f32(q_byte & 0xF) - 8.0) * f32(d); - let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; - kv_shmem[row_offset + idx] = q_lo; - kv_shmem[row_offset + idx + 16u] = q_hi; - } - } - } - } + LOAD_K_Q4_0_TILE_BLOCK #elif defined(K_Q8_0) - for (var elem_idx = local_id.x * K_NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * K_NQ) { - let blck_idx = elem_idx / BLOCK_SIZE; - let block_offset = (elem_idx % BLOCK_SIZE) / K_WEIGHTS_PER_F16; - let k_row = blck_idx / BLOCKS_K; - let global_k_row = kv_tile + k_row; - let block_k = blck_idx % BLOCKS_K; - let row_offset = k_row * HEAD_DIM_QK; - - if (global_k_row < params.seq_len_kv) { - let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; - let base_idx = global_block_idx * K_F16_PER_BLOCK; - let d = K[base_idx]; - for (var j = 0u; j < K_F16_PER_THREAD; j += 2) { - let q_0 = K[base_idx + 1u + block_offset + j]; - let q_1 = K[base_idx + 1u + block_offset + j + 1]; - let q_packed = bitcast(vec2(q_0, q_1)); - for (var k = 0u; k < 4u; k++) { - let q_byte = get_byte_i32(q_packed, k); - let q_val = f32(q_byte) * f32(d); - let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; - kv_shmem[row_offset + idx] = q_val; - } - } - } - } + LOAD_K_Q8_0_TILE_BLOCK #elif defined(KV_DIRECT) // Direct global loads for KV #else @@ -532,59 +533,9 @@ fn main(@builtin(workgroup_id) wg_id: vec3, // load v tile into shared memory #if defined(V_Q4_0) - for (var elem_idx = local_id.x * V_NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * V_NQ) { - let blck_idx = elem_idx / BLOCK_SIZE; - let block_offset = (elem_idx % BLOCK_SIZE) / V_WEIGHTS_PER_F16; - let v_row = blck_idx / BLOCKS_V; - let global_v_row = kv_tile + v_row; - let block_k = blck_idx % BLOCKS_V; - let row_offset = v_row * HEAD_DIM_V; - - if (global_v_row < params.seq_len_kv) { - let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; - let base_idx = global_block_idx * V_F16_PER_BLOCK; - let d = V[base_idx]; - for (var j = 0u; j < V_F16_PER_THREAD; j += 2) { - let q_0 = V[base_idx + 1u + block_offset + j]; - let q_1 = V[base_idx + 1u + block_offset + j + 1]; - let q_packed = bitcast(vec2(q_0, q_1)); - for (var k = 0u; k < 4u; k++) { - let q_byte = get_byte(q_packed, k); - let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * f32(d); - let q_lo = (f32(q_byte & 0xF) - 8.0) * f32(d); - let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; - kv_shmem[row_offset + idx] = q_lo; - kv_shmem[row_offset + idx + 16u] = q_hi; - } - } - } - } + LOAD_V_Q4_0_TILE_BLOCK #elif defined(V_Q8_0) - for (var elem_idx = local_id.x * V_NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * V_NQ) { - let blck_idx = elem_idx / BLOCK_SIZE; - let block_offset = (elem_idx % BLOCK_SIZE) / V_WEIGHTS_PER_F16; - let v_row = blck_idx / BLOCKS_V; - let global_v_row = kv_tile + v_row; - let block_k = blck_idx % BLOCKS_V; - let row_offset = v_row * HEAD_DIM_V; - - if (global_v_row < params.seq_len_kv) { - let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; - let base_idx = global_block_idx * V_F16_PER_BLOCK; - let d = V[base_idx]; - for (var j = 0u; j < V_F16_PER_THREAD; j += 2) { - let q_0 = V[base_idx + 1u + block_offset + j]; - let q_1 = V[base_idx + 1u + block_offset + j + 1]; - let q_packed = bitcast(vec2(q_0, q_1)); - for (var k = 0u; k < 4u; k++) { - let q_byte = get_byte_i32(q_packed, k); - let q_val = f32(q_byte) * f32(d); - let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; - kv_shmem[row_offset + idx] = q_val; - } - } - } - } + LOAD_V_Q8_0_TILE_BLOCK #elif defined(KV_DIRECT) // Direct global loads for KV #else diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl index eb2a8368f438..72991504dd0c 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl @@ -25,6 +25,10 @@ fn store_shmem(val: f16, idx: u32) { } #endif // SCALAR +#define QUANT_SHMEM shmem +#define QUANT_OUT_TYPE f16 +#include "quant_inner_loops.tmpl" + #ifdef INIT_SRC0_SHMEM_FLOAT fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { for (var elem_idx = thread_id * VEC_SIZE; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * VEC_SIZE) { @@ -124,14 +128,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let q_byte_offset = block_byte_base + 2u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP; let q_packed = load_u32_at_src0(q_byte_offset); - - for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) { - let q_byte = get_byte(q_packed, k); - let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; - let q_lo = (f16(q_byte & 0xF) - 8.0) * d; - shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = q_lo; - shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = q_hi; - } + dequant_q4_0_packed_to_shmem(q_packed, d, shmem_idx + j * BYTES_PER_INNER_LOOP); } } } @@ -314,12 +311,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) { let q_byte_offset = block_byte_base + 2u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP; let q_packed = load_u32_at_src0(q_byte_offset); - for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) { - let q_byte = get_byte_i32(q_packed, k); - - let q_val = f16(q_byte) * d; - shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = q_val; - } + dequant_q8_0_packed_to_shmem(q_packed, d, shmem_idx + j * BYTES_PER_INNER_LOOP); } } } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl new file mode 100644 index 000000000000..237308809402 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl @@ -0,0 +1,19 @@ +fn dequant_q4_0_packed_to_shmem(q_packed: u32, d: f16, dst_idx: u32) { + let scale = QUANT_OUT_TYPE(d); + for (var k = 0u; k < 4u; k++) { + let q_byte = get_byte(q_packed, k); + let q_hi = (QUANT_OUT_TYPE((q_byte >> 4) & 0xFu) - QUANT_OUT_TYPE(8.0)) * scale; + let q_lo = (QUANT_OUT_TYPE(q_byte & 0xFu) - QUANT_OUT_TYPE(8.0)) * scale; + QUANT_SHMEM[dst_idx + k] = q_lo; + QUANT_SHMEM[dst_idx + k + 16u] = q_hi; + } +} + +fn dequant_q8_0_packed_to_shmem(q_packed: u32, d: f16, dst_idx: u32) { + let scale = QUANT_OUT_TYPE(d); + for (var k = 0u; k < 4u; k++) { + let q_byte = get_byte_i32(q_packed, k); + let q_val = QUANT_OUT_TYPE(q_byte) * scale; + QUANT_SHMEM[dst_idx + k] = q_val; + } +} From ddedf2fc8290e01947d175e44be24657d4ed4f3d Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Thu, 28 May 2026 11:02:46 -0700 Subject: [PATCH 5/7] Add quantization support to tile path --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 15 +- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 50 ++++-- .../ggml-webgpu/wgsl-shaders/flash_attn.wgsl | 79 +-------- .../wgsl-shaders/flash_attn_quant_blocks.tmpl | 91 ----------- .../flash_attn_quant_staging.tmpl | 150 ++++++++++++++++++ .../wgsl-shaders/flash_attn_tile.wgsl | 48 +++++- .../wgsl-shaders/flash_attn_vec_split.wgsl | 77 +-------- 7 files changed, 238 insertions(+), 272 deletions(-) delete mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_blocks.tmpl create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 971d5ad5439c..3f9b8155f3c7 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -637,19 +637,19 @@ inline size_t ggml_webgpu_flash_attn_tensor_offset(const ggml_tensor * tensor) { return reinterpret_cast(base->data) - ptr_base_addr + tensor->view_offs; } -inline bool ggml_webgpu_flash_attn_f16_vec4_aligned(const ggml_tensor * K, - size_t storage_offset_alignment) { +inline bool ggml_webgpu_flash_attn_float_vec4_aligned(const ggml_tensor * K, + size_t storage_offset_alignment) { const size_t alignment = std::max(1u, storage_offset_alignment); const uint32_t offset_elems = (uint32_t) ((ggml_webgpu_flash_attn_tensor_offset(K) & (alignment - 1)) / ggml_type_size(K->type)); return offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u; } -inline bool ggml_webgpu_flash_attn_f16_vec4_aligned(const ggml_tensor * K, - const ggml_tensor * V, - size_t storage_offset_alignment) { - return ggml_webgpu_flash_attn_f16_vec4_aligned(K, storage_offset_alignment) && - ggml_webgpu_flash_attn_f16_vec4_aligned(V, storage_offset_alignment); +inline bool ggml_webgpu_flash_attn_float_vec4_aligned(const ggml_tensor * K, + const ggml_tensor * V, + size_t storage_offset_alignment) { + return ggml_webgpu_flash_attn_float_vec4_aligned(K, storage_offset_alignment) && + ggml_webgpu_flash_attn_float_vec4_aligned(V, storage_offset_alignment); } inline bool ggml_webgpu_flash_attn_kv_direct(const ggml_tensor * Q, @@ -2742,7 +2742,6 @@ class ggml_webgpu_shader_lib { shader_src = wgsl_flash_attn_tile; defines.push_back("MIN_SUBGROUP_SIZE=" + std::to_string(context.min_subgroup_size) + "u"); defines.push_back("MAX_SUBGROUP_SIZE=" + std::to_string(context.max_subgroup_size) + "u"); - defines.push_back("KV_STAGE_STRIDE=" + std::to_string(std::max(key.common.head_dim_qk, key.common.head_dim_v))); variant += "_tile_sg" + std::to_string(context.min_subgroup_size) + "_" + std::to_string(context.max_subgroup_size); } else { diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index a47f1fd71f0c..e43834e0264d 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1779,23 +1779,27 @@ static bool ggml_webgpu_flash_attn_use_vec_path(const webgpu_global_context & gl GGML_ASSERT(Q != nullptr); const size_t storage_offset_alignment = global_ctx->capabilities.limits.minStorageBufferOffsetAlignment; - const bool k_f16_vec4_aligned = - K->type != GGML_TYPE_F16 || ggml_webgpu_flash_attn_f16_vec4_aligned(K, storage_offset_alignment); - const bool v_f16_vec4_aligned = - V->type != GGML_TYPE_F16 || ggml_webgpu_flash_attn_f16_vec4_aligned(V, storage_offset_alignment); + const bool k_float_vec4_aligned = + (K->type != GGML_TYPE_F16 && K->type != GGML_TYPE_F32) || + ggml_webgpu_flash_attn_float_vec4_aligned(K, storage_offset_alignment); + const bool v_float_vec4_aligned = + (V->type != GGML_TYPE_F16 && V->type != GGML_TYPE_F32) || + ggml_webgpu_flash_attn_float_vec4_aligned(V, storage_offset_alignment); const bool k_vec_type_supported = - K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0; + K->type == GGML_TYPE_F32 || K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0; const bool v_vec_type_supported = - V->type == GGML_TYPE_F16 || V->type == GGML_TYPE_Q4_0 || V->type == GGML_TYPE_Q8_0; + V->type == GGML_TYPE_F32 || V->type == GGML_TYPE_F16 || V->type == GGML_TYPE_Q4_0 || V->type == GGML_TYPE_Q8_0; const uint32_t k_vec_head_align = - K->type == GGML_TYPE_F16 ? GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH : (uint32_t) ggml_blck_size(K->type); + (K->type == GGML_TYPE_F32 || K->type == GGML_TYPE_F16) ? GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH : + (uint32_t) ggml_blck_size(K->type); const uint32_t v_vec_head_align = - V->type == GGML_TYPE_F16 ? GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH : (uint32_t) ggml_blck_size(V->type); + (V->type == GGML_TYPE_F32 || V->type == GGML_TYPE_F16) ? GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH : + (uint32_t) ggml_blck_size(V->type); const bool kv_vec_head_dims_aligned = Q->ne[0] % k_vec_head_align == 0 && V->ne[0] % v_vec_head_align == 0; return global_ctx->capabilities.supports_subgroups && (Q->ne[1] < GGML_WEBGPU_FLASH_ATTN_VEC_MAX_SEQ_LEN) && - kv_vec_head_dims_aligned && k_vec_type_supported && v_vec_type_supported && k_f16_vec4_aligned && - v_f16_vec4_aligned; + kv_vec_head_dims_aligned && k_vec_type_supported && v_vec_type_supported && k_float_vec4_aligned && + v_float_vec4_aligned; } static ggml_webgpu_flash_attn_op ggml_webgpu_flash_attn_prepare(webgpu_context & ctx, @@ -4206,18 +4210,34 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const } const auto & capabilities = ctx->webgpu_global_ctx->capabilities; const size_t storage_offset_alignment = capabilities.limits.minStorageBufferOffsetAlignment; + + // subgroup matrix path requirements const bool use_subgroup_matrix = ggml_webgpu_flash_attn_can_use_subgroup_matrix_path( capabilities.supports_subgroup_matrix, capabilities.sg_mat_k, capabilities.sg_mat_n, src0, src2); - const bool f16_vec4_aligned = - ggml_webgpu_flash_attn_f16_vec4_aligned(src1, src2, storage_offset_alignment); + + // tile path requirements + const bool float_vec4_aligned = + ((src1->type != GGML_TYPE_F16 && src1->type != GGML_TYPE_F32) || + ggml_webgpu_flash_attn_float_vec4_aligned(src1, storage_offset_alignment)) && + ((src2->type != GGML_TYPE_F16 && src2->type != GGML_TYPE_F32) || + ggml_webgpu_flash_attn_float_vec4_aligned(src2, storage_offset_alignment)); + const uint32_t k_tile_head_align = + (src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16) ? + GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH : + (uint32_t) ggml_blck_size(src1->type); + const uint32_t v_tile_head_align = + (src2->type == GGML_TYPE_F32 || src2->type == GGML_TYPE_F16) ? + GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH : + (uint32_t) ggml_blck_size(src2->type); + const bool tile_kv_head_dims_aligned = + src0->ne[0] % k_tile_head_align == 0 && src2->ne[0] % v_tile_head_align == 0; const bool tile_can_dispatch_all_q_rows = capabilities.limits.maxComputeInvocationsPerWorkgroup >= GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * capabilities.max_subgroup_size; const bool use_tile = !use_subgroup_matrix && capabilities.supports_subgroups && - src1->type == GGML_TYPE_F16 && src2->type == GGML_TYPE_F16 && f16_vec4_aligned && - (src0->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && - (src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && + float_vec4_aligned && tile_kv_head_dims_aligned && tile_can_dispatch_all_q_rows; + if (!use_subgroup_matrix && !use_tile) { supports_op = false; break; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl index 7410a53332be..d27e5b5f0885 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl @@ -41,82 +41,6 @@ enable chromium_experimental_subgroup_matrix; // Number of subgroup-matrix-width blocks that span the KV tile. SG_MAT_N must divide KV_TILE. #define KV_BLOCKS (KV_TILE / SG_MAT_N) -// Quantization constants/helpers -#define BLOCK_SIZE 32 -#define BLOCKS_K ((HEAD_DIM_QK + BLOCK_SIZE - 1) / BLOCK_SIZE) -#define BLOCKS_V ((HEAD_DIM_V + BLOCK_SIZE - 1) / BLOCK_SIZE) -// number of quantized elements processed per thread -#if defined(K_Q4_0) -#define K_NQ 16 -#define K_F16_PER_BLOCK 9 -#define K_BLOCK_SIZE_BYTES 18u -#define K_BYTES_PER_THREAD 8u -#define K_BYTES_PER_INNER_LOOP 4u -#elif defined(K_Q8_0) -#define K_NQ 16 -#define K_F16_PER_BLOCK 17 -#define K_BLOCK_SIZE_BYTES 34u -#define K_BYTES_PER_THREAD 16u -#define K_BYTES_PER_INNER_LOOP 4u -#endif - -#if defined(V_Q4_0) -#define V_NQ 16 -#define V_F16_PER_BLOCK 9 -#define V_BLOCK_SIZE_BYTES 18u -#define V_BYTES_PER_THREAD 8u -#define V_BYTES_PER_INNER_LOOP 4u -#elif defined(V_Q8_0) -#define V_NQ 16 -#define V_F16_PER_BLOCK 17 -#define V_BLOCK_SIZE_BYTES 34u -#define V_BYTES_PER_THREAD 16u -#define V_BYTES_PER_INNER_LOOP 4u -#endif - -#if defined(K_Q4_0) || defined(K_Q8_0) -fn load_k_u16_at(byte_offset: u32) -> u32 { - let word = K[byte_offset / 4u]; - let shift = (byte_offset & 2u) * 8u; - return (word >> shift) & 0xFFFFu; -} - -fn load_k_u32_at(byte_offset: u32) -> u32 { - let word_idx = byte_offset / 4u; - let shift = (byte_offset & 3u) * 8u; - let lo = K[word_idx]; - if (shift == 0u) { - return lo; - } - let hi = K[word_idx + 1u]; - return (lo >> shift) | (hi << (32u - shift)); -} -#endif - -#if defined(V_Q4_0) || defined(V_Q8_0) -fn load_v_u16_at(byte_offset: u32) -> u32 { - let word = V[byte_offset / 4u]; - let shift = (byte_offset & 2u) * 8u; - return (word >> shift) & 0xFFFFu; -} - -fn load_v_u32_at(byte_offset: u32) -> u32 { - let word_idx = byte_offset / 4u; - let shift = (byte_offset & 3u) * 8u; - let lo = V[word_idx]; - if (shift == 0u) { - return lo; - } - let hi = V[word_idx + 1u]; - return (lo >> shift) | (hi << (32u - shift)); -} -#endif - -fn f16_from_u16(bits: u32) -> f16 { - let packed = unpack2x16float(bits); - return f16(packed[0]); -} - struct Params { offset_q: u32, offset_k: u32, @@ -263,7 +187,7 @@ fn load_kx4(buf: ptr>, read_write>, scalar_index: u3 #define QUANT_SHMEM kv_shmem #define QUANT_OUT_TYPE f16 #include "quant_inner_loops.tmpl" -#include "flash_attn_quant_blocks.tmpl" +#include "flash_attn_quant_staging.tmpl" #endif @compute @workgroup_size(WG_SIZE) @@ -335,6 +259,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, } for (var kv_tile = 0u; kv_tile < params.seq_len_kv; kv_tile += KV_TILE) { + let kv_count = min(KV_TILE, params.seq_len_kv - kv_tile); // clear inter_shmem to ensure zero-initialized accumulators for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) { inter_shmem[elem_idx] = 0.0; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_blocks.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_blocks.tmpl deleted file mode 100644 index cc43f9cc7b28..000000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_blocks.tmpl +++ /dev/null @@ -1,91 +0,0 @@ -#define LOAD_K_Q4_0_TILE_BLOCK \ -for (var elem_idx = local_id.x * K_NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * K_NQ) { \ - let blck_idx = elem_idx / BLOCK_SIZE; \ - let block_offset = (elem_idx % BLOCK_SIZE) / K_NQ; \ - let k_row = blck_idx / BLOCKS_K; \ - let global_k_row = kv_tile + k_row; \ - let block_k = blck_idx % BLOCKS_K; \ - let row_offset = k_row * HEAD_DIM_QK; \ - \ - if (global_k_row < params.seq_len_kv) { \ - let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; \ - let block_byte_base = global_block_idx * K_BLOCK_SIZE_BYTES; \ - let d = f16_from_u16(load_k_u16_at(block_byte_base)); \ - let thread_byte_offset = block_offset * K_BYTES_PER_THREAD; \ - let shmem_idx = row_offset + block_k * BLOCK_SIZE + thread_byte_offset; \ - for (var j = 0u; j < K_BYTES_PER_THREAD / K_BYTES_PER_INNER_LOOP; j += 1u) { \ - let q_byte_offset = block_byte_base + 2u + thread_byte_offset + j * K_BYTES_PER_INNER_LOOP; \ - let q_packed = load_k_u32_at(q_byte_offset); \ - dequant_q4_0_packed_to_shmem(q_packed, d, shmem_idx + j * K_BYTES_PER_INNER_LOOP); \ - } \ - } \ -} - -#define LOAD_K_Q8_0_TILE_BLOCK \ -for (var elem_idx = local_id.x * K_NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * K_NQ) { \ - let blck_idx = elem_idx / BLOCK_SIZE; \ - let block_offset = (elem_idx % BLOCK_SIZE) / K_NQ; \ - let k_row = blck_idx / BLOCKS_K; \ - let global_k_row = kv_tile + k_row; \ - let block_k = blck_idx % BLOCKS_K; \ - let row_offset = k_row * HEAD_DIM_QK; \ - \ - if (global_k_row < params.seq_len_kv) { \ - let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; \ - let block_byte_base = global_block_idx * K_BLOCK_SIZE_BYTES; \ - let d = f16_from_u16(load_k_u16_at(block_byte_base)); \ - let thread_byte_offset = block_offset * K_BYTES_PER_THREAD; \ - let shmem_idx = row_offset + block_k * BLOCK_SIZE + thread_byte_offset; \ - for (var j = 0u; j < K_BYTES_PER_THREAD / K_BYTES_PER_INNER_LOOP; j += 1u) { \ - let q_byte_offset = block_byte_base + 2u + thread_byte_offset + j * K_BYTES_PER_INNER_LOOP; \ - let q_packed = load_k_u32_at(q_byte_offset); \ - dequant_q8_0_packed_to_shmem(q_packed, d, shmem_idx + j * K_BYTES_PER_INNER_LOOP); \ - } \ - } \ -} - -#define LOAD_V_Q4_0_TILE_BLOCK \ -for (var elem_idx = local_id.x * V_NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * V_NQ) { \ - let blck_idx = elem_idx / BLOCK_SIZE; \ - let block_offset = (elem_idx % BLOCK_SIZE) / V_NQ; \ - let v_row = blck_idx / BLOCKS_V; \ - let global_v_row = kv_tile + v_row; \ - let block_k = blck_idx % BLOCKS_V; \ - let row_offset = v_row * HEAD_DIM_V; \ - \ - if (global_v_row < params.seq_len_kv) { \ - let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; \ - let block_byte_base = global_block_idx * V_BLOCK_SIZE_BYTES; \ - let d = f16_from_u16(load_v_u16_at(block_byte_base)); \ - let thread_byte_offset = block_offset * V_BYTES_PER_THREAD; \ - let shmem_idx = row_offset + block_k * BLOCK_SIZE + thread_byte_offset; \ - for (var j = 0u; j < V_BYTES_PER_THREAD / V_BYTES_PER_INNER_LOOP; j += 1u) { \ - let q_byte_offset = block_byte_base + 2u + thread_byte_offset + j * V_BYTES_PER_INNER_LOOP; \ - let q_packed = load_v_u32_at(q_byte_offset); \ - dequant_q4_0_packed_to_shmem(q_packed, d, shmem_idx + j * V_BYTES_PER_INNER_LOOP); \ - } \ - } \ -} - -#define LOAD_V_Q8_0_TILE_BLOCK \ -for (var elem_idx = local_id.x * V_NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * V_NQ) { \ - let blck_idx = elem_idx / BLOCK_SIZE; \ - let block_offset = (elem_idx % BLOCK_SIZE) / V_NQ; \ - let v_row = blck_idx / BLOCKS_V; \ - let global_v_row = kv_tile + v_row; \ - let block_k = blck_idx % BLOCKS_V; \ - let row_offset = v_row * HEAD_DIM_V; \ - \ - if (global_v_row < params.seq_len_kv) { \ - let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; \ - let block_byte_base = global_block_idx * V_BLOCK_SIZE_BYTES; \ - let d = f16_from_u16(load_v_u16_at(block_byte_base)); \ - let thread_byte_offset = block_offset * V_BYTES_PER_THREAD; \ - let shmem_idx = row_offset + block_k * BLOCK_SIZE + thread_byte_offset; \ - for (var j = 0u; j < V_BYTES_PER_THREAD / V_BYTES_PER_INNER_LOOP; j += 1u) { \ - let q_byte_offset = block_byte_base + 2u + thread_byte_offset + j * V_BYTES_PER_INNER_LOOP; \ - let q_packed = load_v_u32_at(q_byte_offset); \ - dequant_q8_0_packed_to_shmem(q_packed, d, shmem_idx + j * V_BYTES_PER_INNER_LOOP); \ - } \ - } \ -} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl new file mode 100644 index 000000000000..d46e2e79ec9c --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl @@ -0,0 +1,150 @@ +#define BLOCK_SIZE 32 +#define BLOCKS_K ((HEAD_DIM_QK + BLOCK_SIZE - 1) / BLOCK_SIZE) +#define BLOCKS_V ((HEAD_DIM_V + BLOCK_SIZE - 1) / BLOCK_SIZE) + +#if defined(K_Q4_0) +#define K_NQ 16 +#define K_BLOCK_SIZE_BYTES 18u +#define K_BYTES_PER_THREAD 8u +#define K_BYTES_PER_INNER_LOOP 4u +#elif defined(K_Q8_0) +#define K_NQ 16 +#define K_BLOCK_SIZE_BYTES 34u +#define K_BYTES_PER_THREAD 16u +#define K_BYTES_PER_INNER_LOOP 4u +#endif + +#if defined(V_Q4_0) +#define V_NQ 16 +#define V_BLOCK_SIZE_BYTES 18u +#define V_BYTES_PER_THREAD 8u +#define V_BYTES_PER_INNER_LOOP 4u +#elif defined(V_Q8_0) +#define V_NQ 16 +#define V_BLOCK_SIZE_BYTES 34u +#define V_BYTES_PER_THREAD 16u +#define V_BYTES_PER_INNER_LOOP 4u +#endif + +#if defined(K_Q4_0) || defined(K_Q8_0) +fn load_k_u16_at(byte_offset: u32) -> u32 { + let word = K[byte_offset / 4u]; + let shift = (byte_offset & 2u) * 8u; + return (word >> shift) & 0xFFFFu; +} + +fn load_k_u32_at(byte_offset: u32) -> u32 { + let word_idx = byte_offset / 4u; + let shift = (byte_offset & 3u) * 8u; + let lo = K[word_idx]; + if (shift == 0u) { + return lo; + } + let hi = K[word_idx + 1u]; + return (lo >> shift) | (hi << (32u - shift)); +} +#endif + +#if defined(V_Q4_0) || defined(V_Q8_0) +fn load_v_u16_at(byte_offset: u32) -> u32 { + let word = V[byte_offset / 4u]; + let shift = (byte_offset & 2u) * 8u; + return (word >> shift) & 0xFFFFu; +} + +fn load_v_u32_at(byte_offset: u32) -> u32 { + let word_idx = byte_offset / 4u; + let shift = (byte_offset & 3u) * 8u; + let lo = V[word_idx]; + if (shift == 0u) { + return lo; + } + let hi = V[word_idx + 1u]; + return (lo >> shift) | (hi << (32u - shift)); +} +#endif + +fn f16_from_u16(bits: u32) -> f16 { + let packed = unpack2x16float(bits); + return f16(packed[0]); +} + +#define LOAD_K_Q4_0_TILE_BLOCK \ +for (var elem_idx = local_id.x * K_NQ; elem_idx < kv_count * HEAD_DIM_QK; elem_idx += WG_SIZE * K_NQ) { \ + let blck_idx = elem_idx / BLOCK_SIZE; \ + let block_offset = (elem_idx % BLOCK_SIZE) / K_NQ; \ + let k_row = blck_idx / BLOCKS_K; \ + let global_k_row = kv_tile + k_row; \ + let block_k = blck_idx % BLOCKS_K; \ + let row_offset = k_row * HEAD_DIM_QK; \ + let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; \ + let block_byte_base = global_block_idx * K_BLOCK_SIZE_BYTES; \ + let d = f16_from_u16(load_k_u16_at(block_byte_base)); \ + let thread_byte_offset = block_offset * K_BYTES_PER_THREAD; \ + let shmem_idx = row_offset + block_k * BLOCK_SIZE + thread_byte_offset; \ + for (var j = 0u; j < K_BYTES_PER_THREAD / K_BYTES_PER_INNER_LOOP; j += 1u) { \ + let q_byte_offset = block_byte_base + 2u + thread_byte_offset + j * K_BYTES_PER_INNER_LOOP; \ + let q_packed = load_k_u32_at(q_byte_offset); \ + dequant_q4_0_packed_to_shmem(q_packed, d, shmem_idx + j * K_BYTES_PER_INNER_LOOP); \ + } \ +} + +#define LOAD_K_Q8_0_TILE_BLOCK \ +for (var elem_idx = local_id.x * K_NQ; elem_idx < kv_count * HEAD_DIM_QK; elem_idx += WG_SIZE * K_NQ) { \ + let blck_idx = elem_idx / BLOCK_SIZE; \ + let block_offset = (elem_idx % BLOCK_SIZE) / K_NQ; \ + let k_row = blck_idx / BLOCKS_K; \ + let global_k_row = kv_tile + k_row; \ + let block_k = blck_idx % BLOCKS_K; \ + let row_offset = k_row * HEAD_DIM_QK; \ + let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; \ + let block_byte_base = global_block_idx * K_BLOCK_SIZE_BYTES; \ + let d = f16_from_u16(load_k_u16_at(block_byte_base)); \ + let thread_byte_offset = block_offset * K_BYTES_PER_THREAD; \ + let shmem_idx = row_offset + block_k * BLOCK_SIZE + thread_byte_offset; \ + for (var j = 0u; j < K_BYTES_PER_THREAD / K_BYTES_PER_INNER_LOOP; j += 1u) { \ + let q_byte_offset = block_byte_base + 2u + thread_byte_offset + j * K_BYTES_PER_INNER_LOOP; \ + let q_packed = load_k_u32_at(q_byte_offset); \ + dequant_q8_0_packed_to_shmem(q_packed, d, shmem_idx + j * K_BYTES_PER_INNER_LOOP); \ + } \ +} + +#define LOAD_V_Q4_0_TILE_BLOCK \ +for (var elem_idx = local_id.x * V_NQ; elem_idx < kv_count * HEAD_DIM_V; elem_idx += WG_SIZE * V_NQ) { \ + let blck_idx = elem_idx / BLOCK_SIZE; \ + let block_offset = (elem_idx % BLOCK_SIZE) / V_NQ; \ + let v_row = blck_idx / BLOCKS_V; \ + let global_v_row = kv_tile + v_row; \ + let block_k = blck_idx % BLOCKS_V; \ + let row_offset = v_row * HEAD_DIM_V; \ + let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; \ + let block_byte_base = global_block_idx * V_BLOCK_SIZE_BYTES; \ + let d = f16_from_u16(load_v_u16_at(block_byte_base)); \ + let thread_byte_offset = block_offset * V_BYTES_PER_THREAD; \ + let shmem_idx = row_offset + block_k * BLOCK_SIZE + thread_byte_offset; \ + for (var j = 0u; j < V_BYTES_PER_THREAD / V_BYTES_PER_INNER_LOOP; j += 1u) { \ + let q_byte_offset = block_byte_base + 2u + thread_byte_offset + j * V_BYTES_PER_INNER_LOOP; \ + let q_packed = load_v_u32_at(q_byte_offset); \ + dequant_q4_0_packed_to_shmem(q_packed, d, shmem_idx + j * V_BYTES_PER_INNER_LOOP); \ + } \ +} + +#define LOAD_V_Q8_0_TILE_BLOCK \ +for (var elem_idx = local_id.x * V_NQ; elem_idx < kv_count * HEAD_DIM_V; elem_idx += WG_SIZE * V_NQ) { \ + let blck_idx = elem_idx / BLOCK_SIZE; \ + let block_offset = (elem_idx % BLOCK_SIZE) / V_NQ; \ + let v_row = blck_idx / BLOCKS_V; \ + let global_v_row = kv_tile + v_row; \ + let block_k = blck_idx % BLOCKS_V; \ + let row_offset = v_row * HEAD_DIM_V; \ + let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; \ + let block_byte_base = global_block_idx * V_BLOCK_SIZE_BYTES; \ + let d = f16_from_u16(load_v_u16_at(block_byte_base)); \ + let thread_byte_offset = block_offset * V_BYTES_PER_THREAD; \ + let shmem_idx = row_offset + block_k * BLOCK_SIZE + thread_byte_offset; \ + for (var j = 0u; j < V_BYTES_PER_THREAD / V_BYTES_PER_INNER_LOOP; j += 1u) { \ + let q_byte_offset = block_byte_base + 2u + thread_byte_offset + j * V_BYTES_PER_INNER_LOOP; \ + let q_packed = load_v_u32_at(q_byte_offset); \ + dequant_q8_0_packed_to_shmem(q_packed, d, shmem_idx + j * V_BYTES_PER_INNER_LOOP); \ + } \ +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl index dfae83bcb481..86c796faf1e5 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl @@ -1,6 +1,9 @@ enable f16; enable subgroups; +#define BYTE_HELPERS +#include "common_decls.tmpl" + #ifdef Q_F16 #define Q_TYPE f16 #else @@ -9,12 +12,16 @@ enable subgroups; #ifdef K_F32 #define K_TYPE f32 +#elif defined(K_Q4_0) || defined(K_Q8_0) +#define K_TYPE u32 #else #define K_TYPE f16 #endif #ifdef V_F32 #define V_TYPE f32 +#elif defined(V_Q4_0) || defined(V_Q8_0) +#define V_TYPE u32 #else #define V_TYPE f16 #endif @@ -27,7 +34,6 @@ enable subgroups; #define HEAD_DIM_QK 64 #define HEAD_DIM_V 64 -#define KV_STAGE_STRIDE 64 #define Q_TILE 4 #define KV_TILE 64 #define WG_SIZE 128 @@ -70,12 +76,24 @@ struct Params { @group(0) @binding(0) var Q: array; #ifdef KV_OVERLAP +#if defined(K_Q4_0) || defined(K_Q8_0) +@group(0) @binding(1) var K: array; +#else @group(0) @binding(1) var K: array>; +#endif #define V K #else +#if defined(K_Q4_0) || defined(K_Q8_0) +@group(0) @binding(1) var K: array; +#else @group(0) @binding(1) var K: array>; +#endif +#if defined(V_Q4_0) || defined(V_Q8_0) +@group(0) @binding(2) var V: array; +#else @group(0) @binding(2) var V: array>; #endif +#endif #if defined(MASK) && defined(SINKS) #ifdef KV_OVERLAP @@ -127,11 +145,17 @@ const Q_CHUNKS: u32 = HEAD_DIM_QK / 4u; const V_CHUNKS: u32 = HEAD_DIM_V / 4u; const SCORE_REGS_PER_LANE: u32 = (KV_TILE + MIN_SUBGROUP_SIZE - 1u) / MIN_SUBGROUP_SIZE; const OUT_REGS_PER_LANE: u32 = (V_CHUNKS + MIN_SUBGROUP_SIZE - 1u) / MIN_SUBGROUP_SIZE; +const kv_shmem_size = KV_TILE * max(HEAD_DIM_QK, HEAD_DIM_V); var q_shmem: array; -var kv_shmem: array; +var kv_shmem: array; var p_shmem: array; +#define QUANT_SHMEM kv_shmem +#define QUANT_OUT_TYPE f16 +#include "quant_inner_loops.tmpl" +#include "flash_attn_quant_staging.tmpl" + @compute @workgroup_size(WG_SIZE) fn main(@builtin(workgroup_id) wg_id: vec3, @builtin(local_invocation_id) local_id: vec3, @@ -212,18 +236,24 @@ fn main(@builtin(workgroup_id) wg_id: vec3, local_scores[slot] = FLOAT_MIN; } + #if defined(K_Q4_0) + LOAD_K_Q4_0_TILE_BLOCK +#elif defined(K_Q8_0) + LOAD_K_Q8_0_TILE_BLOCK +#else for (var vec_idx_local = local_id.x; vec_idx_local < kv_count * Q_CHUNKS; vec_idx_local += WG_SIZE) { let kv_local = vec_idx_local / Q_CHUNKS; let chunk = vec_idx_local % Q_CHUNKS; let global_k_row = kv_tile + kv_local; let k_vec_index = (k_head_offset + global_k_row * params.stride_k1 + chunk * 4u) >> 2u; let k4 = K[k_vec_index]; - let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u; + let kv_off = kv_local * HEAD_DIM_QK + chunk * 4u; kv_shmem[kv_off + 0u] = f16(k4.x); kv_shmem[kv_off + 1u] = f16(k4.y); kv_shmem[kv_off + 2u] = f16(k4.z); kv_shmem[kv_off + 3u] = f16(k4.w); } +#endif workgroupBarrier(); @@ -244,7 +274,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, q_shmem[q_off + 1u], q_shmem[q_off + 2u], q_shmem[q_off + 3u]); - let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u; + let kv_off = kv_local * HEAD_DIM_QK + chunk * 4u; let kv = vec4( kv_shmem[kv_off + 0u], kv_shmem[kv_off + 1u], @@ -284,18 +314,24 @@ fn main(@builtin(workgroup_id) wg_id: vec3, workgroupBarrier(); + #if defined(V_Q4_0) + LOAD_V_Q4_0_TILE_BLOCK +#elif defined(V_Q8_0) + LOAD_V_Q8_0_TILE_BLOCK +#else for (var vec_idx_local = local_id.x; vec_idx_local < kv_count * V_CHUNKS; vec_idx_local += WG_SIZE) { let kv_local = vec_idx_local / V_CHUNKS; let chunk = vec_idx_local % V_CHUNKS; let global_v_row = kv_tile + kv_local; let v_vec_index = (v_head_offset + global_v_row * params.stride_v1 + chunk * 4u) >> 2u; let v4 = V[v_vec_index]; - let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u; + let kv_off = kv_local * HEAD_DIM_V + chunk * 4u; kv_shmem[kv_off + 0u] = f16(v4.x); kv_shmem[kv_off + 1u] = f16(v4.y); kv_shmem[kv_off + 2u] = f16(v4.z); kv_shmem[kv_off + 3u] = f16(v4.w); } +#endif workgroupBarrier(); @@ -313,7 +349,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, var acc = out_regs[reg_idx]; for (var kv_local = 0u; kv_local < kv_count; kv_local += 1u) { let p = f32(p_shmem[subgroup_p_offset + kv_local]); - let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u; + let kv_off = kv_local * HEAD_DIM_V + chunk * 4u; let v4 = vec4( kv_shmem[kv_off + 0u], kv_shmem[kv_off + 1u], diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl index d3efebac5d20..e229c4904cfe 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl @@ -45,80 +45,6 @@ enable subgroups; #define KV_BLOCKS (KV_TILE / KV_GRANULARITY) -#define BLOCK_SIZE 32 -#define BLOCKS_K ((HEAD_DIM_QK + BLOCK_SIZE - 1) / BLOCK_SIZE) -#define BLOCKS_V ((HEAD_DIM_V + BLOCK_SIZE - 1) / BLOCK_SIZE) -#if defined(K_Q4_0) -#define K_NQ 16 -#define K_F16_PER_BLOCK 9 -#define K_BLOCK_SIZE_BYTES 18u -#define K_BYTES_PER_THREAD 8u -#define K_BYTES_PER_INNER_LOOP 4u -#elif defined(K_Q8_0) -#define K_NQ 16 -#define K_F16_PER_BLOCK 17 -#define K_BLOCK_SIZE_BYTES 34u -#define K_BYTES_PER_THREAD 16u -#define K_BYTES_PER_INNER_LOOP 4u -#endif - -#if defined(V_Q4_0) -#define V_NQ 16 -#define V_F16_PER_BLOCK 9 -#define V_BLOCK_SIZE_BYTES 18u -#define V_BYTES_PER_THREAD 8u -#define V_BYTES_PER_INNER_LOOP 4u -#elif defined(V_Q8_0) -#define V_NQ 16 -#define V_F16_PER_BLOCK 17 -#define V_BLOCK_SIZE_BYTES 34u -#define V_BYTES_PER_THREAD 16u -#define V_BYTES_PER_INNER_LOOP 4u -#endif - -#if defined(K_Q4_0) || defined(K_Q8_0) -fn load_k_u16_at(byte_offset: u32) -> u32 { - let word = K[byte_offset / 4u]; - let shift = (byte_offset & 2u) * 8u; - return (word >> shift) & 0xFFFFu; -} - -fn load_k_u32_at(byte_offset: u32) -> u32 { - let word_idx = byte_offset / 4u; - let shift = (byte_offset & 3u) * 8u; - let lo = K[word_idx]; - if (shift == 0u) { - return lo; - } - let hi = K[word_idx + 1u]; - return (lo >> shift) | (hi << (32u - shift)); -} -#endif - -#if defined(V_Q4_0) || defined(V_Q8_0) -fn load_v_u16_at(byte_offset: u32) -> u32 { - let word = V[byte_offset / 4u]; - let shift = (byte_offset & 2u) * 8u; - return (word >> shift) & 0xFFFFu; -} - -fn load_v_u32_at(byte_offset: u32) -> u32 { - let word_idx = byte_offset / 4u; - let shift = (byte_offset & 3u) * 8u; - let lo = V[word_idx]; - if (shift == 0u) { - return lo; - } - let hi = V[word_idx + 1u]; - return (lo >> shift) | (hi << (32u - shift)); -} -#endif - -fn f16_from_u16(bits: u32) -> f16 { - let packed = unpack2x16float(bits); - return f16(packed[0]); -} - struct Params { offset_q: u32, offset_k: u32, @@ -313,7 +239,7 @@ fn calc_softmax_term(kv_idx: u32, slope: f32, has_bias: bool, apply_mask: bool) #define QUANT_SHMEM kv_shmem #define QUANT_OUT_TYPE f32 #include "quant_inner_loops.tmpl" -#include "flash_attn_quant_blocks.tmpl" +#include "flash_attn_quant_staging.tmpl" #endif @compute @workgroup_size(WG_SIZE) @@ -380,6 +306,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, } for (var kv_tile = iwg * KV_TILE; kv_tile < params.seq_len_kv; kv_tile += KV_TILE * params.nwg) { + let kv_count = min(KV_TILE, params.seq_len_kv - kv_tile); #ifdef BLK let q_blk = q_row_start; let kv_blk = kv_tile / KV_TILE; From ffbd86f059c376ba7abb15cce160a97121e6e1ce Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Thu, 28 May 2026 11:08:10 -0700 Subject: [PATCH 6/7] formatting --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 133 +- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 53 +- ggml/src/ggml-webgpu/pre_wgsl.hpp | 1432 +++++++++-------- 3 files changed, 808 insertions(+), 810 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 3f9b8155f3c7..f51d5f010d64 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -97,6 +97,7 @@ struct ggml_webgpu_shader_lib_context { uint32_t min_subgroup_size = 0; uint32_t max_subgroup_size = 0; bool supports_dot_product = false; + std::string vendor; }; @@ -560,14 +561,13 @@ struct ggml_webgpu_flash_attn_common_pipeline_key { bool operator==(const ggml_webgpu_flash_attn_common_pipeline_key & other) const { return q_type == other.q_type && k_type == other.k_type && v_type == other.v_type && - dst_type == other.dst_type && - head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v && kv_direct == other.kv_direct && - kv_overlap == other.kv_overlap && has_mask == other.has_mask && has_sinks == other.has_sinks && - uses_logit_softcap == other.uses_logit_softcap; + dst_type == other.dst_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v && + kv_direct == other.kv_direct && kv_overlap == other.kv_overlap && has_mask == other.has_mask && + has_sinks == other.has_sinks && uses_logit_softcap == other.uses_logit_softcap; } }; -inline void ggml_webgpu_flash_attn_hash_common_pipeline_key(size_t & seed, +inline void ggml_webgpu_flash_attn_hash_common_pipeline_key(size_t & seed, const ggml_webgpu_flash_attn_common_pipeline_key & key) { ggml_webgpu_hash_combine(seed, key.q_type); ggml_webgpu_hash_combine(seed, key.k_type); @@ -585,9 +585,7 @@ inline void ggml_webgpu_flash_attn_hash_common_pipeline_key(size_t & seed, struct ggml_webgpu_flash_attn_vec_pipeline_key { ggml_webgpu_flash_attn_common_pipeline_key common; - bool operator==(const ggml_webgpu_flash_attn_vec_pipeline_key & other) const { - return common == other.common; - } + bool operator==(const ggml_webgpu_flash_attn_vec_pipeline_key & other) const { return common == other.common; } }; struct ggml_webgpu_flash_attn_vec_pipeline_key_hash { @@ -600,7 +598,7 @@ struct ggml_webgpu_flash_attn_vec_pipeline_key_hash { struct ggml_webgpu_flash_attn_pipeline_key { ggml_webgpu_flash_attn_common_pipeline_key common; - bool use_sg_matrix; + bool use_sg_matrix; bool operator==(const ggml_webgpu_flash_attn_pipeline_key & other) const { return common == other.common && use_sg_matrix == other.use_sg_matrix; @@ -637,9 +635,8 @@ inline size_t ggml_webgpu_flash_attn_tensor_offset(const ggml_tensor * tensor) { return reinterpret_cast(base->data) - ptr_base_addr + tensor->view_offs; } -inline bool ggml_webgpu_flash_attn_float_vec4_aligned(const ggml_tensor * K, - size_t storage_offset_alignment) { - const size_t alignment = std::max(1u, storage_offset_alignment); +inline bool ggml_webgpu_flash_attn_float_vec4_aligned(const ggml_tensor * K, size_t storage_offset_alignment) { + const size_t alignment = std::max(1u, storage_offset_alignment); const uint32_t offset_elems = (uint32_t) ((ggml_webgpu_flash_attn_tensor_offset(K) & (alignment - 1)) / ggml_type_size(K->type)); return offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u; @@ -652,16 +649,14 @@ inline bool ggml_webgpu_flash_attn_float_vec4_aligned(const ggml_tensor * K, ggml_webgpu_flash_attn_float_vec4_aligned(V, storage_offset_alignment); } -inline bool ggml_webgpu_flash_attn_kv_direct(const ggml_tensor * Q, - const ggml_tensor * K, - uint32_t kv_direct_align) { +inline bool ggml_webgpu_flash_attn_kv_direct(const ggml_tensor * Q, const ggml_tensor * K, uint32_t kv_direct_align) { return K->type == GGML_TYPE_F16 && (Q->ne[0] % std::max(1u, kv_direct_align) == 0) && (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0); } inline ggml_webgpu_flash_attn_common_pipeline_key ggml_webgpu_flash_attn_make_common_pipeline_key( const ggml_webgpu_shader_lib_context & context, - uint32_t kv_direct_align) { + uint32_t kv_direct_align) { ggml_webgpu_flash_attn_common_pipeline_key key = {}; key.q_type = context.src0->type; key.k_type = context.src1->type; @@ -669,11 +664,11 @@ inline ggml_webgpu_flash_attn_common_pipeline_key ggml_webgpu_flash_attn_make_co key.dst_type = context.dst->type; key.head_dim_qk = (uint32_t) context.src0->ne[0]; key.head_dim_v = (uint32_t) context.src2->ne[0]; - key.kv_direct = ggml_webgpu_flash_attn_kv_direct(context.src0, context.src1, kv_direct_align); - key.kv_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src2); - key.has_mask = context.src3 != nullptr; - key.has_sinks = context.src4 != nullptr; - key.uses_logit_softcap = ggml_get_op_params_f32(context.dst, 2) != 0.0f; + key.kv_direct = ggml_webgpu_flash_attn_kv_direct(context.src0, context.src1, kv_direct_align); + key.kv_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src2); + key.has_mask = context.src3 != nullptr; + key.has_sinks = context.src4 != nullptr; + key.uses_logit_softcap = ggml_get_op_params_f32(context.dst, 2) != 0.0f; return key; } @@ -840,8 +835,6 @@ inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile, return f16_elems * GGML_WEBGPU_F16_SIZE_BYTES + f32_elems * GGML_WEBGPU_F32_SIZE_BYTES; } - - inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(size_t limit_bytes, uint32_t q_tile, uint32_t kv_granularity, @@ -884,11 +877,11 @@ inline uint32_t ggml_webgpu_flash_attn_get_vec_kv_tile(size_t wg_mem_limit_byt return kv_tile; } -inline bool ggml_webgpu_flash_attn_can_use_subgroup_matrix_path(bool supports_subgroup_matrix, - uint32_t sg_mat_k, - uint32_t sg_mat_n, - const ggml_tensor * Q, - const ggml_tensor * V) { +inline bool ggml_webgpu_flash_attn_can_use_subgroup_matrix_path(bool supports_subgroup_matrix, + uint32_t sg_mat_k, + uint32_t sg_mat_n, + const ggml_tensor * Q, + const ggml_tensor * V) { return supports_subgroup_matrix && Q->ne[0] % sg_mat_k == 0 && V->ne[0] % sg_mat_n == 0; } @@ -1166,9 +1159,7 @@ class ggml_webgpu_shader_lib { webgpu_pipeline, ggml_webgpu_flash_attn_vec_pipeline_key_hash> flash_attn_vec_pipelines; - std::unordered_map + std::unordered_map flash_attn_pipelines; std::unordered_map flash_attn_blk_pipelines; std::unordered_map - mul_mat_vec_pipelines; // fast mat-vec (n==1) + mul_mat_vec_pipelines; // fast mat-vec (n==1) std::unordered_map - mul_mat_fast_pipelines; // fast mat-mat (reg-tile or subgroup) + mul_mat_fast_pipelines; // fast mat-mat (reg-tile or subgroup) std::unordered_map quantize_q8_pipelines; std::unordered_map mul_mat_id_gather_pipelines; // key is fixed @@ -1863,11 +1854,11 @@ class ggml_webgpu_shader_lib { ggml_webgpu_mul_mat_vec_pipeline_key key = {}; key.src0_type = context.src0->type; key.src1_type = context.src1->type; - key.vectorized = (context.src0->ne[0] % 4 == 0 && + key.vectorized = (context.src0->ne[0] % 4 == 0 && (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? - 1 : - 0; - key.use_mmvq = + 1 : + 0; + key.use_mmvq = ggml_webgpu_can_use_mmvq(context.src0, context.src1, context.supports_dot_product, context.vendor); auto it = mul_mat_vec_pipelines.find(key); @@ -1999,11 +1990,11 @@ class ggml_webgpu_shader_lib { ggml_webgpu_mul_mat_pipeline_key key = {}; key.src0_type = context.src0->type; key.src1_type = context.src1->type; - key.vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 && - (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? - 1 : - 0; - key.use_subgroup_matrix = context.supports_subgroup_matrix; + key.vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 && + (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? + 1 : + 0; + key.use_subgroup_matrix = context.supports_subgroup_matrix; auto it = mul_mat_fast_pipelines.find(key); if (it != mul_mat_fast_pipelines.end()) { @@ -2176,10 +2167,10 @@ class ggml_webgpu_shader_lib { key.src0_type = context.src0->type; key.src1_type = context.src1->type; key.n_experts = context.src0->ne[2]; - key.vectorized = (context.src0->ne[0] % 4 == 0 && context.src0->ne[1] % 4 == 0 && + key.vectorized = (context.src0->ne[0] % 4 == 0 && context.src0->ne[1] % 4 == 0 && (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? - 1 : - 0; + 1 : + 0; auto it = mul_mat_id_pipelines.find(key); if (it != mul_mat_id_pipelines.end()) { @@ -2299,10 +2290,10 @@ class ggml_webgpu_shader_lib { key.src0_type = context.src0->type; key.src1_type = context.src1->type; key.n_experts = context.src0->ne[2]; - key.vectorized = (context.src0->ne[0] % 4 == 0 && + key.vectorized = (context.src0->ne[0] % 4 == 0 && (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? - 1 : - 0; + 1 : + 0; auto it = mul_mat_id_vec_pipelines.find(key); if (it != mul_mat_id_vec_pipelines.end()) { @@ -2696,30 +2687,28 @@ class ggml_webgpu_shader_lib { const bool can_use_subgroup_matrix = ggml_webgpu_flash_attn_can_use_subgroup_matrix_path( context.supports_subgroup_matrix, context.sg_mat_k, context.sg_mat_n, context.src0, context.src2); ggml_webgpu_flash_attn_decisions decisions = {}; - decisions.use_sg_matrix = can_use_subgroup_matrix; - decisions.q_tile = - decisions.use_sg_matrix ? context.sg_mat_m : GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE; + decisions.use_sg_matrix = can_use_subgroup_matrix; + decisions.q_tile = decisions.use_sg_matrix ? context.sg_mat_m : GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE; ggml_webgpu_flash_attn_pipeline_key key = {}; - key.common = ggml_webgpu_flash_attn_make_common_pipeline_key( - context, decisions.use_sg_matrix ? context.sg_mat_k : 1u); + key.common = + ggml_webgpu_flash_attn_make_common_pipeline_key(context, decisions.use_sg_matrix ? context.sg_mat_k : 1u); key.common.kv_direct = decisions.use_sg_matrix && key.common.kv_direct; key.use_sg_matrix = decisions.use_sg_matrix; const uint32_t max_kv_tile = ggml_webgpu_flash_attn_max_kv_tile( - context.wg_mem_limit_bytes, decisions.q_tile, - decisions.use_sg_matrix ? context.sg_mat_n : 1u, key.common.head_dim_qk, - key.common.head_dim_v, key.common.has_mask, key.common.kv_direct); + context.wg_mem_limit_bytes, decisions.q_tile, decisions.use_sg_matrix ? context.sg_mat_n : 1u, + key.common.head_dim_qk, key.common.head_dim_v, key.common.has_mask, key.common.kv_direct); GGML_ASSERT(max_kv_tile > 0); decisions.kv_tile = decisions.use_sg_matrix ? std::min(max_kv_tile, context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES) : std::min(GGML_WEBGPU_FLASH_ATTN_TILE_MAX_KV_TILE, max_kv_tile); - decisions.wg_size = decisions.use_sg_matrix ? - std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE) : - std::min(context.max_wg_size, - std::max(GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE, - GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * context.max_subgroup_size)); + decisions.wg_size = + decisions.use_sg_matrix ? + std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE) : + std::min(context.max_wg_size, std::max(GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE, + GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * context.max_subgroup_size)); if (key.common.kv_direct) { decisions.kv_tile = std::min(decisions.kv_tile, GGML_WEBGPU_KV_SEQ_PAD); @@ -2734,9 +2723,8 @@ class ggml_webgpu_shader_lib { } std::string variant = decisions.use_sg_matrix ? "flash_attn" : "flash_attn_tile"; - std::vector defines = - ggml_webgpu_flash_attn_common_defines(key.common, variant, decisions.q_tile, decisions.kv_tile, - decisions.wg_size); + std::vector defines = ggml_webgpu_flash_attn_common_defines(key.common, variant, decisions.q_tile, + decisions.kv_tile, decisions.wg_size); const char * shader_src = nullptr; if (!key.use_sg_matrix) { shader_src = wgsl_flash_attn_tile; @@ -2750,18 +2738,17 @@ class ggml_webgpu_shader_lib { defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n)); defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k)); } - auto pipeline_decisions = std::make_shared(decisions); + auto pipeline_decisions = std::make_shared(decisions); webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, preprocessor.preprocess(shader_src, defines), variant); - pipeline.context = pipeline_decisions; + pipeline.context = pipeline_decisions; flash_attn_pipelines[key] = pipeline; return flash_attn_pipelines[key]; } webgpu_pipeline get_flash_attn_vec_pipeline(const ggml_webgpu_shader_lib_context & context) { ggml_webgpu_flash_attn_vec_pipeline_key key = {}; - key.common = ggml_webgpu_flash_attn_make_common_pipeline_key( - context, GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH); + key.common = ggml_webgpu_flash_attn_make_common_pipeline_key(context, GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH); auto it = flash_attn_vec_pipelines.find(key); if (it != flash_attn_vec_pipelines.end()) { @@ -2769,9 +2756,9 @@ class ggml_webgpu_shader_lib { } ggml_webgpu_flash_attn_vec_decisions decisions = {}; - decisions.kv_tile = ggml_webgpu_flash_attn_get_vec_kv_tile( - context.wg_mem_limit_bytes, key.common.head_dim_qk, key.common.head_dim_v, key.common.has_mask, - key.common.kv_direct); + decisions.kv_tile = + ggml_webgpu_flash_attn_get_vec_kv_tile(context.wg_mem_limit_bytes, key.common.head_dim_qk, + key.common.head_dim_v, key.common.has_mask, key.common.kv_direct); decisions.wg_size = context.max_subgroup_size; std::string variant = "flash_attn_vec"; @@ -2800,10 +2787,10 @@ class ggml_webgpu_shader_lib { } defines.push_back(std::string("VEC_NE=") + std::to_string(vec_ne) + "u"); - auto pipeline_decisions = std::make_shared(decisions); + auto pipeline_decisions = std::make_shared(decisions); webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_flash_attn_vec_split, defines), variant); - pipeline.context = pipeline_decisions; + pipeline.context = pipeline_decisions; flash_attn_vec_pipelines[key] = pipeline; return flash_attn_vec_pipelines[key]; } diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index e43834e0264d..550ffe38b5b3 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1774,28 +1774,22 @@ static bool ggml_webgpu_flash_attn_use_vec_path(const webgpu_global_context & gl const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V) { - GGML_ASSERT(K != nullptr); - GGML_ASSERT(V != nullptr); - GGML_ASSERT(Q != nullptr); - const size_t storage_offset_alignment = global_ctx->capabilities.limits.minStorageBufferOffsetAlignment; - const bool k_float_vec4_aligned = - (K->type != GGML_TYPE_F16 && K->type != GGML_TYPE_F32) || - ggml_webgpu_flash_attn_float_vec4_aligned(K, storage_offset_alignment); - const bool v_float_vec4_aligned = - (V->type != GGML_TYPE_F16 && V->type != GGML_TYPE_F32) || - ggml_webgpu_flash_attn_float_vec4_aligned(V, storage_offset_alignment); - const bool k_vec_type_supported = + const bool k_float_vec4_aligned = (K->type != GGML_TYPE_F16 && K->type != GGML_TYPE_F32) || + ggml_webgpu_flash_attn_float_vec4_aligned(K, storage_offset_alignment); + const bool v_float_vec4_aligned = (V->type != GGML_TYPE_F16 && V->type != GGML_TYPE_F32) || + ggml_webgpu_flash_attn_float_vec4_aligned(V, storage_offset_alignment); + const bool k_vec_type_supported = K->type == GGML_TYPE_F32 || K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0; const bool v_vec_type_supported = V->type == GGML_TYPE_F32 || V->type == GGML_TYPE_F16 || V->type == GGML_TYPE_Q4_0 || V->type == GGML_TYPE_Q8_0; - const uint32_t k_vec_head_align = - (K->type == GGML_TYPE_F32 || K->type == GGML_TYPE_F16) ? GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH : - (uint32_t) ggml_blck_size(K->type); - const uint32_t v_vec_head_align = - (V->type == GGML_TYPE_F32 || V->type == GGML_TYPE_F16) ? GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH : - (uint32_t) ggml_blck_size(V->type); - const bool kv_vec_head_dims_aligned = Q->ne[0] % k_vec_head_align == 0 && V->ne[0] % v_vec_head_align == 0; + const uint32_t k_vec_head_align = (K->type == GGML_TYPE_F32 || K->type == GGML_TYPE_F16) ? + GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH : + (uint32_t) ggml_blck_size(K->type); + const uint32_t v_vec_head_align = (V->type == GGML_TYPE_F32 || V->type == GGML_TYPE_F16) ? + GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH : + (uint32_t) ggml_blck_size(V->type); + const bool kv_vec_head_dims_aligned = Q->ne[0] % k_vec_head_align == 0 && V->ne[0] % v_vec_head_align == 0; return global_ctx->capabilities.supports_subgroups && (Q->ne[1] < GGML_WEBGPU_FLASH_ATTN_VEC_MAX_SEQ_LEN) && kv_vec_head_dims_aligned && k_vec_type_supported && v_vec_type_supported && k_float_vec4_aligned && @@ -4212,7 +4206,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const const size_t storage_offset_alignment = capabilities.limits.minStorageBufferOffsetAlignment; // subgroup matrix path requirements - const bool use_subgroup_matrix = ggml_webgpu_flash_attn_can_use_subgroup_matrix_path( + const bool use_subgroup_matrix = ggml_webgpu_flash_attn_can_use_subgroup_matrix_path( capabilities.supports_subgroup_matrix, capabilities.sg_mat_k, capabilities.sg_mat_n, src0, src2); // tile path requirements @@ -4221,22 +4215,19 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const ggml_webgpu_flash_attn_float_vec4_aligned(src1, storage_offset_alignment)) && ((src2->type != GGML_TYPE_F16 && src2->type != GGML_TYPE_F32) || ggml_webgpu_flash_attn_float_vec4_aligned(src2, storage_offset_alignment)); - const uint32_t k_tile_head_align = - (src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16) ? - GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH : - (uint32_t) ggml_blck_size(src1->type); - const uint32_t v_tile_head_align = - (src2->type == GGML_TYPE_F32 || src2->type == GGML_TYPE_F16) ? - GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH : - (uint32_t) ggml_blck_size(src2->type); - const bool tile_kv_head_dims_aligned = + const uint32_t k_tile_head_align = (src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16) ? + GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH : + (uint32_t) ggml_blck_size(src1->type); + const uint32_t v_tile_head_align = (src2->type == GGML_TYPE_F32 || src2->type == GGML_TYPE_F16) ? + GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH : + (uint32_t) ggml_blck_size(src2->type); + const bool tile_kv_head_dims_aligned = src0->ne[0] % k_tile_head_align == 0 && src2->ne[0] % v_tile_head_align == 0; const bool tile_can_dispatch_all_q_rows = capabilities.limits.maxComputeInvocationsPerWorkgroup >= GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * capabilities.max_subgroup_size; - const bool use_tile = !use_subgroup_matrix && capabilities.supports_subgroups && - float_vec4_aligned && tile_kv_head_dims_aligned && - tile_can_dispatch_all_q_rows; + const bool use_tile = !use_subgroup_matrix && capabilities.supports_subgroups && float_vec4_aligned && + tile_kv_head_dims_aligned && tile_can_dispatch_all_q_rows; if (!use_subgroup_matrix && !use_tile) { supports_op = false; diff --git a/ggml/src/ggml-webgpu/pre_wgsl.hpp b/ggml/src/ggml-webgpu/pre_wgsl.hpp index 702e6cda9391..fb41a961d745 100644 --- a/ggml/src/ggml-webgpu/pre_wgsl.hpp +++ b/ggml/src/ggml-webgpu/pre_wgsl.hpp @@ -17,772 +17,792 @@ namespace pre_wgsl { // Options //============================================================== struct Options { - std::string include_path = "."; - std::vector macros; + std::string include_path = "."; + std::vector macros; }; //============================================================== // Utility: trim //============================================================== -static std::string trim(const std::string &s) { - size_t a = 0; - while (a < s.size() && std::isspace((unsigned char)s[a])) - a++; - size_t b = s.size(); - while (b > a && std::isspace((unsigned char)s[b - 1])) - b--; - return s.substr(a, b - a); +static std::string trim(const std::string & s) { + size_t a = 0; + while (a < s.size() && std::isspace((unsigned char) s[a])) { + a++; + } + size_t b = s.size(); + while (b > a && std::isspace((unsigned char) s[b - 1])) { + b--; + } + return s.substr(a, b - a); } -static std::string trim_value(std::istream &is) { - std::ostringstream ss; - ss << is.rdbuf(); - return trim(ss.str()); +static std::string trim_value(std::istream & is) { + std::ostringstream ss; + ss << is.rdbuf(); + return trim(ss.str()); } static bool isIdentChar(char c) { - return std::isalnum(static_cast(c)) || c == '_'; + return std::isalnum(static_cast(c)) || c == '_'; } -static bool endsWithContinuation(const std::string &line) { - size_t i = line.size(); - while (i > 0 && std::isspace((unsigned char)line[i - 1])) - i--; - return i > 0 && line[i - 1] == '\\'; +static bool endsWithContinuation(const std::string & line) { + size_t i = line.size(); + while (i > 0 && std::isspace((unsigned char) line[i - 1])) { + i--; + } + return i > 0 && line[i - 1] == '\\'; } -static void stripContinuation(std::string &line) { - size_t i = line.size(); - while (i > 0 && std::isspace((unsigned char)line[i - 1])) - i--; - if (i > 0 && line[i - 1] == '\\') { - line.erase(i - 1); - } +static void stripContinuation(std::string & line) { + size_t i = line.size(); + while (i > 0 && std::isspace((unsigned char) line[i - 1])) { + i--; + } + if (i > 0 && line[i - 1] == '\\') { + line.erase(i - 1); + } } -static std::string expandMacrosRecursiveInternal( - const std::string &line, - const std::unordered_map ¯os, - std::unordered_set &visiting); - -static std::string -expandMacroValue(const std::string &name, - const std::unordered_map ¯os, - std::unordered_set &visiting) { - if (visiting.count(name)) - throw std::runtime_error("Recursive macro: " + name); - visiting.insert(name); - - auto it = macros.find(name); - if (it == macros.end()) { - visiting.erase(name); - return name; - } +static std::string expandMacrosRecursiveInternal(const std::string & line, + const std::unordered_map & macros, + std::unordered_set & visiting); - const std::string &value = it->second; - if (value.empty()) { - visiting.erase(name); - return ""; - } +static std::string expandMacroValue(const std::string & name, + const std::unordered_map & macros, + std::unordered_set & visiting) { + if (visiting.count(name)) { + throw std::runtime_error("Recursive macro: " + name); + } + visiting.insert(name); + + auto it = macros.find(name); + if (it == macros.end()) { + visiting.erase(name); + return name; + } + + const std::string & value = it->second; + if (value.empty()) { + visiting.erase(name); + return ""; + } - std::string expanded = expandMacrosRecursiveInternal(value, macros, visiting); - visiting.erase(name); - return expanded; + std::string expanded = expandMacrosRecursiveInternal(value, macros, visiting); + visiting.erase(name); + return expanded; } -static std::string expandMacrosRecursiveInternal( - const std::string &line, - const std::unordered_map ¯os, - std::unordered_set &visiting) { - std::string result; - result.reserve(line.size()); - - size_t i = 0; - while (i < line.size()) { - if (isIdentChar(line[i])) { - size_t start = i; - while (i < line.size() && isIdentChar(line[i])) { - i++; - } - std::string token = line.substr(start, i - start); - - auto it = macros.find(token); - if (it != macros.end()) { - result += expandMacroValue(token, macros, visiting); - } else { - result += token; - } - } else { - result += line[i]; - i++; - } - } - - return result; +static std::string expandMacrosRecursiveInternal(const std::string & line, + const std::unordered_map & macros, + std::unordered_set & visiting) { + std::string result; + result.reserve(line.size()); + + size_t i = 0; + while (i < line.size()) { + if (isIdentChar(line[i])) { + size_t start = i; + while (i < line.size() && isIdentChar(line[i])) { + i++; + } + std::string token = line.substr(start, i - start); + + auto it = macros.find(token); + if (it != macros.end()) { + result += expandMacroValue(token, macros, visiting); + } else { + result += token; + } + } else { + result += line[i]; + i++; + } + } + + return result; } -static std::string expandMacrosRecursive( - const std::string &line, - const std::unordered_map ¯os) { - std::unordered_set visiting; - return expandMacrosRecursiveInternal(line, macros, visiting); +static std::string expandMacrosRecursive(const std::string & line, + const std::unordered_map & macros) { + std::unordered_set visiting; + return expandMacrosRecursiveInternal(line, macros, visiting); } //============================================================== // Tokenizer for expressions in #if/#elif //============================================================== class ExprLexer { -public: - enum Kind { END, IDENT, NUMBER, OP, LPAREN, RPAREN }; + public: + enum Kind { END, IDENT, NUMBER, OP, LPAREN, RPAREN }; - struct Tok { - Kind kind; - std::string text; - }; + struct Tok { + Kind kind; + std::string text; + }; - explicit ExprLexer(std::string_view sv) : src(sv), pos(0) {} + explicit ExprLexer(std::string_view sv) : src(sv), pos(0) {} - Tok next() { - skipWS(); - if (pos >= src.size()) - return {END, ""}; + Tok next() { + skipWS(); + if (pos >= src.size()) { + return { END, "" }; + } - char c = src[pos]; + char c = src[pos]; - // number - if (std::isdigit((unsigned char)c)) { - size_t start = pos; - while (pos < src.size() && std::isdigit((unsigned char)src[pos])) - pos++; - return {NUMBER, std::string(src.substr(start, pos - start))}; - } + // number + if (std::isdigit((unsigned char) c)) { + size_t start = pos; + while (pos < src.size() && std::isdigit((unsigned char) src[pos])) { + pos++; + } + return { NUMBER, std::string(src.substr(start, pos - start)) }; + } - // identifier - if (std::isalpha((unsigned char)c) || c == '_') { - size_t start = pos; - while (pos < src.size() && - (std::isalnum((unsigned char)src[pos]) || src[pos] == '_')) - pos++; - return {IDENT, std::string(src.substr(start, pos - start))}; - } + // identifier + if (std::isalpha((unsigned char) c) || c == '_') { + size_t start = pos; + while (pos < src.size() && (std::isalnum((unsigned char) src[pos]) || src[pos] == '_')) { + pos++; + } + return { IDENT, std::string(src.substr(start, pos - start)) }; + } - if (c == '(') { - pos++; - return {LPAREN, "("}; - } - if (c == ')') { - pos++; - return {RPAREN, ")"}; - } + if (c == '(') { + pos++; + return { LPAREN, "(" }; + } + if (c == ')') { + pos++; + return { RPAREN, ")" }; + } - // multi-char operators - static const char *two_ops[] = { - "==", "!=", "<=", ">=", "&&", "||", "<<", ">>"}; - for (auto op : two_ops) { - if (src.substr(pos, 2) == op) { - pos += 2; - return {OP, std::string(op)}; - } - } + // multi-char operators + static const char * two_ops[] = { "==", "!=", "<=", ">=", "&&", "||", "<<", ">>" }; + for (auto op : two_ops) { + if (src.substr(pos, 2) == op) { + pos += 2; + return { OP, std::string(op) }; + } + } - // single-char operators - if (std::string("+-*/%<>!").find(c) != std::string::npos) { - pos++; - return {OP, std::string(1, c)}; - } + // single-char operators + if (std::string("+-*/%<>!").find(c) != std::string::npos) { + pos++; + return { OP, std::string(1, c) }; + } - // unexpected - pos++; - return {END, ""}; - } + // unexpected + pos++; + return { END, "" }; + } -private: - std::string_view src; - size_t pos; + private: + std::string_view src; + size_t pos; - void skipWS() { - while (pos < src.size() && std::isspace((unsigned char)src[pos])) - pos++; - } + void skipWS() { + while (pos < src.size() && std::isspace((unsigned char) src[pos])) { + pos++; + } + } }; //============================================================== // Expression Parser (recursive descent) //============================================================== class ExprParser { -public: - ExprParser(std::string_view expr, - const std::unordered_map ¯os, - std::unordered_set &visiting) - : lex(expr), macros(macros), visiting(visiting) { - advance(); - } - - int parse() { return parseLogicalOr(); } - -private: - ExprLexer lex; - ExprLexer::Tok tok; - const std::unordered_map ¯os; - std::unordered_set &visiting; - - void advance() { tok = lex.next(); } - - bool acceptOp(const std::string &s) { - if (tok.kind == ExprLexer::OP && tok.text == s) { - advance(); - return true; - } - return false; - } - - bool acceptKind(ExprLexer::Kind k) { - if (tok.kind == k) { - advance(); - return true; - } - return false; - } - - int parseLogicalOr() { - int v = parseLogicalAnd(); - while (acceptOp("||")) { - int rhs = parseLogicalAnd(); - v = (v || rhs); - } - return v; - } - - int parseLogicalAnd() { - int v = parseEquality(); - while (acceptOp("&&")) { - int rhs = parseEquality(); - v = (v && rhs); - } - return v; - } - - int parseEquality() { - int v = parseRelational(); - for (;;) { - if (acceptOp("==")) { - int rhs = parseRelational(); - v = (v == rhs); - } else if (acceptOp("!=")) { - int rhs = parseRelational(); - v = (v != rhs); - } else - break; - } - return v; - } - - int parseRelational() { - int v = parseShift(); - for (;;) { - if (acceptOp("<")) { - int rhs = parseShift(); - v = (v < rhs); - } else if (acceptOp(">")) { - int rhs = parseShift(); - v = (v > rhs); - } else if (acceptOp("<=")) { - int rhs = parseShift(); - v = (v <= rhs); - } else if (acceptOp(">=")) { - int rhs = parseShift(); - v = (v >= rhs); - } else - break; - } - return v; - } - - int parseShift() { - int v = parseAdd(); - for (;;) { - if (acceptOp("<<")) { - int rhs = parseAdd(); - v = (v << rhs); - } else if (acceptOp(">>")) { - int rhs = parseAdd(); - v = (v >> rhs); - } else - break; - } - return v; - } - - int parseAdd() { - int v = parseMult(); - for (;;) { - if (acceptOp("+")) { - int rhs = parseMult(); - v = (v + rhs); - } else if (acceptOp("-")) { - int rhs = parseMult(); - v = (v - rhs); - } else - break; - } - return v; - } - - int parseMult() { - int v = parseUnary(); - for (;;) { - if (acceptOp("*")) { - int rhs = parseUnary(); - v = (v * rhs); - } else if (acceptOp("/")) { - int rhs = parseUnary(); - v = (rhs == 0 ? 0 : v / rhs); - } else if (acceptOp("%")) { - int rhs = parseUnary(); - v = (rhs == 0 ? 0 : v % rhs); - } else - break; - } - return v; - } - - int parseUnary() { - if (acceptOp("!")) - return !parseUnary(); - if (acceptOp("-")) - return -parseUnary(); - if (acceptOp("+")) - return +parseUnary(); - return parsePrimary(); - } - - int parsePrimary() { - // '(' expr ')' - if (acceptKind(ExprLexer::LPAREN)) { - int v = parse(); - if (!acceptKind(ExprLexer::RPAREN)) - throw std::runtime_error("missing ')'"); - return v; - } - - // number - if (tok.kind == ExprLexer::NUMBER) { - int v = std::stoi(tok.text); - advance(); - return v; - } - - // defined(identifier) - if (tok.kind == ExprLexer::IDENT && tok.text == "defined") { - advance(); - if (acceptKind(ExprLexer::LPAREN)) { - if (tok.kind != ExprLexer::IDENT) - throw std::runtime_error("expected identifier in defined()"); - std::string name = tok.text; + public: + ExprParser(std::string_view expr, + const std::unordered_map & macros, + std::unordered_set & visiting) : + lex(expr), + macros(macros), + visiting(visiting) { advance(); - if (!acceptKind(ExprLexer::RPAREN)) - throw std::runtime_error("missing ) in defined()"); - return macros.count(name) ? 1 : 0; - } else { - // defined NAME - if (tok.kind != ExprLexer::IDENT) - throw std::runtime_error("expected identifier in defined NAME"); - std::string name = tok.text; - advance(); - return macros.count(name) ? 1 : 0; - } } - // identifier -> treat as integer, if defined use its value else 0 - if (tok.kind == ExprLexer::IDENT) { - std::string name = tok.text; - advance(); - auto it = macros.find(name); - if (it == macros.end()) - return 0; - if (it->second.empty()) - return 1; - return evalMacroExpression(name, it->second); + int parse() { return parseLogicalOr(); } + + private: + ExprLexer lex; + ExprLexer::Tok tok; + const std::unordered_map & macros; + std::unordered_set & visiting; + + void advance() { tok = lex.next(); } + + bool acceptOp(const std::string & s) { + if (tok.kind == ExprLexer::OP && tok.text == s) { + advance(); + return true; + } + return false; + } + + bool acceptKind(ExprLexer::Kind k) { + if (tok.kind == k) { + advance(); + return true; + } + return false; } - // unexpected - return 0; - } + int parseLogicalOr() { + int v = parseLogicalAnd(); + while (acceptOp("||")) { + int rhs = parseLogicalAnd(); + v = (v || rhs); + } + return v; + } - int evalMacroExpression(const std::string &name, const std::string &value) { - if (visiting.count(name)) - throw std::runtime_error("Recursive macro: " + name); + int parseLogicalAnd() { + int v = parseEquality(); + while (acceptOp("&&")) { + int rhs = parseEquality(); + v = (v && rhs); + } + return v; + } + + int parseEquality() { + int v = parseRelational(); + for (;;) { + if (acceptOp("==")) { + int rhs = parseRelational(); + v = (v == rhs); + } else if (acceptOp("!=")) { + int rhs = parseRelational(); + v = (v != rhs); + } else { + break; + } + } + return v; + } + + int parseRelational() { + int v = parseShift(); + for (;;) { + if (acceptOp("<")) { + int rhs = parseShift(); + v = (v < rhs); + } else if (acceptOp(">")) { + int rhs = parseShift(); + v = (v > rhs); + } else if (acceptOp("<=")) { + int rhs = parseShift(); + v = (v <= rhs); + } else if (acceptOp(">=")) { + int rhs = parseShift(); + v = (v >= rhs); + } else { + break; + } + } + return v; + } + + int parseShift() { + int v = parseAdd(); + for (;;) { + if (acceptOp("<<")) { + int rhs = parseAdd(); + v = (v << rhs); + } else if (acceptOp(">>")) { + int rhs = parseAdd(); + v = (v >> rhs); + } else { + break; + } + } + return v; + } + + int parseAdd() { + int v = parseMult(); + for (;;) { + if (acceptOp("+")) { + int rhs = parseMult(); + v = (v + rhs); + } else if (acceptOp("-")) { + int rhs = parseMult(); + v = (v - rhs); + } else { + break; + } + } + return v; + } + + int parseMult() { + int v = parseUnary(); + for (;;) { + if (acceptOp("*")) { + int rhs = parseUnary(); + v = (v * rhs); + } else if (acceptOp("/")) { + int rhs = parseUnary(); + v = (rhs == 0 ? 0 : v / rhs); + } else if (acceptOp("%")) { + int rhs = parseUnary(); + v = (rhs == 0 ? 0 : v % rhs); + } else { + break; + } + } + return v; + } - visiting.insert(name); - ExprParser ep(value, macros, visiting); - int v = ep.parse(); - visiting.erase(name); - return v; - } + int parseUnary() { + if (acceptOp("!")) { + return !parseUnary(); + } + if (acceptOp("-")) { + return -parseUnary(); + } + if (acceptOp("+")) { + return +parseUnary(); + } + return parsePrimary(); + } + + int parsePrimary() { + // '(' expr ')' + if (acceptKind(ExprLexer::LPAREN)) { + int v = parse(); + if (!acceptKind(ExprLexer::RPAREN)) { + throw std::runtime_error("missing ')'"); + } + return v; + } + + // number + if (tok.kind == ExprLexer::NUMBER) { + int v = std::stoi(tok.text); + advance(); + return v; + } + + // defined(identifier) + if (tok.kind == ExprLexer::IDENT && tok.text == "defined") { + advance(); + if (acceptKind(ExprLexer::LPAREN)) { + if (tok.kind != ExprLexer::IDENT) { + throw std::runtime_error("expected identifier in defined()"); + } + std::string name = tok.text; + advance(); + if (!acceptKind(ExprLexer::RPAREN)) { + throw std::runtime_error("missing ) in defined()"); + } + return macros.count(name) ? 1 : 0; + } else { + // defined NAME + if (tok.kind != ExprLexer::IDENT) { + throw std::runtime_error("expected identifier in defined NAME"); + } + std::string name = tok.text; + advance(); + return macros.count(name) ? 1 : 0; + } + } + + // identifier -> treat as integer, if defined use its value else 0 + if (tok.kind == ExprLexer::IDENT) { + std::string name = tok.text; + advance(); + auto it = macros.find(name); + if (it == macros.end()) { + return 0; + } + if (it->second.empty()) { + return 1; + } + return evalMacroExpression(name, it->second); + } + + // unexpected + return 0; + } + + int evalMacroExpression(const std::string & name, const std::string & value) { + if (visiting.count(name)) { + throw std::runtime_error("Recursive macro: " + name); + } + + visiting.insert(name); + ExprParser ep(value, macros, visiting); + int v = ep.parse(); + visiting.erase(name); + return v; + } }; //============================================================== // Preprocessor //============================================================== class Preprocessor { -public: - explicit Preprocessor(Options opts = {}) : opts_(std::move(opts)) { - // Treat empty include path as current directory - if (opts_.include_path.empty()) { - opts_.include_path = "."; - } - parseMacroDefinitions(opts_.macros); - } - - std::string - preprocess_file(const std::string &filename, - const std::vector &additional_macros = {}) { - std::unordered_map macros; - std::unordered_set predefined; - std::unordered_set include_stack; - buildMacros(additional_macros, macros, predefined); - - std::string result = processFile(filename, macros, predefined, - include_stack, DirectiveMode::All); - return result; - } - - std::string - preprocess(const std::string &contents, - const std::vector &additional_macros = {}) { - std::unordered_map macros; - std::unordered_set predefined; - std::unordered_set include_stack; - buildMacros(additional_macros, macros, predefined); - - std::string result = processString(contents, macros, predefined, - include_stack, DirectiveMode::All); - return result; - } - - std::string preprocess_includes_file(const std::string &filename) { - std::unordered_map macros; - std::unordered_set predefined; - std::unordered_set include_stack; - std::string result = - processFile(filename, macros, predefined, include_stack, - DirectiveMode::IncludesOnly); - return result; - } - - std::string preprocess_includes(const std::string &contents) { - std::unordered_map macros; - std::unordered_set predefined; - std::unordered_set include_stack; - std::string result = - processString(contents, macros, predefined, include_stack, - DirectiveMode::IncludesOnly); - return result; - } - -private: - Options opts_; - std::unordered_map global_macros; - - enum class DirectiveMode { All, IncludesOnly }; - - struct Cond { - bool parent_active; - bool active; - bool taken; - }; - - //---------------------------------------------------------- - // Parse macro definitions into global_macros - //---------------------------------------------------------- - void parseMacroDefinitions(const std::vector ¯o_defs) { - for (const auto &def : macro_defs) { - size_t eq_pos = def.find('='); - if (eq_pos != std::string::npos) { - // Format: NAME=VALUE - std::string name = trim(def.substr(0, eq_pos)); - std::string value = trim(def.substr(eq_pos + 1)); - global_macros[name] = value; - } else { - // Format: NAME - std::string name = trim(def); - global_macros[name] = ""; - } - } - } - - //---------------------------------------------------------- - // Build combined macro map and predefined set for a preprocessing operation - //---------------------------------------------------------- - void buildMacros(const std::vector &additional_macros, - std::unordered_map ¯os, - std::unordered_set &predefined) { - macros = global_macros; - predefined.clear(); - - for (const auto &[name, value] : global_macros) { - predefined.insert(name); - } - - for (const auto &def : additional_macros) { - size_t eq_pos = def.find('='); - std::string name, value; - if (eq_pos != std::string::npos) { - name = trim(def.substr(0, eq_pos)); - value = trim(def.substr(eq_pos + 1)); - } else { - name = trim(def); - value = ""; - } - - // Add to macros map (will override global if same name) - macros[name] = value; - predefined.insert(name); - } - } - - //---------------------------------------------------------- - // Helpers - //---------------------------------------------------------- - std::string loadFile(const std::string &fname) { - std::ifstream f(fname); - if (!f.is_open()) - throw std::runtime_error("Could not open file: " + fname); - std::stringstream ss; - ss << f.rdbuf(); - return ss.str(); - } - - bool condActive(const std::vector &cond) const { - if (cond.empty()) - return true; - return cond.back().active; - } - - //---------------------------------------------------------- - // Process a file - //---------------------------------------------------------- - std::string - processFile(const std::string &name, - std::unordered_map ¯os, - const std::unordered_set &predefined_macros, - std::unordered_set &include_stack, - DirectiveMode mode) { - if (include_stack.count(name)) - throw std::runtime_error("Recursive include: " + name); - - include_stack.insert(name); - std::string shader_code = loadFile(name); - std::string out = processString(shader_code, macros, predefined_macros, - include_stack, mode); - include_stack.erase(name); - return out; - } - - std::string - processIncludeFile(const std::string &fname, - std::unordered_map ¯os, - const std::unordered_set &predefined_macros, - std::unordered_set &include_stack, - DirectiveMode mode) { - std::string full_path = opts_.include_path + "/" + fname; - return processFile(full_path, macros, predefined_macros, include_stack, - mode); - } - - //---------------------------------------------------------- - // Process text - //---------------------------------------------------------- - std::string - processString(const std::string &shader_code, - std::unordered_map ¯os, - const std::unordered_set &predefined_macros, - std::unordered_set &include_stack, - DirectiveMode mode) { - std::vector cond; // Conditional stack for this shader - std::stringstream out; - std::istringstream in(shader_code); - std::string line; - - while (std::getline(in, line)) { - std::string logical = line; - std::string t = trim(logical); - if (!t.empty() && t[0] == '#') { - while (endsWithContinuation(logical)) { - stripContinuation(logical); - if (!std::getline(in, line)) - break; - logical += "\n"; - logical += line; - } - t = trim(logical); - } - - if (!t.empty() && t[0] == '#') { - bool handled = handleDirective(t, out, macros, predefined_macros, cond, - include_stack, mode); - if (mode == DirectiveMode::IncludesOnly && !handled) { - out << logical << "\n"; - } - } else { + public: + explicit Preprocessor(Options opts = {}) : opts_(std::move(opts)) { + // Treat empty include path as current directory + if (opts_.include_path.empty()) { + opts_.include_path = "."; + } + parseMacroDefinitions(opts_.macros); + } + + std::string preprocess_file(const std::string & filename, const std::vector & additional_macros = {}) { + std::unordered_map macros; + std::unordered_set predefined; + std::unordered_set include_stack; + buildMacros(additional_macros, macros, predefined); + + std::string result = processFile(filename, macros, predefined, include_stack, DirectiveMode::All); + return result; + } + + std::string preprocess(const std::string & contents, const std::vector & additional_macros = {}) { + std::unordered_map macros; + std::unordered_set predefined; + std::unordered_set include_stack; + buildMacros(additional_macros, macros, predefined); + + std::string result = processString(contents, macros, predefined, include_stack, DirectiveMode::All); + return result; + } + + std::string preprocess_includes_file(const std::string & filename) { + std::unordered_map macros; + std::unordered_set predefined; + std::unordered_set include_stack; + std::string result = processFile(filename, macros, predefined, include_stack, DirectiveMode::IncludesOnly); + return result; + } + + std::string preprocess_includes(const std::string & contents) { + std::unordered_map macros; + std::unordered_set predefined; + std::unordered_set include_stack; + std::string result = processString(contents, macros, predefined, include_stack, DirectiveMode::IncludesOnly); + return result; + } + + private: + Options opts_; + std::unordered_map global_macros; + + enum class DirectiveMode { All, IncludesOnly }; + + struct Cond { + bool parent_active; + bool active; + bool taken; + }; + + //---------------------------------------------------------- + // Parse macro definitions into global_macros + //---------------------------------------------------------- + void parseMacroDefinitions(const std::vector & macro_defs) { + for (const auto & def : macro_defs) { + size_t eq_pos = def.find('='); + if (eq_pos != std::string::npos) { + // Format: NAME=VALUE + std::string name = trim(def.substr(0, eq_pos)); + std::string value = trim(def.substr(eq_pos + 1)); + global_macros[name] = value; + } else { + // Format: NAME + std::string name = trim(def); + global_macros[name] = ""; + } + } + } + + //---------------------------------------------------------- + // Build combined macro map and predefined set for a preprocessing operation + //---------------------------------------------------------- + void buildMacros(const std::vector & additional_macros, + std::unordered_map & macros, + std::unordered_set & predefined) { + macros = global_macros; + predefined.clear(); + + for (const auto & [name, value] : global_macros) { + predefined.insert(name); + } + + for (const auto & def : additional_macros) { + size_t eq_pos = def.find('='); + std::string name, value; + if (eq_pos != std::string::npos) { + name = trim(def.substr(0, eq_pos)); + value = trim(def.substr(eq_pos + 1)); + } else { + name = trim(def); + value = ""; + } + + // Add to macros map (will override global if same name) + macros[name] = value; + predefined.insert(name); + } + } + + //---------------------------------------------------------- + // Helpers + //---------------------------------------------------------- + std::string loadFile(const std::string & fname) { + std::ifstream f(fname); + if (!f.is_open()) { + throw std::runtime_error("Could not open file: " + fname); + } + std::stringstream ss; + ss << f.rdbuf(); + return ss.str(); + } + + bool condActive(const std::vector & cond) const { + if (cond.empty()) { + return true; + } + return cond.back().active; + } + + //---------------------------------------------------------- + // Process a file + //---------------------------------------------------------- + std::string processFile(const std::string & name, + std::unordered_map & macros, + const std::unordered_set & predefined_macros, + std::unordered_set & include_stack, + DirectiveMode mode) { + if (include_stack.count(name)) { + throw std::runtime_error("Recursive include: " + name); + } + + include_stack.insert(name); + std::string shader_code = loadFile(name); + std::string out = processString(shader_code, macros, predefined_macros, include_stack, mode); + include_stack.erase(name); + return out; + } + + std::string processIncludeFile(const std::string & fname, + std::unordered_map & macros, + const std::unordered_set & predefined_macros, + std::unordered_set & include_stack, + DirectiveMode mode) { + std::string full_path = opts_.include_path + "/" + fname; + return processFile(full_path, macros, predefined_macros, include_stack, mode); + } + + //---------------------------------------------------------- + // Process text + //---------------------------------------------------------- + std::string processString(const std::string & shader_code, + std::unordered_map & macros, + const std::unordered_set & predefined_macros, + std::unordered_set & include_stack, + DirectiveMode mode) { + std::vector cond; // Conditional stack for this shader + std::stringstream out; + std::istringstream in(shader_code); + std::string line; + + while (std::getline(in, line)) { + std::string logical = line; + std::string t = trim(logical); + if (!t.empty() && t[0] == '#') { + while (endsWithContinuation(logical)) { + stripContinuation(logical); + if (!std::getline(in, line)) { + break; + } + logical += "\n"; + logical += line; + } + t = trim(logical); + } + + if (!t.empty() && t[0] == '#') { + bool handled = handleDirective(t, out, macros, predefined_macros, cond, include_stack, mode); + if (mode == DirectiveMode::IncludesOnly && !handled) { + out << logical << "\n"; + } + } else { + if (mode == DirectiveMode::IncludesOnly) { + out << logical << "\n"; + } else if (condActive(cond)) { + // Expand macros in the line before outputting + std::string expanded = expandMacrosRecursive(logical, macros); + out << expanded << "\n"; + } + } + } + + if (mode == DirectiveMode::All && !cond.empty()) { + throw std::runtime_error("Unclosed #if directive"); + } + + return out.str(); + } + + //---------------------------------------------------------- + // Directive handler + //---------------------------------------------------------- + bool handleDirective(const std::string & t, + std::stringstream & out, + std::unordered_map & macros, + const std::unordered_set & predefined_macros, + std::vector & cond, + std::unordered_set & include_stack, + DirectiveMode mode) { + // split into tokens + std::string body = t.substr(1); + std::istringstream iss(body); + std::string cmd; + iss >> cmd; + + if (cmd == "include") { + if (mode == DirectiveMode::All && !condActive(cond)) { + return true; + } + std::string file; + iss >> file; + if (file.size() >= 2 && file.front() == '"' && file.back() == '"') { + file = file.substr(1, file.size() - 2); + } + out << processIncludeFile(file, macros, predefined_macros, include_stack, mode); + return true; + } + if (mode == DirectiveMode::IncludesOnly) { - out << logical << "\n"; - } else if (condActive(cond)) { - // Expand macros in the line before outputting - std::string expanded = expandMacrosRecursive(logical, macros); - out << expanded << "\n"; - } - } - } - - if (mode == DirectiveMode::All && !cond.empty()) - throw std::runtime_error("Unclosed #if directive"); - - return out.str(); - } - - //---------------------------------------------------------- - // Directive handler - //---------------------------------------------------------- - bool handleDirective(const std::string &t, std::stringstream &out, - std::unordered_map ¯os, - const std::unordered_set &predefined_macros, - std::vector &cond, - std::unordered_set &include_stack, - DirectiveMode mode) { - // split into tokens - std::string body = t.substr(1); - std::istringstream iss(body); - std::string cmd; - iss >> cmd; - - if (cmd == "include") { - if (mode == DirectiveMode::All && !condActive(cond)) - return true; - std::string file; - iss >> file; - if (file.size() >= 2 && file.front() == '"' && file.back() == '"') - file = file.substr(1, file.size() - 2); - out << processIncludeFile(file, macros, predefined_macros, include_stack, - mode); - return true; - } - - if (mode == DirectiveMode::IncludesOnly) - return false; - - if (cmd == "define") { - if (!condActive(cond)) - return true; - std::string name; - iss >> name; - // Don't override predefined macros from options - if (predefined_macros.count(name)) - return true; - std::string value = trim_value(iss); - macros[name] = value; - return true; - } - - if (cmd == "undef") { - if (!condActive(cond)) - return true; - std::string name; - iss >> name; - // Don't undef predefined macros from options - if (predefined_macros.count(name)) - return true; - macros.erase(name); - return true; - } - - if (cmd == "ifdef") { - std::string name; - iss >> name; - bool p = condActive(cond); - bool v = macros.count(name); - cond.push_back({p, p && v, p && v}); - return true; - } - - if (cmd == "ifndef") { - std::string name; - iss >> name; - bool p = condActive(cond); - bool v = !macros.count(name); - cond.push_back({p, p && v, p && v}); - return true; - } - - if (cmd == "if") { - std::string expr = trim_value(iss); - bool p = condActive(cond); - bool v = false; - if (p) { - std::unordered_set visiting; - ExprParser ep(expr, macros, visiting); - v = ep.parse() != 0; - } - cond.push_back({p, p && v, p && v}); - return true; - } - - if (cmd == "elif") { - std::string expr = trim_value(iss); - - if (cond.empty()) - throw std::runtime_error("#elif without #if"); - - Cond &c = cond.back(); - if (!c.parent_active) { - c.active = false; - return true; - } - - if (c.taken) { - c.active = false; - return true; - } - - std::unordered_set visiting; - ExprParser ep(expr, macros, visiting); - bool v = ep.parse() != 0; - c.active = v; - if (v) - c.taken = true; - return true; - } - - if (cmd == "else") { - if (cond.empty()) - throw std::runtime_error("#else without #if"); - - Cond &c = cond.back(); - if (!c.parent_active) { - c.active = false; - return true; - } - if (c.taken) { - c.active = false; - } else { - c.active = true; - c.taken = true; - } - return true; - } - - if (cmd == "endif") { - if (cond.empty()) - throw std::runtime_error("#endif without #if"); - cond.pop_back(); - return true; - } - - // Unknown directive - throw std::runtime_error("Unknown directive: #" + cmd); - } + return false; + } + + if (cmd == "define") { + if (!condActive(cond)) { + return true; + } + std::string name; + iss >> name; + // Don't override predefined macros from options + if (predefined_macros.count(name)) { + return true; + } + std::string value = trim_value(iss); + macros[name] = value; + return true; + } + + if (cmd == "undef") { + if (!condActive(cond)) { + return true; + } + std::string name; + iss >> name; + // Don't undef predefined macros from options + if (predefined_macros.count(name)) { + return true; + } + macros.erase(name); + return true; + } + + if (cmd == "ifdef") { + std::string name; + iss >> name; + bool p = condActive(cond); + bool v = macros.count(name); + cond.push_back({ p, p && v, p && v }); + return true; + } + + if (cmd == "ifndef") { + std::string name; + iss >> name; + bool p = condActive(cond); + bool v = !macros.count(name); + cond.push_back({ p, p && v, p && v }); + return true; + } + + if (cmd == "if") { + std::string expr = trim_value(iss); + bool p = condActive(cond); + bool v = false; + if (p) { + std::unordered_set visiting; + ExprParser ep(expr, macros, visiting); + v = ep.parse() != 0; + } + cond.push_back({ p, p && v, p && v }); + return true; + } + + if (cmd == "elif") { + std::string expr = trim_value(iss); + + if (cond.empty()) { + throw std::runtime_error("#elif without #if"); + } + + Cond & c = cond.back(); + if (!c.parent_active) { + c.active = false; + return true; + } + + if (c.taken) { + c.active = false; + return true; + } + + std::unordered_set visiting; + ExprParser ep(expr, macros, visiting); + bool v = ep.parse() != 0; + c.active = v; + if (v) { + c.taken = true; + } + return true; + } + + if (cmd == "else") { + if (cond.empty()) { + throw std::runtime_error("#else without #if"); + } + + Cond & c = cond.back(); + if (!c.parent_active) { + c.active = false; + return true; + } + if (c.taken) { + c.active = false; + } else { + c.active = true; + c.taken = true; + } + return true; + } + + if (cmd == "endif") { + if (cond.empty()) { + throw std::runtime_error("#endif without #if"); + } + cond.pop_back(); + return true; + } + + // Unknown directive + throw std::runtime_error("Unknown directive: #" + cmd); + } }; -} // namespace pre_wgsl +} // namespace pre_wgsl -#endif // PRE_WGSL_HPP +#endif // PRE_WGSL_HPP From cf66f0a30a62d8e349bcdb43e821f1f906410890 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Tue, 2 Jun 2026 09:06:11 -0700 Subject: [PATCH 7/7] Move to functions, add a check --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 10 +- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 7 +- .../ggml-webgpu/wgsl-shaders/flash_attn.wgsl | 68 +++++----- .../flash_attn_quant_staging.tmpl | 126 +++++++----------- .../wgsl-shaders/flash_attn_tile.wgsl | 72 +++++----- .../wgsl-shaders/flash_attn_vec_split.wgsl | 80 +++++------ tests/test-backend-ops.cpp | 1 + 7 files changed, 172 insertions(+), 192 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index e92741560328..a5e7de785b43 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -639,9 +639,8 @@ inline size_t ggml_webgpu_flash_attn_tensor_offset(const ggml_tensor * tensor) { } inline bool ggml_webgpu_flash_attn_float_vec4_aligned(const ggml_tensor * K, size_t storage_offset_alignment) { - const size_t alignment = std::max(1u, storage_offset_alignment); const uint32_t offset_elems = - (uint32_t) ((ggml_webgpu_flash_attn_tensor_offset(K) & (alignment - 1)) / ggml_type_size(K->type)); + (uint32_t) ((ggml_webgpu_flash_attn_tensor_offset(K) & (storage_offset_alignment - 1)) / ggml_type_size(K->type)); return offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u; } @@ -652,8 +651,9 @@ inline bool ggml_webgpu_flash_attn_float_vec4_aligned(const ggml_tensor * K, ggml_webgpu_flash_attn_float_vec4_aligned(V, storage_offset_alignment); } -inline bool ggml_webgpu_flash_attn_kv_direct(const ggml_tensor * Q, const ggml_tensor * K, uint32_t kv_direct_align) { - return K->type == GGML_TYPE_F16 && (Q->ne[0] % std::max(1u, kv_direct_align) == 0) && +inline bool ggml_webgpu_flash_attn_kv_direct( + const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, uint32_t kv_direct_align) { + return K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && (Q->ne[0] % kv_direct_align == 0) && (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0); } @@ -667,7 +667,7 @@ inline ggml_webgpu_flash_attn_common_pipeline_key ggml_webgpu_flash_attn_make_co key.dst_type = context.dst->type; key.head_dim_qk = (uint32_t) context.src0->ne[0]; key.head_dim_v = (uint32_t) context.src2->ne[0]; - key.kv_direct = ggml_webgpu_flash_attn_kv_direct(context.src0, context.src1, kv_direct_align); + key.kv_direct = ggml_webgpu_flash_attn_kv_direct(context.src0, context.src1, context.src2, kv_direct_align); key.kv_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src2); key.has_mask = context.src3 != nullptr; key.has_sinks = context.src4 != nullptr; diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 98e97949262f..c6cfb0bbbadc 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -3626,7 +3626,7 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer const auto & capabilities = ctx->webgpu_global_ctx->capabilities; if (ggml_webgpu_flash_attn_use_vec_path(ctx->webgpu_global_ctx, Q, K, V)) { const bool kv_direct = - ggml_webgpu_flash_attn_kv_direct(Q, K, GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH); + ggml_webgpu_flash_attn_kv_direct(Q, K, V, GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH); const uint32_t kv_tile = ggml_webgpu_flash_attn_get_vec_kv_tile( capabilities.limits.maxComputeWorkgroupStorageSize, (uint32_t) Q->ne[0], (uint32_t) V->ne[0], mask != nullptr, kv_direct); @@ -4228,8 +4228,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const const uint32_t q_tile = use_subgroup_matrix ? capabilities.sg_mat_m : GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE; const uint32_t kv_granularity = use_subgroup_matrix ? capabilities.sg_mat_n : 1u; - const bool kv_direct = - use_subgroup_matrix ? ggml_webgpu_flash_attn_kv_direct(src0, src1, capabilities.sg_mat_k) : false; + const bool kv_direct = use_subgroup_matrix ? + ggml_webgpu_flash_attn_kv_direct(src0, src1, src2, capabilities.sg_mat_k) : + false; const uint32_t max_kv_tile = ggml_webgpu_flash_attn_max_kv_tile( capabilities.limits.maxComputeWorkgroupStorageSize, q_tile, kv_granularity, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], op->src[3] != nullptr, kv_direct); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl index d27e5b5f0885..9767ca3d7543 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl @@ -188,6 +188,36 @@ fn load_kx4(buf: ptr>, read_write>, scalar_index: u3 #define QUANT_OUT_TYPE f16 #include "quant_inner_loops.tmpl" #include "flash_attn_quant_staging.tmpl" + +#if !defined(K_Q4_0) && !defined(K_Q8_0) +fn load_k_tile_block(local_x: u32, kv_count: u32, kv_tile: u32, k_head_offset: u32) { + for (var elem_idx = local_x; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) { + let k_row = elem_idx / HEAD_DIM_QK; + let k_col = elem_idx % HEAD_DIM_QK; + let global_k_row = kv_tile + k_row; + let global_k_row_offset = k_head_offset + global_k_row * params.stride_k1; + kv_shmem[elem_idx] = f16(select( + 0.0, + K[global_k_row_offset + k_col], + global_k_row < params.seq_len_kv && k_col < HEAD_DIM_QK)); + } +} +#endif + +#if !defined(V_Q4_0) && !defined(V_Q8_0) +fn load_v_tile_block(local_x: u32, kv_count: u32, kv_tile: u32, v_head_offset: u32) { + for (var elem_idx = local_x; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE) { + let v_row = elem_idx / HEAD_DIM_V; + let v_col = elem_idx % HEAD_DIM_V; + let global_v_row = kv_tile + v_row; + let global_v_row_offset = v_head_offset + global_v_row * params.stride_v1; + kv_shmem[elem_idx] = f16(select( + 0.0, + V[global_v_row_offset + v_col], + global_v_row < params.seq_len_kv && v_col < HEAD_DIM_V)); + } +} +#endif #endif @compute @workgroup_size(WG_SIZE) @@ -266,23 +296,8 @@ fn main(@builtin(workgroup_id) wg_id: vec3, } // load k tile into shared memory -#if defined(K_Q4_0) - LOAD_K_Q4_0_TILE_BLOCK -#elif defined(K_Q8_0) - LOAD_K_Q8_0_TILE_BLOCK -#elif defined(KV_DIRECT) - // Direct global loads for KV -#else - for (var elem_idx = local_id.x; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) { - let k_row = elem_idx / HEAD_DIM_QK; - let k_col = elem_idx % HEAD_DIM_QK; - let global_k_row = kv_tile + k_row; - let global_k_row_offset = k_head_offset + global_k_row * params.stride_k1; - kv_shmem[elem_idx] = f16(select( - 0.0, - K[global_k_row_offset + k_col], - global_k_row < params.seq_len_kv && k_col < HEAD_DIM_QK)); - } +#ifndef KV_DIRECT + load_k_tile_block(local_id.x, kv_count, kv_tile, k_head_offset); #endif workgroupBarrier(); @@ -421,23 +436,8 @@ fn main(@builtin(workgroup_id) wg_id: vec3, } // load v tile into shared memory -#if defined(V_Q4_0) - LOAD_V_Q4_0_TILE_BLOCK -#elif defined(V_Q8_0) - LOAD_V_Q8_0_TILE_BLOCK -#elif defined(KV_DIRECT) - // Direct global loads for KV -#else - for (var elem_idx = local_id.x; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE) { - let v_row = elem_idx / HEAD_DIM_V; - let v_col = elem_idx % HEAD_DIM_V; - let global_v_row = kv_tile + v_row; - let global_v_row_offset = v_head_offset + global_v_row * params.stride_v1; - kv_shmem[elem_idx] = f16(select( - 0.0, - V[global_v_row_offset + v_col], - global_v_row < params.seq_len_kv && v_col < HEAD_DIM_V)); - } +#ifndef KV_DIRECT + load_v_tile_block(local_id.x, kv_count, kv_tile, v_head_offset); #endif workgroupBarrier(); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl index d46e2e79ec9c..8f41eb7bfdbc 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl @@ -69,82 +69,56 @@ fn f16_from_u16(bits: u32) -> f16 { return f16(packed[0]); } -#define LOAD_K_Q4_0_TILE_BLOCK \ -for (var elem_idx = local_id.x * K_NQ; elem_idx < kv_count * HEAD_DIM_QK; elem_idx += WG_SIZE * K_NQ) { \ - let blck_idx = elem_idx / BLOCK_SIZE; \ - let block_offset = (elem_idx % BLOCK_SIZE) / K_NQ; \ - let k_row = blck_idx / BLOCKS_K; \ - let global_k_row = kv_tile + k_row; \ - let block_k = blck_idx % BLOCKS_K; \ - let row_offset = k_row * HEAD_DIM_QK; \ - let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; \ - let block_byte_base = global_block_idx * K_BLOCK_SIZE_BYTES; \ - let d = f16_from_u16(load_k_u16_at(block_byte_base)); \ - let thread_byte_offset = block_offset * K_BYTES_PER_THREAD; \ - let shmem_idx = row_offset + block_k * BLOCK_SIZE + thread_byte_offset; \ - for (var j = 0u; j < K_BYTES_PER_THREAD / K_BYTES_PER_INNER_LOOP; j += 1u) { \ - let q_byte_offset = block_byte_base + 2u + thread_byte_offset + j * K_BYTES_PER_INNER_LOOP; \ - let q_packed = load_k_u32_at(q_byte_offset); \ - dequant_q4_0_packed_to_shmem(q_packed, d, shmem_idx + j * K_BYTES_PER_INNER_LOOP); \ - } \ -} - -#define LOAD_K_Q8_0_TILE_BLOCK \ -for (var elem_idx = local_id.x * K_NQ; elem_idx < kv_count * HEAD_DIM_QK; elem_idx += WG_SIZE * K_NQ) { \ - let blck_idx = elem_idx / BLOCK_SIZE; \ - let block_offset = (elem_idx % BLOCK_SIZE) / K_NQ; \ - let k_row = blck_idx / BLOCKS_K; \ - let global_k_row = kv_tile + k_row; \ - let block_k = blck_idx % BLOCKS_K; \ - let row_offset = k_row * HEAD_DIM_QK; \ - let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; \ - let block_byte_base = global_block_idx * K_BLOCK_SIZE_BYTES; \ - let d = f16_from_u16(load_k_u16_at(block_byte_base)); \ - let thread_byte_offset = block_offset * K_BYTES_PER_THREAD; \ - let shmem_idx = row_offset + block_k * BLOCK_SIZE + thread_byte_offset; \ - for (var j = 0u; j < K_BYTES_PER_THREAD / K_BYTES_PER_INNER_LOOP; j += 1u) { \ - let q_byte_offset = block_byte_base + 2u + thread_byte_offset + j * K_BYTES_PER_INNER_LOOP; \ - let q_packed = load_k_u32_at(q_byte_offset); \ - dequant_q8_0_packed_to_shmem(q_packed, d, shmem_idx + j * K_BYTES_PER_INNER_LOOP); \ - } \ -} - -#define LOAD_V_Q4_0_TILE_BLOCK \ -for (var elem_idx = local_id.x * V_NQ; elem_idx < kv_count * HEAD_DIM_V; elem_idx += WG_SIZE * V_NQ) { \ - let blck_idx = elem_idx / BLOCK_SIZE; \ - let block_offset = (elem_idx % BLOCK_SIZE) / V_NQ; \ - let v_row = blck_idx / BLOCKS_V; \ - let global_v_row = kv_tile + v_row; \ - let block_k = blck_idx % BLOCKS_V; \ - let row_offset = v_row * HEAD_DIM_V; \ - let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; \ - let block_byte_base = global_block_idx * V_BLOCK_SIZE_BYTES; \ - let d = f16_from_u16(load_v_u16_at(block_byte_base)); \ - let thread_byte_offset = block_offset * V_BYTES_PER_THREAD; \ - let shmem_idx = row_offset + block_k * BLOCK_SIZE + thread_byte_offset; \ - for (var j = 0u; j < V_BYTES_PER_THREAD / V_BYTES_PER_INNER_LOOP; j += 1u) { \ - let q_byte_offset = block_byte_base + 2u + thread_byte_offset + j * V_BYTES_PER_INNER_LOOP; \ - let q_packed = load_v_u32_at(q_byte_offset); \ - dequant_q4_0_packed_to_shmem(q_packed, d, shmem_idx + j * V_BYTES_PER_INNER_LOOP); \ - } \ +#if defined(K_Q4_0) || defined(K_Q8_0) +fn load_k_tile_block(local_x: u32, kv_count: u32, kv_tile: u32, k_head_offset: u32) { + for (var elem_idx = local_x * K_NQ; elem_idx < kv_count * HEAD_DIM_QK; elem_idx += WG_SIZE * K_NQ) { + let blck_idx = elem_idx / BLOCK_SIZE; + let block_offset = (elem_idx % BLOCK_SIZE) / K_NQ; + let k_row = blck_idx / BLOCKS_K; + let global_k_row = kv_tile + k_row; + let block_k = blck_idx % BLOCKS_K; + let row_offset = k_row * HEAD_DIM_QK; + let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; + let block_byte_base = global_block_idx * K_BLOCK_SIZE_BYTES; + let d = f16_from_u16(load_k_u16_at(block_byte_base)); + let thread_byte_offset = block_offset * K_BYTES_PER_THREAD; + let shmem_idx = row_offset + block_k * BLOCK_SIZE + thread_byte_offset; + for (var j = 0u; j < K_BYTES_PER_THREAD / K_BYTES_PER_INNER_LOOP; j += 1u) { + let q_byte_offset = block_byte_base + 2u + thread_byte_offset + j * K_BYTES_PER_INNER_LOOP; + let q_packed = load_k_u32_at(q_byte_offset); +#if defined(K_Q4_0) + dequant_q4_0_packed_to_shmem(q_packed, d, shmem_idx + j * K_BYTES_PER_INNER_LOOP); +#elif defined(K_Q8_0) + dequant_q8_0_packed_to_shmem(q_packed, d, shmem_idx + j * K_BYTES_PER_INNER_LOOP); +#endif + } + } } +#endif -#define LOAD_V_Q8_0_TILE_BLOCK \ -for (var elem_idx = local_id.x * V_NQ; elem_idx < kv_count * HEAD_DIM_V; elem_idx += WG_SIZE * V_NQ) { \ - let blck_idx = elem_idx / BLOCK_SIZE; \ - let block_offset = (elem_idx % BLOCK_SIZE) / V_NQ; \ - let v_row = blck_idx / BLOCKS_V; \ - let global_v_row = kv_tile + v_row; \ - let block_k = blck_idx % BLOCKS_V; \ - let row_offset = v_row * HEAD_DIM_V; \ - let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; \ - let block_byte_base = global_block_idx * V_BLOCK_SIZE_BYTES; \ - let d = f16_from_u16(load_v_u16_at(block_byte_base)); \ - let thread_byte_offset = block_offset * V_BYTES_PER_THREAD; \ - let shmem_idx = row_offset + block_k * BLOCK_SIZE + thread_byte_offset; \ - for (var j = 0u; j < V_BYTES_PER_THREAD / V_BYTES_PER_INNER_LOOP; j += 1u) { \ - let q_byte_offset = block_byte_base + 2u + thread_byte_offset + j * V_BYTES_PER_INNER_LOOP; \ - let q_packed = load_v_u32_at(q_byte_offset); \ - dequant_q8_0_packed_to_shmem(q_packed, d, shmem_idx + j * V_BYTES_PER_INNER_LOOP); \ - } \ +#if defined(V_Q4_0) || defined(V_Q8_0) +fn load_v_tile_block(local_x: u32, kv_count: u32, kv_tile: u32, v_head_offset: u32) { + for (var elem_idx = local_x * V_NQ; elem_idx < kv_count * HEAD_DIM_V; elem_idx += WG_SIZE * V_NQ) { + let blck_idx = elem_idx / BLOCK_SIZE; + let block_offset = (elem_idx % BLOCK_SIZE) / V_NQ; + let v_row = blck_idx / BLOCKS_V; + let global_v_row = kv_tile + v_row; + let block_k = blck_idx % BLOCKS_V; + let row_offset = v_row * HEAD_DIM_V; + let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; + let block_byte_base = global_block_idx * V_BLOCK_SIZE_BYTES; + let d = f16_from_u16(load_v_u16_at(block_byte_base)); + let thread_byte_offset = block_offset * V_BYTES_PER_THREAD; + let shmem_idx = row_offset + block_k * BLOCK_SIZE + thread_byte_offset; + for (var j = 0u; j < V_BYTES_PER_THREAD / V_BYTES_PER_INNER_LOOP; j += 1u) { + let q_byte_offset = block_byte_base + 2u + thread_byte_offset + j * V_BYTES_PER_INNER_LOOP; + let q_packed = load_v_u32_at(q_byte_offset); +#if defined(V_Q4_0) + dequant_q4_0_packed_to_shmem(q_packed, d, shmem_idx + j * V_BYTES_PER_INNER_LOOP); +#elif defined(V_Q8_0) + dequant_q8_0_packed_to_shmem(q_packed, d, shmem_idx + j * V_BYTES_PER_INNER_LOOP); +#endif + } + } } +#endif diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl index 86c796faf1e5..e68934113fc1 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl @@ -156,6 +156,40 @@ var p_shmem: array; #include "quant_inner_loops.tmpl" #include "flash_attn_quant_staging.tmpl" +#if !defined(K_Q4_0) && !defined(K_Q8_0) +fn load_k_tile_block(local_x: u32, kv_count: u32, kv_tile: u32, k_head_offset: u32) { + for (var vec_idx_local = local_x; vec_idx_local < kv_count * Q_CHUNKS; vec_idx_local += WG_SIZE) { + let kv_local = vec_idx_local / Q_CHUNKS; + let chunk = vec_idx_local % Q_CHUNKS; + let global_k_row = kv_tile + kv_local; + let k_vec_index = (k_head_offset + global_k_row * params.stride_k1 + chunk * 4u) >> 2u; + let k4 = K[k_vec_index]; + let kv_off = kv_local * HEAD_DIM_QK + chunk * 4u; + kv_shmem[kv_off + 0u] = f16(k4.x); + kv_shmem[kv_off + 1u] = f16(k4.y); + kv_shmem[kv_off + 2u] = f16(k4.z); + kv_shmem[kv_off + 3u] = f16(k4.w); + } +} +#endif + +#if !defined(V_Q4_0) && !defined(V_Q8_0) +fn load_v_tile_block(local_x: u32, kv_count: u32, kv_tile: u32, v_head_offset: u32) { + for (var vec_idx_local = local_x; vec_idx_local < kv_count * V_CHUNKS; vec_idx_local += WG_SIZE) { + let kv_local = vec_idx_local / V_CHUNKS; + let chunk = vec_idx_local % V_CHUNKS; + let global_v_row = kv_tile + kv_local; + let v_vec_index = (v_head_offset + global_v_row * params.stride_v1 + chunk * 4u) >> 2u; + let v4 = V[v_vec_index]; + let kv_off = kv_local * HEAD_DIM_V + chunk * 4u; + kv_shmem[kv_off + 0u] = f16(v4.x); + kv_shmem[kv_off + 1u] = f16(v4.y); + kv_shmem[kv_off + 2u] = f16(v4.z); + kv_shmem[kv_off + 3u] = f16(v4.w); + } +} +#endif + @compute @workgroup_size(WG_SIZE) fn main(@builtin(workgroup_id) wg_id: vec3, @builtin(local_invocation_id) local_id: vec3, @@ -236,23 +270,8 @@ fn main(@builtin(workgroup_id) wg_id: vec3, local_scores[slot] = FLOAT_MIN; } - #if defined(K_Q4_0) - LOAD_K_Q4_0_TILE_BLOCK -#elif defined(K_Q8_0) - LOAD_K_Q8_0_TILE_BLOCK -#else - for (var vec_idx_local = local_id.x; vec_idx_local < kv_count * Q_CHUNKS; vec_idx_local += WG_SIZE) { - let kv_local = vec_idx_local / Q_CHUNKS; - let chunk = vec_idx_local % Q_CHUNKS; - let global_k_row = kv_tile + kv_local; - let k_vec_index = (k_head_offset + global_k_row * params.stride_k1 + chunk * 4u) >> 2u; - let k4 = K[k_vec_index]; - let kv_off = kv_local * HEAD_DIM_QK + chunk * 4u; - kv_shmem[kv_off + 0u] = f16(k4.x); - kv_shmem[kv_off + 1u] = f16(k4.y); - kv_shmem[kv_off + 2u] = f16(k4.z); - kv_shmem[kv_off + 3u] = f16(k4.w); - } +#ifndef KV_DIRECT + load_k_tile_block(local_id.x, kv_count, kv_tile, k_head_offset); #endif workgroupBarrier(); @@ -314,23 +333,8 @@ fn main(@builtin(workgroup_id) wg_id: vec3, workgroupBarrier(); - #if defined(V_Q4_0) - LOAD_V_Q4_0_TILE_BLOCK -#elif defined(V_Q8_0) - LOAD_V_Q8_0_TILE_BLOCK -#else - for (var vec_idx_local = local_id.x; vec_idx_local < kv_count * V_CHUNKS; vec_idx_local += WG_SIZE) { - let kv_local = vec_idx_local / V_CHUNKS; - let chunk = vec_idx_local % V_CHUNKS; - let global_v_row = kv_tile + kv_local; - let v_vec_index = (v_head_offset + global_v_row * params.stride_v1 + chunk * 4u) >> 2u; - let v4 = V[v_vec_index]; - let kv_off = kv_local * HEAD_DIM_V + chunk * 4u; - kv_shmem[kv_off + 0u] = f16(v4.x); - kv_shmem[kv_off + 1u] = f16(v4.y); - kv_shmem[kv_off + 2u] = f16(v4.z); - kv_shmem[kv_off + 3u] = f16(v4.w); - } +#ifndef KV_DIRECT + load_v_tile_block(local_id.x, kv_count, kv_tile, v_head_offset); #endif workgroupBarrier(); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl index e229c4904cfe..30ed97cca0c3 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl @@ -240,6 +240,42 @@ fn calc_softmax_term(kv_idx: u32, slope: f32, has_bias: bool, apply_mask: bool) #define QUANT_OUT_TYPE f32 #include "quant_inner_loops.tmpl" #include "flash_attn_quant_staging.tmpl" + +#if !defined(K_Q4_0) && !defined(K_Q8_0) +fn load_k_tile_block(local_x: u32, kv_count: u32, kv_tile: u32, k_head_offset: u32) { + for (var elem_idx = local_x * 4u; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * 4u) { + let k_row = elem_idx / HEAD_DIM_QK; + let k_col = elem_idx % HEAD_DIM_QK; + let global_k_row = kv_tile + k_row; + let global_k_row_offset = k_head_offset + global_k_row * params.stride_k1; + let in_bounds = global_k_row < params.seq_len_kv && (k_col + 3u) < HEAD_DIM_QK; + let vec_idx = (global_k_row_offset + k_col) >> 2u; + let k4 = select(vec4(0.0), K[vec_idx], in_bounds); + kv_shmem[elem_idx + 0u] = f32(k4.x); + kv_shmem[elem_idx + 1u] = f32(k4.y); + kv_shmem[elem_idx + 2u] = f32(k4.z); + kv_shmem[elem_idx + 3u] = f32(k4.w); + } +} +#endif + +#if !defined(V_Q4_0) && !defined(V_Q8_0) +fn load_v_tile_block(local_x: u32, kv_count: u32, kv_tile: u32, v_head_offset: u32) { + for (var elem_idx = local_x * 4u; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * 4u) { + let v_row = elem_idx / HEAD_DIM_V; + let v_col = elem_idx % HEAD_DIM_V; + let global_v_row = kv_tile + v_row; + let global_v_row_offset = v_head_offset + global_v_row * params.stride_v1; + let in_bounds = global_v_row < params.seq_len_kv && (v_col + 3u) < HEAD_DIM_V; + let vec_idx = (global_v_row_offset + v_col) >> 2u; + let v4 = select(vec4(0.0), V[vec_idx], in_bounds); + kv_shmem[elem_idx + 0u] = f32(v4.x); + kv_shmem[elem_idx + 1u] = f32(v4.y); + kv_shmem[elem_idx + 2u] = f32(v4.z); + kv_shmem[elem_idx + 3u] = f32(v4.w); + } +} +#endif #endif @compute @workgroup_size(WG_SIZE) @@ -323,26 +359,8 @@ fn main(@builtin(workgroup_id) wg_id: vec3, } // load k tile into shared memory -#if defined(K_Q4_0) - LOAD_K_Q4_0_TILE_BLOCK -#elif defined(K_Q8_0) - LOAD_K_Q8_0_TILE_BLOCK -#elif defined(KV_DIRECT) - // Direct global loads for KV -#else - for (var elem_idx = local_id.x * 4u; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * 4u) { - let k_row = elem_idx / HEAD_DIM_QK; - let k_col = elem_idx % HEAD_DIM_QK; - let global_k_row = kv_tile + k_row; - let global_k_row_offset = k_head_offset + global_k_row * params.stride_k1; - let in_bounds = global_k_row < params.seq_len_kv && (k_col + 3u) < HEAD_DIM_QK; - let vec_idx = (global_k_row_offset + k_col) >> 2u; - let k4 = select(vec4(0.0), K[vec_idx], in_bounds); - kv_shmem[elem_idx + 0u] = f32(k4.x); - kv_shmem[elem_idx + 1u] = f32(k4.y); - kv_shmem[elem_idx + 2u] = f32(k4.z); - kv_shmem[elem_idx + 3u] = f32(k4.w); - } +#ifndef KV_DIRECT + load_k_tile_block(local_id.x, kv_count, kv_tile, k_head_offset); #endif workgroupBarrier(); @@ -459,26 +477,8 @@ fn main(@builtin(workgroup_id) wg_id: vec3, } // load v tile into shared memory -#if defined(V_Q4_0) - LOAD_V_Q4_0_TILE_BLOCK -#elif defined(V_Q8_0) - LOAD_V_Q8_0_TILE_BLOCK -#elif defined(KV_DIRECT) - // Direct global loads for KV -#else - for (var elem_idx = local_id.x * 4u; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * 4u) { - let v_row = elem_idx / HEAD_DIM_V; - let v_col = elem_idx % HEAD_DIM_V; - let global_v_row = kv_tile + v_row; - let global_v_row_offset = v_head_offset + global_v_row * params.stride_v1; - let in_bounds = global_v_row < params.seq_len_kv && (v_col + 3u) < HEAD_DIM_V; - let vec_idx = (global_v_row_offset + v_col) >> 2u; - let v4 = select(vec4(0.0), V[vec_idx], in_bounds); - kv_shmem[elem_idx + 0u] = f32(v4.x); - kv_shmem[elem_idx + 1u] = f32(v4.y); - kv_shmem[elem_idx + 2u] = f32(v4.z); - kv_shmem[elem_idx + 3u] = f32(v4.w); - } +#ifndef KV_DIRECT + load_v_tile_block(local_id.x, kv_count, kv_tile, v_head_offset); #endif workgroupBarrier(); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 58c5fdd10dbe..ba89a94fc977 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -9046,6 +9046,7 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_flash_attn_ext(64, 64, 4, {1, 1}, 128, 2, true, false, 0, 0, GGML_PREC_F32, GGML_TYPE_Q4_0, GGML_TYPE_F16)); test_cases.emplace_back(new test_flash_attn_ext(72, 72, 4, {1, 1}, 96, 2, true, false, 0, 0, GGML_PREC_F32, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)); test_cases.emplace_back(new test_flash_attn_ext(64, 64, 4, {1, 1}, 96, 2, true, false, 0, 0, GGML_PREC_F32, GGML_TYPE_F16, GGML_TYPE_F32)); + test_cases.emplace_back(new test_flash_attn_ext(128, 128, 4, {1, 1}, 256, 1, false, false, 0, 0, GGML_PREC_F32, GGML_TYPE_F16, GGML_TYPE_Q4_0)); test_cases.emplace_back(new test_flash_attn_ext(128, 128, 4, {1, 1}, 96, 2, true, false, 0, 0, GGML_PREC_F32, GGML_TYPE_Q1_0, GGML_TYPE_Q1_0)); test_cases.emplace_back(new test_flash_attn_ext(128, 64, 4, {1, 1}, 128, 2, true, false, 0, 0, GGML_PREC_F32, GGML_TYPE_Q1_0, GGML_TYPE_Q4_0)); test_cases.emplace_back(new test_flash_attn_ext(64, 128, 4, {1, 1}, 128, 2, true, false, 0, 0, GGML_PREC_F32, GGML_TYPE_Q4_0, GGML_TYPE_Q1_0));