Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions ggml/src/ggml-webgpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,11 @@ file(MAKE_DIRECTORY ${SHADER_OUTPUT_DIR})

message(STATUS "Shader output dir: ${SHADER_OUTPUT_DIR}")

# Find all WGSL files
file(GLOB WGSL_SHADER_FILES "${SHADER_DIR}/*.wgsl")
# Find all WGSL sources
file(GLOB WGSL_SHADER_FILES
"${SHADER_DIR}/*.wgsl"
"${SHADER_DIR}/*.tmpl"
)

# Generate the header using a Python script
add_custom_command(
Expand Down
659 changes: 347 additions & 312 deletions ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp

Large diffs are not rendered by default.

415 changes: 223 additions & 192 deletions ggml/src/ggml-webgpu/ggml-webgpu.cpp

Large diffs are not rendered by default.

44 changes: 37 additions & 7 deletions ggml/src/ggml-webgpu/pre_wgsl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,33 @@ static std::string trim(const std::string & s) {
}

// Reads text from `is` and returns it with trim() applied.
// NOTE(review): two alternative bodies appear back to back in this span, which looks like
// a diff rendering of a removed body followed by its replacement - confirm which one is
// intended. As written, the first `return` ends the function, so the std::ostringstream
// statements after it are never reached.
static std::string trim_value(std::istream & is) {
std::string str;
// Body 1: take only the text up to the next newline.
std::getline(is, str);
return trim(str);
// Body 2: take everything left in the stream, embedded newlines included.
std::ostringstream ss;
ss << is.rdbuf();
return trim(ss.str());
}

// Returns true when `c` may appear inside an identifier: an underscore or any
// character std::isalnum accepts.
static bool isIdentChar(char c) {
    if (c == '_') {
        return true;
    }
    // std::isalnum requires a value representable as unsigned char, so convert first.
    const unsigned char uc = static_cast<unsigned char>(c);
    return std::isalnum(uc) != 0;
}

// Returns true when the last non-whitespace character of `line` is a backslash.
// Empty and all-whitespace lines yield false.
static bool endsWithContinuation(const std::string & line) {
    // Scan from the back; the first non-whitespace character decides the answer.
    for (auto it = line.rbegin(); it != line.rend(); ++it) {
        if (std::isspace(static_cast<unsigned char>(*it))) {
            continue;
        }
        return *it == '\\';
    }
    return false;
}

// If the last non-whitespace character of `line` is a backslash, truncates `line`
// just before it (dropping the backslash and any whitespace after it).
// Otherwise `line` is left unchanged.
static void stripContinuation(std::string & line) {
    // `end` becomes one past the last non-whitespace character (0 if there is none).
    size_t end = line.size();
    for (; end > 0; --end) {
        if (!std::isspace(static_cast<unsigned char>(line[end - 1]))) {
            break;
        }
    }
    if (end == 0 || line[end - 1] != '\\') {
        return;
    }
    line.resize(end - 1);
}

static std::string expandMacrosRecursiveInternal(const std::string & line,
const std::unordered_map<std::string, std::string> & macros,
std::unordered_set<std::string> & visiting);
Expand Down Expand Up @@ -595,19 +613,31 @@ class Preprocessor {
std::string line;

while (std::getline(in, line)) {
std::string t = trim(line);
std::string logical = line;
std::string t = trim(logical);
if (!t.empty() && t[0] == '#') {
while (endsWithContinuation(logical)) {
stripContinuation(logical);
if (!std::getline(in, line)) {
break;
}
logical += "\n";
logical += line;
}
t = trim(logical);
}

if (!t.empty() && t[0] == '#') {
bool handled = handleDirective(t, out, macros, predefined_macros, cond, include_stack, mode);
if (mode == DirectiveMode::IncludesOnly && !handled) {
out << line << "\n";
out << logical << "\n";
}
} else {
if (mode == DirectiveMode::IncludesOnly) {
out << line << "\n";
out << logical << "\n";
} else if (condActive(cond)) {
// Expand macros in the line before outputting
std::string expanded = expandMacrosRecursive(line, macros);
std::string expanded = expandMacrosRecursive(logical, macros);
out << expanded << "\n";
}
}
Expand Down
219 changes: 36 additions & 183 deletions ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,23 @@ enable f16;
enable subgroups;
enable chromium_experimental_subgroup_matrix;

#ifdef KV_F32
#define KV_TYPE f32
#elif defined(KV_Q4_0) || defined(KV_Q8_0)
#define KV_TYPE u32
#define BYTE_HELPERS
#include "common_decls.tmpl"

#ifdef K_F32
#define K_TYPE f32
#elif defined(K_Q4_0) || defined(K_Q8_0)
#define K_TYPE u32
#else
#define K_TYPE f16
#endif

#ifdef V_F32
#define V_TYPE f32
#elif defined(V_Q4_0) || defined(V_Q8_0)
#define V_TYPE u32
#else
#define KV_TYPE f16
#define V_TYPE f16
#endif

// Default values
Expand All @@ -30,76 +41,6 @@ enable chromium_experimental_subgroup_matrix;
// Number of subgroup-matrix-width blocks that span the KV tile. SG_MAT_N must divide KV_TILE.
#define KV_BLOCKS (KV_TILE / SG_MAT_N)

// Quantization constants/helpers
#define BLOCK_SIZE 32
#define BLOCKS_K ((HEAD_DIM_QK + BLOCK_SIZE - 1) / BLOCK_SIZE)
#define BLOCKS_V ((HEAD_DIM_V + BLOCK_SIZE - 1) / BLOCK_SIZE)
// number of quantized elements processed per thread
#if defined(KV_Q4_0)
#define NQ 16
// Q4_0 has 32 elements, 1 f16 for scale, 8 f16 for 4-bit weights
#define F16_PER_BLOCK 9
#define BLOCK_SIZE_BYTES 18u
#define WEIGHTS_PER_F16 4
#elif defined(KV_Q8_0)
#define NQ 8
// Q8_0 has 32 elements, 1 f16 for scale, 16 f16 for 8-bit weights
#define F16_PER_BLOCK 17
#define BLOCK_SIZE_BYTES 34u
#define WEIGHTS_PER_F16 2
#endif
#define F16_PER_THREAD (NQ / WEIGHTS_PER_F16)

// Ok not to put these in a define block, compiler will remove if unused
fn get_byte(value: u32, index: u32) -> u32 {
return (value >> (index * 8)) & 0xFF;
}

fn get_byte_i32(value: u32, index: u32) -> i32 {
return bitcast<i32>(((value >> (index * 8)) & 0xFF) << 24) >> 24;
}

#if defined(KV_Q4_0) || defined(KV_Q8_0)
fn load_k_u16_at(byte_offset: u32) -> u32 {
let word = K[byte_offset / 4u];
let shift = (byte_offset & 2u) * 8u;
return (word >> shift) & 0xFFFFu;
}

fn load_k_u32_at(byte_offset: u32) -> u32 {
let word_idx = byte_offset / 4u;
let shift = (byte_offset & 3u) * 8u;
let lo = K[word_idx];
if (shift == 0u) {
return lo;
}
let hi = K[word_idx + 1u];
return (lo >> shift) | (hi << (32u - shift));
}

fn load_v_u16_at(byte_offset: u32) -> u32 {
let word = V[byte_offset / 4u];
let shift = (byte_offset & 2u) * 8u;
return (word >> shift) & 0xFFFFu;
}

fn load_v_u32_at(byte_offset: u32) -> u32 {
let word_idx = byte_offset / 4u;
let shift = (byte_offset & 3u) * 8u;
let lo = V[word_idx];
if (shift == 0u) {
return lo;
}
let hi = V[word_idx + 1u];
return (lo >> shift) | (hi << (32u - shift));
}

fn f16_from_u16(bits: u32) -> f16 {
let packed = unpack2x16float(bits);
return f16(packed[0]);
}
#endif

struct Params {
offset_q: u32,
offset_k: u32,
Expand Down Expand Up @@ -139,11 +80,11 @@ struct Params {

@group(0) @binding(0) var<storage, read_write> Q: array<f32>;
#ifdef KV_OVERLAP
@group(0) @binding(1) var<storage, read_write> K: array<KV_TYPE>;
@group(0) @binding(1) var<storage, read_write> K: array<K_TYPE>;
#define V K
#else
@group(0) @binding(1) var<storage, read_write> K: array<KV_TYPE>;
@group(0) @binding(2) var<storage, read_write> V: array<KV_TYPE>;
@group(0) @binding(1) var<storage, read_write> K: array<K_TYPE>;
@group(0) @binding(2) var<storage, read_write> V: array<V_TYPE>;
#endif

#if defined(MASK) && defined(SINKS)
Expand Down Expand Up @@ -238,10 +179,17 @@ fn load_f32x4(buf: ptr<storage, array<vec4<f32>>, read_write>, scalar_index: u32
return (*buf)[scalar_index >> 2u];
}

fn load_kvx4(buf: ptr<storage, array<vec4<KV_TYPE>>, read_write>, scalar_index: u32) -> vec4<KV_TYPE> {
fn load_kx4(buf: ptr<storage, array<vec4<K_TYPE>>, read_write>, scalar_index: u32) -> vec4<K_TYPE> {
return (*buf)[scalar_index >> 2u];
}

#ifndef KV_DIRECT
#define QUANT_SHMEM kv_shmem
#define QUANT_OUT_TYPE f16
#include "quant_inner_loops.tmpl"
#include "flash_attn_quant_staging.tmpl"
#endif

@compute @workgroup_size(WG_SIZE)
fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
@builtin(local_invocation_id) local_id: vec3<u32>,
Expand Down Expand Up @@ -311,64 +259,17 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
}

for (var kv_tile = 0u; kv_tile < params.seq_len_kv; kv_tile += KV_TILE) {
let kv_count = min(KV_TILE, params.seq_len_kv - kv_tile);
// clear inter_shmem to ensure zero-initialized accumulators
for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) {
inter_shmem[elem_idx] = 0.0;
}

// load k tile into shared memory
#if defined(KV_Q4_0)
for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) {
let blck_idx = elem_idx / BLOCK_SIZE;
let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
let k_row = blck_idx / BLOCKS_K;
let global_k_row = kv_tile + k_row;
let block_k = blck_idx % BLOCKS_K;
let row_offset = k_row * HEAD_DIM_QK;

if (global_k_row < params.seq_len_kv) {
let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k;
let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES;
let d = f16_from_u16(load_k_u16_at(block_byte_base));
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
let q_packed = load_k_u32_at(q_byte_offset);
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte(q_packed, k);
let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
let q_lo = (f16(q_byte & 0xF) - 8.0) * d;
let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
kv_shmem[row_offset + idx] = q_lo;
kv_shmem[row_offset + idx + 16u] = q_hi;
}
}
}
}
#elif defined(KV_Q8_0)
for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) {
let blck_idx = elem_idx / BLOCK_SIZE;
let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
let k_row = blck_idx / BLOCKS_K;
let global_k_row = kv_tile + k_row;
let block_k = blck_idx % BLOCKS_K;
let row_offset = k_row * HEAD_DIM_QK;

if (global_k_row < params.seq_len_kv) {
let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k;
let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES;
let d = f16_from_u16(load_k_u16_at(block_byte_base));
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
let q_packed = load_k_u32_at(q_byte_offset);
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte_i32(q_packed, k);
let q_val = f16(q_byte) * d;
let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
kv_shmem[row_offset + idx] = q_val;
}
}
}
}
#if defined(K_Q4_0)
LOAD_K_Q4_0_TILE_BLOCK
#elif defined(K_Q8_0)
LOAD_K_Q8_0_TILE_BLOCK
#elif defined(KV_DIRECT)
// Direct global loads for KV
#else
Expand Down Expand Up @@ -520,58 +421,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
}

