Skip to content
Merged
Show file tree
Hide file tree
Changes from 33 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
976ebc6
naive vectorized version
ArberSephirotheca Feb 25, 2026
94abbac
add vectorized flash attention
ArberSephirotheca Mar 3, 2026
1033085
update vec version
ArberSephirotheca Mar 3, 2026
c307a4b
remove unused path and shader
ArberSephirotheca Mar 3, 2026
f8e317c
remove unused helper functions
ArberSephirotheca Mar 3, 2026
52709dd
add comments
ArberSephirotheca Mar 4, 2026
df6ef45
remove pad path
ArberSephirotheca Mar 4, 2026
838306f
ggml-webgpu: fix flash-attn vec nwg=1 path and tighten vec specializa…
ArberSephirotheca Mar 7, 2026
d61ec8f
change back to vec4
ArberSephirotheca Mar 7, 2026
042a1a5
enable multi split
ArberSephirotheca Mar 7, 2026
b61e63d
enable vec path when:
ArberSephirotheca Mar 7, 2026
3602743
update flast_attn_vec_split.wgsl to reduce redundant workgroup barrie…
ArberSephirotheca Mar 8, 2026
356d6ff
enable vec path for q4 and q8
ArberSephirotheca Mar 8, 2026
1ae041d
flash-attn vec nwg=1 fast path (skip tmp/reduce staging)
ArberSephirotheca Mar 8, 2026
33a547e
use packed f16 K loads in flash-attn vec split
ArberSephirotheca Mar 8, 2026
638c49b
use packed f16 K loads in flash-attn vec split on host side
ArberSephirotheca Mar 8, 2026
0abac39
tune flash-attn vec f16 VEC_NE by head dim
ArberSephirotheca Mar 8, 2026
83a42b3
cleanup
ArberSephirotheca Mar 11, 2026
2595b1a
cleanup
ArberSephirotheca Mar 11, 2026
25096b9
keep host side clean
ArberSephirotheca Mar 11, 2026
3d6bfe0
cleanup host side
ArberSephirotheca Mar 11, 2026
68fa272
change back to original host wait/submit behavior
ArberSephirotheca Mar 11, 2026
5065dc6
formatting
ArberSephirotheca Mar 11, 2026
03d0625
reverted param-buffer pool r ecfactor
ArberSephirotheca Mar 11, 2026
5dd2a4b
add helper functions
ArberSephirotheca Mar 11, 2026
1e0d856
ggml-webgpu: move flash-attn vec pipeline caching back into shader lib
ArberSephirotheca Mar 19, 2026
88bf352
ggml-webgpu: remove duplicate functions
ArberSephirotheca Mar 19, 2026
5c2fefe
ggml-webgpu: reserve flash-attn vec scratch in dst buffer allocation
ArberSephirotheca Mar 19, 2026
59aa7d8
ggml-webgpu: revert unrelated change
ArberSephirotheca Mar 19, 2026
cac8500
ggml-webgpu: revert deleted comment
ArberSephirotheca Mar 19, 2026
ff11e38
Merge branch 'master' into backup/subgroup_size_agnostic_rebased_ggml…
ArberSephirotheca Apr 1, 2026
4e0100b
disable uniformity check
ArberSephirotheca Apr 1, 2026
56fee6e
remove unnecessary change
ArberSephirotheca Apr 1, 2026
29c09c2
Update ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl
reeselevine Apr 2, 2026
f40c9e7
Update ggml/src/ggml-webgpu/ggml-webgpu.cpp
reeselevine Apr 2, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
230 changes: 191 additions & 39 deletions ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,12 @@ struct ggml_webgpu_generic_shader_decisions {
uint32_t wg_size = 0;
};

struct ggml_webgpu_processed_shader {
Comment thread
reeselevine marked this conversation as resolved.
std::string wgsl;
std::string variant;
std::shared_ptr<void> decisions;
};

