Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
483615d
Add inplace softmax
reeselevine Sep 24, 2025
27b893a
Move rms_norm to split row approach
reeselevine Sep 24, 2025
f9bb89c
Update debug for supports_op
reeselevine Sep 24, 2025
5d8e678
clean up debug statements
reeselevine Sep 30, 2025
aa1c9b2
neg f16xf32xip builds and runs, havent actually ran a model that uses…
XXjcontiniXX Oct 1, 2025
c3ae382
neg passes backend test
XXjcontiniXX Oct 1, 2025
8a6ec84
unary operators pass ggml tests
XXjcontiniXX Oct 9, 2025
7b09baa
resolving merge conflicts
XXjcontiniXX Oct 10, 2025
5360e28
rms_norm double declaration bug atoned
XXjcontiniXX Oct 10, 2025
cb08583
abides by editor-config
XXjcontiniXX Oct 10, 2025
3627499
removed vestigial files
XXjcontiniXX Oct 10, 2025
74c6add
fixed autoconfig
XXjcontiniXX Oct 10, 2025
4cf28d7
All operators (including xielu) working
XXjcontiniXX Oct 12, 2025
f9282c6
removed unnecessary checking if node->src[1] exists for unary operators
XXjcontiniXX Oct 12, 2025
8c70b8f
responded and dealt with PR comments
XXjcontiniXX Oct 15, 2025
e1f6bae
implemented REPL_Template support and removed bug in unary operators …
XXjcontiniXX Oct 30, 2025
c41a1cb
formatted embed wgsl and ggml-webgpu.cpp
XXjcontiniXX Oct 30, 2025
c6bc125
Faster tensors (#8)
reeselevine Nov 5, 2025
7c2b2ef
Use map for shader replacements instead of pair of strings
reeselevine Nov 8, 2025
c201d0d
Wasm (#9)
reeselevine Nov 11, 2025
6db7298
Merge remote-tracking branch 'upstream/master'
reeselevine Nov 12, 2025
56e6959
Remove extra whitespace
reeselevine Nov 12, 2025
2abcdd6
Merge remote-tracking branch 'upstream/master'
reeselevine Nov 14, 2025
5bcd577
Move wasm single-thread logic out of test-backend-ops for cpu backend
reeselevine Nov 14, 2025
f9ba819
Disable multiple threads for emscripten single-thread builds in ggml_…
reeselevine Nov 17, 2025
5ca9b5e
Refactored pipelines and workgroup calculations (#10)
neha-ha Nov 18, 2025
b4663f8
Merge remote-tracking branch 'james/master' into staging
reeselevine Nov 18, 2025
e35099e
Start work on flash attention
reeselevine Dec 3, 2025
cc5ff86
Shader structure set up (many bugs still)
reeselevine Dec 16, 2025
ff4badb
debugging
reeselevine Dec 17, 2025
abbc5b2
Working first test
reeselevine Dec 17, 2025
fd1e3db
Working with head grouping, head sizes to 128, logit softcap, mask/si…
reeselevine Dec 17, 2025
2f39c2a
Generalize softmax to work with multiple subgroups, f16 accumulation,…
reeselevine Dec 19, 2025
efd49e1
Start work on integrating pre-wgsl
reeselevine Dec 29, 2025
1dc20ce
Separate structs/initial shader compilation library into separate files
reeselevine Dec 30, 2025
b072b4b
Work on compilation choices for flashattention
reeselevine Dec 30, 2025
7886418
Work on subgroup matrix/tile size portability
reeselevine Dec 31, 2025
d523a40
subgroup size agnostic online softmax
reeselevine Dec 31, 2025
af052b4
Merge remote-tracking branch 'origin/master' into fa
reeselevine Dec 31, 2025
e72c0e4
Cleanups, quantization types
reeselevine Jan 1, 2026
e36c9cd
more cleanup
reeselevine Jan 1, 2026
ef5fd1b
fix wasm build
reeselevine Jan 1, 2026
2fc8060
Refactor flashattention to increase parallelism, use direct loads for…
reeselevine Jan 3, 2026
4070a04
Checkpoint
reeselevine Jan 3, 2026
f71815f
formatting
reeselevine Jan 4, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 156 additions & 0 deletions ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
#ifndef GGML_WEBGPU_SHADER_LIB_HPP
#define GGML_WEBGPU_SHADER_LIB_HPP

#include "ggml.h"
#include "pre_wgsl.hpp"

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

#define GGML_WEBGPU_F16_SIZE_BYTES 2
#define GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES 8u
#define GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE 128u

struct ggml_webgpu_flash_attn_shader_lib_context {
ggml_type kv_type;
uint32_t head_dim_qk;
uint32_t head_dim_v;
bool kv_direct;
bool has_mask;
bool has_sinks;
bool uses_logit_softcap;
uint32_t sg_mat_m;
uint32_t sg_mat_n;
uint32_t sg_mat_k;
size_t wg_mem_limit_bytes;
uint32_t max_subgroup_size;
};

struct ggml_webgpu_flash_attn_shader_decisions {
uint32_t q_tile = 0;
uint32_t kv_tile = 0;
uint32_t wg_size = 0;
};

struct ggml_webgpu_processed_shader {
std::string wgsl;
std::string variant;
ggml_webgpu_flash_attn_shader_decisions decisions;
};

// This is exposed because it's necessary in supports_op
inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile,
uint32_t kv_tile,
uint32_t head_dim_qk,
uint32_t head_dim_v,
bool has_mask,
bool kv_direct) {
const uint32_t max_head_dim = std::max(head_dim_qk, head_dim_v);
size_t elems = 0;
elems += q_tile * head_dim_qk; // q_shmem
if (!kv_direct) {
elems += kv_tile * max_head_dim; // kv_shmem
}
elems += q_tile * head_dim_v; // o_shmem
if (has_mask) {
elems += q_tile * kv_tile; // mask_shmem
}
elems += q_tile * kv_tile; // inter_shmem
elems += q_tile; // row_max_shmem
elems += q_tile; // exp_sum_shmem
return elems * GGML_WEBGPU_F16_SIZE_BYTES;
}

static uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_flash_attn_shader_lib_context & context) {
const size_t limit_bytes = context.wg_mem_limit_bytes;
const size_t q_tile = context.sg_mat_m;
const size_t base_q_bytes = (context.head_dim_qk + context.head_dim_v + 2) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES;
size_t bytes_per_kv = 0;
if (!context.kv_direct) {
bytes_per_kv += std::max(context.head_dim_qk, context.head_dim_v);
}
if (context.has_mask) {
bytes_per_kv += q_tile;
}
bytes_per_kv += q_tile;
bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES;
const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv;
return (max_kv_tile / context.sg_mat_n) * context.sg_mat_n;
}

inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader(
pre_wgsl::Preprocessor & preprocessor,
const char * shader_src,
const ggml_webgpu_flash_attn_shader_lib_context & context) {
std::vector<std::string> defines;
std::string variant = "flash_attn";

switch (context.kv_type) {
case GGML_TYPE_F32:
defines.push_back("KV_F32");
break;
case GGML_TYPE_F16:
defines.push_back("KV_F16");
break;
case GGML_TYPE_Q4_0:
defines.push_back("KV_Q4_0");
break;
case GGML_TYPE_Q8_0:
defines.push_back("KV_Q8_0");
break;
default:
GGML_ABORT("Unsupported KV type for flash attention shader");
}
variant += std::string("_") + ggml_type_name(context.kv_type);

if (context.has_mask) {
defines.push_back("MASK");
variant += "_mask";
}
if (context.has_sinks) {
defines.push_back("SINKS");
variant += "_sinks";
}
if (context.uses_logit_softcap) {
defines.push_back("LOGIT_SOFTCAP");
variant += "_lgsc";
}

if (context.kv_direct) {
defines.push_back("KV_DIRECT");
variant += "_kvdirect";
}

defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(context.head_dim_qk));
variant += std::string("_hsqk") + std::to_string(context.head_dim_qk);

defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.head_dim_v));
variant += std::string("_hsv") + std::to_string(context.head_dim_v);

// For now these are not part of the variant name
defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m));
defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n));
defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k));

// Add chosen Q/KV tile sizes
uint32_t q_tile = context.sg_mat_m;
uint32_t kv_tile = std::min(ggml_webgpu_flash_attn_max_kv_tile(context),
context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES);
defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile));
defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile));

// workgroup size
uint32_t wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE);

defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));

ggml_webgpu_processed_shader result;
result.wgsl = preprocessor.preprocess(shader_src, defines);
result.variant = variant;
result.decisions.q_tile = q_tile;
result.decisions.kv_tile = kv_tile;
result.decisions.wg_size = wg_size;
return result;
}

#endif // GGML_WEBGPU_SHADER_LIB_HPP
Loading
Loading