// load v tile into shared memory
#if defined(KV_Q4_0)
for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) {
let blck_idx = elem_idx / BLOCK_SIZE;
let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
let v_row = blck_idx / BLOCKS_V;
let global_v_row = kv_tile + v_row;
let block_k = blck_idx % BLOCKS_V;
let row_offset = v_row * HEAD_DIM_V;

if (global_v_row < params.seq_len_kv) {
let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k;
let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES;
let d = f16_from_u16(load_v_u16_at(block_byte_base));
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
let q_packed = load_v_u32_at(q_byte_offset);
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte(q_packed, k);
let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
let q_lo = (f16(q_byte & 0xF) - 8.0) * d;
let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
kv_shmem[row_offset + idx] = q_lo;
kv_shmem[row_offset + idx + 16u] = q_hi;
}
}
}
}
#elif defined(KV_Q8_0)
for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) {
let blck_idx = elem_idx / BLOCK_SIZE;
let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
let v_row = blck_idx / BLOCKS_V;
let global_v_row = kv_tile + v_row;
let block_k = blck_idx % BLOCKS_V;
let row_offset = v_row * HEAD_DIM_V;

if (global_v_row < params.seq_len_kv) {
let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k;
let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES;
let d = f16_from_u16(load_v_u16_at(block_byte_base));
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j);
let q_packed = load_v_u32_at(q_byte_offset);
for (var k = 0u; k < 4u; k++) {
let q_byte = get_byte_i32(q_packed, k);
let q_val = f16(q_byte) * d;
let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
kv_shmem[row_offset + idx] = q_val;
}
}
}
}
#if defined(V_Q4_0)
LOAD_V_Q4_0_TILE_BLOCK
#elif defined(V_Q8_0)
LOAD_V_Q8_0_TILE_BLOCK
#elif defined(KV_DIRECT)
// Direct global loads for KV
#else
Expand Down
Loading
Loading