struct ggml_webgpu_ssm_conv_shader_decisions {
uint32_t block_size;
uint32_t tokens_per_wg;
Expand Down Expand Up @@ -384,11 +390,12 @@ struct ggml_webgpu_flash_attn_pipeline_key {
bool has_mask;
bool has_sinks;
bool uses_logit_softcap;
bool use_vec;

bool operator==(const ggml_webgpu_flash_attn_pipeline_key & other) const {
return kv_type == other.kv_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v &&
kv_direct == other.kv_direct && has_mask == other.has_mask && has_sinks == other.has_sinks &&
uses_logit_softcap == other.uses_logit_softcap;
uses_logit_softcap == other.uses_logit_softcap && use_vec == other.use_vec;
}
};

Expand All @@ -402,6 +409,7 @@ struct ggml_webgpu_flash_attn_pipeline_key_hash {
ggml_webgpu_hash_combine(seed, key.has_mask);
ggml_webgpu_hash_combine(seed, key.has_sinks);
ggml_webgpu_hash_combine(seed, key.uses_logit_softcap);
ggml_webgpu_hash_combine(seed, key.use_vec);
return seed;
}
};
Expand All @@ -421,6 +429,115 @@ struct ggml_webgpu_flash_attn_shader_decisions {
uint32_t wg_size = 0;
};

inline uint32_t ggml_webgpu_flash_attn_pick_vec_ne(const ggml_webgpu_flash_attn_pipeline_key & key) {
// Keep conservative defaults unless this is the f16 vec-split shape family.
if (key.kv_type != GGML_TYPE_F16 || key.head_dim_qk != key.head_dim_v) {
return 1u;
}

// Head-dim specializations used by the tuned vec f16 path.
switch (key.head_dim_qk) {
case 64: return 2u;
case 96: return 4u;
case 128: return 1u;
case 192: return 2u;
case 576: return 2u;
default: return 1u;
}
}

struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key {
uint32_t head_dim_v;
uint32_t wg_size;
};

struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key_hash {
size_t operator()(const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.head_dim_v);
ggml_webgpu_hash_combine(seed, key.wg_size);
return seed;
}
};

inline bool operator==(const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & lhs,
const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & rhs) {
return lhs.head_dim_v == rhs.head_dim_v && lhs.wg_size == rhs.wg_size;
}

struct ggml_webgpu_flash_attn_vec_reduce_shader_lib_context {
ggml_webgpu_flash_attn_vec_reduce_pipeline_key key;
uint32_t max_wg_size;
};

inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_vec_reduce_shader(
pre_wgsl::Preprocessor & preprocessor,
const char * shader_src,
const ggml_webgpu_flash_attn_vec_reduce_shader_lib_context & context) {
std::vector<std::string> defines;
std::string variant = "flash_attn_vec_reduce";

defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.key.head_dim_v));
variant += std::string("_hsv") + std::to_string(context.key.head_dim_v);

defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
variant += std::string("_wg") + std::to_string(context.max_wg_size);

ggml_webgpu_processed_shader result;
result.wgsl = preprocessor.preprocess(shader_src, defines);
result.variant = variant;
return result;
}

struct ggml_webgpu_flash_attn_blk_pipeline_key {
uint32_t q_tile;
uint32_t kv_tile;

bool operator==(const ggml_webgpu_flash_attn_blk_pipeline_key & other) const {
return q_tile == other.q_tile && kv_tile == other.kv_tile;
}
};

struct ggml_webgpu_flash_attn_blk_pipeline_key_hash {
size_t operator()(const ggml_webgpu_flash_attn_blk_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.q_tile);
ggml_webgpu_hash_combine(seed, key.kv_tile);
return seed;
}
};

struct ggml_webgpu_flash_attn_blk_shader_lib_context {
ggml_webgpu_flash_attn_blk_pipeline_key key;
uint32_t max_wg_size;
};

inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_blk_shader(
pre_wgsl::Preprocessor & preprocessor,
const char * shader_src,
const ggml_webgpu_flash_attn_blk_shader_lib_context & context) {
std::vector<std::string> defines;
std::string variant = "flash_attn_vec_blk";

defines.push_back(std::string("Q_TILE=") + std::to_string(context.key.q_tile));
variant += std::string("_qt") + std::to_string(context.key.q_tile);

defines.push_back(std::string("KV_TILE=") + std::to_string(context.key.kv_tile));
variant += std::string("_kvt") + std::to_string(context.key.kv_tile);

uint32_t wg_size = 1;
while ((wg_size << 1) <= context.max_wg_size) {
wg_size <<= 1;
}
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
variant += std::string("_wg") + std::to_string(wg_size);

ggml_webgpu_processed_shader result;
result.wgsl = preprocessor.preprocess(shader_src, defines);
result.variant = variant;
return result;
}

// This is exposed because it's necessary in supports_op
inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile,
uint32_t kv_tile,
Expand Down Expand Up @@ -659,6 +776,14 @@ class ggml_webgpu_shader_lib {
repeat_pipelines; // type
std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash>
flash_attn_pipelines;
std::unordered_map<ggml_webgpu_flash_attn_vec_reduce_pipeline_key,
webgpu_pipeline,
ggml_webgpu_flash_attn_vec_reduce_pipeline_key_hash>
flash_attn_vec_reduce_pipelines;
std::unordered_map<ggml_webgpu_flash_attn_blk_pipeline_key,
webgpu_pipeline,
ggml_webgpu_flash_attn_blk_pipeline_key_hash>
flash_attn_blk_pipelines;
std::unordered_map<ggml_webgpu_legacy_mul_mat_pipeline_key,
webgpu_pipeline,
ggml_webgpu_legacy_mul_mat_pipeline_key_hash>
Expand Down Expand Up @@ -1675,32 +1800,16 @@ class ggml_webgpu_shader_lib {
return repeat_pipelines[key];
}

webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context) {
const bool has_mask = context.src3 != nullptr;
const bool has_sinks = context.src4 != nullptr;

bool kv_direct = (context.src1->type == GGML_TYPE_F16) && (context.src0->ne[0] % context.sg_mat_k == 0) &&
(context.src1->ne[1] % context.sg_mat_n == 0);

ggml_webgpu_flash_attn_pipeline_key key = {
.kv_type = context.src1->type,
.head_dim_qk = (uint32_t) context.src0->ne[0],
.head_dim_v = (uint32_t) context.src2->ne[0],
.kv_direct = kv_direct,
.has_mask = has_mask,
.has_sinks = has_sinks,
.uses_logit_softcap = (*(float *) &context.dst->op_params[2]) != 0.0f,
};

auto it = flash_attn_pipelines.find(key);
webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_flash_attn_shader_lib_context & context) {
auto it = flash_attn_pipelines.find(context.key);
if (it != flash_attn_pipelines.end()) {
return it->second;
}

std::vector<std::string> defines;
std::string variant = "flash_attn";

switch (key.kv_type) {
switch (context.key.kv_type) {
case GGML_TYPE_F32:
defines.push_back("KV_F32");
break;
Expand All @@ -1716,41 +1825,52 @@ class ggml_webgpu_shader_lib {
default:
GGML_ABORT("Unsupported KV type for flash attention shader");
}
variant += std::string("_") + ggml_type_name(key.kv_type);
variant += std::string("_") + ggml_type_name(context.key.kv_type);

if (key.has_mask) {
if (context.key.has_mask) {
defines.push_back("MASK");
variant += "_mask";
}
if (key.has_sinks) {
if (context.key.has_sinks) {
defines.push_back("SINKS");
variant += "_sinks";
}
if (key.uses_logit_softcap) {
if (context.key.uses_logit_softcap) {
defines.push_back("LOGIT_SOFTCAP");
variant += "_lgsc";
}
if (key.kv_direct) {
if (context.key.kv_direct) {
defines.push_back("KV_DIRECT");
variant += "_kvdirect";
}
if (context.key.has_mask && context.key.use_vec) {
defines.push_back("BLK");
variant += "_blk";
}

defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(key.head_dim_qk));
variant += std::string("_hsqk") + std::to_string(key.head_dim_qk);
defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(context.key.head_dim_qk));
variant += std::string("_hsqk") + std::to_string(context.key.head_dim_qk);

defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v));
variant += std::string("_hsv") + std::to_string(key.head_dim_v);
defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.key.head_dim_v));
variant += std::string("_hsv") + std::to_string(context.key.head_dim_v);

defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m));
defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n));
defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k));

uint32_t q_tile = context.sg_mat_m;
uint32_t q_tile = context.sg_mat_m;
uint32_t kv_tile =
std::min(ggml_webgpu_flash_attn_max_kv_tile({ key, context.sg_mat_m, context.sg_mat_n, context.sg_mat_k,
context.wg_mem_limit_bytes, context.max_subgroup_size }),
std::min(ggml_webgpu_flash_attn_max_kv_tile(context),
context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES);
if (key.kv_direct) {
if (context.key.use_vec) {
q_tile = 1;
kv_tile = std::max(context.sg_mat_n, std::min(32u, ggml_webgpu_flash_attn_max_kv_tile(context)));
kv_tile = (kv_tile / context.sg_mat_n) * context.sg_mat_n;
const uint32_t vec_ne = ggml_webgpu_flash_attn_pick_vec_ne(context.key);
defines.push_back(std::string("VEC_NE=") + std::to_string(vec_ne) + "u");
}
if (context.key.kv_direct) {
GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD);
while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) {
kv_tile -= context.sg_mat_n;
}
Expand All @@ -1759,19 +1879,51 @@ class ggml_webgpu_shader_lib {
defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile));
defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile));

uint32_t wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE);
uint32_t wg_size = 0;
if (context.key.use_vec) {
wg_size = std::max(1u, std::min<uint32_t>(32u, context.max_subgroup_size));
} else {
wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE);
}
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));

auto processed = preprocessor.preprocess(wgsl_flash_attn, defines);
const char * shader_src = context.key.use_vec ? wgsl_flash_attn_vec_split : wgsl_flash_attn;
webgpu_pipeline pipeline =
ggml_webgpu_create_pipeline(device, preprocessor.preprocess(shader_src, defines), variant);
auto decisions = std::make_shared<ggml_webgpu_flash_attn_shader_decisions>();
decisions->q_tile = q_tile;
decisions->kv_tile = kv_tile;
decisions->wg_size = wg_size;
pipeline.context = decisions;
flash_attn_pipelines[context.key] = pipeline;
return flash_attn_pipelines[context.key];
}

webgpu_pipeline get_flash_attn_blk_pipeline(const ggml_webgpu_flash_attn_blk_shader_lib_context & context) {
auto it = flash_attn_blk_pipelines.find(context.key);
if (it != flash_attn_blk_pipelines.end()) {
return it->second;
}

ggml_webgpu_processed_shader processed =
ggml_webgpu_preprocess_flash_attn_blk_shader(preprocessor, wgsl_flash_attn_vec_blk, context);
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed.wgsl, processed.variant);
flash_attn_blk_pipelines[context.key] = pipeline;
return flash_attn_blk_pipelines[context.key];
}

webgpu_pipeline get_flash_attn_vec_reduce_pipeline(
const ggml_webgpu_flash_attn_vec_reduce_shader_lib_context & context) {
auto it = flash_attn_vec_reduce_pipelines.find(context.key);
if (it != flash_attn_vec_reduce_pipelines.end()) {
return it->second;
}

webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
pipeline.context = decisions;
flash_attn_pipelines[key] = pipeline;
return flash_attn_pipelines[key];
ggml_webgpu_processed_shader processed =
ggml_webgpu_preprocess_flash_attn_vec_reduce_shader(preprocessor, wgsl_flash_attn_vec_reduce, context);
webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed.wgsl, processed.variant);
flash_attn_vec_reduce_pipelines[context.key] = pipeline;
return flash_attn_vec_reduce_pipelines[context.key];
}

webgpu_pipeline get_cpy_pipeline(const ggml_webgpu_shader_lib_context & context) {
Expand Down
Loading
Loading