From 483615da761dd599ede15b18af540c523c7a642d Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Wed, 24 Sep 2025 14:56:33 -0700 Subject: [PATCH 01/40] Add inplace softmax --- ggml/include/ggml.h | 7 + ggml/src/ggml-webgpu/ggml-webgpu.cpp | 124 ++++++- .../wgsl-shaders/soft_max.tmpl.wgsl | 338 ++++++++++++++++++ ggml/src/ggml.c | 9 + tests/test-backend-ops.cpp | 18 +- 5 files changed, 483 insertions(+), 13 deletions(-) create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 5028a9cebf2..34bd3204580 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -1617,6 +1617,13 @@ extern "C" { float scale, float max_bias); + GGML_API struct ggml_tensor * ggml_soft_max_ext_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * mask, + float scale, + float max_bias); + GGML_API void ggml_soft_max_add_sinks( struct ggml_tensor * a, struct ggml_tensor * sinks); diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 93200a4d29f..3d2d92fc57a 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -130,15 +130,16 @@ struct webgpu_context_struct { wgpu::ComputePipeline set_rows_pipeline; wgpu::ComputePipeline get_rows_pipeline[30]; wgpu::ComputePipeline get_rows_f32_no_vec_pipeline; - wgpu::ComputePipeline cpy_pipeline[2][2]; // src type, dst type - wgpu::ComputePipeline add_pipeline[2][2]; // type, inplace - wgpu::ComputePipeline sub_pipeline[2][2]; // type, inplace - wgpu::ComputePipeline mul_pipeline[2][2]; // type, inplace - wgpu::ComputePipeline div_pipeline[2][2]; // type, inplace - wgpu::ComputePipeline rms_norm_pipeline[2]; // inplace - wgpu::ComputePipeline rope_pipeline[2][2][2]; // type, ff, inplace - wgpu::ComputePipeline glu_pipeline[7][2][2]; // glu-op, type, split - wgpu::ComputePipeline scale_pipeline[2]; // inplace + wgpu::ComputePipeline cpy_pipeline[2][2]; // src type, 
dst type + wgpu::ComputePipeline add_pipeline[2][2]; // type, inplace + wgpu::ComputePipeline sub_pipeline[2][2]; // type, inplace + wgpu::ComputePipeline mul_pipeline[2][2]; // type, inplace + wgpu::ComputePipeline div_pipeline[2][2]; // type, inplace + wgpu::ComputePipeline rms_norm_pipeline[2]; // inplace + wgpu::ComputePipeline rope_pipeline[2][2][2]; // type, ff, inplace + wgpu::ComputePipeline glu_pipeline[7][2][2]; // glu-op, type, split + wgpu::ComputePipeline scale_pipeline[2]; // inplace + wgpu::ComputePipeline soft_max_pipeline[3][2][2]; // (no_mask, f32_mask, f16_mask), has_sink, inplace size_t memset_bytes_per_thread; @@ -912,6 +913,79 @@ static void ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tens ggml_op_name(dst->op)); } +static void ggml_webgpu_soft_max(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * src2, + ggml_tensor * dst) { + const int inplace = ggml_webgpu_tensor_equal(src0, dst); + const int mask_type = (src1 != nullptr) ? src1->type : 2; // use 2 for no mask here + const int has_sink = (src2 != nullptr); + float max_bias; + memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); + float n_head_log2 = float(1u << (uint32_t) floor(log2(src0->ne[2]))); + float m0 = powf(2.0f, -(max_bias) / n_head_log2); + float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + std::vector params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), + mask_type < 2 ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)) : 0, + has_sink ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)) : 0, + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), + mask_type < 2 ? 
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) : 0, + mask_type < 2 ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) : 0, + mask_type < 2 ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) : 0, + (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), + (uint32_t) ggml_nelements(dst), + (uint32_t) src0->ne[0], + (uint32_t) src0->ne[1], + (uint32_t) src0->ne[2], + mask_type < 2 ? (uint32_t) src1->ne[2] : 0, + mask_type < 2 ? (uint32_t) src1->ne[3] : 0, + *(uint32_t *) dst->op_params, // scale + *(uint32_t *) &max_bias, + *(uint32_t *) &n_head_log2, + *(uint32_t *) &m0, + *(uint32_t *) &m1 + }; + + std::vector entries = { + { .binding = 0, + .buffer = ggml_webgpu_tensor_buf(src0), + .offset = ggml_webgpu_tensor_align_offset(ctx, src0), + .size = ggml_webgpu_tensor_binding_size(ctx, src0) } + }; + uint32_t binding_num = 1; + if (mask_type < 2) { + entries.push_back({ .binding = binding_num, + .buffer = ggml_webgpu_tensor_buf(src1), + .offset = ggml_webgpu_tensor_align_offset(ctx, src1), + .size = ggml_webgpu_tensor_binding_size(ctx, src1) }); + binding_num++; + } + if (has_sink) { + entries.push_back({ .binding = binding_num, + .buffer = ggml_webgpu_tensor_buf(src2), + .offset = ggml_webgpu_tensor_align_offset(ctx, src2), + .size = ggml_webgpu_tensor_binding_size(ctx, src2) }); + binding_num++; + } + if (!inplace) { + entries.push_back({ .binding = binding_num, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + } + + ggml_backend_webgpu_build_and_enqueue(ctx, ctx->soft_max_pipeline[mask_type][has_sink][inplace], params, entries, + ggml_nrows(dst), ggml_op_name(dst->op)); +} + // Returns true if node has enqueued work into the queue, false otherwise static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) { if 
(ggml_is_empty(node)) { @@ -1512,6 +1586,38 @@ static void ggml_webgpu_init_scale_pipeline(webgpu_context & webgpu_ctx) { "scale_f32_inplace", constants); } +static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) { + std::vector constants(1); + constants[0].key = "wg_size"; + constants[0].value = 64; + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[2][0][0], wgsl_soft_max_f32, + "soft_max_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[2][0][1], wgsl_soft_max_f32_inplace, + "soft_max_f32_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[2][1][0], wgsl_soft_max_f32_sink, + "soft_max_f32_sink", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[2][1][1], + wgsl_soft_max_f32_sink_inplace, "soft_max_f32_sink_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[0][0][0], wgsl_soft_max_f32_mask_f32, + "soft_max_f32_mask_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[0][0][1], + wgsl_soft_max_f32_mask_f32_inplace, "soft_max_f32_mask_f32_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[1][0][0], wgsl_soft_max_f32_mask_f16, + "soft_max_f32_mask_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[1][0][1], + wgsl_soft_max_f32_mask_f16_inplace, "soft_max_f32_mask_f16_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[0][1][0], + wgsl_soft_max_f32_mask_f32_sink, "soft_max_f32_mask_f32_sink", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[0][1][1], + wgsl_soft_max_f32_mask_f32_sink_inplace, "soft_max_f32_mask_f32_sink_inplace", + constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, 
webgpu_ctx->soft_max_pipeline[1][1][0], + wgsl_soft_max_f32_mask_f16_sink, "soft_max_f32_mask_f16_sink", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[1][1][1], + wgsl_soft_max_f32_mask_f16_sink_inplace, "soft_max_f32_mask_f16_sink_inplace", + constants); +} + static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) { GGML_UNUSED(params); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl new file mode 100644 index 00000000000..c62988d4845 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl @@ -0,0 +1,338 @@ +#define(VARIANTS) +[ + { + "SHADER_NAME": "soft_max_f32", + "DECLS": ["BASE_BINDINGS", "NOT_INPLACE", "NO_MASK", "NO_SINK"] + }, + { + "SHADER_NAME": "soft_max_f32_inplace", + "DECLS": ["BASE_BINDINGS_INPLACE", "INPLACE", "NO_MASK", "NO_SINK"] + }, + { + "SHADER_NAME": "soft_max_f32_sink", + "DECLS": ["SINK_BINDINGS", "NOT_INPLACE", "NO_MASK", "SINK"] + }, + { + "SHADER_NAME": "soft_max_f32_sink_inplace", + "DECLS": ["SINK_BINDINGS_INPLACE", "INPLACE", "NO_MASK", "SINK"] + }, + { + "SHADER_NAME": "soft_max_f32_mask_f32", + "REPLS": { + "MASK_TYPE" : "f32", + }, + "DECLS": ["MASK_BINDINGS", "NOT_INPLACE", "MASK", "NO_SINK"] + }, + { + "SHADER_NAME": "soft_max_f32_mask_f32_inplace", + "REPLS": { + "MASK_TYPE" : "f32", + }, + "DECLS": ["MASK_BINDINGS_INPLACE", "INPLACE", "MASK", "NO_SINK"] + }, + { + "SHADER_NAME": "soft_max_f32_mask_f16", + "REPLS": { + "MASK_TYPE" : "f16", + }, + "DECLS": ["MASK_BINDINGS", "NOT_INPLACE", "MASK", "NO_SINK"] + }, + { + "SHADER_NAME": "soft_max_f32_mask_f16_inplace", + "REPLS": { + "MASK_TYPE" : "f16", + }, + "DECLS": ["MASK_BINDINGS_INPLACE", "INPLACE", "MASK", "NO_SINK"] + }, + { + "SHADER_NAME": "soft_max_f32_mask_f32_sink", + "REPLS": { + "MASK_TYPE" : "f32", + }, + "DECLS": ["MASK_SINK_BINDINGS", "NOT_INPLACE", "MASK", "SINK"] + }, + { + 
"SHADER_NAME": "soft_max_f32_mask_f32_sink_inplace", + "REPLS": { + "MASK_TYPE" : "f32", + }, + "DECLS": ["MASK_SINK_BINDINGS_INPLACE", "INPLACE", "MASK", "SINK"] + }, + { + "SHADER_NAME": "soft_max_f32_mask_f16_sink", + "REPLS": { + "MASK_TYPE" : "f16", + }, + "DECLS": ["MASK_SINK_BINDINGS", "NOT_INPLACE", "MASK", "SINK"] + }, + { + "SHADER_NAME": "soft_max_f32_mask_f16_sink_inplace", + "REPLS": { + "MASK_TYPE" : "f16", + }, + "DECLS": ["MASK_SINK_BINDINGS_INPLACE", "INPLACE", "MASK", "SINK"] + } +] +#end(VARIANTS) + +#define(DECLS) + +#decl(BASE_BINDINGS) +@group(0) @binding(1) +var dst: array; + +@group(0) @binding(2) +var params: Params; +#enddecl(BASE_BINDINGS) + +#decl(BASE_BINDINGS_INPLACE) +@group(0) @binding(1) +var params: Params; +#enddecl(BASE_BINDINGS_INPLACE) + +#decl(SINK_BINDINGS) +@group(0) @binding(1) +var sinks: array; + +@group(0) @binding(2) +var dst: array; + +@group(0) @binding(3) +var params: Params; +#enddecl(SINK_BINDINGS) + +#decl(SINK_BINDINGS_INPLACE) +@group(0) @binding(1) +var sinks: array; + +@group(0) @binding(2) +var params: Params; +#enddecl(SINK_BINDINGS_INPLACE) + +#decl(MASK_BINDINGS) +@group(0) @binding(1) +var mask: array<{{MASK_TYPE}}>; + +@group(0) @binding(2) +var dst: array; + +@group(0) @binding(3) +var params: Params; +#enddecl(MASK_BINDINGS) + +#decl(MASK_BINDINGS_INPLACE) +@group(0) @binding(1) +var mask: array<{{MASK_TYPE}}>; + +@group(0) @binding(2) +var params: Params; +#enddecl(MASK_BINDINGS_INPLACE) + +#decl(MASK_SINK_BINDINGS) +@group(0) @binding(1) +var mask: array<{{MASK_TYPE}}>; + +@group(0) @binding(2) +var sinks: array; + +@group(0) @binding(3) +var dst: array; + +@group(0) @binding(4) +var params: Params; +#enddecl(MASK_SINK_BINDINGS) + +#decl(MASK_SINK_BINDINGS_INPLACE) +@group(0) @binding(1) +var mask: array<{{MASK_TYPE}}>; + +@group(0) @binding(2) +var sinks: array; + +@group(0) @binding(3) +var params: Params; +#enddecl(MASK_SINK_BINDINGS_INPLACE) + +#decl(NOT_INPLACE) +fn inter_value(i: u32) -> f32 { 
+ return dst[i]; +} + +fn update(i: u32, val: f32) { + dst[i] = val; +} +#enddecl(NOT_INPLACE) + +#decl(INPLACE) +fn inter_value(i: u32) -> f32 { + return src[i]; +} + +fn update(i: u32, val: f32) { + src[i] = val; +} +#enddecl(INPLACE) + +#decl(NO_MASK) +fn mask_val(i: u32) -> f32 { + return 0.0; +} +#enddecl(NO_MASK) + +#decl(MASK) +fn mask_val(i: u32) -> f32 { + return f32(mask[i]); +} +#enddecl(MASK) + +#decl(NO_SINK) +fn lower_max_bound(i2: u32) -> f32 { + return -1e30; +} + +fn add_sinks(val: f32, i2: u32, max_val: f32) -> f32 { + return val; +} +#enddecl(NO_SINK) + +#decl(SINK) +fn lower_max_bound(i2: u32) -> f32 { + return sinks[params.offset_sinks + i2]; +} + +fn add_sinks(val: f32, i2: u32, max_val: f32) -> f32 { + return val + exp(sinks[params.offset_sinks + i2] - max_val); +} +#enddecl(SINK) + +#end(DECLS) + +#define(SHADER) +enable f16; + +struct Params { + offset_src0: u32, + offset_src1: u32, + offset_sinks: u32, + offset_dst: u32, + + // Strides (in elements) + stride_src01: u32, + stride_src02: u32, + stride_src03: u32, + + stride_src11: u32, + stride_src12: u32, + stride_src13: u32, + + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + // shape of src0/dst + ne: u32, + ne0: u32, + ne1: u32, + ne2: u32, + + // shape of src1 + ne12: u32, + ne13: u32, + + scale: f32, + max_bias: f32, + n_head_log2: f32, + m0: f32, + m1: f32, +}; + +@group(0) @binding(0) +var src: array; + +DECLS + +const CACHE_SIZE: u32 = 16; + +override wg_size: u32; +var scratch: array; + +@compute @workgroup_size(wg_size) +fn main(@builtin(workgroup_id) wid: vec3, + @builtin(local_invocation_id) lid: vec3) { + + var i = wid.x; + let i3 = i / (params.ne2 * params.ne1); + i = i % (params.ne2 * params.ne1); + let i2 = i / params.ne1; + let i1 = i % params.ne1; + let i_src0_row = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01; + let i_src1_row = params.offset_src1 + (i3 % params.ne13) * params.stride_src13 + (i2 % 
params.ne12) * params.stride_src12 + i1 * params.stride_src11; + let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1; + let elems = (params.ne0 + wg_size - 1) / wg_size; + + let head = f32(i2); + let slope = select(1, select(pow(params.m1, 2 * (head - params.n_head_log2) + 1), pow(params.m0, head + 1), head < params.n_head_log2), params.max_bias > 0); + + var cache: array; + + var max_val = lower_max_bound(i2); + for (var j: u32 = 0; j < elems; j++) { + let col = j * wg_size + lid.x; + if (col < params.ne0) { + let val = src[i_src0_row + col] * params.scale + slope * mask_val(i_src1_row + col); + max_val = max(max_val, val); + if (col < CACHE_SIZE) { + cache[col] = val; + } + } + } + + scratch[lid.x] = max_val; + workgroupBarrier(); + var offset = wg_size / 2; + while (offset > 0) { + if (lid.x < offset) { + scratch[lid.x] = max(scratch[lid.x], scratch[lid.x + offset]); + } + offset = offset / 2; + workgroupBarrier(); + } + let row_max = scratch[0]; + + var sum = 0.0f; + for (var j: u32 = 0; j < elems; j++) { + let col = j * wg_size + lid.x; + if (col < params.ne0) { + let val = select(src[i_src0_row + col] * params.scale + slope * mask_val(i_src1_row + col), + cache[col], col < CACHE_SIZE); + let ex = exp(val - row_max); + sum += ex; + if (col < CACHE_SIZE) { + cache[col] = ex; + } else { + update(i_dst_row + col, ex); + } + } + } + + scratch[lid.x] = sum; + workgroupBarrier(); + offset = wg_size / 2; + while (offset > 0) { + if (lid.x < offset) { + scratch[lid.x] += scratch[lid.x + offset]; + } + offset = offset / 2; + workgroupBarrier(); + } + let row_sum = add_sinks(scratch[0], i2, row_max); + + let sum_recip = 1.0 / row_sum; + for (var j: u32 = 0; j < elems; j++) { + let col = j * wg_size + lid.x; + if (col < params.ne0) { + update(i_dst_row + col, select(inter_value(i_dst_row + col), cache[col], col < CACHE_SIZE) * sum_recip); + } + } +} +#end(SHADER) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c 
index aecbdad5a3d..d753c9a18ac 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -3829,6 +3829,15 @@ struct ggml_tensor * ggml_soft_max_ext( return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false); } +struct ggml_tensor * ggml_soft_max_ext_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * mask, + float scale, + float max_bias) { + return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, true); +} + void ggml_soft_max_add_sinks( struct ggml_tensor * a, struct ggml_tensor * sinks) { diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 62d815cc268..6b751db1149 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -3752,9 +3752,10 @@ struct test_soft_max : public test_case { const std::array nr23; // broadcast only dims 2 and 3 const float scale; const float max_bias; + const bool inplace; std::string vars() override { - return VARS_TO_STR8(type, ne, mask, sinks, m_prec, nr23, scale, max_bias); + return VARS_TO_STR9(type, ne, mask, sinks, m_prec, nr23, scale, max_bias, inplace); } // the 1024 test with bias occasionally fails: @@ -3770,8 +3771,9 @@ struct test_soft_max : public test_case { ggml_type m_prec = GGML_TYPE_F32, std::array nr23 = {1, 1}, float scale = 1.0f, - float max_bias = 0.0f) - : type(type), ne(ne), mask(mask), sinks(sinks), m_prec(m_prec), nr23(nr23), scale(scale), max_bias(max_bias) {} + float max_bias = 0.0f, + bool inplace = false) + : type(type), ne(ne), mask(mask), sinks(sinks), m_prec(m_prec), nr23(nr23), scale(scale), max_bias(max_bias), inplace(inplace) {} ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2]*nr23[0], ne[3]*nr23[1]); @@ -3790,7 +3792,12 @@ struct test_soft_max : public test_case { ggml_set_name(sinks, "sinks"); } - ggml_tensor * out = ggml_soft_max_ext(ctx, a, mask, scale, max_bias); + ggml_tensor * out; + if (inplace) { + out = ggml_soft_max_ext_inplace(ctx, a, 
mask, scale, max_bias); + } else { + out = ggml_soft_max_ext(ctx, a, mask, scale, max_bias); + } ggml_soft_max_add_sinks(out, sinks); ggml_set_name(out, "out"); @@ -6562,6 +6569,9 @@ static std::vector> make_test_cases_eval() { } } } + // inplace tests + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, mask, sinks, GGML_TYPE_F32, {1, 1}, 0.1f, 0.0f, true)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, mask, sinks, GGML_TYPE_F16, {1, 1}, 0.1f, 0.0f, true)); } } test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true, true, GGML_TYPE_F32, {1, 1}, 0.1f, 0.0f)); From 27b893a6f86cbcb1b01e43c4095f4fec3b886dce Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Wed, 24 Sep 2025 15:48:26 -0700 Subject: [PATCH 02/40] Move rms_norm to split row approach --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 48 ++++++++++--------- .../ggml-webgpu/wgsl-shaders/rms_norm.wgsl | 43 +++++++++++++---- .../wgsl-shaders/soft_max.tmpl.wgsl | 48 +++++++++++-------- 3 files changed, 87 insertions(+), 52 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 3d2d92fc57a..3e25fcb246c 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -28,6 +28,7 @@ /* Constants */ #define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 16 +#define WEBGPU_WAIT_ANY_BATCH_SIZE 64 #define WEBGPU_MUL_MAT_WG_SIZE 64 #define WEBGPU_NUM_PARAM_BUFS 100 #define WEBGPU_PARAMS_BUF_SIZE_BYTES 128 // enough for 32 parameters @@ -35,6 +36,9 @@ #define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4 #define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4 +// For operations which process a row in parallel, this seems like a reasonable default +#define WEBGPU_ROW_SPLIT_WG_SIZE 64 + /* End Constants */ // This is a "fake" base pointer, since WebGPU buffers do not have pointers to their locations. 
@@ -257,8 +261,12 @@ static void ggml_backend_webgpu_wait_on_submission(webgpu_context & ctx) { }), UINT64_MAX); } else { - // existing callbacks, wait on them - ctx->instance.WaitAny(ctx->callback_futures.size(), ctx->callback_futures.data(), UINT64_MAX); + // WebGPU implementations may limit the number of futures that can be waited on at once, + // so wait in batches (64 is what Dawn supports). + for (size_t i = 0; i < ctx->callback_futures.size(); i += WEBGPU_WAIT_ANY_BATCH_SIZE) { + size_t end = std::min(i + WEBGPU_WAIT_ANY_BATCH_SIZE, ctx->callback_futures.size()); + ctx->instance.WaitAny(end - i, ctx->callback_futures.data() + i, UINT64_MAX); + } ctx->callback_futures.clear(); } } @@ -727,9 +735,7 @@ static void ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_t .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); } - size_t max_wg_size = ctx->max_wg_size_x; - uint32_t wg_x = (src->ne[1] * src->ne[2] * src->ne[3] + max_wg_size - 1) / max_wg_size; - ggml_backend_webgpu_build_and_enqueue(ctx, ctx->rms_norm_pipeline[inplace], params, entries, wg_x, + ggml_backend_webgpu_build_and_enqueue(ctx, ctx->rms_norm_pipeline[inplace], params, entries, ggml_nrows(src), ggml_op_name(dst->op)); } @@ -1311,11 +1317,11 @@ static ggml_guid_t ggml_backend_webgpu_guid(void) { return reinterpret_cast((void *) guid_str); } -// The max workgroup size is a common constant -static std::vector ggml_webgpu_max_wg_size_entry(webgpu_context & webgpu_ctx) { +// Workgroup size is a common constant +static std::vector ggml_webgpu_wg_size_entry(uint32_t wg_size) { std::vector constants(1); constants[0].key = "wg_size"; - constants[0].value = webgpu_ctx->max_wg_size_x; + constants[0].value = wg_size; return constants; } @@ -1383,11 +1389,11 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) { ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_pipeline, 
wgsl_set_rows, "set_rows", - ggml_webgpu_max_wg_size_entry(webgpu_ctx)); + ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x)); } static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx); + std::vector constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_F32], wgsl_get_rows_f32_vec, "get_rows_f32_vec", constants); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_f32_no_vec_pipeline, wgsl_get_rows_f32, @@ -1437,7 +1443,7 @@ static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) { } static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx); + std::vector constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline[GGML_TYPE_F32][GGML_TYPE_F32], wgsl_cpy_f32_f32, "cpy_f32_f32", constants); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline[GGML_TYPE_F32][GGML_TYPE_F16], @@ -1449,7 +1455,7 @@ static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) { } static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx); + std::vector constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32][0], wgsl_add_f32, "add_f32", constants); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16][0], wgsl_add_f16, "add_f16", @@ -1461,7 +1467,7 @@ static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) { } static void ggml_webgpu_init_sub_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx); + 
std::vector constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F32][0], wgsl_sub_f32, "sub_f32", constants); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F16][0], wgsl_sub_f16, "sub_f16", @@ -1473,7 +1479,7 @@ static void ggml_webgpu_init_sub_pipeline(webgpu_context & webgpu_ctx) { } static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx); + std::vector constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32][0], wgsl_mul_f32, "mul_f32", constants); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16][0], wgsl_mul_f16, "mul_f16", @@ -1485,7 +1491,7 @@ static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) { } static void ggml_webgpu_init_div_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx); + std::vector constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F32][0], wgsl_div_f32, "div_f32", constants); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F16][0], wgsl_div_f16, "div_f16", @@ -1497,7 +1503,7 @@ static void ggml_webgpu_init_div_pipeline(webgpu_context & webgpu_ctx) { } static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx); + std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_pipeline[0], wgsl_rms_norm, "rms_norm", constants); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_pipeline[1], 
wgsl_rms_norm_inplace, @@ -1505,7 +1511,7 @@ static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) { } static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx); + std::vector constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][0][0], wgsl_rope_f32, "rope_f32", constants); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][0][1], @@ -1525,7 +1531,7 @@ static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) { } static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx); + std::vector constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x); // reglu ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_REGLU][GGML_TYPE_F32][0], wgsl_reglu_f32, "reglu_f32", constants); @@ -1579,7 +1585,7 @@ static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) { } static void ggml_webgpu_init_scale_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx); + std::vector constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->scale_pipeline[0], wgsl_scale_f32, "scale_f32", constants); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->scale_pipeline[1], wgsl_scale_f32_inplace, @@ -1587,9 +1593,7 @@ static void ggml_webgpu_init_scale_pipeline(webgpu_context & webgpu_ctx) { } static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants(1); - constants[0].key = "wg_size"; - constants[0].value = 64; + std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE); ggml_webgpu_create_pipeline(webgpu_ctx->device, 
webgpu_ctx->soft_max_pipeline[2][0][0], wgsl_soft_max_f32, "soft_max_f32", constants); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[2][0][1], wgsl_soft_max_f32_inplace, diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl index a275eeb9783..4f72bb1c851 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl @@ -71,14 +71,14 @@ var src: array; DECLS override wg_size: u32; +var scratch: array; + @compute @workgroup_size(wg_size) -fn main(@builtin(global_invocation_id) gid: vec3) { - if (gid.x >= params.ne1 * params.ne2 * params.ne3) { - return; - } +fn main(@builtin(workgroup_id) wid: vec3, + @builtin(local_invocation_id) lid: vec3) { // one thread per row - var i = gid.x; + var i = wid.x; let i3 = i / (params.ne2 * params.ne1); i = i % (params.ne2 * params.ne1); let i2 = i / params.ne1; @@ -86,13 +86,38 @@ fn main(@builtin(global_invocation_id) gid: vec3) { let i_src_row = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1; let i_dst_row = params.offset_src + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1; + let elems = (params.ne0 + wg_size - 1) / wg_size; + var sum = 0.0f; - for (var j: u32 = 0; j < params.ne0; j++) { - sum += src[i_src_row + j] * src[i_src_row + j]; + var col = lid.x; + for (var j: u32 = 0; j < elems; j++) { + if (col >= params.ne0) { + break; + } + sum += pow(src[i_src_row + col], 2.0); + col += wg_size; } + + scratch[lid.x] = sum; + workgroupBarrier(); + var offset = wg_size / 2; + while (offset > 0) { + if (lid.x < offset) { + scratch[lid.x] += scratch[lid.x + offset]; + } + offset = offset / 2; + workgroupBarrier(); + } + sum = scratch[0]; + let scale = 1.0/sqrt(sum/f32(params.ne0) + params.eps); - for (var j: u32 = 0; j < params.ne0; j++) { - update(i_src_row + j, i_dst_row + j, scale); + col = lid.x; + for (var j: u32 = 
0; j < elems; j++) { + if (col >= params.ne0) { + break; + } + update(i_src_row + col, i_dst_row + col, scale); + col += wg_size; } } #end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl index c62988d4845..64ab576c083 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl @@ -276,15 +276,17 @@ fn main(@builtin(workgroup_id) wid: vec3, var cache: array; var max_val = lower_max_bound(i2); + var col = lid.x; for (var j: u32 = 0; j < elems; j++) { - let col = j * wg_size + lid.x; - if (col < params.ne0) { - let val = src[i_src0_row + col] * params.scale + slope * mask_val(i_src1_row + col); - max_val = max(max_val, val); - if (col < CACHE_SIZE) { - cache[col] = val; - } + if (col >= params.ne0) { + break; } + let val = src[i_src0_row + col] * params.scale + slope * mask_val(i_src1_row + col); + max_val = max(max_val, val); + if (col < CACHE_SIZE) { + cache[col] = val; + } + col += wg_size; } scratch[lid.x] = max_val; @@ -300,19 +302,21 @@ fn main(@builtin(workgroup_id) wid: vec3, let row_max = scratch[0]; var sum = 0.0f; + col = lid.x; for (var j: u32 = 0; j < elems; j++) { - let col = j * wg_size + lid.x; - if (col < params.ne0) { - let val = select(src[i_src0_row + col] * params.scale + slope * mask_val(i_src1_row + col), - cache[col], col < CACHE_SIZE); - let ex = exp(val - row_max); - sum += ex; - if (col < CACHE_SIZE) { - cache[col] = ex; - } else { - update(i_dst_row + col, ex); - } + if (col >= params.ne0) { + break; + } + let val = select(src[i_src0_row + col] * params.scale + slope * mask_val(i_src1_row + col), + cache[col], col < CACHE_SIZE); + let ex = exp(val - row_max); + sum += ex; + if (col < CACHE_SIZE) { + cache[col] = ex; + } else { + update(i_dst_row + col, ex); } + col += wg_size; } scratch[lid.x] = sum; @@ -328,11 +332,13 @@ fn main(@builtin(workgroup_id) wid: vec3, let row_sum = 
add_sinks(scratch[0], i2, row_max); let sum_recip = 1.0 / row_sum; + col = lid.x; for (var j: u32 = 0; j < elems; j++) { - let col = j * wg_size + lid.x; - if (col < params.ne0) { - update(i_dst_row + col, select(inter_value(i_dst_row + col), cache[col], col < CACHE_SIZE) * sum_recip); + if (col >= params.ne0) { + break; } + update(i_dst_row + col, select(inter_value(i_dst_row + col), cache[col], col < CACHE_SIZE) * sum_recip); + col += wg_size; } } #end(SHADER) From f9bb89c63382287100dbaaa6bfaafeba2f38d258 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Wed, 24 Sep 2025 16:15:55 -0700 Subject: [PATCH 03/40] Update debug for supports_op --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 3e25fcb246c..5b7a0ddefcd 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1703,6 +1703,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * src0 = op->src[0]; ggml_tensor * src1 = op->src[1]; + ggml_tensor * src2 = op->src[2]; // on smaller devices (or CI), tensors may be larger than the max storage buffer size if (ggml_nbytes(op) > webgpu_ctx->limits.maxStorageBufferBindingSize || @@ -1733,7 +1734,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); break; case GGML_OP_SET_ROWS: - supports_op = (op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_I64); + supports_op = (op->type == GGML_TYPE_F16 && src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I64); break; case GGML_OP_GET_ROWS: if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_I32 || @@ -1808,11 +1809,27 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const 
default: break; } + if (ggml_nbytes(op) > webgpu_ctx->limits.maxStorageBufferBindingSize || + (src0 != nullptr && ggml_nbytes(src0) > webgpu_ctx->limits.maxStorageBufferBindingSize) || + (src1 != nullptr && ggml_nbytes(src1) > webgpu_ctx->limits.maxStorageBufferBindingSize) || + (src2 != nullptr && ggml_nbytes(src2) > webgpu_ctx->limits.maxStorageBufferBindingSize)) { + supports_op = false; +#ifdef GGML_WEBGPU_DEBUG + WEBGPU_LOG_DEBUG("ggml_webgpu op not supported due to size: "); +#endif + } + #ifdef GGML_WEBGPU_DEBUG if (!supports_op) { - WEBGPU_LOG_DEBUG("not supported: " << ggml_op_name(op->op) << " with types dst: " << ggml_type_name(op->type) - << ", src0: " << (op->src[0] ? ggml_type_name(op->src[0]->type) : "null") - << ", src1: " << (op->src[1] ? ggml_type_name(op->src[1]->type) : "null")); + WEBGPU_LOG_DEBUG("ggml_webgpu op not supported: " + << ggml_op_name(op->op) << " with types dst: " << ggml_type_name(op->type) + << ", src0: " << (op->src[0] ? ggml_type_name(op->src[0]->type) : "null") + << ", src1: " << (op->src[1] ? ggml_type_name(op->src[1]->type) : "null")); + } else { + WEBGPU_LOG_DEBUG("ggml_webgpu op supported: " + << ggml_op_name(op->op) << " with types dst: " << ggml_type_name(op->type) + << ", src0: " << (op->src[0] ? ggml_type_name(op->src[0]->type) : "null") + << ", src1: " << (op->src[1] ? 
ggml_type_name(op->src[1]->type) : "null")); } #endif return supports_op; From 5d8e6784e249ea6dc5f13a556e8b3db7b5885584 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Tue, 30 Sep 2025 10:11:43 -0700 Subject: [PATCH 04/40] clean up debug statements --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 5b7a0ddefcd..de68c5689bb 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1814,12 +1814,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const (src1 != nullptr && ggml_nbytes(src1) > webgpu_ctx->limits.maxStorageBufferBindingSize) || (src2 != nullptr && ggml_nbytes(src2) > webgpu_ctx->limits.maxStorageBufferBindingSize)) { supports_op = false; -#ifdef GGML_WEBGPU_DEBUG WEBGPU_LOG_DEBUG("ggml_webgpu op not supported due to size: "); -#endif } -#ifdef GGML_WEBGPU_DEBUG if (!supports_op) { WEBGPU_LOG_DEBUG("ggml_webgpu op not supported: " << ggml_op_name(op->op) << " with types dst: " << ggml_type_name(op->type) @@ -1831,7 +1828,6 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const << ", src0: " << (op->src[0] ? ggml_type_name(op->src[0]->type) : "null") << ", src1: " << (op->src[1] ? 
 ggml_type_name(op->src[1]->type) : "null")); } -#endif return supports_op; } From aa1c9b2f8877a405470ca56709c42a1fd43713de Mon Sep 17 00:00:00 2001 From: James Contini Date: Tue, 30 Sep 2025 23:55:27 -0700 Subject: [PATCH 05/40] neg f16xf32xip builds and runs, haven't actually run a model that uses neg kernel yet though --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 60 +++++++++++++++++++ ggml/src/ggml-webgpu/wgsl-shaders/neg.wgsl | 41 +++++++++++++ .../wgsl-shaders/neg_in_place.wgsl | 38 ++++++++++++ 3 files changed, 139 insertions(+) create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/neg.wgsl create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/neg_in_place.wgsl diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index de68c5689bb..24ed1fe8aec 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -144,6 +144,8 @@ struct webgpu_context_struct { wgpu::ComputePipeline glu_pipeline[7][2][2]; // glu-op, type, split wgpu::ComputePipeline scale_pipeline[2]; // inplace wgpu::ComputePipeline soft_max_pipeline[3][2][2]; // (no_mask, f32_mask, f16_mask), has_sink, inplace + wgpu::ComputePipeline neg_pipeline; + wgpu::ComputePipeline neg_ip_pipeline; size_t memset_bytes_per_thread; @@ -992,6 +994,36 @@ static void ggml_webgpu_soft_max(webgpu_context & ctx, ggml_nrows(dst), ggml_op_name(dst->op)); } +static void ggml_webgpu_neg( webgpu_context & ctx, + ggml_tensor * src, + ggml_tensor * dst, + wgpu::ComputePipeline & pipeline, + bool in_place) { + std::vector params = { + (uint32_t) ggml_nelements(dst) + }; + + std::vector entries = { + { .binding = 0, + .buffer = ggml_webgpu_tensor_buf(src), + .offset = ggml_webgpu_tensor_align_offset(ctx, src), + .size = ggml_webgpu_tensor_binding_size(ctx, src) }, + + }; + if (!in_place) { + entries.push_back({ .binding = 1, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, 
dst) }); + } + + size_t max_wg_size = ctx->max_wg_size_x; + uint32_t wg_x = (ggml_nelements(dst) + max_wg_size - 1) / max_wg_size; + + ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op)); +} + + // Returns true if node has enqueued work into the queue, false otherwise static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) { if (ggml_is_empty(node)) { @@ -1060,6 +1092,22 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) { case GGML_OP_SCALE: ggml_webgpu_scale(ctx, src0, node); break; + case GGML_OP_UNARY: { + // if unary, switch on unary operators + const ggml_unary_op unary_op = ggml_get_unary_op(node); + switch (unary_op) { + case GGML_UNARY_OP_NEG: + if (ggml_webgpu_tensor_equal(src0, node)) { + ggml_webgpu_neg(ctx, src0, node, ctx->neg_ip_pipeline, true); + } else { + ggml_webgpu_neg(ctx, src0, src1, ctx->neg_pipeline, false); + } + break; + default: + return false; + } + break; + } default: return false; } @@ -1622,6 +1670,18 @@ static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) { constants); } +static void ggml_webgpu_init_neg_pipeline(webgpu_context & webgpu_ctx) { + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->neg_pipeline, wgsl_neg_f32, "neg_f32", + ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x)); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->neg_pipeline, wgsl_neg_f16, "neg_f16", + ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x)); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->neg_ip_pipeline, wgsl_neg_in_place_f32, "neg_in_place_f32", + ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x)); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->neg_ip_pipeline, wgsl_neg_in_place_f16, "neg_in_place_f16", + ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x)); + +} + static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) { 
GGML_UNUSED(params); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/neg.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/neg.wgsl new file mode 100644 index 00000000000..7aa2a75dddc --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/neg.wgsl @@ -0,0 +1,41 @@ +#define(VARIANTS) + +[ + { + "REPLS": { + "TYPE" : "f32", + } + }, + { + "REPLS": { + "TYPE" : "f16", + } + } +] + +#end(VARIANTS) + +#define(SHADER) + +enable f16; + +@group(0) @binding(0) +var src: array<{{TYPE}}>; + +@group(0) @binding(1) +var dst: array<{{TYPE}}>; + +@group(0) @binding(2) +var params: Params; + + +override wg_size: u32; +@compute @workgroup_size(wg_size) +fn main(@builtin(global_invocation_id) gid: vec3) { + if (gid.x < params.ne) { + dst[gid.x] = -src[gid.x]; + } + +} + +#end(SHADER) \ No newline at end of file diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/neg_in_place.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/neg_in_place.wgsl new file mode 100644 index 00000000000..1ca0b3a76be --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/neg_in_place.wgsl @@ -0,0 +1,38 @@ +#define(VARIANTS) + +[ + { + "REPLS": { + "TYPE" : "f32", + } + }, + { + "REPLS": { + "TYPE" : "f16", + } + } +] + +#end(VARIANTS) + +#define(SHADER) + +enable f16; + +@group(0) @binding(0) +var src: array<{{TYPE}}>; + +@group(0) @binding(1) +var params: Params; + + +override wg_size: u32; +@compute @workgroup_size(wg_size) +fn main(@builtin(global_invocation_id) gid: vec3) { + if (gid.x < params.ne) { + src[gid.x] = -src[gid.x]; + } + +} + +#end(SHADER) \ No newline at end of file From c3ae38278a2db236adc5912c9140e4f0d63f2c19 Mon Sep 17 00:00:00 2001 From: James Contini Date: Wed, 1 Oct 2025 16:22:40 -0700 Subject: [PATCH 06/40] neg passes backend test --- ggml/src/ggml-webgpu/wgsl-shaders/neg.wgsl | 64 ++++++++++++++++--- .../wgsl-shaders/neg_in_place.wgsl | 64 ++++++++++++++++--- 2 files changed, 110 insertions(+), 18 deletions(-) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/neg.wgsl 
b/ggml/src/ggml-webgpu/wgsl-shaders/neg.wgsl index 7aa2a75dddc..23feb9aa7da 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/neg.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/neg.wgsl @@ -3,20 +3,19 @@ [ { "REPLS": { - "TYPE" : "f32", + "TYPE": "f32", } }, { "REPLS": { - "TYPE" : "f16", + "TYPE": "f16", } - } + }, ] #end(VARIANTS) #define(SHADER) - enable f16; @group(0) @binding(0) @@ -25,17 +24,64 @@ var src: array<{{TYPE}}>; @group(0) @binding(1) var dst: array<{{TYPE}}>; +struct Params { + ne: u32, // total number of elements + offset_src: u32, // in elements + offset_dst: u32, // in elements + + // Strides (in elements) — may be permuted + stride_src0: u32, + stride_src1: u32, + stride_src2: u32, + stride_src3: u32, + + stride_dst0: u32, + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + // Logical shapes + src_ne0: u32, + src_ne1: u32, + src_ne2: u32, + + dst_ne0: u32, + dst_ne1: u32, + dst_ne2: u32 +}; + @group(0) @binding(2) var params: Params; - override wg_size: u32; @compute @workgroup_size(wg_size) fn main(@builtin(global_invocation_id) gid: vec3) { - if (gid.x < params.ne) { - dst[gid.x] = -src[gid.x]; + if (gid.x >= params.ne) { + return; } -} + var i = gid.x; + let i3 = i / (params.src_ne2 * params.src_ne1 * params.src_ne0); + i = i % (params.src_ne2 * params.src_ne1 * params.src_ne0); + let i2 = i / (params.src_ne1 * params.src_ne0); + i = i % (params.src_ne1 * params.src_ne0); + let i1 = i / params.src_ne0; + let i0 = i % params.src_ne0; + + var j = gid.x; + let j3 = j / (params.dst_ne2 * params.dst_ne1 * params.dst_ne0); + j = j % (params.dst_ne2 * params.dst_ne1 * params.dst_ne0); + let j2 = j / (params.dst_ne1 * params.dst_ne0); + j = j % (params.dst_ne1 * params.dst_ne0); + let j1 = j / params.dst_ne0; + let j0 = j % params.dst_ne0; -#end(SHADER) \ No newline at end of file + let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 + + i2 * params.stride_src2 + i3 * params.stride_src3; + + let dst_idx = j0 * params.stride_dst0 + 
j1 * params.stride_dst1 + + j2 * params.stride_dst2 + j3 * params.stride_dst3; + + dst[params.offset_dst + dst_idx] = -((src[params.offset_src + src_idx])); +} +#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/neg_in_place.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/neg_in_place.wgsl index 1ca0b3a76be..732b56cea23 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/neg_in_place.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/neg_in_place.wgsl @@ -3,36 +3,82 @@ [ { "REPLS": { - "TYPE" : "f32", + "TYPE": "f32", } }, { "REPLS": { - "TYPE" : "f16", + "TYPE": "f16", } - } + }, ] #end(VARIANTS) #define(SHADER) - enable f16; @group(0) @binding(0) var src: array<{{TYPE}}>; +struct Params { + ne: u32, // total number of elements + offset_src: u32, // in elements + offset_dst: u32, // in elements + + // Strides (in elements) — may be permuted + stride_src0: u32, + stride_src1: u32, + stride_src2: u32, + stride_src3: u32, + + stride_dst0: u32, + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + // Logical shapes + src_ne0: u32, + src_ne1: u32, + src_ne2: u32, + + dst_ne0: u32, + dst_ne1: u32, + dst_ne2: u32 +}; + @group(0) @binding(1) var params: Params; - override wg_size: u32; @compute @workgroup_size(wg_size) fn main(@builtin(global_invocation_id) gid: vec3) { - if (gid.x < params.ne) { - src[gid.x] = -src[gid.x]; + if (gid.x >= params.ne) { + return; } -} + var i = gid.x; + let i3 = i / (params.src_ne2 * params.src_ne1 * params.src_ne0); + i = i % (params.src_ne2 * params.src_ne1 * params.src_ne0); + let i2 = i / (params.src_ne1 * params.src_ne0); + i = i % (params.src_ne1 * params.src_ne0); + let i1 = i / params.src_ne0; + let i0 = i % params.src_ne0; + + var j = gid.x; + let j3 = j / (params.dst_ne2 * params.dst_ne1 * params.dst_ne0); + j = j % (params.dst_ne2 * params.dst_ne1 * params.dst_ne0); + let j2 = j / (params.dst_ne1 * params.dst_ne0); + j = j % (params.dst_ne1 * params.dst_ne0); + let j1 = j / params.dst_ne0; + let j0 = j % params.dst_ne0; 
-#end(SHADER) \ No newline at end of file + let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 + + i2 * params.stride_src2 + i3 * params.stride_src3; + + let dst_idx = j0 * params.stride_dst0 + j1 * params.stride_dst1 + + j2 * params.stride_dst2 + j3 * params.stride_dst3; + + dst[params.offset_dst + dst_idx] = -((src[params.offset_src + src_idx])); +} +#end(SHADER) From 8a6ec843a50ab82f8cef59b4558eb63f318ba02d Mon Sep 17 00:00:00 2001 From: James Contini Date: Wed, 8 Oct 2025 18:06:47 -0700 Subject: [PATCH 07/40] unary operators pass ggml tests --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 393 +++++++++++---- .../ggml-webgpu/wgsl-shaders/unary_op.wgsl | 467 ++++++++++++++++++ 2 files changed, 770 insertions(+), 90 deletions(-) create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 24ed1fe8aec..2d1b6bd518e 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -144,8 +144,24 @@ struct webgpu_context_struct { wgpu::ComputePipeline glu_pipeline[7][2][2]; // glu-op, type, split wgpu::ComputePipeline scale_pipeline[2]; // inplace wgpu::ComputePipeline soft_max_pipeline[3][2][2]; // (no_mask, f32_mask, f16_mask), has_sink, inplace - wgpu::ComputePipeline neg_pipeline; - wgpu::ComputePipeline neg_ip_pipeline; + wgpu::ComputePipeline unary_pipeline[16][2][2]; + + +/* wgpu::ComputePipeline abs_pipeline[2][2]; // abs + wgpu::ComputePipeline sgn_pipeline[2][2]; // sgn + wgpu::ComputePipeline neg_pipeline[2][2]; // neg + wgpu::ComputePipeline step_pipeline[2][2]; // step + wgpu::ComputePipeline tanh_pipeline[2][2]; // tanh + wgpu::ComputePipeline elu_pipeline[2][2]; // elu + wgpu::ComputePipeline relu_pipeline[2][2]; // relu + wgpu::ComputePipeline sigmoid_pipeline[2][2]; // sigmoid + wgpu::ComputePipeline gelu_pipeline[2][2]; // gelu + wgpu::ComputePipeline gelu_quick_pipeline[2][2]; // gelu_quick + wgpu::ComputePipeline 
silu_pipeline[2][2]; // silu (a.k.a. swish) + wgpu::ComputePipeline hardswish_pipeline[2][2]; // hardswish + wgpu::ComputePipeline hardsigmoid_pipeline[2][2]; // hardsigmoid + wgpu::ComputePipeline exp_pipeline[2][2]; // exp + wgpu::ComputePipeline gelu_erf_pipeline[2][2]; // gelu_erf */ size_t memset_bytes_per_thread; @@ -250,6 +266,7 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device, // Wait for the queue to finish processing all submitted work static void ggml_backend_webgpu_wait_on_submission(webgpu_context & ctx) { + std::lock_guard lock(ctx->mutex); if (ctx->callback_futures.empty()) { // no existing callbacks, wait on queue submission @@ -274,6 +291,7 @@ static void ggml_backend_webgpu_wait_on_submission(webgpu_context & ctx) { } static void ggml_backend_webgpu_submit_queue(webgpu_context & ctx) { + std::lock_guard lock(ctx->mutex); WEBGPU_LOG_DEBUG("ggml_backend_webgpu_submit_queue()"); if (ctx->staged_command_bufs.empty()) { @@ -373,6 +391,7 @@ static void ggml_backend_webgpu_build_and_enqueue(webgpu_context & uint32_t wg_x, const char * bind_group_label = nullptr, bool submit_and_wait = false) { + webgpu_pool_bufs params_bufs = ctx->param_buf_pool.alloc_bufs(); ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, wgpu::MapMode::Write, 0, params_bufs.host_buf.GetSize()); @@ -491,39 +510,6 @@ static bool ggml_webgpu_tensor_equal(ggml_tensor * a, ggml_tensor * b) { (ggml_webgpu_tensor_offset(a) == ggml_webgpu_tensor_offset(b)); } -static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { - uint32_t ne = (uint32_t) ggml_nelements(dst); - - std::vector params = { - ne, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), - // Convert byte-strides to element-strides - (uint32_t) (src->nb[0] / ggml_type_size(src->type)), (uint32_t) (src->nb[1] / ggml_type_size(src->type)), - (uint32_t) (src->nb[2] 
/ ggml_type_size(src->type)), (uint32_t) (src->nb[3] / ggml_type_size(src->type)), - (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), - (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), - // Logical shapes - (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) dst->ne[0], - (uint32_t) dst->ne[1], (uint32_t) dst->ne[2] - }; - - std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } - }; - - size_t max_wg_size = ctx->max_wg_size_x; - uint32_t wg_x = (ne + max_wg_size - 1) / max_wg_size; - ggml_backend_webgpu_build_and_enqueue(ctx, ctx->cpy_pipeline[src->type][dst->type], params, entries, wg_x, - ggml_op_name(dst->op)); -} - static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * idx, ggml_tensor * dst) { // For set rows specifically, we need to check if src and idx are empty tensors. 
if (ggml_is_empty(src) || ggml_is_empty(idx)) { @@ -659,6 +645,83 @@ static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_t ggml_op_name(dst->op)); } +static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { + uint32_t ne = (uint32_t) ggml_nelements(dst); + + std::vector params = { + ne, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + // Convert byte-strides to element-strides + (uint32_t) (src->nb[0] / ggml_type_size(src->type)), (uint32_t) (src->nb[1] / ggml_type_size(src->type)), + (uint32_t) (src->nb[2] / ggml_type_size(src->type)), (uint32_t) (src->nb[3] / ggml_type_size(src->type)), + (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), + // Logical shapes + (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) dst->ne[0], + (uint32_t) dst->ne[1], (uint32_t) dst->ne[2] + }; + + std::vector entries = { + { .binding = 0, + .buffer = ggml_webgpu_tensor_buf(src), + .offset = ggml_webgpu_tensor_align_offset(ctx, src), + .size = ggml_webgpu_tensor_binding_size(ctx, src) }, + { .binding = 1, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst) } + }; + + size_t max_wg_size = ctx->max_wg_size_x; + uint32_t wg_x = (ne + max_wg_size - 1) / max_wg_size; + ggml_backend_webgpu_build_and_enqueue(ctx, ctx->cpy_pipeline[src->type][dst->type], params, entries, wg_x, + ggml_op_name(dst->op)); +} + +static void ggml_webgpu_unary_op( webgpu_context & ctx, + ggml_tensor * src, + ggml_tensor * dst, + wgpu::ComputePipeline & pipeline, + bool in_place) { + + + uint32_t ne = (uint32_t) ggml_nelements(dst); + + std::vector 
params = { + ne, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + // Convert byte-strides to element-strides + (uint32_t) (src->nb[0] / ggml_type_size(src->type)), (uint32_t) (src->nb[1] / ggml_type_size(src->type)), + (uint32_t) (src->nb[2] / ggml_type_size(src->type)), (uint32_t) (src->nb[3] / ggml_type_size(src->type)), + (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), + // Logical shapes + (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) dst->ne[0], + (uint32_t) dst->ne[1], (uint32_t) dst->ne[2] + }; + + + std::vector entries = { + { .binding = 0, + .buffer = ggml_webgpu_tensor_buf(src), + .offset = ggml_webgpu_tensor_align_offset(ctx, src), + .size = ggml_webgpu_tensor_binding_size(ctx, src) }, + + }; + if (!in_place) { + entries.push_back({ .binding = 1, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + } + + size_t max_wg_size = ctx->max_wg_size_x; + uint32_t wg_x = (ggml_nelements(dst) + max_wg_size - 1) / max_wg_size; + + ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op)); + +} + static void ggml_webgpu_binary_op(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, @@ -994,38 +1057,12 @@ static void ggml_webgpu_soft_max(webgpu_context & ctx, ggml_nrows(dst), ggml_op_name(dst->op)); } -static void ggml_webgpu_neg( webgpu_context & ctx, - ggml_tensor * src, - ggml_tensor * dst, - wgpu::ComputePipeline & pipeline, - bool in_place) { - std::vector params = { - (uint32_t) ggml_nelements(dst) - }; - - std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - 
.offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) }, - - }; - if (!in_place) { - entries.push_back({ .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); - } - size_t max_wg_size = ctx->max_wg_size_x; - uint32_t wg_x = (ggml_nelements(dst) + max_wg_size - 1) / max_wg_size; - - ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op)); -} // Returns true if node has enqueued work into the queue, false otherwise static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) { + if (ggml_is_empty(node)) { return false; } @@ -1035,6 +1072,8 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) { ggml_tensor * src1 = node->src[1]; ggml_tensor * src2 = node->src[2]; + + switch (node->op) { // no-ops case GGML_OP_NONE: @@ -1092,29 +1131,23 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) { case GGML_OP_SCALE: ggml_webgpu_scale(ctx, src0, node); break; - case GGML_OP_UNARY: { - // if unary, switch on unary operators - const ggml_unary_op unary_op = ggml_get_unary_op(node); - switch (unary_op) { - case GGML_UNARY_OP_NEG: - if (ggml_webgpu_tensor_equal(src0, node)) { - ggml_webgpu_neg(ctx, src0, node, ctx->neg_ip_pipeline, true); - } else { - ggml_webgpu_neg(ctx, src0, src1, ctx->neg_pipeline, false); - } - break; - default: - return false; + case GGML_OP_UNARY: + { + const ggml_unary_op UNARY_OP = ggml_get_unary_op(node); + int in_place = ggml_webgpu_tensor_equal(src0, node); + ggml_webgpu_unary_op(ctx, src0, node, ctx->unary_pipeline[UNARY_OP][node->type][in_place], in_place); + + break; } - break; - } + default: return false; - } + } return true; } static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) { + 
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_graph_compute(" << cgraph->n_nodes << " nodes)"); ggml_backend_webgpu_context * backend_ctx = static_cast(backend->context); @@ -1296,6 +1329,8 @@ static const char * ggml_backend_webgpu_buffer_type_get_name(ggml_backend_buffer static ggml_backend_buffer_t ggml_backend_webgpu_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + + WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_type_alloc_buffer(" << size << ")"); ggml_backend_webgpu_device_context * ctx = static_cast(buft->device->context); @@ -1307,6 +1342,8 @@ static ggml_backend_buffer_t ggml_backend_webgpu_buffer_type_alloc_buffer(ggml_b ggml_backend_webgpu_buffer_context * buf_ctx = new ggml_backend_webgpu_buffer_context(ctx->webgpu_ctx, buf); + + return ggml_backend_buffer_init(buft, ggml_backend_webgpu_buffer_interface, buf_ctx, size); } @@ -1670,19 +1707,162 @@ static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) { constants); } -static void ggml_webgpu_init_neg_pipeline(webgpu_context & webgpu_ctx) { - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->neg_pipeline, wgsl_neg_f32, "neg_f32", - ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x)); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->neg_pipeline, wgsl_neg_f16, "neg_f16", - ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x)); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->neg_ip_pipeline, wgsl_neg_in_place_f32, "neg_in_place_f32", - ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x)); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->neg_ip_pipeline, wgsl_neg_in_place_f16, "neg_in_place_f16", - ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x)); - +static void ggml_webgpu_init_unary_pipeline(webgpu_context & webgpu_ctx) { + std::vector constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x); + + // ABS + ggml_webgpu_create_pipeline(webgpu_ctx->device, 
webgpu_ctx->unary_pipeline[GGML_UNARY_OP_ABS][GGML_TYPE_F32][0], + wgsl_abs_f32, "abs_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_ABS][GGML_TYPE_F16][0], + wgsl_abs_f16, "abs_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_ABS][GGML_TYPE_F32][1], + wgsl_abs_in_place_f32, "abs_in_place_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_ABS][GGML_TYPE_F16][1], + wgsl_abs_in_place_f16, "abs_in_place_f16", constants); + + // SGN + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SGN][GGML_TYPE_F32][0], + wgsl_sgn_f32, "sgn_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SGN][GGML_TYPE_F16][0], + wgsl_sgn_f16, "sgn_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SGN][GGML_TYPE_F32][1], + wgsl_sgn_in_place_f32, "sgn_in_place_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SGN][GGML_TYPE_F16][1], + wgsl_sgn_in_place_f16, "sgn_in_place_f16", constants); + + // NEG + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_NEG][GGML_TYPE_F32][0], + wgsl_neg_f32, "neg_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_NEG][GGML_TYPE_F16][0], + wgsl_neg_f16, "neg_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_NEG][GGML_TYPE_F32][1], + wgsl_neg_in_place_f32, "neg_in_place_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_NEG][GGML_TYPE_F16][1], + wgsl_neg_in_place_f16, "neg_in_place_f16", constants); + + // STEP + ggml_webgpu_create_pipeline(webgpu_ctx->device, 
webgpu_ctx->unary_pipeline[GGML_UNARY_OP_STEP][GGML_TYPE_F32][0], + wgsl_step_f32, "step_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_STEP][GGML_TYPE_F16][0], + wgsl_step_f16, "step_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_STEP][GGML_TYPE_F32][1], + wgsl_step_in_place_f32, "step_in_place_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_STEP][GGML_TYPE_F16][1], + wgsl_step_in_place_f16, "step_in_place_f16", constants); + + // TANH + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_TANH][GGML_TYPE_F32][0], + wgsl_tanh_f32, "tanh_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_TANH][GGML_TYPE_F16][0], + wgsl_tanh_f16, "tanh_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_TANH][GGML_TYPE_F32][1], + wgsl_tanh_in_place_f32, "tanh_in_place_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_TANH][GGML_TYPE_F16][1], + wgsl_tanh_in_place_f16, "tanh_in_place_f16", constants); + + // ELU + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_ELU][GGML_TYPE_F32][0], + wgsl_elu_f32, "elu_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_ELU][GGML_TYPE_F16][0], + wgsl_elu_f16, "elu_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_ELU][GGML_TYPE_F32][1], + wgsl_elu_in_place_f32, "elu_in_place_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_ELU][GGML_TYPE_F16][1], + wgsl_elu_in_place_f16, "elu_in_place_f16", constants); + + // RELU + 
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_RELU][GGML_TYPE_F32][0], + wgsl_relu_f32, "relu_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_RELU][GGML_TYPE_F16][0], + wgsl_relu_f16, "relu_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_RELU][GGML_TYPE_F32][1], + wgsl_relu_in_place_f32, "relu_in_place_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_RELU][GGML_TYPE_F16][1], + wgsl_relu_in_place_f16, "relu_in_place_f16", constants); + + // SIGMOID + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SIGMOID][GGML_TYPE_F32][0], + wgsl_sigmoid_f32, "sigmoid_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SIGMOID][GGML_TYPE_F16][0], + wgsl_sigmoid_f16, "sigmoid_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SIGMOID][GGML_TYPE_F32][1], + wgsl_sigmoid_in_place_f32, "sigmoid_in_place_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SIGMOID][GGML_TYPE_F16][1], + wgsl_sigmoid_in_place_f16, "sigmoid_in_place_f16", constants); + + // GELU + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU][GGML_TYPE_F32][0], + wgsl_gelu_f32, "gelu_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU][GGML_TYPE_F16][0], + wgsl_gelu_f16, "gelu_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU][GGML_TYPE_F32][1], + wgsl_gelu_in_place_f32, "gelu_in_place_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU][GGML_TYPE_F16][1], + 
wgsl_gelu_in_place_f16, "gelu_in_place_f16", constants); + + // GELU_QUICK + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F32][0], + wgsl_gelu_quick_f32, "gelu_quick_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F16][0], + wgsl_gelu_quick_f16, "gelu_quick_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F32][1], + wgsl_gelu_quick_in_place_f32, "gelu_quick_in_place_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F16][1], + wgsl_gelu_quick_in_place_f16, "gelu_quick_in_place_f16", constants); + + // SILU + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SILU][GGML_TYPE_F32][0], + wgsl_silu_f32, "silu_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SILU][GGML_TYPE_F16][0], + wgsl_silu_f16, "silu_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SILU][GGML_TYPE_F32][1], + wgsl_silu_in_place_f32, "silu_in_place_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SILU][GGML_TYPE_F16][1], + wgsl_silu_in_place_f16, "silu_in_place_f16", constants); + + // HARDSWISH + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F32][0], + wgsl_hardswish_f32, "hardswish_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F16][0], + wgsl_hardswish_f16, "hardswish_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F32][1], + wgsl_hardswish_in_place_f32, 
"hardswish_in_place_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F16][1], + wgsl_hardswish_in_place_f16, "hardswish_in_place_f16", constants); + + // HARDSIGMOID + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F32][0], + wgsl_hardsigmoid_f32, "hardsigmoid_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F16][0], + wgsl_hardsigmoid_f16, "hardsigmoid_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F32][1], + wgsl_hardsigmoid_in_place_f32, "hardsigmoid_in_place_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F16][1], + wgsl_hardsigmoid_in_place_f16, "hardsigmoid_in_place_f16", constants); + + // EXP + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_EXP][GGML_TYPE_F32][0], + wgsl_exp_f32, "exp_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_EXP][GGML_TYPE_F16][0], + wgsl_exp_f16, "exp_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_EXP][GGML_TYPE_F32][1], + wgsl_exp_in_place_f32, "exp_in_place_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_EXP][GGML_TYPE_F16][1], + wgsl_exp_in_place_f16, "exp_in_place_f16", constants); + + // GELU_ERF + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F32][0], + wgsl_gelu_erf_f32, "gelu_erf_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F16][0], + wgsl_gelu_erf_f16, "gelu_erf_f16", 
constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F32][1], + wgsl_gelu_erf_in_place_f32, "gelu_erf_in_place_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F16][1], + wgsl_gelu_erf_in_place_f16, "gelu_erf_in_place_f16", constants); } static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) { + GGML_UNUSED(params); WEBGPU_LOG_DEBUG("ggml_backend_webgpu_device_init()"); @@ -1701,12 +1881,13 @@ static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, co /* .device = */ dev, /* .context = */ &backend_ctx, }; - + //tried return &backend; } static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggml_backend_dev_t dev) { // See GGML Backend Buffer Type Interface section + static struct ggml_backend_buffer_type ggml_backend_webgpu_buffer_type = { /* .iface = */ { /* .get_name = */ ggml_backend_webgpu_buffer_type_get_name, @@ -1757,6 +1938,7 @@ static bool ggml_webgpu_supported_qtype(ggml_type type) { } static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) { + ggml_backend_webgpu_device_context * ctx = static_cast(dev->context); webgpu_context webgpu_ctx = ctx->webgpu_ctx; @@ -1866,6 +2048,10 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const case GGML_OP_SCALE: supports_op = op->type == GGML_TYPE_F32; break; + case GGML_OP_UNARY: + supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type) && + (src1 ? (src1->type == op->type) : true); + break; default: break; } @@ -1888,6 +2074,8 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const << ", src0: " << (op->src[0] ? ggml_type_name(op->src[0]->type) : "null") << ", src1: " << (op->src[1] ? 
ggml_type_name(op->src[1]->type) : "null")); } + + return supports_op; } @@ -1929,6 +2117,8 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t GGML_ASSERT(index == 0); WEBGPU_LOG_DEBUG("ggml_backend_reg_get_device()"); + + ggml_backend_webgpu_reg_context * reg_ctx = static_cast(reg->context); webgpu_context ctx = reg_ctx->webgpu_ctx; @@ -1996,6 +2186,7 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage, wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead); + ggml_webgpu_init_memset_pipeline(ctx); ggml_webgpu_init_mul_mat_pipeline(ctx); ggml_webgpu_init_set_rows_pipeline(ctx); @@ -2009,6 +2200,24 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t ggml_webgpu_init_rope_pipeline(ctx); ggml_webgpu_init_glu_pipeline(ctx); ggml_webgpu_init_scale_pipeline(ctx); + ggml_webgpu_init_unary_pipeline(ctx); + + +/* ggml_webgpu_init_abs_pipeline(ctx); + ggml_webgpu_init_sgn_pipeline(ctx); + ggml_webgpu_init_neg_pipeline(ctx); + ggml_webgpu_init_step_pipeline(ctx); + ggml_webgpu_init_tanh_pipeline(ctx); + ggml_webgpu_init_elu_pipeline(ctx); + ggml_webgpu_init_relu_pipeline(ctx); + ggml_webgpu_init_sigmoid_pipeline(ctx); + ggml_webgpu_init_gelu_pipeline(ctx); + ggml_webgpu_init_gelu_quick_pipeline(ctx); + ggml_webgpu_init_silu_pipeline(ctx); + ggml_webgpu_init_hardswish_pipeline(ctx); + ggml_webgpu_init_hardsigmoid_pipeline(ctx); + ggml_webgpu_init_exp_pipeline(ctx); + ggml_webgpu_init_gelu_erf_pipeline(ctx); */ #ifdef GGML_WEBGPU_DEBUG // Initialize debug buffers @@ -2035,6 +2244,8 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t /* .reg = */ reg, /* .context = */ &device_ctx, }; + + return &device; } @@ -2048,6 +2259,7 @@ static const struct ggml_backend_reg_i ggml_backend_webgpu_reg_i = { /* End GGML Backend Registration Interface */ ggml_backend_reg_t ggml_backend_webgpu_reg() { + 
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_reg()"); webgpu_context webgpu_ctx = std::make_shared(); @@ -2073,8 +2285,9 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() { } ggml_backend_t ggml_backend_webgpu_init(void) { + ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_webgpu_reg(), 0); - + return ggml_backend_webgpu_device_init(dev, nullptr); } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl new file mode 100644 index 00000000000..7b78759dd00 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl @@ -0,0 +1,467 @@ +#define(VARIANTS) + +[ + { + "SHADER_NAME": "abs_f32", + "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = abs(src[src_i]);" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "abs_f16", + "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = abs(src[src_i]);" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "abs_in_place_f32", + "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = abs(src[src_i]);" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "abs_in_place_f16", + "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = abs(src[src_i]);" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "sgn_f32", + "REPLS": { + "TYPE": "f32", + "FUNC": "dst[dst_i] = select(select(0.0, -1.0, src[src_i] < 0.0), 1.0, src[src_i] > 0.0);" + }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "sgn_f16", + "REPLS": { + "TYPE": "f16", + "FUNC": "dst[dst_i] = select(select(0.0h, -1.0h, src[src_i] < 0.0h), 1.0h, src[src_i] > 0.0h);" + }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "sgn_in_place_f32", + "REPLS": { + "TYPE": "f32", + "FUNC": "src[dst_i] = select(select(0.0, -1.0, src[src_i] < 0.0), 1.0, src[src_i] > 0.0);" + }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "sgn_in_place_f16", + "REPLS": { + "TYPE": "f16", + "FUNC": "src[dst_i] = select(select(0.0h, -1.0h, src[src_i] < 0.0h), 1.0h, src[src_i] > 0.0h);" + }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "neg_f32", 
+ "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = -src[src_i];" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "neg_f16", + "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = -src[src_i];" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "neg_in_place_f32", + "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = -src[src_i];" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "neg_in_place_f16", + "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = -src[src_i];" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "step_f32", + "REPLS": { + "TYPE": "f32", + "FUNC": "dst[dst_i] = select(0.0, 1.0, src[src_i] > 0.0);" + }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "step_f16", + "REPLS": { + "TYPE": "f16", + "FUNC": "dst[dst_i] = select(0.0h, 1.0h, src[src_i] > 0.0h);" + }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "step_in_place_f32", + "REPLS": { + "TYPE": "f32", + "FUNC": "src[dst_i] = select(0.0, 1.0, src[src_i] > 0.0);" + }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "step_in_place_f16", + "REPLS": { + "TYPE": "f16", + "FUNC": "src[dst_i] = select(0.0h, 1.0h, src[src_i] > 0.0h);" + }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "tanh_f32", + "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = tanh(clamp(src[src_i], -9.010913, 9.010913));" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "tanh_f16", + "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = tanh(clamp(src[src_i], -9.010913, 9.010913));" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "tanh_in_place_f32", + "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = tanh(clamp(src[src_i], -9.010913, 9.010913));" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "tanh_in_place_f16", + "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = tanh(clamp(src[src_i], -9.010913, 9.010913));" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "elu_f32", + "REPLS": { + "TYPE": "f32", + "FUNC": "dst[dst_i] = select(exp(src[src_i]) - 1.0, src[src_i], src[src_i] > 0.0);" + }, + "DECLS": 
["NOT_INPLACE"] + }, + { + "SHADER_NAME": "elu_f16", + "REPLS": { + "TYPE": "f16", + "FUNC": "dst[dst_i] = select(exp(src[src_i]) - 1.0h, src[src_i], src[src_i] > 0.0h);" + }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "elu_in_place_f32", + "REPLS": { + "TYPE": "f32", + "FUNC": "src[dst_i] = select(exp(src[src_i]) - 1.0, src[src_i], src[src_i] > 0.0);" + }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "elu_in_place_f16", + "REPLS": { + "TYPE": "f16", + "FUNC": "src[dst_i] = select(exp(src[src_i]) - 1.0h, src[src_i], src[src_i] > 0.0h);" + }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "relu_f32", + "REPLS": { + "TYPE": "f32", + "FUNC": "dst[dst_i] = select(0.0, src[src_i], src[src_i] > 0.0);" + }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "relu_f16", + "REPLS": { + "TYPE": "f16", + "FUNC": "dst[dst_i] = select(0.0h, src[src_i], src[src_i] > 0.0h);" + }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "relu_in_place_f32", + "REPLS": { + "TYPE": "f32", + "FUNC": "src[dst_i] = select(0.0, src[src_i], src[src_i] > 0.0);" + }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "relu_in_place_f16", + "REPLS": { + "TYPE": "f16", + "FUNC": "src[dst_i] = select(0.0h, src[src_i], src[src_i] > 0.0h);" + }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "sigmoid_f32", + "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = 1.0 / (1.0 + exp(-src[src_i]));" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "sigmoid_f16", + "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = 1.0h / (1.0h + exp(-src[src_i]));" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "sigmoid_in_place_f32", + "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = 1.0 / (1.0 + exp(-src[src_i]));" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "sigmoid_in_place_f16", + "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = 1.0h / (1.0h + exp(-src[src_i]));" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "gelu_f32", + "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = 0.5 * 
src[src_i] * (1.0 + tanh(clamp(sqrt(2.0 / 3.14159265) * (src[src_i] + 0.044715 * pow(src[src_i], 3.0)), -9.010913, 9.010913)));" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "gelu_f16", + "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = 0.5h * src[src_i] * (1.0h + tanh(clamp(sqrt(2.0h / 3.14159265h) * (src[src_i] + 0.044715h * pow(src[src_i], 3.0h)), -9.010913, 9.010913)));" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "gelu_in_place_f32", + "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(sqrt(2.0 / 3.14159265) * (src[src_i] + 0.044715 * pow(src[src_i], 3.0)), -9.010913, 9.010913)));" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "gelu_in_place_f16", + "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = 0.5h * src[src_i] * (1.0h + tanh(clamp(sqrt(2.0h / 3.14159265h) * (src[src_i] + 0.044715h * pow(src[src_i], 3.0h)), -9.010913, 9.010913)));" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "gelu_quick_f32", + "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = src[src_i] * 0.5 * (1.0 + tanh(clamp(0.79788456 * src[src_i] * (1.0 + 0.044715 * src[src_i] * src[src_i]), -9.010913, 9.010913)));" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "gelu_quick_f16", + "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = src[src_i] * 0.5h * (1.0h + tanh(clamp(0.79788456h * src[src_i] * (1.0h + 0.044715h * src[src_i] * src[src_i]), -9.010913, 9.010913)));" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "gelu_quick_in_place_f32", + "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = src[src_i] * 0.5 * (1.0 + tanh(clamp(0.79788456 * src[src_i] * (1.0 + 0.044715 * src[src_i] * src[src_i]), -9.010913, 9.010913)));" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "gelu_quick_in_place_f16", + "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = src[src_i] * 0.5h * (1.0h + tanh(0.79788456h * src[src_i] * (1.0h + 0.044715h * src[src_i] * src[src_i])));" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "silu_f32", + 
"REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = src[src_i] / (1.0 + exp(-src[src_i]));" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "silu_f16", + "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = src[src_i] / (1.0h + exp(-src[src_i]));" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "silu_in_place_f32", + "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = src[src_i] / (1.0 + exp(-src[src_i]));" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "silu_in_place_f16", + "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = src[src_i] / (1.0h + exp(-src[src_i]));" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "hardswish_f32", + "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = src[src_i] * min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "hardswish_f16", + "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = src[src_i] * min(1.0h, max(0.0h, (src[src_i] + 3.0h) / 6.0h));" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "hardswish_in_place_f32", + "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = src[src_i] * min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "hardswish_in_place_f16", + "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = src[src_i] * min(1.0h, max(0.0h, (src[src_i] + 3.0h) / 6.0h));" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "hardsigmoid_f32", + "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "hardsigmoid_f16", + "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = min(1.0h, max(0.0h, (src[src_i] + 3.0h) / 6.0h));" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "hardsigmoid_in_place_f32", + "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "hardsigmoid_in_place_f16", + "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = min(1.0h, 
max(0.0h, (src[src_i] + 3.0h) / 6.0h));" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "exp_f32", + "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = exp(src[src_i]);" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "exp_f16", + "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = exp(src[src_i]);" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "exp_in_place_f32", + "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = exp(src[src_i]);" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "exp_in_place_f16", + "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = exp(src[src_i]);" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "gelu_erf_f32", + "REPLS": { + "TYPE": "f32", + "FUNC": "dst[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(0.79788456 * (src[src_i] + 0.044715 * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913)));" + }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "gelu_erf_f16", + "REPLS": { + "TYPE": "f16", + "FUNC": "dst[dst_i] = 0.5h * src[src_i] * (1.0h + tanh(clamp(0.79788456h * (src[src_i] + 0.044715h * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913)));" + }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "gelu_erf_in_place_f32", + "REPLS": { + "TYPE": "f32", + "FUNC": "src[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(0.79788456 * (src[src_i] + 0.044715 * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913)));" + }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "gelu_erf_in_place_f16", + "REPLS": { + "TYPE": "f16", + "FUNC": "src[dst_i] = 0.5h * src[src_i] * (1.0h + tanh(clamp(0.79788456h * (src[src_i] + 0.044715h * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913)));" + }, + "DECLS": ["INPLACE"] + } +] + +#end(VARIANTS) + +#define(DECLS) + +#decl(NOT_INPLACE) + +fn update(dst_i: u32, src_i: u32) { + {{FUNC}} +} + +@group(0) @binding(1) +var dst: array<{{TYPE}}>; + +@group(0) @binding(2) +var params: Params; + +#enddecl(NOT_INPLACE) + +#decl(INPLACE) + +fn update(dst_i: u32, src_i: u32) 
{ + {{FUNC}} // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458 +} + +@group(0) @binding(1) +var params: Params; + +#enddecl(INPLACE) + +#end(DECLS) + + +#define(SHADER) + +enable f16; + +struct Params { + ne: u32, // total number of elements + offset_src: u32, // in elements + offset_dst: u32, // in elements + + // Strides (in elements) — may be permuted + stride_src0: u32, + stride_src1: u32, + stride_src2: u32, + stride_src3: u32, + + stride_dst0: u32, + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + // Logical shapes + src_ne0: u32, + src_ne1: u32, + src_ne2: u32, + + dst_ne0: u32, + dst_ne1: u32, + dst_ne2: u32 +}; + +@group(0) @binding(0) +var src: array<{{TYPE}}>; + + +DECLS + +override wg_size: u32; +@compute @workgroup_size(wg_size) +fn main(@builtin(global_invocation_id) gid: vec3) { + if (gid.x >= params.ne) { + return; + } + + var i = gid.x; + let i3 = i / (params.src_ne2 * params.src_ne1 * params.src_ne0); + i = i % (params.src_ne2 * params.src_ne1 * params.src_ne0); + let i2 = i / (params.src_ne1 * params.src_ne0); + i = i % (params.src_ne1 * params.src_ne0); + let i1 = i / params.src_ne0; + let i0 = i % params.src_ne0; + + var j = gid.x; + let j3 = j / (params.dst_ne2 * params.dst_ne1 * params.dst_ne0); + j = j % (params.dst_ne2 * params.dst_ne1 * params.dst_ne0); + let j2 = j / (params.dst_ne1 * params.dst_ne0); + j = j % (params.dst_ne1 * params.dst_ne0); + let j1 = j / params.dst_ne0; + let j0 = j % params.dst_ne0; + + let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 + + i2 * params.stride_src2 + i3 * params.stride_src3; + + let dst_idx = j0 * params.stride_dst0 + j1 * params.stride_dst1 + + j2 * params.stride_dst2 + j3 * params.stride_dst3; + + + update(params.offset_dst + dst_idx, params.offset_src + src_idx); +} + +#end(SHADER) From 5360e2852a4b51197d7d67d0a5d42e908b02d7ed Mon Sep 17 00:00:00 2001 From: James Contini Date: Fri, 10 Oct 2025 12:45:57 -0700 Subject: [PATCH 
08/40] rms_norm: remove duplicate `elems` declaration --- ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl | 2 -- 1 file changed, 2 deletions(-) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl index 7d99fb70de5..712b921f1ab 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl @@ -88,8 +88,6 @@ fn main(@builtin(workgroup_id) wid: vec3, let elems = (params.ne0 + wg_size - 1) / wg_size; - let elems = (params.ne0 + wg_size - 1) / wg_size; - var sum = 0.0f; var col = lid.x; for (var j: u32 = 0; j < elems; j++) { From cb0858333785757804c5104e59c4981843207c16 Mon Sep 17 00:00:00 2001 From: James Contini Date: Fri, 10 Oct 2025 12:59:32 -0700 Subject: [PATCH 09/40] ggml-webgpu: conform to editorconfig formatting --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 194 ++++++++++++++++++++++++--- 1 file changed, 173 insertions(+), 21 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 57d085b31d0..b5ec7816751 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -261,8 +261,8 @@ struct webgpu_context_struct { webgpu_pipeline rope_pipeline[2][2][2]; // type, ff, inplace webgpu_pipeline glu_pipeline[7][2][2]; // glu-op, type, split webgpu_pipeline scale_pipeline[2]; // inplace - wgpu::ComputePipeline soft_max_pipeline[3][2][2]; // (no_mask, f32_mask, f16_mask), has_sink, inplace - wgpu::ComputePipeline unary_pipeline[16][2][2]; + webgpu_pipeline soft_max_pipeline[3][2][2]; // (no_mask, f32_mask, f16_mask), has_sink, inplace + webgpu_pipeline unary_pipeline[16][2][2]; size_t memset_bytes_per_thread; @@ -719,7 +719,7 @@ static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, g size_t max_wg_size = ctx->max_wg_size_x; uint32_t wg_x = (ne + max_wg_size - 1) / max_wg_size; return ggml_backend_webgpu_build(ctx, ctx->cpy_pipeline[src->type][dst->type], params, entries, wg_x); - + } @@ 
-866,9 +866,9 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, static webgpu_command ggml_webgpu_unary_op( webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst, - wgpu::ComputePipeline & pipeline, + webgpu_pipeline & pipeline, bool in_place) { - + uint32_t ne = (uint32_t) ggml_nelements(dst); @@ -902,8 +902,8 @@ static webgpu_command ggml_webgpu_unary_op( webgpu_context & ctx, size_t max_wg_size = ctx->max_wg_size_x; uint32_t wg_x = (ggml_nelements(dst) + max_wg_size - 1) / max_wg_size; - - return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op)); + + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx, @@ -1250,7 +1250,7 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * src1 = node->src[1]; ggml_tensor * src2 = node->src[2]; - + switch (node->op) { // no-ops @@ -1299,13 +1299,11 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, return ggml_webgpu_scale(ctx, src0, node); case GGML_OP_SOFT_MAX: return ggml_webgpu_soft_max(ctx, src0, src1, src2, node); - case GGML_OP_UNARY: + case GGML_OP_UNARY: { const ggml_unary_op UNARY_OP = ggml_get_unary_op(node); int in_place = ggml_webgpu_tensor_equal(src0, node); - ggml_webgpu_unary_op(ctx, src0, node, ctx->unary_pipeline[UNARY_OP][node->type][in_place], in_place); - - break; + return ggml_webgpu_unary_op(ctx, src0, node, ctx->unary_pipeline[UNARY_OP][node->type][in_place], in_place); } default: @@ -1314,7 +1312,7 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, } static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) { - + WEBGPU_LOG_DEBUG("ggml_backend_webgpu_graph_compute(" << cgraph->n_nodes << " nodes)"); ggml_backend_webgpu_context * backend_ctx = static_cast(backend->context); @@ -1534,7 +1532,7 @@ static const char * 
ggml_backend_webgpu_buffer_type_get_name(ggml_backend_buffer static ggml_backend_buffer_t ggml_backend_webgpu_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { - + WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_type_alloc_buffer(" << size << ")"); ggml_backend_webgpu_device_context * ctx = static_cast(buft->device->context); @@ -1873,6 +1871,160 @@ static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) { wgsl_geglu_quick_f16_split, "geglu_quick_f16_split", constants); } +static void ggml_webgpu_init_unary_pipeline(webgpu_context & webgpu_ctx) { + std::vector constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x); + + // ABS + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_ABS][GGML_TYPE_F32][0], + wgsl_abs_f32, "abs_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_ABS][GGML_TYPE_F16][0], + wgsl_abs_f16, "abs_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_ABS][GGML_TYPE_F32][1], + wgsl_abs_in_place_f32, "abs_in_place_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_ABS][GGML_TYPE_F16][1], + wgsl_abs_in_place_f16, "abs_in_place_f16", constants); + + // SGN + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SGN][GGML_TYPE_F32][0], + wgsl_sgn_f32, "sgn_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SGN][GGML_TYPE_F16][0], + wgsl_sgn_f16, "sgn_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SGN][GGML_TYPE_F32][1], + wgsl_sgn_in_place_f32, "sgn_in_place_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SGN][GGML_TYPE_F16][1], + wgsl_sgn_in_place_f16, "sgn_in_place_f16", constants); + + 
// NEG + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_NEG][GGML_TYPE_F32][0], + wgsl_neg_f32, "neg_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_NEG][GGML_TYPE_F16][0], + wgsl_neg_f16, "neg_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_NEG][GGML_TYPE_F32][1], + wgsl_neg_in_place_f32, "neg_in_place_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_NEG][GGML_TYPE_F16][1], + wgsl_neg_in_place_f16, "neg_in_place_f16", constants); + + // STEP + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_STEP][GGML_TYPE_F32][0], + wgsl_step_f32, "step_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_STEP][GGML_TYPE_F16][0], + wgsl_step_f16, "step_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_STEP][GGML_TYPE_F32][1], + wgsl_step_in_place_f32, "step_in_place_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_STEP][GGML_TYPE_F16][1], + wgsl_step_in_place_f16, "step_in_place_f16", constants); + + // TANH + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_TANH][GGML_TYPE_F32][0], + wgsl_tanh_f32, "tanh_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_TANH][GGML_TYPE_F16][0], + wgsl_tanh_f16, "tanh_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_TANH][GGML_TYPE_F32][1], + wgsl_tanh_in_place_f32, "tanh_in_place_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_TANH][GGML_TYPE_F16][1], + wgsl_tanh_in_place_f16, "tanh_in_place_f16", 
constants); + + // ELU + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_ELU][GGML_TYPE_F32][0], + wgsl_elu_f32, "elu_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_ELU][GGML_TYPE_F16][0], + wgsl_elu_f16, "elu_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_ELU][GGML_TYPE_F32][1], + wgsl_elu_in_place_f32, "elu_in_place_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_ELU][GGML_TYPE_F16][1], + wgsl_elu_in_place_f16, "elu_in_place_f16", constants); + + // RELU + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_RELU][GGML_TYPE_F32][0], + wgsl_relu_f32, "relu_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_RELU][GGML_TYPE_F16][0], + wgsl_relu_f16, "relu_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_RELU][GGML_TYPE_F32][1], + wgsl_relu_in_place_f32, "relu_in_place_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_RELU][GGML_TYPE_F16][1], + wgsl_relu_in_place_f16, "relu_in_place_f16", constants); + + // SIGMOID + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SIGMOID][GGML_TYPE_F32][0], + wgsl_sigmoid_f32, "sigmoid_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SIGMOID][GGML_TYPE_F16][0], + wgsl_sigmoid_f16, "sigmoid_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SIGMOID][GGML_TYPE_F32][1], + wgsl_sigmoid_in_place_f32, "sigmoid_in_place_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, 
webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SIGMOID][GGML_TYPE_F16][1], + wgsl_sigmoid_in_place_f16, "sigmoid_in_place_f16", constants); + + // GELU + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU][GGML_TYPE_F32][0], + wgsl_gelu_f32, "gelu_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU][GGML_TYPE_F16][0], + wgsl_gelu_f16, "gelu_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU][GGML_TYPE_F32][1], + wgsl_gelu_in_place_f32, "gelu_in_place_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU][GGML_TYPE_F16][1], + wgsl_gelu_in_place_f16, "gelu_in_place_f16", constants); + + // GELU_QUICK + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F32][0], + wgsl_gelu_quick_f32, "gelu_quick_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F16][0], + wgsl_gelu_quick_f16, "gelu_quick_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F32][1], + wgsl_gelu_quick_in_place_f32, "gelu_quick_in_place_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F16][1], + wgsl_gelu_quick_in_place_f16, "gelu_quick_in_place_f16", constants); + + // SILU + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SILU][GGML_TYPE_F32][0], + wgsl_silu_f32, "silu_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SILU][GGML_TYPE_F16][0], + wgsl_silu_f16, "silu_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, 
webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SILU][GGML_TYPE_F32][1], + wgsl_silu_in_place_f32, "silu_in_place_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SILU][GGML_TYPE_F16][1], + wgsl_silu_in_place_f16, "silu_in_place_f16", constants); + + // HARDSWISH + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F32][0], + wgsl_hardswish_f32, "hardswish_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F16][0], + wgsl_hardswish_f16, "hardswish_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F32][1], + wgsl_hardswish_in_place_f32, "hardswish_in_place_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F16][1], + wgsl_hardswish_in_place_f16, "hardswish_in_place_f16", constants); + + // HARDSIGMOID + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F32][0], + wgsl_hardsigmoid_f32, "hardsigmoid_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F16][0], + wgsl_hardsigmoid_f16, "hardsigmoid_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F32][1], + wgsl_hardsigmoid_in_place_f32, "hardsigmoid_in_place_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F16][1], + wgsl_hardsigmoid_in_place_f16, "hardsigmoid_in_place_f16", constants); + + // EXP + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_EXP][GGML_TYPE_F32][0], + wgsl_exp_f32, "exp_f32", constants); + 
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_EXP][GGML_TYPE_F16][0], + wgsl_exp_f16, "exp_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_EXP][GGML_TYPE_F32][1], + wgsl_exp_in_place_f32, "exp_in_place_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_EXP][GGML_TYPE_F16][1], + wgsl_exp_in_place_f16, "exp_in_place_f16", constants); + + // GELU_ERF + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F32][0], + wgsl_gelu_erf_f32, "gelu_erf_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F16][0], + wgsl_gelu_erf_f16, "gelu_erf_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F32][1], + wgsl_gelu_erf_in_place_f32, "gelu_erf_in_place_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F16][1], + wgsl_gelu_erf_in_place_f16, "gelu_erf_in_place_f16", constants); +} + static void ggml_webgpu_init_scale_pipeline(webgpu_context & webgpu_ctx) { std::vector constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->scale_pipeline[0], wgsl_scale_f32, "scale_f32", @@ -1912,7 +2064,7 @@ static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) { } static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) { - + GGML_UNUSED(params); WEBGPU_LOG_DEBUG("ggml_backend_webgpu_device_init()"); @@ -1937,7 +2089,7 @@ static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, co static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggml_backend_dev_t dev) { // See GGML Backend 
Buffer Type Interface section - + static struct ggml_backend_buffer_type ggml_backend_webgpu_buffer_type = { /* .iface = */ { /* .get_name = */ ggml_backend_webgpu_buffer_type_get_name, @@ -1988,7 +2140,7 @@ static bool ggml_webgpu_supported_qtype(ggml_type type) { } static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) { - + ggml_backend_webgpu_device_context * ctx = static_cast(dev->context); webgpu_context webgpu_ctx = ctx->webgpu_ctx; @@ -2250,7 +2402,7 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage, wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead); - + ggml_webgpu_init_memset_pipeline(ctx); ggml_webgpu_init_mul_mat_pipeline(ctx); ggml_webgpu_init_set_rows_pipeline(ctx); @@ -2307,7 +2459,7 @@ static const struct ggml_backend_reg_i ggml_backend_webgpu_reg_i = { /* End GGML Backend Registration Interface */ ggml_backend_reg_t ggml_backend_webgpu_reg() { - + WEBGPU_LOG_DEBUG("ggml_backend_webgpu_reg()"); webgpu_context webgpu_ctx = std::make_shared(); @@ -2333,9 +2485,9 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() { } ggml_backend_t ggml_backend_webgpu_init(void) { - + ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_webgpu_reg(), 0); - + return ggml_backend_webgpu_device_init(dev, nullptr); } From 362749910be4f0120c8ffb21ceddeb7d2c088e51 Mon Sep 17 00:00:00 2001 From: James Contini Date: Fri, 10 Oct 2025 13:10:46 -0700 Subject: [PATCH 10/40] removed vestigial files --- ggml/src/ggml-webgpu/wgsl-shaders/neg.wgsl | 87 ------------------- .../wgsl-shaders/neg_in_place.wgsl | 84 ------------------ 2 files changed, 171 deletions(-) delete mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/neg.wgsl delete mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/neg_in_place.wgsl diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/neg.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/neg.wgsl deleted file mode 100644 index 
23feb9aa7da..00000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/neg.wgsl +++ /dev/null @@ -1,87 +0,0 @@ -#define(VARIANTS) - -[ - { - "REPLS": { - "TYPE": "f32", - } - }, - { - "REPLS": { - "TYPE": "f16", - } - }, -] - -#end(VARIANTS) - -#define(SHADER) -enable f16; - -@group(0) @binding(0) -var src: array<{{TYPE}}>; - -@group(0) @binding(1) -var dst: array<{{TYPE}}>; - -struct Params { - ne: u32, // total number of elements - offset_src: u32, // in elements - offset_dst: u32, // in elements - - // Strides (in elements) — may be permuted - stride_src0: u32, - stride_src1: u32, - stride_src2: u32, - stride_src3: u32, - - stride_dst0: u32, - stride_dst1: u32, - stride_dst2: u32, - stride_dst3: u32, - - // Logical shapes - src_ne0: u32, - src_ne1: u32, - src_ne2: u32, - - dst_ne0: u32, - dst_ne1: u32, - dst_ne2: u32 -}; - -@group(0) @binding(2) -var params: Params; - -override wg_size: u32; -@compute @workgroup_size(wg_size) -fn main(@builtin(global_invocation_id) gid: vec3) { - if (gid.x >= params.ne) { - return; - } - - var i = gid.x; - let i3 = i / (params.src_ne2 * params.src_ne1 * params.src_ne0); - i = i % (params.src_ne2 * params.src_ne1 * params.src_ne0); - let i2 = i / (params.src_ne1 * params.src_ne0); - i = i % (params.src_ne1 * params.src_ne0); - let i1 = i / params.src_ne0; - let i0 = i % params.src_ne0; - - var j = gid.x; - let j3 = j / (params.dst_ne2 * params.dst_ne1 * params.dst_ne0); - j = j % (params.dst_ne2 * params.dst_ne1 * params.dst_ne0); - let j2 = j / (params.dst_ne1 * params.dst_ne0); - j = j % (params.dst_ne1 * params.dst_ne0); - let j1 = j / params.dst_ne0; - let j0 = j % params.dst_ne0; - - let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 + - i2 * params.stride_src2 + i3 * params.stride_src3; - - let dst_idx = j0 * params.stride_dst0 + j1 * params.stride_dst1 + - j2 * params.stride_dst2 + j3 * params.stride_dst3; - - dst[params.offset_dst + dst_idx] = -((src[params.offset_src + src_idx])); -} -#end(SHADER) diff --git 
a/ggml/src/ggml-webgpu/wgsl-shaders/neg_in_place.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/neg_in_place.wgsl deleted file mode 100644 index 732b56cea23..00000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/neg_in_place.wgsl +++ /dev/null @@ -1,84 +0,0 @@ -#define(VARIANTS) - -[ - { - "REPLS": { - "TYPE": "f32", - } - }, - { - "REPLS": { - "TYPE": "f16", - } - }, -] - -#end(VARIANTS) - -#define(SHADER) -enable f16; - -@group(0) @binding(0) -var src: array<{{TYPE}}>; - -struct Params { - ne: u32, // total number of elements - offset_src: u32, // in elements - offset_dst: u32, // in elements - - // Strides (in elements) — may be permuted - stride_src0: u32, - stride_src1: u32, - stride_src2: u32, - stride_src3: u32, - - stride_dst0: u32, - stride_dst1: u32, - stride_dst2: u32, - stride_dst3: u32, - - // Logical shapes - src_ne0: u32, - src_ne1: u32, - src_ne2: u32, - - dst_ne0: u32, - dst_ne1: u32, - dst_ne2: u32 -}; - -@group(0) @binding(1) -var params: Params; - -override wg_size: u32; -@compute @workgroup_size(wg_size) -fn main(@builtin(global_invocation_id) gid: vec3) { - if (gid.x >= params.ne) { - return; - } - - var i = gid.x; - let i3 = i / (params.src_ne2 * params.src_ne1 * params.src_ne0); - i = i % (params.src_ne2 * params.src_ne1 * params.src_ne0); - let i2 = i / (params.src_ne1 * params.src_ne0); - i = i % (params.src_ne1 * params.src_ne0); - let i1 = i / params.src_ne0; - let i0 = i % params.src_ne0; - - var j = gid.x; - let j3 = j / (params.dst_ne2 * params.dst_ne1 * params.dst_ne0); - j = j % (params.dst_ne2 * params.dst_ne1 * params.dst_ne0); - let j2 = j / (params.dst_ne1 * params.dst_ne0); - j = j % (params.dst_ne1 * params.dst_ne0); - let j1 = j / params.dst_ne0; - let j0 = j % params.dst_ne0; - - let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 + - i2 * params.stride_src2 + i3 * params.stride_src3; - - let dst_idx = j0 * params.stride_dst0 + j1 * params.stride_dst1 + - j2 * params.stride_dst2 + j3 * params.stride_dst3; - - 
dst[params.offset_dst + dst_idx] = -((src[params.offset_src + src_idx])); -} -#end(SHADER) From 74c6add1761a59d2c2ff60b60e8ad3c8300f6d3e Mon Sep 17 00:00:00 2001 From: James Contini Date: Fri, 10 Oct 2025 13:16:48 -0700 Subject: [PATCH 11/40] fixed autoconfig --- ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl index 7b78759dd00..5690850e777 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl @@ -208,7 +208,7 @@ "SHADER_NAME": "sigmoid_in_place_f16", "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = 1.0h / (1.0h + exp(-src[src_i]));" }, "DECLS": ["INPLACE"] - }, + }, { "SHADER_NAME": "gelu_f32", "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(sqrt(2.0 / 3.14159265) * (src[src_i] + 0.044715 * pow(src[src_i], 3.0)), -9.010913, 9.010913)));" }, @@ -437,7 +437,7 @@ fn main(@builtin(global_invocation_id) gid: vec3) { if (gid.x >= params.ne) { return; } - + var i = gid.x; let i3 = i / (params.src_ne2 * params.src_ne1 * params.src_ne0); i = i % (params.src_ne2 * params.src_ne1 * params.src_ne0); From 4cf28d7dec41c29186d66152735b244c5699f9dc Mon Sep 17 00:00:00 2001 From: James Contini Date: Sun, 12 Oct 2025 13:32:45 -0700 Subject: [PATCH 12/40] All operators (inlcluding xielu) working --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 55 ++- .../ggml-webgpu/wgsl-shaders/unary_op.wgsl | 318 +++++++++++++----- 2 files changed, 277 insertions(+), 96 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index b5ec7816751..46b1b14f428 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -262,7 +262,7 @@ struct webgpu_context_struct { webgpu_pipeline glu_pipeline[7][2][2]; // glu-op, type, split webgpu_pipeline scale_pipeline[2]; // 
inplace webgpu_pipeline soft_max_pipeline[3][2][2]; // (no_mask, f32_mask, f16_mask), has_sink, inplace - webgpu_pipeline unary_pipeline[16][2][2]; + webgpu_pipeline unary_pipeline[GGML_UNARY_OP_COUNT][2][2]; size_t memset_bytes_per_thread; @@ -344,6 +344,8 @@ static void ggml_webgpu_create_pipeline(wgpu::Device & pipeline_desc.compute.constants = constants.data(); pipeline_desc.compute.constantCount = constants.size(); } + + pipeline = { device.CreateComputePipeline(&pipeline_desc), label }; } @@ -867,7 +869,8 @@ static webgpu_command ggml_webgpu_unary_op( webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst, webgpu_pipeline & pipeline, - bool in_place) { + bool in_place, + bool additional_params=false) { uint32_t ne = (uint32_t) ggml_nelements(dst); @@ -885,6 +888,11 @@ static webgpu_command ggml_webgpu_unary_op( webgpu_context & ctx, (uint32_t) dst->ne[1], (uint32_t) dst->ne[2] }; + if (additional_params) { + for (uint i = 1; i < 5; i++) { + params.push_back((uint32_t)(ggml_get_op_params_f32(dst, i))); // alpha_n, alpha_p, beta, eps + } + } std::vector entries = { { .binding = 0, @@ -1302,8 +1310,10 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, case GGML_OP_UNARY: { const ggml_unary_op UNARY_OP = ggml_get_unary_op(node); + int in_place = ggml_webgpu_tensor_equal(src0, node); - return ggml_webgpu_unary_op(ctx, src0, node, ctx->unary_pipeline[UNARY_OP][node->type][in_place], in_place); + bool XIELU = (UNARY_OP == GGML_UNARY_OP_XIELU); + return ggml_webgpu_unary_op(ctx, src0, node, ctx->unary_pipeline[UNARY_OP][node->type][in_place], in_place, XIELU); } default: @@ -2023,6 +2033,16 @@ static void ggml_webgpu_init_unary_pipeline(webgpu_context & webgpu_ctx) { wgsl_gelu_erf_in_place_f32, "gelu_erf_in_place_f32", constants); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F16][1], wgsl_gelu_erf_in_place_f16, "gelu_erf_in_place_f16", constants); + + // XIELU + 
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_XIELU][GGML_TYPE_F32][0], + wgsl_xielu_f32, "xielu_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_XIELU][GGML_TYPE_F16][0], + wgsl_xielu_f16, "xielu_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_XIELU][GGML_TYPE_F32][1], + wgsl_xielu_in_place_f32, "xielu_in_place_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_XIELU][GGML_TYPE_F16][1], + wgsl_xielu_in_place_f16, "xielu_in_place_f16", constants); } static void ggml_webgpu_init_scale_pipeline(webgpu_context & webgpu_ctx) { @@ -2254,9 +2274,36 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const supports_op = op->type == GGML_TYPE_F32; break; case GGML_OP_UNARY: - supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type) && + { + const ggml_unary_op UNARY_OP = ggml_get_unary_op(op); + + switch (UNARY_OP) { + case GGML_UNARY_OP_ABS: + case GGML_UNARY_OP_SGN: + case GGML_UNARY_OP_NEG: + case GGML_UNARY_OP_STEP: + case GGML_UNARY_OP_TANH: + case GGML_UNARY_OP_ELU: + case GGML_UNARY_OP_RELU: + case GGML_UNARY_OP_SIGMOID: + case GGML_UNARY_OP_GELU: + case GGML_UNARY_OP_GELU_QUICK: + case GGML_UNARY_OP_SILU: + case GGML_UNARY_OP_HARDSWISH: + case GGML_UNARY_OP_HARDSIGMOID: + case GGML_UNARY_OP_EXP: + case GGML_UNARY_OP_GELU_ERF: + case GGML_UNARY_OP_XIELU: + supports_op = supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type) && (src1 ? 
(src1->type == op->type) : true); + break; + case GGML_UNARY_OP_COUNT: + default: + break; + } + } break; + default: break; } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl index 5690850e777..f9fe9cb3c4b 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl @@ -4,22 +4,23 @@ { "SHADER_NAME": "abs_f32", "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = abs(src[src_i]);" }, - "DECLS": ["NOT_INPLACE"] + "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] + }, { "SHADER_NAME": "abs_f16", "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = abs(src[src_i]);" }, - "DECLS": ["NOT_INPLACE"] + "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] }, { "SHADER_NAME": "abs_in_place_f32", "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = abs(src[src_i]);" }, - "DECLS": ["INPLACE"] + "DECLS": ["INPLACE_DFLT_PARAMS"] }, { "SHADER_NAME": "abs_in_place_f16", "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = abs(src[src_i]);" }, - "DECLS": ["INPLACE"] + "DECLS": ["INPLACE_DFLT_PARAMS"] }, { "SHADER_NAME": "sgn_f32", @@ -27,7 +28,7 @@ "TYPE": "f32", "FUNC": "dst[dst_i] = select(select(0.0, -1.0, src[src_i] < 0.0), 1.0, src[src_i] > 0.0);" }, - "DECLS": ["NOT_INPLACE"] + "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] }, { "SHADER_NAME": "sgn_f16", @@ -35,7 +36,7 @@ "TYPE": "f16", "FUNC": "dst[dst_i] = select(select(0.0h, -1.0h, src[src_i] < 0.0h), 1.0h, src[src_i] > 0.0h);" }, - "DECLS": ["NOT_INPLACE"] + "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] }, { "SHADER_NAME": "sgn_in_place_f32", @@ -43,7 +44,7 @@ "TYPE": "f32", "FUNC": "src[dst_i] = select(select(0.0, -1.0, src[src_i] < 0.0), 1.0, src[src_i] > 0.0);" }, - "DECLS": ["INPLACE"] + "DECLS": ["INPLACE_DFLT_PARAMS"] }, { "SHADER_NAME": "sgn_in_place_f16", @@ -51,27 +52,27 @@ "TYPE": "f16", "FUNC": "src[dst_i] = select(select(0.0h, -1.0h, src[src_i] < 0.0h), 1.0h, src[src_i] > 0.0h);" }, - "DECLS": ["INPLACE"] + "DECLS": ["INPLACE_DFLT_PARAMS"] }, { "SHADER_NAME": 
"neg_f32", "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = -src[src_i];" }, - "DECLS": ["NOT_INPLACE"] + "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] }, { "SHADER_NAME": "neg_f16", "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = -src[src_i];" }, - "DECLS": ["NOT_INPLACE"] + "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] }, { "SHADER_NAME": "neg_in_place_f32", "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = -src[src_i];" }, - "DECLS": ["INPLACE"] + "DECLS": ["INPLACE_DFLT_PARAMS"] }, { "SHADER_NAME": "neg_in_place_f16", "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = -src[src_i];" }, - "DECLS": ["INPLACE"] + "DECLS": ["INPLACE_DFLT_PARAMS"] }, { "SHADER_NAME": "step_f32", @@ -79,7 +80,7 @@ "TYPE": "f32", "FUNC": "dst[dst_i] = select(0.0, 1.0, src[src_i] > 0.0);" }, - "DECLS": ["NOT_INPLACE"] + "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] }, { "SHADER_NAME": "step_f16", @@ -87,7 +88,7 @@ "TYPE": "f16", "FUNC": "dst[dst_i] = select(0.0h, 1.0h, src[src_i] > 0.0h);" }, - "DECLS": ["NOT_INPLACE"] + "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] }, { "SHADER_NAME": "step_in_place_f32", @@ -95,7 +96,7 @@ "TYPE": "f32", "FUNC": "src[dst_i] = select(0.0, 1.0, src[src_i] > 0.0);" }, - "DECLS": ["INPLACE"] + "DECLS": ["INPLACE_DFLT_PARAMS"] }, { "SHADER_NAME": "step_in_place_f16", @@ -103,27 +104,27 @@ "TYPE": "f16", "FUNC": "src[dst_i] = select(0.0h, 1.0h, src[src_i] > 0.0h);" }, - "DECLS": ["INPLACE"] + "DECLS": ["INPLACE_DFLT_PARAMS"] }, { "SHADER_NAME": "tanh_f32", - "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = tanh(clamp(src[src_i], -9.010913, 9.010913));" }, - "DECLS": ["NOT_INPLACE"] + "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = tanh(clamp(src[src_i], -9.010913, 9.010913)); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" }, + "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] }, { "SHADER_NAME": "tanh_f16", - "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = tanh(clamp(src[src_i], -9.010913, 9.010913));" }, - "DECLS": ["NOT_INPLACE"] + "REPLS": { "TYPE": "f16", "FUNC": 
"dst[dst_i] = tanh(clamp(src[src_i], -9.010913, 9.010913)); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" }, + "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] }, { "SHADER_NAME": "tanh_in_place_f32", - "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = tanh(clamp(src[src_i], -9.010913, 9.010913));" }, - "DECLS": ["INPLACE"] + "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = tanh(clamp(src[src_i], -9.010913, 9.010913)); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" }, + "DECLS": ["INPLACE_DFLT_PARAMS"] }, { "SHADER_NAME": "tanh_in_place_f16", - "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = tanh(clamp(src[src_i], -9.010913, 9.010913));" }, - "DECLS": ["INPLACE"] + "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = tanh(clamp(src[src_i], -9.010913, 9.010913)); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" }, + "DECLS": ["INPLACE_DFLT_PARAMS"] }, { "SHADER_NAME": "elu_f32", @@ -131,7 +132,7 @@ "TYPE": "f32", "FUNC": "dst[dst_i] = select(exp(src[src_i]) - 1.0, src[src_i], src[src_i] > 0.0);" }, - "DECLS": ["NOT_INPLACE"] + "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] }, { "SHADER_NAME": "elu_f16", @@ -139,7 +140,7 @@ "TYPE": "f16", "FUNC": "dst[dst_i] = select(exp(src[src_i]) - 1.0h, src[src_i], src[src_i] > 0.0h);" }, - "DECLS": ["NOT_INPLACE"] + "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] }, { "SHADER_NAME": "elu_in_place_f32", @@ -147,7 +148,7 @@ "TYPE": "f32", "FUNC": "src[dst_i] = select(exp(src[src_i]) - 1.0, src[src_i], src[src_i] > 0.0);" }, - "DECLS": ["INPLACE"] + "DECLS": ["INPLACE_DFLT_PARAMS"] }, { "SHADER_NAME": "elu_in_place_f16", @@ -155,7 +156,7 @@ "TYPE": "f16", "FUNC": "src[dst_i] = select(exp(src[src_i]) - 1.0h, src[src_i], src[src_i] > 0.0h);" }, - "DECLS": ["INPLACE"] + "DECLS": ["INPLACE_DFLT_PARAMS"] }, { "SHADER_NAME": "relu_f32", @@ -163,7 +164,7 @@ "TYPE": "f32", "FUNC": "dst[dst_i] = select(0.0, src[src_i], src[src_i] > 0.0);" }, - 
"DECLS": ["NOT_INPLACE"] + "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] }, { "SHADER_NAME": "relu_f16", @@ -171,7 +172,7 @@ "TYPE": "f16", "FUNC": "dst[dst_i] = select(0.0h, src[src_i], src[src_i] > 0.0h);" }, - "DECLS": ["NOT_INPLACE"] + "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] }, { "SHADER_NAME": "relu_in_place_f32", @@ -179,7 +180,7 @@ "TYPE": "f32", "FUNC": "src[dst_i] = select(0.0, src[src_i], src[src_i] > 0.0);" }, - "DECLS": ["INPLACE"] + "DECLS": ["INPLACE_DFLT_PARAMS"] }, { "SHADER_NAME": "relu_in_place_f16", @@ -187,179 +188,211 @@ "TYPE": "f16", "FUNC": "src[dst_i] = select(0.0h, src[src_i], src[src_i] > 0.0h);" }, - "DECLS": ["INPLACE"] + "DECLS": ["INPLACE_DFLT_PARAMS"] }, { "SHADER_NAME": "sigmoid_f32", "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = 1.0 / (1.0 + exp(-src[src_i]));" }, - "DECLS": ["NOT_INPLACE"] + "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] }, { "SHADER_NAME": "sigmoid_f16", "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = 1.0h / (1.0h + exp(-src[src_i]));" }, - "DECLS": ["NOT_INPLACE"] + "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] }, { "SHADER_NAME": "sigmoid_in_place_f32", "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = 1.0 / (1.0 + exp(-src[src_i]));" }, - "DECLS": ["INPLACE"] + "DECLS": ["INPLACE_DFLT_PARAMS"] }, { "SHADER_NAME": "sigmoid_in_place_f16", "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = 1.0h / (1.0h + exp(-src[src_i]));" }, - "DECLS": ["INPLACE"] + "DECLS": ["INPLACE_DFLT_PARAMS"] }, { "SHADER_NAME": "gelu_f32", - "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(sqrt(2.0 / 3.14159265) * (src[src_i] + 0.044715 * pow(src[src_i], 3.0)), -9.010913, 9.010913)));" }, - "DECLS": ["NOT_INPLACE"] + "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(sqrt(2.0 / 3.14159265) * (src[src_i] + 0.044715 * pow(src[src_i], 3.0)), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" }, + "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] }, { 
"SHADER_NAME": "gelu_f16", - "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = 0.5h * src[src_i] * (1.0h + tanh(clamp(sqrt(2.0h / 3.14159265h) * (src[src_i] + 0.044715h * pow(src[src_i], 3.0h)), -9.010913, 9.010913)));" }, - "DECLS": ["NOT_INPLACE"] + "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = 0.5h * src[src_i] * (1.0h + tanh(clamp(sqrt(2.0h / 3.14159265h) * (src[src_i] + 0.044715h * pow(src[src_i], 3.0h)), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" }, + "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] }, { "SHADER_NAME": "gelu_in_place_f32", - "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(sqrt(2.0 / 3.14159265) * (src[src_i] + 0.044715 * pow(src[src_i], 3.0)), -9.010913, 9.010913)));" }, - "DECLS": ["INPLACE"] + "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(sqrt(2.0 / 3.14159265) * (src[src_i] + 0.044715 * pow(src[src_i], 3.0)), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" }, + "DECLS": ["INPLACE_DFLT_PARAMS"] }, { "SHADER_NAME": "gelu_in_place_f16", - "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = 0.5h * src[src_i] * (1.0h + tanh(clamp(sqrt(2.0h / 3.14159265h) * (src[src_i] + 0.044715h * pow(src[src_i], 3.0h)), -9.010913, 9.010913)));" }, - "DECLS": ["INPLACE"] + "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = 0.5h * src[src_i] * (1.0h + tanh(clamp(sqrt(2.0h / 3.14159265h) * (src[src_i] + 0.044715h * pow(src[src_i], 3.0h)), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" }, + "DECLS": ["INPLACE_DFLT_PARAMS"] }, { "SHADER_NAME": "gelu_quick_f32", - "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = src[src_i] * 0.5 * (1.0 + tanh(clamp(0.79788456 * src[src_i] * (1.0 + 0.044715 * src[src_i] * src[src_i]), -9.010913, 9.010913)));" }, - "DECLS": ["NOT_INPLACE"] + "REPLS": { "TYPE": "f32", 
"FUNC": "dst[dst_i] = src[src_i] * 0.5 * (1.0 + tanh(clamp(0.79788456 * src[src_i] * (1.0 + 0.044715 * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" }, + "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] }, { "SHADER_NAME": "gelu_quick_f16", - "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = src[src_i] * 0.5h * (1.0h + tanh(clamp(0.79788456h * src[src_i] * (1.0h + 0.044715h * src[src_i] * src[src_i]), -9.010913, 9.010913)));" }, - "DECLS": ["NOT_INPLACE"] + "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = src[src_i] * 0.5h * (1.0h + tanh(clamp(0.79788456h * src[src_i] * (1.0h + 0.044715h * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" }, + "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] }, { "SHADER_NAME": "gelu_quick_in_place_f32", - "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = src[src_i] * 0.5 * (1.0 + tanh(clamp(0.79788456 * src[src_i] * (1.0 + 0.044715 * src[src_i] * src[src_i]), -9.010913, 9.010913)));" }, - "DECLS": ["INPLACE"] + "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = src[src_i] * 0.5 * (1.0 + tanh(clamp(0.79788456 * src[src_i] * (1.0 + 0.044715 * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" }, + "DECLS": ["INPLACE_DFLT_PARAMS"] }, { "SHADER_NAME": "gelu_quick_in_place_f16", - "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = src[src_i] * 0.5h * (1.0h + tanh(0.79788456h * src[src_i] * (1.0h + 0.044715h * src[src_i] * src[src_i])));" }, - "DECLS": ["INPLACE"] + "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = src[src_i] * 0.5h * (1.0h + tanh(0.79788456h * src[src_i] * (1.0h + 0.044715h * src[src_i] * src[src_i]))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" }, + "DECLS": ["INPLACE_DFLT_PARAMS"] }, { "SHADER_NAME": "silu_f32", "REPLS": { 
"TYPE": "f32", "FUNC": "dst[dst_i] = src[src_i] / (1.0 + exp(-src[src_i]));" }, - "DECLS": ["NOT_INPLACE"] + "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] }, { "SHADER_NAME": "silu_f16", "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = src[src_i] / (1.0h + exp(-src[src_i]));" }, - "DECLS": ["NOT_INPLACE"] + "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] }, { "SHADER_NAME": "silu_in_place_f32", "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = src[src_i] / (1.0 + exp(-src[src_i]));" }, - "DECLS": ["INPLACE"] + "DECLS": ["INPLACE_DFLT_PARAMS"] }, { "SHADER_NAME": "silu_in_place_f16", "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = src[src_i] / (1.0h + exp(-src[src_i]));" }, - "DECLS": ["INPLACE"] + "DECLS": ["INPLACE_DFLT_PARAMS"] }, { "SHADER_NAME": "hardswish_f32", "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = src[src_i] * min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));" }, - "DECLS": ["NOT_INPLACE"] + "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] }, { "SHADER_NAME": "hardswish_f16", "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = src[src_i] * min(1.0h, max(0.0h, (src[src_i] + 3.0h) / 6.0h));" }, - "DECLS": ["NOT_INPLACE"] + "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] }, { "SHADER_NAME": "hardswish_in_place_f32", "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = src[src_i] * min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));" }, - "DECLS": ["INPLACE"] + "DECLS": ["INPLACE_DFLT_PARAMS"] }, { "SHADER_NAME": "hardswish_in_place_f16", "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = src[src_i] * min(1.0h, max(0.0h, (src[src_i] + 3.0h) / 6.0h));" }, - "DECLS": ["INPLACE"] + "DECLS": ["INPLACE_DFLT_PARAMS"] }, { "SHADER_NAME": "hardsigmoid_f32", "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));" }, - "DECLS": ["NOT_INPLACE"] + "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] }, { "SHADER_NAME": "hardsigmoid_f16", "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = min(1.0h, max(0.0h, (src[src_i] + 3.0h) / 6.0h));" }, - "DECLS": ["NOT_INPLACE"] + "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] }, { 
"SHADER_NAME": "hardsigmoid_in_place_f32", "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));" }, - "DECLS": ["INPLACE"] + "DECLS": ["INPLACE_DFLT_PARAMS"] }, { "SHADER_NAME": "hardsigmoid_in_place_f16", "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = min(1.0h, max(0.0h, (src[src_i] + 3.0h) / 6.0h));" }, - "DECLS": ["INPLACE"] + "DECLS": ["INPLACE_DFLT_PARAMS"] }, { "SHADER_NAME": "exp_f32", "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = exp(src[src_i]);" }, - "DECLS": ["NOT_INPLACE"] + "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] }, { "SHADER_NAME": "exp_f16", "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = exp(src[src_i]);" }, - "DECLS": ["NOT_INPLACE"] + "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] }, { "SHADER_NAME": "exp_in_place_f32", "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = exp(src[src_i]);" }, - "DECLS": ["INPLACE"] + "DECLS": ["INPLACE_DFLT_PARAMS"] }, { "SHADER_NAME": "exp_in_place_f16", "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = exp(src[src_i]);" }, - "DECLS": ["INPLACE"] + "DECLS": ["INPLACE_DFLT_PARAMS"] }, { "SHADER_NAME": "gelu_erf_f32", "REPLS": { "TYPE": "f32", - "FUNC": "dst[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(0.79788456 * (src[src_i] + 0.044715 * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913)));" + "FUNC": "dst[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(0.79788456 * (src[src_i] + 0.044715 * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" }, - "DECLS": ["NOT_INPLACE"] + "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] }, { "SHADER_NAME": "gelu_erf_f16", "REPLS": { "TYPE": "f16", - "FUNC": "dst[dst_i] = 0.5h * src[src_i] * (1.0h + tanh(clamp(0.79788456h * (src[src_i] + 0.044715h * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913)));" + "FUNC": "dst[dst_i] = 0.5h * src[src_i] * (1.0h + tanh(clamp(0.79788456h * (src[src_i] + 0.044715h * src[src_i] * src[src_i] * src[src_i]), 
-9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" }, - "DECLS": ["NOT_INPLACE"] + "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] }, { "SHADER_NAME": "gelu_erf_in_place_f32", "REPLS": { "TYPE": "f32", - "FUNC": "src[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(0.79788456 * (src[src_i] + 0.044715 * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913)));" + "FUNC": "src[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(0.79788456 * (src[src_i] + 0.044715 * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" }, - "DECLS": ["INPLACE"] + "DECLS": ["INPLACE_DFLT_PARAMS"] }, { "SHADER_NAME": "gelu_erf_in_place_f16", "REPLS": { "TYPE": "f16", - "FUNC": "src[dst_i] = 0.5h * src[src_i] * (1.0h + tanh(clamp(0.79788456h * (src[src_i] + 0.044715h * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913)));" + "FUNC": "src[dst_i] = 0.5h * src[src_i] * (1.0h + tanh(clamp(0.79788456h * (src[src_i] + 0.044715h * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" + }, + "DECLS": ["INPLACE_DFLT_PARAMS"] + }, + { + "SHADER_NAME": "xielu_f32", + "REPLS": { + "TYPE": "f32", + "FUNC": "dst[dst_i] = select(((exp(min(src[src_i], f32(params.eps))) - 1.0) - src[src_i]) * f32(params.alpha_n) + f32(params.beta) * src[src_i], f32(params.alpha_p) * src[src_i] * src[src_i] + f32(params.beta) * src[src_i], src[src_i] > 0.0);" + }, + "DECLS": ["NOT_INPLACE_EXT_PARAMS"] + }, + { + "SHADER_NAME": "xielu_f16", + "REPLS": { + "TYPE": "f16", + "FUNC": "dst[dst_i] = select(((exp(min(src[src_i], f16(params.eps))) - 1.0h) - src[src_i]) * f16(params.alpha_n) + f16(params.beta) * src[src_i], f16(params.alpha_p) * src[src_i] * src[src_i] + f16(params.beta) * src[src_i], src[src_i] > 0.0h);" + }, + "DECLS": ["NOT_INPLACE_EXT_PARAMS"] + }, 
+ { + "SHADER_NAME": "xielu_in_place_f32", + "REPLS": { + "TYPE": "f32", + "FUNC": "src[dst_i] = select(((exp(min(src[src_i], f32(params.eps))) - 1.0) - src[src_i]) * f32(params.alpha_n) + f32(params.beta) * src[src_i], f32(params.alpha_p) * src[src_i] * src[src_i] + f32(params.beta) * src[src_i], src[src_i] > 0.0);" + }, + "DECLS": ["INPLACE_EXT_PARAMS"] + }, + { + "SHADER_NAME": "xielu_in_place_f16", + "REPLS": { + "TYPE": "f16", + "FUNC": "src[dst_i] = select(((exp(min(src[src_i], f16(params.eps))) - 1.0h) - src[src_i]) * f16(params.alpha_n) + f16(params.beta) * src[src_i], f16(params.alpha_p) * src[src_i] * src[src_i] + f16(params.beta) * src[src_i], src[src_i] > 0.0h);" }, - "DECLS": ["INPLACE"] + "DECLS": ["INPLACE_EXT_PARAMS"] } ] @@ -367,11 +400,7 @@ #define(DECLS) -#decl(NOT_INPLACE) - -fn update(dst_i: u32, src_i: u32) { - {{FUNC}} -} +#decl(NOT_INPLACE_DFLT_PARAMS) @group(0) @binding(1) var dst: array<{{TYPE}}>; @@ -379,25 +408,74 @@ var dst: array<{{TYPE}}>; @group(0) @binding(2) var params: Params; -#enddecl(NOT_INPLACE) +struct Params { + ne: u32, // total number of elements + offset_src: u32, // in elements + offset_dst: u32, // in elements + + // Strides (in elements) — may be permuted + stride_src0: u32, + stride_src1: u32, + stride_src2: u32, + stride_src3: u32, + + stride_dst0: u32, + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, -#decl(INPLACE) + // Logical shapes + src_ne0: u32, + src_ne1: u32, + src_ne2: u32, -fn update(dst_i: u32, src_i: u32) { - {{FUNC}} // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458 -} + dst_ne0: u32, + dst_ne1: u32, + dst_ne2: u32 +}; + +#enddecl(NOT_INPLACE_DFLT_PARAMS) + +#decl(INPLACE_DFLT_PARAMS) @group(0) @binding(1) var params: Params; -#enddecl(INPLACE) +struct Params { + ne: u32, // total number of elements + offset_src: u32, // in elements + offset_dst: u32, // in elements -#end(DECLS) + // Strides (in elements) — may be permuted + stride_src0: u32, + 
stride_src1: u32, + stride_src2: u32, + stride_src3: u32, + + stride_dst0: u32, + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + // Logical shapes + src_ne0: u32, + src_ne1: u32, + src_ne2: u32, -#define(SHADER) + dst_ne0: u32, + dst_ne1: u32, + dst_ne2: u32 +}; -enable f16; +#enddecl(INPLACE_DFLT_PARAMS) + +#decl(NOT_INPLACE_EXT_PARAMS) + +@group(0) @binding(1) +var dst: array<{{TYPE}}>; + +@group(0) @binding(2) +var params: Params; struct Params { ne: u32, // total number of elements @@ -422,13 +500,69 @@ struct Params { dst_ne0: u32, dst_ne1: u32, - dst_ne2: u32 + dst_ne2: u32, + + // XIELU params + alpha_n: u32, + alpha_p: u32, + beta: u32, + eps: u32 +}; + +#enddecl(NOT_INPLACE_EXT_PARAMS) + +#decl(INPLACE_EXT_PARAMS) + +@group(0) @binding(1) +var params: Params; + +struct Params { + ne: u32, // total number of elements + offset_src: u32, // in elements + offset_dst: u32, // in elements + + // Strides (in elements) — may be permuted + stride_src0: u32, + stride_src1: u32, + stride_src2: u32, + stride_src3: u32, + + stride_dst0: u32, + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + // Logical shapes + src_ne0: u32, + src_ne1: u32, + src_ne2: u32, + + dst_ne0: u32, + dst_ne1: u32, + dst_ne2: u32, + + // XIELU params + alpha_n: u32, + alpha_p: u32, + beta: u32, + eps: u32 }; +#enddecl(INPLACE_EXT_PARAMS) + +#end(DECLS) + +#define(SHADER) + +enable f16; + +fn update(dst_i: u32, src_i: u32) { + {{FUNC}} +} + @group(0) @binding(0) var src: array<{{TYPE}}>; - DECLS override wg_size: u32; From f9282c660c10dec4487d434549bdb707a9cd9f37 Mon Sep 17 00:00:00 2001 From: James Contini Date: Sun, 12 Oct 2025 13:41:41 -0700 Subject: [PATCH 13/40] removed unnecesarry checking if node->src[1] exists for unary operators --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 46b1b14f428..6a4d5202031 100644 --- 
a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -2294,8 +2294,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const case GGML_UNARY_OP_EXP: case GGML_UNARY_OP_GELU_ERF: case GGML_UNARY_OP_XIELU: - supports_op = supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type) && - (src1 ? (src1->type == op->type) : true); + supports_op = supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type); break; case GGML_UNARY_OP_COUNT: default: From 8c70b8fece445cdc9a8c660dbddbf201e52da2bb Mon Sep 17 00:00:00 2001 From: James Contini Date: Wed, 15 Oct 2025 16:14:20 -0700 Subject: [PATCH 14/40] responded and dealt with PR comments --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 258 +++++----- .../ggml-webgpu/wgsl-shaders/unary_op.wgsl | 452 ++++++------------ 2 files changed, 278 insertions(+), 432 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 6a4d5202031..2f4fdc1c3cc 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -62,7 +62,6 @@ #define WEBGPU_MUL_MAT_WG_SIZE 256 #define WEBGPU_NUM_PARAM_BUFS 32u #define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 8u -#define WEBGPU_WAIT_ANY_BATCH_SIZE 64 #define WEBGPU_WAIT_ANY_TIMEOUT_MS 0 // Maximum number of in-flight submissions per-thread, to avoid exhausting the parameter buffer pool #define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE @@ -252,19 +251,18 @@ struct webgpu_context_struct { webgpu_pipeline set_rows_pipeline; webgpu_pipeline get_rows_pipeline[30]; webgpu_pipeline get_rows_f32_no_vec_pipeline; - webgpu_pipeline cpy_pipeline[2][2]; // src type, dst type - webgpu_pipeline add_pipeline[2][2]; // type, inplace - webgpu_pipeline sub_pipeline[2][2]; // type, inplace - webgpu_pipeline mul_pipeline[2][2]; // type, inplace - webgpu_pipeline 
div_pipeline[2][2]; // type, inplace - webgpu_pipeline rms_norm_pipeline[2]; // inplace - webgpu_pipeline rope_pipeline[2][2][2]; // type, ff, inplace - webgpu_pipeline glu_pipeline[7][2][2]; // glu-op, type, split - webgpu_pipeline scale_pipeline[2]; // inplace + webgpu_pipeline cpy_pipeline[2][2]; // src type, dst type + webgpu_pipeline add_pipeline[2][2]; // type, inplace + webgpu_pipeline sub_pipeline[2][2]; // type, inplace + webgpu_pipeline mul_pipeline[2][2]; // type, inplace + webgpu_pipeline div_pipeline[2][2]; // type, inplace + webgpu_pipeline rms_norm_pipeline[2]; // inplace + webgpu_pipeline rope_pipeline[2][2][2]; // type, ff, inplace + webgpu_pipeline glu_pipeline[7][2][2]; // glu-op, type, split + webgpu_pipeline scale_pipeline[2]; // inplace webgpu_pipeline soft_max_pipeline[3][2][2]; // (no_mask, f32_mask, f16_mask), has_sink, inplace webgpu_pipeline unary_pipeline[GGML_UNARY_OP_COUNT][2][2]; - size_t memset_bytes_per_thread; // Staging buffer for reading data from the GPU @@ -345,7 +343,6 @@ static void ggml_webgpu_create_pipeline(wgpu::Device & pipeline_desc.compute.constantCount = constants.size(); } - pipeline = { device.CreateComputePipeline(&pipeline_desc), label }; } @@ -721,10 +718,8 @@ static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, g size_t max_wg_size = ctx->max_wg_size_x; uint32_t wg_x = (ne + max_wg_size - 1) / max_wg_size; return ggml_backend_webgpu_build(ctx, ctx->cpy_pipeline[src->type][dst->type], params, entries, wg_x); - } - static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * idx, @@ -865,14 +860,12 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, return ggml_backend_webgpu_build(ctx, ctx->mul_mat_pipeline[src0->type][src1->type], params, entries, wg_x); } -static webgpu_command ggml_webgpu_unary_op( webgpu_context & ctx, - ggml_tensor * src, - ggml_tensor * dst, - webgpu_pipeline & pipeline, - bool in_place, - bool 
additional_params=false) { - - +static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, + ggml_tensor * src, + ggml_tensor * dst, + webgpu_pipeline & pipeline, + bool in_place, + const std::vector & xielu_params = {}) { uint32_t ne = (uint32_t) ggml_nelements(dst); std::vector params = { @@ -888,18 +881,13 @@ static webgpu_command ggml_webgpu_unary_op( webgpu_context & ctx, (uint32_t) dst->ne[1], (uint32_t) dst->ne[2] }; - if (additional_params) { - for (uint i = 1; i < 5; i++) { - params.push_back((uint32_t)(ggml_get_op_params_f32(dst, i))); // alpha_n, alpha_p, beta, eps - } - } + params.insert(params.end(), xielu_params.begin(), xielu_params.end()); std::vector entries = { { .binding = 0, .buffer = ggml_webgpu_tensor_buf(src), .offset = ggml_webgpu_tensor_align_offset(ctx, src), .size = ggml_webgpu_tensor_binding_size(ctx, src) }, - }; if (!in_place) { entries.push_back({ .binding = 1, @@ -1258,8 +1246,6 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * src1 = node->src[1]; ggml_tensor * src2 = node->src[2]; - - switch (node->op) { // no-ops case GGML_OP_NONE: @@ -1309,11 +1295,24 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, return ggml_webgpu_soft_max(ctx, src0, src1, src2, node); case GGML_OP_UNARY: { - const ggml_unary_op UNARY_OP = ggml_get_unary_op(node); + const ggml_unary_op UNARY_OP = ggml_get_unary_op(node); + int in_place = ggml_webgpu_tensor_equal(src0, node); + std::vector xielu_params; - int in_place = ggml_webgpu_tensor_equal(src0, node); - bool XIELU = (UNARY_OP == GGML_UNARY_OP_XIELU); - return ggml_webgpu_unary_op(ctx, src0, node, ctx->unary_pipeline[UNARY_OP][node->type][in_place], in_place, XIELU); + switch (UNARY_OP) { + case GGML_UNARY_OP_XIELU: + xielu_params = { + static_cast(ggml_get_op_params_f32(node, 1)), // alpha_n + static_cast(ggml_get_op_params_f32(node, 2)), // alpha_p + static_cast(ggml_get_op_params_f32(node, 3)), // beta + static_cast(ggml_get_op_params_f32(node, 
4)) // eps + }; + break; + default: + break; + } + return ggml_webgpu_unary_op(ctx, src0, node, ctx->unary_pipeline[UNARY_OP][node->type][in_place], + in_place, xielu_params); } default: @@ -1322,7 +1321,6 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, } static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) { - WEBGPU_LOG_DEBUG("ggml_backend_webgpu_graph_compute(" << cgraph->n_nodes << " nodes)"); ggml_backend_webgpu_context * backend_ctx = static_cast(backend->context); @@ -1541,8 +1539,6 @@ static const char * ggml_backend_webgpu_buffer_type_get_name(ggml_backend_buffer static ggml_backend_buffer_t ggml_backend_webgpu_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { - - WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_type_alloc_buffer(" << size << ")"); ggml_backend_webgpu_device_context * ctx = static_cast(buft->device->context); @@ -1554,8 +1550,6 @@ static ggml_backend_buffer_t ggml_backend_webgpu_buffer_type_alloc_buffer(ggml_b ggml_backend_webgpu_buffer_context * buf_ctx = new ggml_backend_webgpu_buffer_context(ctx->webgpu_ctx, buf); - - return ggml_backend_buffer_init(buft, ggml_backend_webgpu_buffer_interface, buf_ctx, size); } @@ -1886,163 +1880,179 @@ static void ggml_webgpu_init_unary_pipeline(webgpu_context & webgpu_ctx) { // ABS ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_ABS][GGML_TYPE_F32][0], - wgsl_abs_f32, "abs_f32", constants); + wgsl_abs_f32, "abs_f32", constants); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_ABS][GGML_TYPE_F16][0], - wgsl_abs_f16, "abs_f16", constants); + wgsl_abs_f16, "abs_f16", constants); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_ABS][GGML_TYPE_F32][1], - wgsl_abs_in_place_f32, "abs_in_place_f32", constants); + wgsl_abs_in_place_f32, "abs_in_place_f32", constants); 
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_ABS][GGML_TYPE_F16][1], - wgsl_abs_in_place_f16, "abs_in_place_f16", constants); + wgsl_abs_in_place_f16, "abs_in_place_f16", constants); // SGN ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SGN][GGML_TYPE_F32][0], - wgsl_sgn_f32, "sgn_f32", constants); + wgsl_sgn_f32, "sgn_f32", constants); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SGN][GGML_TYPE_F16][0], - wgsl_sgn_f16, "sgn_f16", constants); + wgsl_sgn_f16, "sgn_f16", constants); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SGN][GGML_TYPE_F32][1], - wgsl_sgn_in_place_f32, "sgn_in_place_f32", constants); + wgsl_sgn_in_place_f32, "sgn_in_place_f32", constants); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SGN][GGML_TYPE_F16][1], - wgsl_sgn_in_place_f16, "sgn_in_place_f16", constants); + wgsl_sgn_in_place_f16, "sgn_in_place_f16", constants); // NEG ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_NEG][GGML_TYPE_F32][0], - wgsl_neg_f32, "neg_f32", constants); + wgsl_neg_f32, "neg_f32", constants); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_NEG][GGML_TYPE_F16][0], - wgsl_neg_f16, "neg_f16", constants); + wgsl_neg_f16, "neg_f16", constants); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_NEG][GGML_TYPE_F32][1], - wgsl_neg_in_place_f32, "neg_in_place_f32", constants); + wgsl_neg_in_place_f32, "neg_in_place_f32", constants); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_NEG][GGML_TYPE_F16][1], - wgsl_neg_in_place_f16, "neg_in_place_f16", constants); + wgsl_neg_in_place_f16, "neg_in_place_f16", constants); // STEP ggml_webgpu_create_pipeline(webgpu_ctx->device, 
webgpu_ctx->unary_pipeline[GGML_UNARY_OP_STEP][GGML_TYPE_F32][0], - wgsl_step_f32, "step_f32", constants); + wgsl_step_f32, "step_f32", constants); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_STEP][GGML_TYPE_F16][0], - wgsl_step_f16, "step_f16", constants); + wgsl_step_f16, "step_f16", constants); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_STEP][GGML_TYPE_F32][1], - wgsl_step_in_place_f32, "step_in_place_f32", constants); + wgsl_step_in_place_f32, "step_in_place_f32", constants); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_STEP][GGML_TYPE_F16][1], - wgsl_step_in_place_f16, "step_in_place_f16", constants); + wgsl_step_in_place_f16, "step_in_place_f16", constants); // TANH ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_TANH][GGML_TYPE_F32][0], - wgsl_tanh_f32, "tanh_f32", constants); + wgsl_tanh_f32, "tanh_f32", constants); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_TANH][GGML_TYPE_F16][0], - wgsl_tanh_f16, "tanh_f16", constants); + wgsl_tanh_f16, "tanh_f16", constants); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_TANH][GGML_TYPE_F32][1], - wgsl_tanh_in_place_f32, "tanh_in_place_f32", constants); + wgsl_tanh_in_place_f32, "tanh_in_place_f32", constants); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_TANH][GGML_TYPE_F16][1], - wgsl_tanh_in_place_f16, "tanh_in_place_f16", constants); + wgsl_tanh_in_place_f16, "tanh_in_place_f16", constants); // ELU ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_ELU][GGML_TYPE_F32][0], - wgsl_elu_f32, "elu_f32", constants); + wgsl_elu_f32, "elu_f32", constants); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_ELU][GGML_TYPE_F16][0], - wgsl_elu_f16, 
"elu_f16", constants); + wgsl_elu_f16, "elu_f16", constants); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_ELU][GGML_TYPE_F32][1], - wgsl_elu_in_place_f32, "elu_in_place_f32", constants); + wgsl_elu_in_place_f32, "elu_in_place_f32", constants); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_ELU][GGML_TYPE_F16][1], - wgsl_elu_in_place_f16, "elu_in_place_f16", constants); + wgsl_elu_in_place_f16, "elu_in_place_f16", constants); // RELU ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_RELU][GGML_TYPE_F32][0], - wgsl_relu_f32, "relu_f32", constants); + wgsl_relu_f32, "relu_f32", constants); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_RELU][GGML_TYPE_F16][0], - wgsl_relu_f16, "relu_f16", constants); + wgsl_relu_f16, "relu_f16", constants); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_RELU][GGML_TYPE_F32][1], - wgsl_relu_in_place_f32, "relu_in_place_f32", constants); + wgsl_relu_in_place_f32, "relu_in_place_f32", constants); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_RELU][GGML_TYPE_F16][1], - wgsl_relu_in_place_f16, "relu_in_place_f16", constants); + wgsl_relu_in_place_f16, "relu_in_place_f16", constants); // SIGMOID ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SIGMOID][GGML_TYPE_F32][0], - wgsl_sigmoid_f32, "sigmoid_f32", constants); + wgsl_sigmoid_f32, "sigmoid_f32", constants); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SIGMOID][GGML_TYPE_F16][0], - wgsl_sigmoid_f16, "sigmoid_f16", constants); + wgsl_sigmoid_f16, "sigmoid_f16", constants); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SIGMOID][GGML_TYPE_F32][1], - wgsl_sigmoid_in_place_f32, "sigmoid_in_place_f32", constants); + 
wgsl_sigmoid_in_place_f32, "sigmoid_in_place_f32", constants); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SIGMOID][GGML_TYPE_F16][1], - wgsl_sigmoid_in_place_f16, "sigmoid_in_place_f16", constants); + wgsl_sigmoid_in_place_f16, "sigmoid_in_place_f16", constants); // GELU ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU][GGML_TYPE_F32][0], - wgsl_gelu_f32, "gelu_f32", constants); + wgsl_gelu_f32, "gelu_f32", constants); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU][GGML_TYPE_F16][0], - wgsl_gelu_f16, "gelu_f16", constants); + wgsl_gelu_f16, "gelu_f16", constants); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU][GGML_TYPE_F32][1], - wgsl_gelu_in_place_f32, "gelu_in_place_f32", constants); + wgsl_gelu_in_place_f32, "gelu_in_place_f32", constants); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU][GGML_TYPE_F16][1], - wgsl_gelu_in_place_f16, "gelu_in_place_f16", constants); + wgsl_gelu_in_place_f16, "gelu_in_place_f16", constants); // GELU_QUICK - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F32][0], - wgsl_gelu_quick_f32, "gelu_quick_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F16][0], - wgsl_gelu_quick_f16, "gelu_quick_f16", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F32][1], - wgsl_gelu_quick_in_place_f32, "gelu_quick_in_place_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F16][1], - wgsl_gelu_quick_in_place_f16, "gelu_quick_in_place_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, + 
webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F32][0], + wgsl_gelu_quick_f32, "gelu_quick_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, + webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F16][0], + wgsl_gelu_quick_f16, "gelu_quick_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, + webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F32][1], + wgsl_gelu_quick_in_place_f32, "gelu_quick_in_place_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, + webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F16][1], + wgsl_gelu_quick_in_place_f16, "gelu_quick_in_place_f16", constants); // SILU ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SILU][GGML_TYPE_F32][0], - wgsl_silu_f32, "silu_f32", constants); + wgsl_silu_f32, "silu_f32", constants); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SILU][GGML_TYPE_F16][0], - wgsl_silu_f16, "silu_f16", constants); + wgsl_silu_f16, "silu_f16", constants); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SILU][GGML_TYPE_F32][1], - wgsl_silu_in_place_f32, "silu_in_place_f32", constants); + wgsl_silu_in_place_f32, "silu_in_place_f32", constants); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_SILU][GGML_TYPE_F16][1], - wgsl_silu_in_place_f16, "silu_in_place_f16", constants); + wgsl_silu_in_place_f16, "silu_in_place_f16", constants); // HARDSWISH - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F32][0], - wgsl_hardswish_f32, "hardswish_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F16][0], - wgsl_hardswish_f16, "hardswish_f16", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, 
webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F32][1], - wgsl_hardswish_in_place_f32, "hardswish_in_place_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F16][1], - wgsl_hardswish_in_place_f16, "hardswish_in_place_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, + webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F32][0], + wgsl_hardswish_f32, "hardswish_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, + webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F16][0], + wgsl_hardswish_f16, "hardswish_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, + webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F32][1], + wgsl_hardswish_in_place_f32, "hardswish_in_place_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, + webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F16][1], + wgsl_hardswish_in_place_f16, "hardswish_in_place_f16", constants); // HARDSIGMOID - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F32][0], - wgsl_hardsigmoid_f32, "hardsigmoid_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F16][0], - wgsl_hardsigmoid_f16, "hardsigmoid_f16", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F32][1], - wgsl_hardsigmoid_in_place_f32, "hardsigmoid_in_place_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F16][1], - wgsl_hardsigmoid_in_place_f16, "hardsigmoid_in_place_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, + webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F32][0], + wgsl_hardsigmoid_f32, "hardsigmoid_f32", 
constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, + webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F16][0], + wgsl_hardsigmoid_f16, "hardsigmoid_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, + webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F32][1], + wgsl_hardsigmoid_in_place_f32, "hardsigmoid_in_place_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, + webgpu_ctx->unary_pipeline[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F16][1], + wgsl_hardsigmoid_in_place_f16, "hardsigmoid_in_place_f16", constants); // EXP ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_EXP][GGML_TYPE_F32][0], - wgsl_exp_f32, "exp_f32", constants); + wgsl_exp_f32, "exp_f32", constants); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_EXP][GGML_TYPE_F16][0], - wgsl_exp_f16, "exp_f16", constants); + wgsl_exp_f16, "exp_f16", constants); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_EXP][GGML_TYPE_F32][1], - wgsl_exp_in_place_f32, "exp_in_place_f32", constants); + wgsl_exp_in_place_f32, "exp_in_place_f32", constants); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_EXP][GGML_TYPE_F16][1], - wgsl_exp_in_place_f16, "exp_in_place_f16", constants); + wgsl_exp_in_place_f16, "exp_in_place_f16", constants); // GELU_ERF - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F32][0], - wgsl_gelu_erf_f32, "gelu_erf_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F16][0], - wgsl_gelu_erf_f16, "gelu_erf_f16", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F32][1], - wgsl_gelu_erf_in_place_f32, "gelu_erf_in_place_f32", constants); - 
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F16][1], - wgsl_gelu_erf_in_place_f16, "gelu_erf_in_place_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, + webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F32][0], wgsl_gelu_erf_f32, + "gelu_erf_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, + webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F16][0], wgsl_gelu_erf_f16, + "gelu_erf_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, + webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F32][1], + wgsl_gelu_erf_in_place_f32, "gelu_erf_in_place_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, + webgpu_ctx->unary_pipeline[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F16][1], + wgsl_gelu_erf_in_place_f16, "gelu_erf_in_place_f16", constants); // XIELU ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_XIELU][GGML_TYPE_F32][0], - wgsl_xielu_f32, "xielu_f32", constants); + wgsl_xielu_f32, "xielu_f32", constants); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_XIELU][GGML_TYPE_F16][0], - wgsl_xielu_f16, "xielu_f16", constants); + wgsl_xielu_f16, "xielu_f16", constants); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_XIELU][GGML_TYPE_F32][1], - wgsl_xielu_in_place_f32, "xielu_in_place_f32", constants); + wgsl_xielu_in_place_f32, "xielu_in_place_f32", constants); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->unary_pipeline[GGML_UNARY_OP_XIELU][GGML_TYPE_F16][1], - wgsl_xielu_in_place_f16, "xielu_in_place_f16", constants); + wgsl_xielu_in_place_f16, "xielu_in_place_f16", constants); } static void ggml_webgpu_init_scale_pipeline(webgpu_context & webgpu_ctx) { @@ -2084,7 +2094,6 @@ static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) { } static ggml_backend_t 
ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) { - GGML_UNUSED(params); WEBGPU_LOG_DEBUG("ggml_backend_webgpu_device_init()"); @@ -2103,7 +2112,6 @@ static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, co /* .device = */ dev, /* .context = */ &backend_ctx, }; - //tried return &backend; } @@ -2160,7 +2168,6 @@ static bool ggml_webgpu_supported_qtype(ggml_type type) { } static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) { - ggml_backend_webgpu_device_context * ctx = static_cast(dev->context); webgpu_context webgpu_ctx = ctx->webgpu_ctx; @@ -2294,9 +2301,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const case GGML_UNARY_OP_EXP: case GGML_UNARY_OP_GELU_ERF: case GGML_UNARY_OP_XIELU: - supports_op = supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type); + supports_op = supports_op = + (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type); break; - case GGML_UNARY_OP_COUNT: default: break; } @@ -2448,7 +2455,6 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage, wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead); - ggml_webgpu_init_memset_pipeline(ctx); ggml_webgpu_init_mul_mat_pipeline(ctx); ggml_webgpu_init_set_rows_pipeline(ctx); @@ -2505,7 +2511,6 @@ static const struct ggml_backend_reg_i ggml_backend_webgpu_reg_i = { /* End GGML Backend Registration Interface */ ggml_backend_reg_t ggml_backend_webgpu_reg() { - WEBGPU_LOG_DEBUG("ggml_backend_webgpu_reg()"); webgpu_context webgpu_ctx = std::make_shared(); @@ -2531,7 +2536,6 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() { } ggml_backend_t ggml_backend_webgpu_init(void) { - ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_webgpu_reg(), 0); return ggml_backend_webgpu_device_init(dev, nullptr); diff --git 
a/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl index f9fe9cb3c4b..7f632a24e53 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl @@ -3,396 +3,339 @@ [ { "SHADER_NAME": "abs_f32", - "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = abs(src[src_i]);" }, - "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] - + "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = abs(src[src_i]);", "EXT_PARAMS": "" }, + "DECLS": ["NOT_INPLACE"] }, { "SHADER_NAME": "abs_f16", - "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = abs(src[src_i]);" }, - "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] + "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = abs(src[src_i]);", "EXT_PARAMS": "" }, + "DECLS": ["NOT_INPLACE"] }, { "SHADER_NAME": "abs_in_place_f32", - "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = abs(src[src_i]);" }, - "DECLS": ["INPLACE_DFLT_PARAMS"] + "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = abs(src[src_i]);", "EXT_PARAMS": "" }, + "DECLS": ["INPLACE"] }, { "SHADER_NAME": "abs_in_place_f16", - "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = abs(src[src_i]);" }, - "DECLS": ["INPLACE_DFLT_PARAMS"] + "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = abs(src[src_i]);", "EXT_PARAMS": "" }, + "DECLS": ["INPLACE"] }, { "SHADER_NAME": "sgn_f32", - "REPLS": { - "TYPE": "f32", - "FUNC": "dst[dst_i] = select(select(0.0, -1.0, src[src_i] < 0.0), 1.0, src[src_i] > 0.0);" - }, - "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] + "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = select(select(0.0, -1.0, src[src_i] < 0.0), 1.0, src[src_i] > 0.0);", "EXT_PARAMS": "" }, + "DECLS": ["NOT_INPLACE"] }, { "SHADER_NAME": "sgn_f16", - "REPLS": { - "TYPE": "f16", - "FUNC": "dst[dst_i] = select(select(0.0h, -1.0h, src[src_i] < 0.0h), 1.0h, src[src_i] > 0.0h);" - }, - "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] + "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = select(select(0.0h, -1.0h, src[src_i] < 0.0h), 1.0h, src[src_i] > 0.0h);", 
"EXT_PARAMS": "" }, + "DECLS": ["NOT_INPLACE"] }, { "SHADER_NAME": "sgn_in_place_f32", - "REPLS": { - "TYPE": "f32", - "FUNC": "src[dst_i] = select(select(0.0, -1.0, src[src_i] < 0.0), 1.0, src[src_i] > 0.0);" - }, - "DECLS": ["INPLACE_DFLT_PARAMS"] + "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = select(select(0.0, -1.0, src[src_i] < 0.0), 1.0, src[src_i] > 0.0);", "EXT_PARAMS": "" }, + "DECLS": ["INPLACE"] }, { "SHADER_NAME": "sgn_in_place_f16", - "REPLS": { - "TYPE": "f16", - "FUNC": "src[dst_i] = select(select(0.0h, -1.0h, src[src_i] < 0.0h), 1.0h, src[src_i] > 0.0h);" - }, - "DECLS": ["INPLACE_DFLT_PARAMS"] + "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = select(select(0.0h, -1.0h, src[src_i] < 0.0h), 1.0h, src[src_i] > 0.0h);", "EXT_PARAMS": "" }, + "DECLS": ["INPLACE"] }, { "SHADER_NAME": "neg_f32", - "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = -src[src_i];" }, - "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] + "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = -src[src_i];", "EXT_PARAMS": "" }, + "DECLS": ["NOT_INPLACE"] }, { "SHADER_NAME": "neg_f16", - "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = -src[src_i];" }, - "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] + "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = -src[src_i];", "EXT_PARAMS": "" }, + "DECLS": ["NOT_INPLACE"] }, { "SHADER_NAME": "neg_in_place_f32", - "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = -src[src_i];" }, - "DECLS": ["INPLACE_DFLT_PARAMS"] + "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = -src[src_i];", "EXT_PARAMS": "" }, + "DECLS": ["INPLACE"] }, { "SHADER_NAME": "neg_in_place_f16", - "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = -src[src_i];" }, - "DECLS": ["INPLACE_DFLT_PARAMS"] + "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = -src[src_i];", "EXT_PARAMS": "" }, + "DECLS": ["INPLACE"] }, { "SHADER_NAME": "step_f32", - "REPLS": { - "TYPE": "f32", - "FUNC": "dst[dst_i] = select(0.0, 1.0, src[src_i] > 0.0);" - }, - "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] + "REPLS": { "TYPE": "f32", "FUNC": 
"dst[dst_i] = select(0.0, 1.0, src[src_i] > 0.0);", "EXT_PARAMS": "" }, + "DECLS": ["NOT_INPLACE"] }, { "SHADER_NAME": "step_f16", - "REPLS": { - "TYPE": "f16", - "FUNC": "dst[dst_i] = select(0.0h, 1.0h, src[src_i] > 0.0h);" - }, - "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] + "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = select(0.0h, 1.0h, src[src_i] > 0.0h);", "EXT_PARAMS": "" }, + "DECLS": ["NOT_INPLACE"] }, { "SHADER_NAME": "step_in_place_f32", - "REPLS": { - "TYPE": "f32", - "FUNC": "src[dst_i] = select(0.0, 1.0, src[src_i] > 0.0);" - }, - "DECLS": ["INPLACE_DFLT_PARAMS"] + "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = select(0.0, 1.0, src[src_i] > 0.0);", "EXT_PARAMS": "" }, + "DECLS": ["INPLACE"] }, { "SHADER_NAME": "step_in_place_f16", - "REPLS": { - "TYPE": "f16", - "FUNC": "src[dst_i] = select(0.0h, 1.0h, src[src_i] > 0.0h);" - }, - "DECLS": ["INPLACE_DFLT_PARAMS"] + "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = select(0.0h, 1.0h, src[src_i] > 0.0h);", "EXT_PARAMS": "" }, + "DECLS": ["INPLACE"] }, { "SHADER_NAME": "tanh_f32", - "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = tanh(clamp(src[src_i], -9.010913, 9.010913)); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" }, - "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] + "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = tanh(clamp(src[src_i], -9.010913, 9.010913)); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", "EXT_PARAMS": "" }, + "DECLS": ["NOT_INPLACE"] }, { "SHADER_NAME": "tanh_f16", - "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = tanh(clamp(src[src_i], -9.010913, 9.010913)); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" }, - "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] + "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = tanh(clamp(src[src_i], -9.010913, 9.010913)); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", "EXT_PARAMS": "" }, + 
"DECLS": ["NOT_INPLACE"] }, { "SHADER_NAME": "tanh_in_place_f32", - "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = tanh(clamp(src[src_i], -9.010913, 9.010913)); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" }, - "DECLS": ["INPLACE_DFLT_PARAMS"] + "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = tanh(clamp(src[src_i], -9.010913, 9.010913)); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", "EXT_PARAMS": "" }, + "DECLS": ["INPLACE"] }, { "SHADER_NAME": "tanh_in_place_f16", - "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = tanh(clamp(src[src_i], -9.010913, 9.010913)); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" }, - "DECLS": ["INPLACE_DFLT_PARAMS"] + "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = tanh(clamp(src[src_i], -9.010913, 9.010913)); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", "EXT_PARAMS": "" }, + "DECLS": ["INPLACE"] }, { "SHADER_NAME": "elu_f32", - "REPLS": { - "TYPE": "f32", - "FUNC": "dst[dst_i] = select(exp(src[src_i]) - 1.0, src[src_i], src[src_i] > 0.0);" - }, - "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] + "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = select(exp(src[src_i]) - 1.0, src[src_i], src[src_i] > 0.0);", "EXT_PARAMS": "" }, + "DECLS": ["NOT_INPLACE"] }, { "SHADER_NAME": "elu_f16", - "REPLS": { - "TYPE": "f16", - "FUNC": "dst[dst_i] = select(exp(src[src_i]) - 1.0h, src[src_i], src[src_i] > 0.0h);" - }, - "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] + "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = select(exp(src[src_i]) - 1.0h, src[src_i], src[src_i] > 0.0h);", "EXT_PARAMS": "" }, + "DECLS": ["NOT_INPLACE"] }, { "SHADER_NAME": "elu_in_place_f32", - "REPLS": { - "TYPE": "f32", - "FUNC": "src[dst_i] = select(exp(src[src_i]) - 1.0, src[src_i], src[src_i] > 0.0);" - }, - "DECLS": ["INPLACE_DFLT_PARAMS"] + "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = 
select(exp(src[src_i]) - 1.0, src[src_i], src[src_i] > 0.0);", "EXT_PARAMS": "" }, + "DECLS": ["INPLACE"] }, { "SHADER_NAME": "elu_in_place_f16", - "REPLS": { - "TYPE": "f16", - "FUNC": "src[dst_i] = select(exp(src[src_i]) - 1.0h, src[src_i], src[src_i] > 0.0h);" - }, - "DECLS": ["INPLACE_DFLT_PARAMS"] + "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = select(exp(src[src_i]) - 1.0h, src[src_i], src[src_i] > 0.0h);", "EXT_PARAMS": "" }, + "DECLS": ["INPLACE"] }, { "SHADER_NAME": "relu_f32", - "REPLS": { - "TYPE": "f32", - "FUNC": "dst[dst_i] = select(0.0, src[src_i], src[src_i] > 0.0);" - }, - "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] + "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = select(0.0, src[src_i], src[src_i] > 0.0);", "EXT_PARAMS": "" }, + "DECLS": ["NOT_INPLACE"] }, { "SHADER_NAME": "relu_f16", - "REPLS": { - "TYPE": "f16", - "FUNC": "dst[dst_i] = select(0.0h, src[src_i], src[src_i] > 0.0h);" - }, - "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] + "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = select(0.0h, src[src_i], src[src_i] > 0.0h);", "EXT_PARAMS": "" }, + "DECLS": ["NOT_INPLACE"] }, { "SHADER_NAME": "relu_in_place_f32", - "REPLS": { - "TYPE": "f32", - "FUNC": "src[dst_i] = select(0.0, src[src_i], src[src_i] > 0.0);" - }, - "DECLS": ["INPLACE_DFLT_PARAMS"] + "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = select(0.0, src[src_i], src[src_i] > 0.0);", "EXT_PARAMS": "" }, + "DECLS": ["INPLACE"] }, { "SHADER_NAME": "relu_in_place_f16", - "REPLS": { - "TYPE": "f16", - "FUNC": "src[dst_i] = select(0.0h, src[src_i], src[src_i] > 0.0h);" - }, - "DECLS": ["INPLACE_DFLT_PARAMS"] + "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = select(0.0h, src[src_i], src[src_i] > 0.0h);", "EXT_PARAMS": "" }, + "DECLS": ["INPLACE"] }, { "SHADER_NAME": "sigmoid_f32", - "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = 1.0 / (1.0 + exp(-src[src_i]));" }, - "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] + "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = 1.0 / (1.0 + exp(-src[src_i]));", "EXT_PARAMS": "" }, 
+ "DECLS": ["NOT_INPLACE"] }, { "SHADER_NAME": "sigmoid_f16", - "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = 1.0h / (1.0h + exp(-src[src_i]));" }, - "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] + "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = 1.0h / (1.0h + exp(-src[src_i]));", "EXT_PARAMS": "" }, + "DECLS": ["NOT_INPLACE"] }, { "SHADER_NAME": "sigmoid_in_place_f32", - "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = 1.0 / (1.0 + exp(-src[src_i]));" }, - "DECLS": ["INPLACE_DFLT_PARAMS"] + "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = 1.0 / (1.0 + exp(-src[src_i]));", "EXT_PARAMS": "" }, + "DECLS": ["INPLACE"] }, { "SHADER_NAME": "sigmoid_in_place_f16", - "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = 1.0h / (1.0h + exp(-src[src_i]));" }, - "DECLS": ["INPLACE_DFLT_PARAMS"] + "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = 1.0h / (1.0h + exp(-src[src_i]));", "EXT_PARAMS": "" }, + "DECLS": ["INPLACE"] }, { "SHADER_NAME": "gelu_f32", - "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(sqrt(2.0 / 3.14159265) * (src[src_i] + 0.044715 * pow(src[src_i], 3.0)), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" }, - "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] + "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(sqrt(2.0 / 3.14159265) * (src[src_i] + 0.044715 * pow(src[src_i], 3.0)), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", "EXT_PARAMS": "" }, + "DECLS": ["NOT_INPLACE"] }, { "SHADER_NAME": "gelu_f16", - "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = 0.5h * src[src_i] * (1.0h + tanh(clamp(sqrt(2.0h / 3.14159265h) * (src[src_i] + 0.044715h * pow(src[src_i], 3.0h)), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" }, - "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] + "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = 
0.5h * src[src_i] * (1.0h + tanh(clamp(sqrt(2.0h / 3.14159265h) * (src[src_i] + 0.044715h * pow(src[src_i], 3.0h)), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", "EXT_PARAMS": "" }, + "DECLS": ["NOT_INPLACE"] }, { "SHADER_NAME": "gelu_in_place_f32", - "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(sqrt(2.0 / 3.14159265) * (src[src_i] + 0.044715 * pow(src[src_i], 3.0)), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" }, - "DECLS": ["INPLACE_DFLT_PARAMS"] + "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(sqrt(2.0 / 3.14159265) * (src[src_i] + 0.044715 * pow(src[src_i], 3.0)), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", "EXT_PARAMS": "" }, + "DECLS": ["INPLACE"] }, { "SHADER_NAME": "gelu_in_place_f16", - "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = 0.5h * src[src_i] * (1.0h + tanh(clamp(sqrt(2.0h / 3.14159265h) * (src[src_i] + 0.044715h * pow(src[src_i], 3.0h)), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" }, - "DECLS": ["INPLACE_DFLT_PARAMS"] + "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = 0.5h * src[src_i] * (1.0h + tanh(clamp(sqrt(2.0h / 3.14159265h) * (src[src_i] + 0.044715h * pow(src[src_i], 3.0h)), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", "EXT_PARAMS": "" }, + "DECLS": ["INPLACE"] }, { "SHADER_NAME": "gelu_quick_f32", - "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = src[src_i] * 0.5 * (1.0 + tanh(clamp(0.79788456 * src[src_i] * (1.0 + 0.044715 * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" }, - "DECLS": 
["NOT_INPLACE_DFLT_PARAMS"] + "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = src[src_i] * 0.5 * (1.0 + tanh(clamp(0.79788456 * src[src_i] * (1.0 + 0.044715 * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", "EXT_PARAMS": "" }, + "DECLS": ["NOT_INPLACE"] }, { "SHADER_NAME": "gelu_quick_f16", - "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = src[src_i] * 0.5h * (1.0h + tanh(clamp(0.79788456h * src[src_i] * (1.0h + 0.044715h * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" }, - "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] + "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = src[src_i] * 0.5h * (1.0h + tanh(clamp(0.79788456h * src[src_i] * (1.0h + 0.044715h * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", "EXT_PARAMS": "" }, + "DECLS": ["NOT_INPLACE"] }, { "SHADER_NAME": "gelu_quick_in_place_f32", - "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = src[src_i] * 0.5 * (1.0 + tanh(clamp(0.79788456 * src[src_i] * (1.0 + 0.044715 * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" }, - "DECLS": ["INPLACE_DFLT_PARAMS"] + "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = src[src_i] * 0.5 * (1.0 + tanh(clamp(0.79788456 * src[src_i] * (1.0 + 0.044715 * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", "EXT_PARAMS": "" }, + "DECLS": ["INPLACE"] }, { "SHADER_NAME": "gelu_quick_in_place_f16", - "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = src[src_i] * 0.5h * (1.0h + tanh(0.79788456h * src[src_i] * (1.0h + 0.044715h * src[src_i] * src[src_i]))); // Regarding tanh() domain restrictions in wgsl 
https://github.com/gpuweb/gpuweb/issues/4458" }, - "DECLS": ["INPLACE_DFLT_PARAMS"] + "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = src[src_i] * 0.5h * (1.0h + tanh(0.79788456h * src[src_i] * (1.0h + 0.044715h * src[src_i] * src[src_i]))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", "EXT_PARAMS": "" }, + "DECLS": ["INPLACE"] }, { "SHADER_NAME": "silu_f32", - "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = src[src_i] / (1.0 + exp(-src[src_i]));" }, - "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] + "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = src[src_i] / (1.0 + exp(-src[src_i]));", "EXT_PARAMS": "" }, + "DECLS": ["NOT_INPLACE"] }, { "SHADER_NAME": "silu_f16", - "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = src[src_i] / (1.0h + exp(-src[src_i]));" }, - "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] + "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = src[src_i] / (1.0h + exp(-src[src_i]));", "EXT_PARAMS": "" }, + "DECLS": ["NOT_INPLACE"] }, { "SHADER_NAME": "silu_in_place_f32", - "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = src[src_i] / (1.0 + exp(-src[src_i]));" }, - "DECLS": ["INPLACE_DFLT_PARAMS"] + "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = src[src_i] / (1.0 + exp(-src[src_i]));", "EXT_PARAMS": "" }, + "DECLS": ["INPLACE"] }, { "SHADER_NAME": "silu_in_place_f16", - "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = src[src_i] / (1.0h + exp(-src[src_i]));" }, - "DECLS": ["INPLACE_DFLT_PARAMS"] + "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = src[src_i] / (1.0h + exp(-src[src_i]));", "EXT_PARAMS": "" }, + "DECLS": ["INPLACE"] }, { "SHADER_NAME": "hardswish_f32", - "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = src[src_i] * min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));" }, - "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] + "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = src[src_i] * min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));", "EXT_PARAMS": "" }, + "DECLS": ["NOT_INPLACE"] }, { "SHADER_NAME": "hardswish_f16", - "REPLS": { "TYPE": 
"f16", "FUNC": "dst[dst_i] = src[src_i] * min(1.0h, max(0.0h, (src[src_i] + 3.0h) / 6.0h));" }, - "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] + "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = src[src_i] * min(1.0h, max(0.0h, (src[src_i] + 3.0h) / 6.0h));", "EXT_PARAMS": "" }, + "DECLS": ["NOT_INPLACE"] }, { "SHADER_NAME": "hardswish_in_place_f32", - "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = src[src_i] * min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));" }, - "DECLS": ["INPLACE_DFLT_PARAMS"] + "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = src[src_i] * min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));", "EXT_PARAMS": "" }, + "DECLS": ["INPLACE"] }, { "SHADER_NAME": "hardswish_in_place_f16", - "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = src[src_i] * min(1.0h, max(0.0h, (src[src_i] + 3.0h) / 6.0h));" }, - "DECLS": ["INPLACE_DFLT_PARAMS"] + "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = src[src_i] * min(1.0h, max(0.0h, (src[src_i] + 3.0h) / 6.0h));", "EXT_PARAMS": "" }, + "DECLS": ["INPLACE"] }, { "SHADER_NAME": "hardsigmoid_f32", - "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));" }, - "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] + "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));", "EXT_PARAMS": "" }, + "DECLS": ["NOT_INPLACE"] }, { "SHADER_NAME": "hardsigmoid_f16", - "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = min(1.0h, max(0.0h, (src[src_i] + 3.0h) / 6.0h));" }, - "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] + "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = min(1.0h, max(0.0h, (src[src_i] + 3.0h) / 6.0h));", "EXT_PARAMS": "" }, + "DECLS": ["NOT_INPLACE"] }, { "SHADER_NAME": "hardsigmoid_in_place_f32", - "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));" }, - "DECLS": ["INPLACE_DFLT_PARAMS"] + "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));", "EXT_PARAMS": "" }, + "DECLS": ["INPLACE"] }, { "SHADER_NAME": 
"hardsigmoid_in_place_f16", - "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = min(1.0h, max(0.0h, (src[src_i] + 3.0h) / 6.0h));" }, - "DECLS": ["INPLACE_DFLT_PARAMS"] + "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = min(1.0h, max(0.0h, (src[src_i] + 3.0h) / 6.0h));", "EXT_PARAMS": "" }, + "DECLS": ["INPLACE"] }, { "SHADER_NAME": "exp_f32", - "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = exp(src[src_i]);" }, - "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] + "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = exp(src[src_i]);", "EXT_PARAMS": "" }, + "DECLS": ["NOT_INPLACE"] }, { "SHADER_NAME": "exp_f16", - "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = exp(src[src_i]);" }, - "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] + "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = exp(src[src_i]);", "EXT_PARAMS": "" }, + "DECLS": ["NOT_INPLACE"] }, { "SHADER_NAME": "exp_in_place_f32", - "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = exp(src[src_i]);" }, - "DECLS": ["INPLACE_DFLT_PARAMS"] + "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = exp(src[src_i]);", "EXT_PARAMS": "" }, + "DECLS": ["INPLACE"] }, { "SHADER_NAME": "exp_in_place_f16", - "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = exp(src[src_i]);" }, - "DECLS": ["INPLACE_DFLT_PARAMS"] + "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = exp(src[src_i]);", "EXT_PARAMS": "" }, + "DECLS": ["INPLACE"] }, { "SHADER_NAME": "gelu_erf_f32", - "REPLS": { - "TYPE": "f32", - "FUNC": "dst[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(0.79788456 * (src[src_i] + 0.044715 * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" - }, - "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] + "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(0.79788456 * (src[src_i] + 0.044715 * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", 
"EXT_PARAMS": "" }, + "DECLS": ["NOT_INPLACE"] }, { "SHADER_NAME": "gelu_erf_f16", - "REPLS": { - "TYPE": "f16", - "FUNC": "dst[dst_i] = 0.5h * src[src_i] * (1.0h + tanh(clamp(0.79788456h * (src[src_i] + 0.044715h * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" - }, - "DECLS": ["NOT_INPLACE_DFLT_PARAMS"] + "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = 0.5h * src[src_i] * (1.0h + tanh(clamp(0.79788456h * (src[src_i] + 0.044715h * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", "EXT_PARAMS": "" }, + "DECLS": ["NOT_INPLACE"] }, { "SHADER_NAME": "gelu_erf_in_place_f32", - "REPLS": { - "TYPE": "f32", - "FUNC": "src[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(0.79788456 * (src[src_i] + 0.044715 * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" - }, - "DECLS": ["INPLACE_DFLT_PARAMS"] + "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(0.79788456 * (src[src_i] + 0.044715 * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", "EXT_PARAMS": "" }, + "DECLS": ["INPLACE"] }, { "SHADER_NAME": "gelu_erf_in_place_f16", - "REPLS": { - "TYPE": "f16", - "FUNC": "src[dst_i] = 0.5h * src[src_i] * (1.0h + tanh(clamp(0.79788456h * (src[src_i] + 0.044715h * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" - }, - "DECLS": ["INPLACE_DFLT_PARAMS"] + "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = 0.5h * src[src_i] * (1.0h + tanh(clamp(0.79788456h * (src[src_i] + 0.044715h * src[src_i] * src[src_i] * src[src_i]), 
-9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", "EXT_PARAMS": "" }, + "DECLS": ["INPLACE"] }, { "SHADER_NAME": "xielu_f32", "REPLS": { "TYPE": "f32", - "FUNC": "dst[dst_i] = select(((exp(min(src[src_i], f32(params.eps))) - 1.0) - src[src_i]) * f32(params.alpha_n) + f32(params.beta) * src[src_i], f32(params.alpha_p) * src[src_i] * src[src_i] + f32(params.beta) * src[src_i], src[src_i] > 0.0);" + "FUNC": "dst[dst_i] = select(((exp(min(src[src_i], f32(params.eps))) - 1.0) - src[src_i]) * f32(params.alpha_n) + f32(params.beta) * src[src_i], f32(params.alpha_p) * src[src_i] * src[src_i] + f32(params.beta) * src[src_i], src[src_i] > 0.0);", + "EXT_PARAMS": "alpha_n: u32, alpha_p: u32, beta: u32, eps: u32" }, - "DECLS": ["NOT_INPLACE_EXT_PARAMS"] + "DECLS": ["NOT_INPLACE"] }, { "SHADER_NAME": "xielu_f16", "REPLS": { "TYPE": "f16", - "FUNC": "dst[dst_i] = select(((exp(min(src[src_i], f16(params.eps))) - 1.0h) - src[src_i]) * f16(params.alpha_n) + f16(params.beta) * src[src_i], f16(params.alpha_p) * src[src_i] * src[src_i] + f16(params.beta) * src[src_i], src[src_i] > 0.0h);" + "FUNC": "dst[dst_i] = select(((exp(min(src[src_i], f16(params.eps))) - 1.0h) - src[src_i]) * f16(params.alpha_n) + f16(params.beta) * src[src_i], f16(params.alpha_p) * src[src_i] * src[src_i] + f16(params.beta) * src[src_i], src[src_i] > 0.0h);", + "EXT_PARAMS": "alpha_n: u32, alpha_p: u32, beta: u32, eps: u32" }, - "DECLS": ["NOT_INPLACE_EXT_PARAMS"] + "DECLS": ["NOT_INPLACE"] }, { "SHADER_NAME": "xielu_in_place_f32", "REPLS": { "TYPE": "f32", - "FUNC": "src[dst_i] = select(((exp(min(src[src_i], f32(params.eps))) - 1.0) - src[src_i]) * f32(params.alpha_n) + f32(params.beta) * src[src_i], f32(params.alpha_p) * src[src_i] * src[src_i] + f32(params.beta) * src[src_i], src[src_i] > 0.0);" + "FUNC": "src[dst_i] = select(((exp(min(src[src_i], f32(params.eps))) - 1.0) - src[src_i]) * f32(params.alpha_n) + f32(params.beta) * src[src_i], 
f32(params.alpha_p) * src[src_i] * src[src_i] + f32(params.beta) * src[src_i], src[src_i] > 0.0);", + "EXT_PARAMS": "alpha_n: u32, alpha_p: u32, beta: u32, eps: u32" }, - "DECLS": ["INPLACE_EXT_PARAMS"] + "DECLS": ["INPLACE"] }, { "SHADER_NAME": "xielu_in_place_f16", "REPLS": { "TYPE": "f16", - "FUNC": "src[dst_i] = select(((exp(min(src[src_i], f16(params.eps))) - 1.0h) - src[src_i]) * f16(params.alpha_n) + f16(params.beta) * src[src_i], f16(params.alpha_p) * src[src_i] * src[src_i] + f16(params.beta) * src[src_i], src[src_i] > 0.0h);" + "FUNC": "src[dst_i] = select(((exp(min(src[src_i], f16(params.eps))) - 1.0h) - src[src_i]) * f16(params.alpha_n) + f16(params.beta) * src[src_i], f16(params.alpha_p) * src[src_i] * src[src_i] + f16(params.beta) * src[src_i], src[src_i] > 0.0h);", + "EXT_PARAMS": "alpha_n: u32, alpha_p: u32, beta: u32, eps: u32" }, - "DECLS": ["INPLACE_EXT_PARAMS"] + "DECLS": ["INPLACE"] } ] @@ -400,7 +343,7 @@ #define(DECLS) -#decl(NOT_INPLACE_DFLT_PARAMS) +#decl(INPLACE) @group(0) @binding(1) var dst: array<{{TYPE}}>; @@ -408,68 +351,9 @@ var dst: array<{{TYPE}}>; @group(0) @binding(2) var params: Params; -struct Params { - ne: u32, // total number of elements - offset_src: u32, // in elements - offset_dst: u32, // in elements - - // Strides (in elements) — may be permuted - stride_src0: u32, - stride_src1: u32, - stride_src2: u32, - stride_src3: u32, +#enddecl(INPLACE) - stride_dst0: u32, - stride_dst1: u32, - stride_dst2: u32, - stride_dst3: u32, - - // Logical shapes - src_ne0: u32, - src_ne1: u32, - src_ne2: u32, - - dst_ne0: u32, - dst_ne1: u32, - dst_ne2: u32 -}; - -#enddecl(NOT_INPLACE_DFLT_PARAMS) - -#decl(INPLACE_DFLT_PARAMS) - -@group(0) @binding(1) -var params: Params; - -struct Params { - ne: u32, // total number of elements - offset_src: u32, // in elements - offset_dst: u32, // in elements - - // Strides (in elements) — may be permuted - stride_src0: u32, - stride_src1: u32, - stride_src2: u32, - stride_src3: u32, - - stride_dst0: 
u32, - stride_dst1: u32, - stride_dst2: u32, - stride_dst3: u32, - - // Logical shapes - src_ne0: u32, - src_ne1: u32, - src_ne2: u32, - - dst_ne0: u32, - dst_ne1: u32, - dst_ne2: u32 -}; - -#enddecl(INPLACE_DFLT_PARAMS) - -#decl(NOT_INPLACE_EXT_PARAMS) +#decl(NOT_INPLACE) @group(0) @binding(1) var dst: array<{{TYPE}}>; @@ -477,44 +361,22 @@ var dst: array<{{TYPE}}>; @group(0) @binding(2) var params: Params; -struct Params { - ne: u32, // total number of elements - offset_src: u32, // in elements - offset_dst: u32, // in elements - - // Strides (in elements) — may be permuted - stride_src0: u32, - stride_src1: u32, - stride_src2: u32, - stride_src3: u32, +#enddecl(NOT_INPLACE) - stride_dst0: u32, - stride_dst1: u32, - stride_dst2: u32, - stride_dst3: u32, - - // Logical shapes - src_ne0: u32, - src_ne1: u32, - src_ne2: u32, +#end(DECLS) - dst_ne0: u32, - dst_ne1: u32, - dst_ne2: u32, +#define(SHADER) - // XIELU params - alpha_n: u32, - alpha_p: u32, - beta: u32, - eps: u32 -}; +enable f16; -#enddecl(NOT_INPLACE_EXT_PARAMS) +fn update(dst_i: u32, src_i: u32) { + {{FUNC}} +} -#decl(INPLACE_EXT_PARAMS) +@group(0) @binding(0) +var src: array<{{TYPE}}>; -@group(0) @binding(1) -var params: Params; +DECLS struct Params { ne: u32, // total number of elements @@ -541,30 +403,9 @@ struct Params { dst_ne1: u32, dst_ne2: u32, - // XIELU params - alpha_n: u32, - alpha_p: u32, - beta: u32, - eps: u32 + {{EXT_PARAMS}} }; -#enddecl(INPLACE_EXT_PARAMS) - -#end(DECLS) - -#define(SHADER) - -enable f16; - -fn update(dst_i: u32, src_i: u32) { - {{FUNC}} -} - -@group(0) @binding(0) -var src: array<{{TYPE}}>; - -DECLS - override wg_size: u32; @compute @workgroup_size(wg_size) fn main(@builtin(global_invocation_id) gid: vec3) { @@ -599,3 +440,4 @@ fn main(@builtin(global_invocation_id) gid: vec3) { } #end(SHADER) + From e1f6baea31645e5d96ad53664acae856f74b96f4 Mon Sep 17 00:00:00 2001 From: James Contini Date: Wed, 29 Oct 2025 23:08:37 -0700 Subject: [PATCH 15/40] implemented 
REPL_Template support and removed bug in unary operators kernel --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 4 +- .../ggml-webgpu/wgsl-shaders/embed_wgsl.py | 16 + .../ggml-webgpu/wgsl-shaders/unary_op.wgsl | 696 +++++++++--------- 3 files changed, 375 insertions(+), 341 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 2f4fdc1c3cc..0fc2691cc4f 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -865,7 +865,7 @@ static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * dst, webgpu_pipeline & pipeline, bool in_place, - const std::vector & xielu_params = {}) { + const std::vector & extra_params = {}) { uint32_t ne = (uint32_t) ggml_nelements(dst); std::vector params = { @@ -881,7 +881,7 @@ static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, (uint32_t) dst->ne[1], (uint32_t) dst->ne[2] }; - params.insert(params.end(), xielu_params.begin(), xielu_params.end()); + params.insert(params.end(), extra_params.begin(), extra_params.end()); std::vector entries = { { .binding = 0, diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py b/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py index 251051eaeca..7de19bef77f 100755 --- a/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +++ b/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py @@ -18,6 +18,14 @@ def parse_decls(decls_text): decls[name.strip()] = code.strip() return decls +def replace_repl_placeholders(variant, template_map): + + for repl, code in variant["REPLS"].items(): + for key, val in template_map.items(): + # Match "key" and avoid matching subsequences using by using \b + code = re.sub(rf'\b{re.escape(str(key))}\b', str(val), code) + variant["REPLS"][repl] = code + return variant def replace_placeholders(shader_text, replacements): for key, val in replacements.items(): @@ -71,6 +79,10 @@ def generate_variants(fname, input_dir, output_dir, outfile): decls_map = 
parse_decls(extract_block(text, "DECLS")) except ValueError: decls_map = {} + try: + templates_map = ast.literal_eval(extract_block(text, "REPL_TEMPLATES")) + except ValueError: + templates_map = {} with open(os.path.join(input_dir, "common_decls.tmpl"), "r", encoding="utf-8") as f: common_decls = f.read() @@ -85,11 +97,15 @@ def generate_variants(fname, input_dir, output_dir, outfile): decls_code = "" for key in decls: if key not in decls_map: + raise ValueError(f"DECLS key '{key}' not found.") decls_code += decls_map[key] + "\n\n" final_shader = re.sub(r'\bDECLS\b', decls_code, shader_template) + if "REPLS" in variant: + variant = replace_repl_placeholders(variant, templates_map) + final_shader = replace_placeholders(final_shader, variant["REPLS"]) final_shader = replace_placeholders(final_shader, variant["REPLS"]) final_shader = expand_includes(final_shader, input_dir) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl index 7f632a24e53..93f0fac66f1 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl @@ -1,342 +1,363 @@ +#define(REPL_TEMPLATES) + +{ + "XIELU_FUNC": "{{MUTATE}}[dst_i] = select(((exp(min(src[src_i], {{TYPE}}(params.eps))) - 1.0) - src[src_i]) * {{TYPE}}(params.alpha_n) + {{TYPE}}(params.beta) * src[src_i], {{TYPE}}(params.alpha_p) * src[src_i] * src[src_i] + {{TYPE}}(params.beta) * src[src_i], src[src_i] > 0.0);", + "ABS_FUNC": "{{MUTATE}}[dst_i] = abs(src[src_i]);", + "SGN_FUNC": "{{MUTATE}}[dst_i] = select({{TYPE}}(select(0.0, -1.0, src[src_i] < 0.0)), {{TYPE}}(1.0), src[src_i] > 0.0);", + "NEG_FUNC": "{{MUTATE}}[dst_i] = -src[src_i];", + "STEP_FUNC": "{{MUTATE}}[dst_i] = {{TYPE}}(select(0.0, 1.0, src[src_i] > 0.0));", + "TANH_FUNC": "{{MUTATE}}[dst_i] = tanh(clamp(src[src_i], -9.010913, 9.010913)); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", + "RELU_FUNC": 
"{{MUTATE}}[dst_i] = select(0.0, src[src_i], src[src_i] > 0.0);", + "ELU_FUNC": "{{MUTATE}}[dst_i] = select(exp(src[src_i]) - 1.0, src[src_i], src[src_i] > 0.0);", + "HARDSIGMOID_FUNC": "{{MUTATE}}[dst_i] = min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));", + "SIGMOID_FUNC": "{{MUTATE}}[dst_i] = 1.0 / (1.0 + exp(-src[src_i]));", + "SILU_FUNC": "{{MUTATE}}[dst_i] = src[src_i] / (1.0 + exp(-src[src_i]));", + "EXP_FUNC": "{{MUTATE}}[dst_i] = exp(src[src_i]);", + "HARDSWISH_FUNC": "{{MUTATE}}[dst_i] = src[src_i] * min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));", + "GELU_FUNC": "{{MUTATE}}[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(sqrt(2.0 / 3.14159265) * (src[src_i] + 0.044715 * pow(src[src_i], 3.0)), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", + "GELU_QUICK_FUNC": "{{MUTATE}}[dst_i] = src[src_i] * 0.5 * (1.0 + tanh(clamp(0.79788456 * (src[src_i] + 0.044715 * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", + "GELU_ERF_FUNC": "{{MUTATE}}[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(0.79788456 * (src[src_i] + 0.044715 * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" +} + +#end(REPL_TEMPLATES) + #define(VARIANTS) [ - { - "SHADER_NAME": "abs_f32", - "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = abs(src[src_i]);", "EXT_PARAMS": "" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "abs_f16", - "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = abs(src[src_i]);", "EXT_PARAMS": "" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "abs_in_place_f32", - "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = abs(src[src_i]);", "EXT_PARAMS": "" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "abs_in_place_f16", - "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = abs(src[src_i]);", 
"EXT_PARAMS": "" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "sgn_f32", - "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = select(select(0.0, -1.0, src[src_i] < 0.0), 1.0, src[src_i] > 0.0);", "EXT_PARAMS": "" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "sgn_f16", - "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = select(select(0.0h, -1.0h, src[src_i] < 0.0h), 1.0h, src[src_i] > 0.0h);", "EXT_PARAMS": "" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "sgn_in_place_f32", - "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = select(select(0.0, -1.0, src[src_i] < 0.0), 1.0, src[src_i] > 0.0);", "EXT_PARAMS": "" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "sgn_in_place_f16", - "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = select(select(0.0h, -1.0h, src[src_i] < 0.0h), 1.0h, src[src_i] > 0.0h);", "EXT_PARAMS": "" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "neg_f32", - "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = -src[src_i];", "EXT_PARAMS": "" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "neg_f16", - "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = -src[src_i];", "EXT_PARAMS": "" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "neg_in_place_f32", - "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = -src[src_i];", "EXT_PARAMS": "" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "neg_in_place_f16", - "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = -src[src_i];", "EXT_PARAMS": "" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "step_f32", - "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = select(0.0, 1.0, src[src_i] > 0.0);", "EXT_PARAMS": "" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "step_f16", - "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = select(0.0h, 1.0h, src[src_i] > 0.0h);", "EXT_PARAMS": "" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "step_in_place_f32", - "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = select(0.0, 1.0, src[src_i] > 0.0);", "EXT_PARAMS": "" }, - "DECLS": 
["INPLACE"] - }, - { - "SHADER_NAME": "step_in_place_f16", - "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = select(0.0h, 1.0h, src[src_i] > 0.0h);", "EXT_PARAMS": "" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "tanh_f32", - "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = tanh(clamp(src[src_i], -9.010913, 9.010913)); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", "EXT_PARAMS": "" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "tanh_f16", - "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = tanh(clamp(src[src_i], -9.010913, 9.010913)); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", "EXT_PARAMS": "" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "tanh_in_place_f32", - "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = tanh(clamp(src[src_i], -9.010913, 9.010913)); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", "EXT_PARAMS": "" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "tanh_in_place_f16", - "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = tanh(clamp(src[src_i], -9.010913, 9.010913)); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", "EXT_PARAMS": "" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "elu_f32", - "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = select(exp(src[src_i]) - 1.0, src[src_i], src[src_i] > 0.0);", "EXT_PARAMS": "" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "elu_f16", - "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = select(exp(src[src_i]) - 1.0h, src[src_i], src[src_i] > 0.0h);", "EXT_PARAMS": "" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "elu_in_place_f32", - "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = select(exp(src[src_i]) - 1.0, src[src_i], src[src_i] > 0.0);", "EXT_PARAMS": "" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "elu_in_place_f16", - "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] 
= select(exp(src[src_i]) - 1.0h, src[src_i], src[src_i] > 0.0h);", "EXT_PARAMS": "" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "relu_f32", - "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = select(0.0, src[src_i], src[src_i] > 0.0);", "EXT_PARAMS": "" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "relu_f16", - "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = select(0.0h, src[src_i], src[src_i] > 0.0h);", "EXT_PARAMS": "" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "relu_in_place_f32", - "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = select(0.0, src[src_i], src[src_i] > 0.0);", "EXT_PARAMS": "" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "relu_in_place_f16", - "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = select(0.0h, src[src_i], src[src_i] > 0.0h);", "EXT_PARAMS": "" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "sigmoid_f32", - "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = 1.0 / (1.0 + exp(-src[src_i]));", "EXT_PARAMS": "" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "sigmoid_f16", - "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = 1.0h / (1.0h + exp(-src[src_i]));", "EXT_PARAMS": "" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "sigmoid_in_place_f32", - "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = 1.0 / (1.0 + exp(-src[src_i]));", "EXT_PARAMS": "" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "sigmoid_in_place_f16", - "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = 1.0h / (1.0h + exp(-src[src_i]));", "EXT_PARAMS": "" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "gelu_f32", - "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(sqrt(2.0 / 3.14159265) * (src[src_i] + 0.044715 * pow(src[src_i], 3.0)), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", "EXT_PARAMS": "" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "gelu_f16", - "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = 
0.5h * src[src_i] * (1.0h + tanh(clamp(sqrt(2.0h / 3.14159265h) * (src[src_i] + 0.044715h * pow(src[src_i], 3.0h)), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", "EXT_PARAMS": "" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "gelu_in_place_f32", - "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(sqrt(2.0 / 3.14159265) * (src[src_i] + 0.044715 * pow(src[src_i], 3.0)), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", "EXT_PARAMS": "" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "gelu_in_place_f16", - "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = 0.5h * src[src_i] * (1.0h + tanh(clamp(sqrt(2.0h / 3.14159265h) * (src[src_i] + 0.044715h * pow(src[src_i], 3.0h)), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", "EXT_PARAMS": "" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "gelu_quick_f32", - "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = src[src_i] * 0.5 * (1.0 + tanh(clamp(0.79788456 * src[src_i] * (1.0 + 0.044715 * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", "EXT_PARAMS": "" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "gelu_quick_f16", - "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = src[src_i] * 0.5h * (1.0h + tanh(clamp(0.79788456h * src[src_i] * (1.0h + 0.044715h * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", "EXT_PARAMS": "" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "gelu_quick_in_place_f32", - "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = src[src_i] * 0.5 * (1.0 + tanh(clamp(0.79788456 * src[src_i] * (1.0 + 0.044715 * src[src_i] * src[src_i]), -9.010913, 9.010913))); 
// Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", "EXT_PARAMS": "" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "gelu_quick_in_place_f16", - "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = src[src_i] * 0.5h * (1.0h + tanh(0.79788456h * src[src_i] * (1.0h + 0.044715h * src[src_i] * src[src_i]))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", "EXT_PARAMS": "" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "silu_f32", - "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = src[src_i] / (1.0 + exp(-src[src_i]));", "EXT_PARAMS": "" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "silu_f16", - "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = src[src_i] / (1.0h + exp(-src[src_i]));", "EXT_PARAMS": "" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "silu_in_place_f32", - "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = src[src_i] / (1.0 + exp(-src[src_i]));", "EXT_PARAMS": "" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "silu_in_place_f16", - "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = src[src_i] / (1.0h + exp(-src[src_i]));", "EXT_PARAMS": "" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "hardswish_f32", - "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = src[src_i] * min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));", "EXT_PARAMS": "" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "hardswish_f16", - "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = src[src_i] * min(1.0h, max(0.0h, (src[src_i] + 3.0h) / 6.0h));", "EXT_PARAMS": "" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "hardswish_in_place_f32", - "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = src[src_i] * min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));", "EXT_PARAMS": "" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "hardswish_in_place_f16", - "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = src[src_i] * min(1.0h, max(0.0h, (src[src_i] + 3.0h) / 6.0h));", "EXT_PARAMS": "" 
}, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "hardsigmoid_f32", - "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));", "EXT_PARAMS": "" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "hardsigmoid_f16", - "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = min(1.0h, max(0.0h, (src[src_i] + 3.0h) / 6.0h));", "EXT_PARAMS": "" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "hardsigmoid_in_place_f32", - "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));", "EXT_PARAMS": "" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "hardsigmoid_in_place_f16", - "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = min(1.0h, max(0.0h, (src[src_i] + 3.0h) / 6.0h));", "EXT_PARAMS": "" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "exp_f32", - "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = exp(src[src_i]);", "EXT_PARAMS": "" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "exp_f16", - "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = exp(src[src_i]);", "EXT_PARAMS": "" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "exp_in_place_f32", - "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = exp(src[src_i]);", "EXT_PARAMS": "" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "exp_in_place_f16", - "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = exp(src[src_i]);", "EXT_PARAMS": "" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "gelu_erf_f32", - "REPLS": { "TYPE": "f32", "FUNC": "dst[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(0.79788456 * (src[src_i] + 0.044715 * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", "EXT_PARAMS": "" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "gelu_erf_f16", - "REPLS": { "TYPE": "f16", "FUNC": "dst[dst_i] = 0.5h * src[src_i] * (1.0h + tanh(clamp(0.79788456h * (src[src_i] + 0.044715h * src[src_i] * src[src_i] * 
src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", "EXT_PARAMS": "" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "gelu_erf_in_place_f32", - "REPLS": { "TYPE": "f32", "FUNC": "src[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(0.79788456 * (src[src_i] + 0.044715 * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", "EXT_PARAMS": "" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "gelu_erf_in_place_f16", - "REPLS": { "TYPE": "f16", "FUNC": "src[dst_i] = 0.5h * src[src_i] * (1.0h + tanh(clamp(0.79788456h * (src[src_i] + 0.044715h * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", "EXT_PARAMS": "" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "xielu_f32", - "REPLS": { - "TYPE": "f32", - "FUNC": "dst[dst_i] = select(((exp(min(src[src_i], f32(params.eps))) - 1.0) - src[src_i]) * f32(params.alpha_n) + f32(params.beta) * src[src_i], f32(params.alpha_p) * src[src_i] * src[src_i] + f32(params.beta) * src[src_i], src[src_i] > 0.0);", - "EXT_PARAMS": "alpha_n: u32, alpha_p: u32, beta: u32, eps: u32" - }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "xielu_f16", - "REPLS": { - "TYPE": "f16", - "FUNC": "dst[dst_i] = select(((exp(min(src[src_i], f16(params.eps))) - 1.0h) - src[src_i]) * f16(params.alpha_n) + f16(params.beta) * src[src_i], f16(params.alpha_p) * src[src_i] * src[src_i] + f16(params.beta) * src[src_i], src[src_i] > 0.0h);", - "EXT_PARAMS": "alpha_n: u32, alpha_p: u32, beta: u32, eps: u32" - }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "xielu_in_place_f32", - "REPLS": { - "TYPE": "f32", - "FUNC": "src[dst_i] = select(((exp(min(src[src_i], f32(params.eps))) - 1.0) - src[src_i]) * f32(params.alpha_n) + f32(params.beta) * src[src_i], 
f32(params.alpha_p) * src[src_i] * src[src_i] + f32(params.beta) * src[src_i], src[src_i] > 0.0);", - "EXT_PARAMS": "alpha_n: u32, alpha_p: u32, beta: u32, eps: u32" - }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "xielu_in_place_f16", - "REPLS": { - "TYPE": "f16", - "FUNC": "src[dst_i] = select(((exp(min(src[src_i], f16(params.eps))) - 1.0h) - src[src_i]) * f16(params.alpha_n) + f16(params.beta) * src[src_i], f16(params.alpha_p) * src[src_i] * src[src_i] + f16(params.beta) * src[src_i], src[src_i] > 0.0h);", - "EXT_PARAMS": "alpha_n: u32, alpha_p: u32, beta: u32, eps: u32" - }, - "DECLS": ["INPLACE"] - } + { + "SHADER_NAME": "abs_f32", + "REPLS": { "TYPE": "f32", "FUNC": "ABS_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "abs_f16", + "REPLS": { "TYPE": "f16", "FUNC": "ABS_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "abs_in_place_f32", + "REPLS": { "TYPE": "f32", "FUNC": "ABS_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "abs_in_place_f16", + "REPLS": { "TYPE": "f16", "FUNC": "ABS_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + + { + "SHADER_NAME": "sgn_f32", + "REPLS": { "TYPE": "f32", "FUNC": "SGN_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "sgn_f16", + "REPLS": { "TYPE": "f16", "FUNC": "SGN_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "sgn_in_place_f32", + "REPLS": { "TYPE": "f32", "FUNC": "SGN_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "sgn_in_place_f16", + "REPLS": { "TYPE": "f16", "FUNC": "SGN_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + + { + "SHADER_NAME": "neg_f32", + "REPLS": { "TYPE": "f32", "FUNC": "NEG_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + 
"SHADER_NAME": "neg_f16", + "REPLS": { "TYPE": "f16", "FUNC": "NEG_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "neg_in_place_f32", + "REPLS": { "TYPE": "f32", "FUNC": "NEG_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "neg_in_place_f16", + "REPLS": { "TYPE": "f16", "FUNC": "NEG_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + + { + "SHADER_NAME": "step_f32", + "REPLS": { "TYPE": "f32", "FUNC": "STEP_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "step_f16", + "REPLS": { "TYPE": "f16", "FUNC": "STEP_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "step_in_place_f32", + "REPLS": { "TYPE": "f32", "FUNC": "STEP_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "step_in_place_f16", + "REPLS": { "TYPE": "f16", "FUNC": "STEP_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + + { + "SHADER_NAME": "tanh_f32", + "REPLS": { "TYPE": "f32", "FUNC": "TANH_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "tanh_f16", + "REPLS": { "TYPE": "f16", "FUNC": "TANH_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "tanh_in_place_f32", + "REPLS": { "TYPE": "f32", "FUNC": "TANH_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "tanh_in_place_f16", + "REPLS": { "TYPE": "f16", "FUNC": "TANH_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + + { + "SHADER_NAME": "elu_f32", + "REPLS": { "TYPE": "f32", "FUNC": "ELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "elu_f16", + "REPLS": { "TYPE": "f16", "FUNC": "ELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + 
"SHADER_NAME": "elu_in_place_f32", + "REPLS": { "TYPE": "f32", "FUNC": "ELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "elu_in_place_f16", + "REPLS": { "TYPE": "f16", "FUNC": "ELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + + { + "SHADER_NAME": "relu_f32", + "REPLS": { "TYPE": "f32", "FUNC": "RELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "relu_f16", + "REPLS": { "TYPE": "f16", "FUNC": "RELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "relu_in_place_f32", + "REPLS": { "TYPE": "f32", "FUNC": "RELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "relu_in_place_f16", + "REPLS": { "TYPE": "f16", "FUNC": "RELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + + { + "SHADER_NAME": "sigmoid_f32", + "REPLS": { "TYPE": "f32", "FUNC": "SIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "sigmoid_f16", + "REPLS": { "TYPE": "f16", "FUNC": "SIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "sigmoid_in_place_f32", + "REPLS": { "TYPE": "f32", "FUNC": "SIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "sigmoid_in_place_f16", + "REPLS": { "TYPE": "f16", "FUNC": "SIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + + { + "SHADER_NAME": "silu_f32", + "REPLS": { "TYPE": "f32", "FUNC": "SILU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "silu_f16", + "REPLS": { "TYPE": "f16", "FUNC": "SILU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "silu_in_place_f32", + "REPLS": { "TYPE": "f32", "FUNC": "SILU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + 
"DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "silu_in_place_f16", + "REPLS": { "TYPE": "f16", "FUNC": "SILU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + + { + "SHADER_NAME": "exp_f32", + "REPLS": { "TYPE": "f32", "FUNC": "EXP_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "exp_f16", + "REPLS": { "TYPE": "f16", "FUNC": "EXP_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "exp_in_place_f32", + "REPLS": { "TYPE": "f32", "FUNC": "EXP_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "exp_in_place_f16", + "REPLS": { "TYPE": "f16", "FUNC": "EXP_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + + { + "SHADER_NAME": "hardsigmoid_f32", + "REPLS": { "TYPE": "f32", "FUNC": "HARDSIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "hardsigmoid_f16", + "REPLS": { "TYPE": "f16", "FUNC": "HARDSIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "hardsigmoid_in_place_f32", + "REPLS": { "TYPE": "f32", "FUNC": "HARDSIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "hardsigmoid_in_place_f16", + "REPLS": { "TYPE": "f16", "FUNC": "HARDSIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + + { + "SHADER_NAME": "hardswish_f32", + "REPLS": { "TYPE": "f32", "FUNC": "HARDSWISH_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "hardswish_f16", + "REPLS": { "TYPE": "f16", "FUNC": "HARDSWISH_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "hardswish_in_place_f32", + "REPLS": { "TYPE": "f32", "FUNC": "HARDSWISH_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "hardswish_in_place_f16", 
+ "REPLS": { "TYPE": "f16", "FUNC": "HARDSWISH_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + + { + "SHADER_NAME": "gelu_f32", + "REPLS": { "TYPE": "f32", "FUNC": "GELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "gelu_f16", + "REPLS": { "TYPE": "f16", "FUNC": "GELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "gelu_in_place_f32", + "REPLS": { "TYPE": "f32", "FUNC": "GELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "gelu_in_place_f16", + "REPLS": { "TYPE": "f16", "FUNC": "GELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + + { + "SHADER_NAME": "gelu_quick_f32", + "REPLS": { "TYPE": "f32", "FUNC": "GELU_QUICK_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "gelu_quick_f16", + "REPLS": { "TYPE": "f16", "FUNC": "GELU_QUICK_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "gelu_quick_in_place_f32", + "REPLS": { "TYPE": "f32", "FUNC": "GELU_QUICK_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "gelu_quick_in_place_f16", + "REPLS": { "TYPE": "f16", "FUNC": "GELU_QUICK_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + + { + "SHADER_NAME": "xielu_f32", + "REPLS": { "TYPE": "f32", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: u32, alpha_p: u32, beta: u32, eps: u32", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "xielu_f16", + "REPLS": { "TYPE": "f16", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: u32, alpha_p: u32, beta: u32, eps: u32", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "xielu_in_place_f32", + "REPLS": { "TYPE": "f32", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: u32, alpha_p: u32, beta: u32, eps: u32", "MUTATE": "src" }, + "DECLS": 
["INPLACE"] + }, + { + "SHADER_NAME": "xielu_in_place_f16", + "REPLS": { "TYPE": "f16", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: u32, alpha_p: u32, beta: u32, eps: u32", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "gelu_erf_f32", + "REPLS": { "TYPE": "f32", "FUNC": "GELU_ERF_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "gelu_erf_f16", + "REPLS": { "TYPE": "f16", "FUNC": "GELU_ERF_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "gelu_erf_in_place_f32", + "REPLS": { "TYPE": "f32", "FUNC": "GELU_ERF_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "gelu_erf_in_place_f16", + "REPLS": { "TYPE": "f16", "FUNC": "GELU_ERF_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + } ] #end(VARIANTS) @@ -346,9 +367,6 @@ #decl(INPLACE) @group(0) @binding(1) -var dst: array<{{TYPE}}>; - -@group(0) @binding(2) var params: Params; #enddecl(INPLACE) From c41a1cb54f70d964cef1a8da5fc0d41b9390c065 Mon Sep 17 00:00:00 2001 From: James Contini Date: Wed, 29 Oct 2025 23:13:06 -0700 Subject: [PATCH 16/40] formatted embed wgsl and ggml-webgpu.cpp --- ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py b/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py index 7de19bef77f..b60eeb4df5d 100755 --- a/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +++ b/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py @@ -19,7 +19,6 @@ def parse_decls(decls_text): return decls def replace_repl_placeholders(variant, template_map): - for repl, code in variant["REPLS"].items(): for key, val in template_map.items(): # Match "key" and avoid matching subsequences using by using \b @@ -97,15 +96,13 @@ def generate_variants(fname, input_dir, output_dir, outfile): decls_code = "" for key in decls: if key not in decls_map: - 
raise ValueError(f"DECLS key '{key}' not found.") decls_code += decls_map[key] + "\n\n" - final_shader = re.sub(r'\bDECLS\b', decls_code, shader_template) - if "REPLS" in variant: variant = replace_repl_placeholders(variant, templates_map) final_shader = replace_placeholders(final_shader, variant["REPLS"]) + # second run to expand placeholders in repl_template final_shader = replace_placeholders(final_shader, variant["REPLS"]) final_shader = expand_includes(final_shader, input_dir) From c6bc12599fe9651d9e9458ba89e0efac4d47c43e Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Wed, 5 Nov 2025 08:24:14 -0800 Subject: [PATCH 17/40] Faster tensors (#8) Add fast matrix and matrix/vector multiplication. --- .github/workflows/build.yml | 18 +- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 326 +++++++++++++++++- .../ggml-webgpu/wgsl-shaders/embed_wgsl.py | 9 +- .../wgsl-shaders/mul_mat.tmpl.wgsl | 10 +- .../wgsl-shaders/mul_mat_decls.tmpl | 97 ++++++ .../wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl | 247 +++++++++++++ .../mul_mat_subgroup_matrix.tmpl.wgsl | 302 ++++++++++++++++ .../wgsl-shaders/mul_mat_vec.tmpl.wgsl | 267 ++++++++++++++ 8 files changed, 1247 insertions(+), 29 deletions(-) create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 15e11330952..36084c55078 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -161,15 +161,16 @@ jobs: - name: Dawn Dependency id: dawn-depends run: | - DAWN_VERSION="v1.0.0" + DAWN_VERSION="v2.0.0" DAWN_OWNER="reeselevine" DAWN_REPO="dawn" - DAWN_ASSET_NAME="Dawn-a1a6b45cced25a3b7f4fb491e0ae70796cc7f22b-macos-latest-Release.tar.gz" + 
DAWN_ASSET_NAME="Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-macos-latest-Release.zip" echo "Fetching release asset from https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}" - curl -L -o artifact.tar.gz \ + curl -L -o artifact.zip \ "https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}" mkdir dawn - tar -xvf artifact.tar.gz -C dawn --strip-components=1 + unzip artifact.zip + tar -xvf Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-macos-latest-Release.tar.gz -C dawn --strip-components=1 - name: Build id: cmake_build @@ -521,15 +522,16 @@ jobs: id: dawn-depends run: | sudo apt-get install -y libxrandr-dev libxinerama-dev libxcursor-dev mesa-common-dev libx11-xcb-dev libxi-dev - DAWN_VERSION="v1.0.0" + DAWN_VERSION="v2.0.0" DAWN_OWNER="reeselevine" DAWN_REPO="dawn" - DAWN_ASSET_NAME="Dawn-a1a6b45cced25a3b7f4fb491e0ae70796cc7f22b-ubuntu-latest-Release.tar.gz" + DAWN_ASSET_NAME="Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-ubuntu-latest-Release.zip" echo "Fetching release asset from https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}" - curl -L -o artifact.tar.gz \ + curl -L -o artifact.zip \ "https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}" mkdir dawn - tar -xvf artifact.tar.gz -C dawn --strip-components=1 + unzip artifact.zip + tar -xvf Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-ubuntu-latest-Release.tar.gz -C dawn --strip-components=1 - name: Build id: cmake_build diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 1a157567315..2b9c28bc3b6 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -15,6 +15,7 @@ #include #include #include +#include #include #include #include @@ -73,6 +74,30 @@ // For operations which process a row in parallel, this seems like a reasonable default #define 
WEBGPU_ROW_SPLIT_WG_SIZE 64 +// Matrix multiplication parameters + +// Register tiling parameters +#define WEBGPU_MUL_MAT_TILE_M 8 +#define WEBGPU_MUL_MAT_TILE_N 8 +#define WEBGPU_MUL_MAT_WG_SIZE_M 8 +#define WEBGPU_MUL_MAT_WG_SIZE_N 8 +#define WEBGPU_MUL_MAT_TILE_K 32 + +// Subgroup matrix parameters +// The number of subgroups in the M dimension +#define WEBGPU_MUL_MAT_SUBGROUP_M 2 +// The number of subgroups in the N dimension +#define WEBGPU_MUL_MAT_SUBGROUP_N 2 +// The number of subgroup matrices each subgroup accumulates over +#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M 4 +#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2 + +// Matrix-vector multiplication parameters +#define WEBGPU_MUL_MAT_VEC_WG_SIZE 256 +// Must be multiple of 4 to work with vectorized paths, and must divide mul_mat_vec wg size +#define WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG 64 +#define WEBGPU_MUL_MAT_VEC_TILE_K 256 + /* End Constants */ // This is a "fake" base pointer, since WebGPU buffers do not have pointers to their locations. @@ -236,6 +261,10 @@ struct webgpu_context_struct { wgpu::Queue queue; wgpu::Limits limits; + bool supports_subgroup_matrix = false; + uint32_t subgroup_size; + wgpu::SubgroupMatrixConfig subgroup_matrix_config; + // Separate this out from limits since on some Metal systems, the limit returned by // querying the limits is higher than the actual allowed maximum. 
uint32_t max_wg_size_x; @@ -247,6 +276,11 @@ struct webgpu_context_struct { webgpu_buf_pool set_rows_error_buf_pool; webgpu_pipeline memset_pipeline; + + std::map>> mul_mat_pipelines; // src0_type, src1_type, vectorized + std::map>> + mul_mat_vec_pipelines; // src0_type, src1_type, vectorized + webgpu_pipeline mul_mat_pipeline[30][2]; webgpu_pipeline set_rows_pipeline[1][2]; // dst->type, vectorized webgpu_pipeline get_rows_pipeline[30]; @@ -321,6 +355,25 @@ struct ggml_backend_webgpu_buffer_context { /* WebGPU object initializations */ +// Process a WGSL shader string, replacing tokens of the form {{KEY}} with +// the corresponding values provided in `repls`. +static std::string ggml_webgpu_process_shader_repls(const char * src, + const std::vector> & repls) { + if (!src) { + return std::string(); + } + std::string s = src; + for (const auto & kv : repls) { + std::string token = "{{" + kv.first + "}}"; + size_t pos = 0; + while ((pos = s.find(token, pos)) != std::string::npos) { + s.replace(pos, token.length(), kv.second); + pos += kv.second.length(); + } + } + return s; +} + static void ggml_webgpu_create_pipeline(wgpu::Device & device, webgpu_pipeline & pipeline, const char * shader_code, @@ -346,6 +399,30 @@ static void ggml_webgpu_create_pipeline(wgpu::Device & pipeline = { device.CreateComputePipeline(&pipeline_desc), label }; } +static webgpu_pipeline ggml_webgpu_create_pipeline2(wgpu::Device & device, + const char * shader_code, + const char * label, + const std::vector & constants = {}) { + wgpu::ShaderSourceWGSL shader_source; + shader_source.code = shader_code; + + wgpu::ShaderModuleDescriptor shader_desc; + shader_desc.nextInChain = &shader_source; + + wgpu::ShaderModule shader_module = device.CreateShaderModule(&shader_desc); + + wgpu::ComputePipelineDescriptor pipeline_desc; + pipeline_desc.label = label; + pipeline_desc.compute.module = shader_module; + pipeline_desc.compute.entryPoint = "main"; // Entry point in the WGSL code + pipeline_desc.layout 
= nullptr; // nullptr means auto layout + if (constants.size() > 0) { + pipeline_desc.compute.constants = constants.data(); + pipeline_desc.compute.constantCount = constants.size(); + } + return { device.CreateComputePipeline(&pipeline_desc), label }; +} + static void ggml_webgpu_create_buffer(wgpu::Device & device, wgpu::Buffer & buffer, size_t size, @@ -512,6 +589,7 @@ static webgpu_command ggml_backend_webgpu_build(webgpu_context & std::vector params, std::vector bind_group_entries, uint32_t wg_x, + uint32_t wg_y = 1, std::optional set_rows_error_bufs = std::nullopt) { webgpu_pool_bufs params_bufs = ctx->param_buf_pool.alloc_bufs(); @@ -557,7 +635,7 @@ static webgpu_command ggml_backend_webgpu_build(webgpu_context & #endif pass.SetPipeline(pipeline.pipeline); pass.SetBindGroup(0, bind_group); - pass.DispatchWorkgroups(wg_x, 1, 1); + pass.DispatchWorkgroups(wg_x, wg_y, 1); pass.End(); #ifdef GGML_WEBGPU_GPU_PROFILE @@ -779,7 +857,7 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, uint32_t wg_x = (threads + max_wg_size - 1) / max_wg_size; - return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, error_bufs); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, 1, error_bufs); } static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx, @@ -835,8 +913,8 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), - (uint32_t) dst->ne[1], // number of rows in result (M) - (uint32_t) dst->ne[0], // number of columns in result (N) + (uint32_t) dst->ne[0], // number of rows in result (M, transposed) + (uint32_t) dst->ne[1], // number of columns in result (N) (uint32_t) src0->ne[0], // number of columns in src0/src1 (K) (uint32_t) (src0->nb[1] / 
ggml_type_size(src0->type)), // stride (elements/blocks) of src0 in dimension 1 (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), // stride (elements/blocks) of src1 in dimension 1 @@ -865,9 +943,67 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, dst) }, }; + webgpu_pipeline pipeline = ctx->mul_mat_pipeline[src0->type][src1->type]; + uint32_t wg_x = (dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3] + WEBGPU_MUL_MAT_WG_SIZE - 1) / WEBGPU_MUL_MAT_WG_SIZE; - return ggml_backend_webgpu_build(ctx, ctx->mul_mat_pipeline[src0->type][src1->type], params, entries, wg_x); + uint32_t wg_y = 1; + + bool use_fast = false; + switch (src1->type) { + case GGML_TYPE_F16: + use_fast = (src0->type == GGML_TYPE_F16); + break; + case GGML_TYPE_F32: + switch (src0->type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_Q4_0: + use_fast = true; + break; + default: + break; + } + break; + default: + break; + } + + if (use_fast) { + int vectorized = src0->ne[0] % 4 == 0 && dst->ne[0] % 4 == 0 && dst->ne[1] % 4 == 0; + if (dst->ne[1] == 1) { + // We don't support vectorized mul_mat_vec for quantized types + vectorized = vectorized && (src0->type < 2); + pipeline = ctx->mul_mat_vec_pipelines[src0->type][src1->type][vectorized]; + uint32_t batches = dst->ne[2] * dst->ne[3]; + uint32_t output_groups = + (dst->ne[0] + WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG - 1) / WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG; + uint32_t total_wg = output_groups * batches; + wg_x = total_wg % ctx->limits.maxComputeWorkgroupsPerDimension; + wg_y = (total_wg + ctx->limits.maxComputeWorkgroupsPerDimension - 1) / + ctx->limits.maxComputeWorkgroupsPerDimension; + } else { + pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][vectorized]; + uint32_t wg_m; + uint32_t wg_n; + if (ctx->supports_subgroup_matrix) { + // The total number of subgroups/workgroups needed per matrix. 
+ uint32_t wg_m_sg_tile = + WEBGPU_MUL_MAT_SUBGROUP_M * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M * ctx->subgroup_matrix_config.M; + wg_m = (dst->ne[0] + wg_m_sg_tile - 1) / wg_m_sg_tile; + uint32_t wg_n_sg_tile = + WEBGPU_MUL_MAT_SUBGROUP_N * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N * ctx->subgroup_matrix_config.N; + wg_n = (dst->ne[1] + wg_n_sg_tile - 1) / wg_n_sg_tile; + } else { + uint32_t tile_m_s = WEBGPU_MUL_MAT_TILE_M * WEBGPU_MUL_MAT_WG_SIZE_M; + uint32_t tile_n_s = WEBGPU_MUL_MAT_TILE_N * WEBGPU_MUL_MAT_WG_SIZE_N; + wg_m = (dst->ne[0] + tile_m_s - 1) / tile_m_s; + wg_n = (dst->ne[1] + tile_n_s - 1) / tile_n_s; + } + wg_x = wg_m * wg_n * dst->ne[2] * dst->ne[3]; + } + } + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); } static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx, @@ -1583,12 +1719,6 @@ static void ggml_webgpu_init_memset_pipeline(webgpu_context & webgpu_ctx) { } static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F32][GGML_TYPE_F32], - wgsl_mul_mat_f32_f32, "mul_mat_f32_f32"); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F16][GGML_TYPE_F16], - wgsl_mul_mat_f16_f16, "mul_mat_f16_f16"); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F16][GGML_TYPE_F32], - wgsl_mul_mat_f16_f32, "mul_mat_f16_f32"); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_0][GGML_TYPE_F32], wgsl_mul_mat_q4_0_f32, "mul_mat_q4_0_f32"); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_1][GGML_TYPE_F32], @@ -1627,6 +1757,136 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { wgsl_mul_mat_iq4_nl_f32, "mul_mat_iq4_nl_f32"); ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ4_XS][GGML_TYPE_F32], 
wgsl_mul_mat_iq4_xs_f32, "mul_mat_iq4_xs_f32"); + + if (webgpu_ctx->supports_subgroup_matrix) { + std::vector> sg_matrix_repls; + sg_matrix_repls.emplace_back("WEBGPU_MAX_SUBGROUP_SIZE", std::to_string(webgpu_ctx->subgroup_size)); + sg_matrix_repls.emplace_back("WEBGPU_TILE_K", std::to_string(WEBGPU_MUL_MAT_TILE_K)); + sg_matrix_repls.emplace_back("WEBGPU_SUBGROUP_M", std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M)); + sg_matrix_repls.emplace_back("WEBGPU_SUBGROUP_N", std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N)); + sg_matrix_repls.emplace_back("WEBGPU_SUBGROUP_MATRIX_M", std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M)); + sg_matrix_repls.emplace_back("WEBGPU_SUBGROUP_MATRIX_N", std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N)); + sg_matrix_repls.emplace_back("WEBGPU_SG_MAT_M_SIZE", std::to_string(webgpu_ctx->subgroup_matrix_config.M)); + sg_matrix_repls.emplace_back("WEBGPU_SG_MAT_N_SIZE", std::to_string(webgpu_ctx->subgroup_matrix_config.N)); + sg_matrix_repls.emplace_back("WEBGPU_SG_MAT_K_SIZE", std::to_string(webgpu_ctx->subgroup_matrix_config.K)); + + std::string proc_mul_mat_subgroup_matrix_f32_f32 = + ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32, sg_matrix_repls); + std::string proc_mul_mat_subgroup_matrix_f32_f32_vec = + ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32_vec, sg_matrix_repls); + std::string proc_mul_mat_subgroup_matrix_f16_f32 = + ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32, sg_matrix_repls); + std::string proc_mul_mat_subgroup_matrix_f16_f32_vec = + ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32_vec, sg_matrix_repls); + std::string proc_mul_mat_subgroup_matrix_f16_f16 = + ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16, sg_matrix_repls); + std::string proc_mul_mat_subgroup_matrix_f16_f16_vec = + ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16_vec, sg_matrix_repls); + std::string 
proc_mul_mat_subgroup_matrix_q4_0_f32 = + ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_q4_0_f32, sg_matrix_repls); + std::string proc_mul_mat_subgroup_matrix_q4_0_f32_vec = + ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_q4_0_f32_vec, sg_matrix_repls); + + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f32_f32.c_str(), "mul_mat_subgroup_matrix_f32_f32"); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f32_f32_vec.c_str(), + "mul_mat_subgroup_matrix_f32_f32_vec"); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f16_f32.c_str(), "mul_mat_subgroup_matrix_f16_f32"); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f16_f32_vec.c_str(), + "mul_mat_subgroup_matrix_f16_f32_vec"); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f16_f16.c_str(), "mul_mat_subgroup_matrix_f16_f16"); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f16_f16_vec.c_str(), + "mul_mat_subgroup_matrix_f16_f16_vec"); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, proc_mul_mat_subgroup_matrix_q4_0_f32.c_str(), "mul_mat_subgroup_matrix_q4_0_f32"); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][1] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_subgroup_matrix_q4_0_f32_vec.c_str(), + "mul_mat_subgroup_matrix_q4_0_f32_vec"); + } else { + std::vector mul_mat_reg_tile_constants(3); + 
mul_mat_reg_tile_constants[0].key = "TILE_K"; + mul_mat_reg_tile_constants[0].value = WEBGPU_MUL_MAT_TILE_K; + mul_mat_reg_tile_constants[1].key = "WORKGROUP_SIZE_M"; + mul_mat_reg_tile_constants[1].value = WEBGPU_MUL_MAT_WG_SIZE_M; + mul_mat_reg_tile_constants[2].key = "WORKGROUP_SIZE_N"; + mul_mat_reg_tile_constants[2].value = WEBGPU_MUL_MAT_WG_SIZE_N; + + std::vector> reg_repls; + reg_repls.emplace_back("WEBGPU_TILE_M", std::to_string(WEBGPU_MUL_MAT_TILE_M)); + reg_repls.emplace_back("WEBGPU_TILE_N", std::to_string(WEBGPU_MUL_MAT_TILE_N)); + + // Process each reg-tile shader with tile replacements. + // Keep the processed strings in-scope so .c_str() remains valid. + std::string proc_mul_mat_reg_tile_f32_f32 = + ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32, reg_repls); + std::string proc_mul_mat_reg_tile_f32_f32_vec = + ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32_vec, reg_repls); + std::string proc_mul_mat_reg_tile_f16_f32 = + ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32, reg_repls); + std::string proc_mul_mat_reg_tile_f16_f32_vec = + ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32_vec, reg_repls); + std::string proc_mul_mat_reg_tile_f16_f16 = + ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16, reg_repls); + std::string proc_mul_mat_reg_tile_f16_f16_vec = + ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16_vec, reg_repls); + std::string proc_mul_mat_reg_tile_q4_0_f32 = + ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32, reg_repls); + std::string proc_mul_mat_reg_tile_q4_0_f32_vec = + ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32_vec, reg_repls); + + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f32_f32.c_str(), + "mul_mat_reg_tile_f32_f32", mul_mat_reg_tile_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = + 
ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f32_f32_vec.c_str(), + "mul_mat_reg_tile_f32_f32_vec", mul_mat_reg_tile_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f16_f32.c_str(), + "mul_mat_reg_tile_f16_f32", mul_mat_reg_tile_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f16_f32_vec.c_str(), + "mul_mat_reg_tile_f16_f32_vec", mul_mat_reg_tile_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f16_f16.c_str(), + "mul_mat_reg_tile_f16_f16", mul_mat_reg_tile_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f16_f16_vec.c_str(), + "mul_mat_reg_tile_f16_f16_vec", mul_mat_reg_tile_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_q4_0_f32.c_str(), + "mul_mat_reg_tile_q4_0_f32", mul_mat_reg_tile_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][1] = + ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_q4_0_f32_vec.c_str(), + "mul_mat_reg_tile_q4_0_f32_vec", mul_mat_reg_tile_constants); + } + + std::vector mul_mat_vec_constants(3); + mul_mat_vec_constants[0].key = "WORKGROUP_SIZE"; + mul_mat_vec_constants[0].value = WEBGPU_MUL_MAT_VEC_WG_SIZE; + mul_mat_vec_constants[1].key = "TILE_K"; + mul_mat_vec_constants[1].value = WEBGPU_MUL_MAT_VEC_TILE_K; + mul_mat_vec_constants[2].key = "OUTPUTS_PER_WG"; + mul_mat_vec_constants[2].value = WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG; + + webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, 
wgsl_mul_mat_vec_f32_f32, "mul_mat_vec_f32_f32", mul_mat_vec_constants); + webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, wgsl_mul_mat_vec_f32_f32_vec, "mul_mat_vec_f32_f32_vec", mul_mat_vec_constants); + webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, wgsl_mul_mat_vec_f16_f32, "mul_mat_vec_f16_f32", mul_mat_vec_constants); + webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, wgsl_mul_mat_vec_f16_f32_vec, "mul_mat_vec_f16_f32_vec", mul_mat_vec_constants); + webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, wgsl_mul_mat_vec_f16_f16, "mul_mat_vec_f16_f16", mul_mat_vec_constants); + webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, wgsl_mul_mat_vec_f16_f16_vec, "mul_mat_vec_f16_f16_vec", mul_mat_vec_constants); + webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, wgsl_mul_mat_vec_q4_0_f32, "mul_mat_vec_q4_0_f32", mul_mat_vec_constants); } static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) { @@ -2124,7 +2384,13 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t webgpu_context ctx = reg_ctx->webgpu_ctx; - wgpu::RequestAdapterOptions options = {}; + // TODO: track need for these toggles: https://issues.chromium.org/issues/42251215 + const char * const adapterEnabledToggles[] = { "vulkan_enable_f16_on_nvidia", "use_vulkan_memory_model" }; + wgpu::DawnTogglesDescriptor adapterTogglesDesc; + adapterTogglesDesc.enabledToggles = adapterEnabledToggles; + adapterTogglesDesc.enabledToggleCount = 2; + wgpu::RequestAdapterOptions options = {}; + options.nextInChain = &adapterTogglesDesc; 
ctx->instance.WaitAny(ctx->instance.RequestAdapter( &options, wgpu::CallbackMode::AllowSpontaneous, [&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) { @@ -2140,12 +2406,46 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t ctx->adapter.GetLimits(&ctx->limits); ctx->max_wg_size_x = 288; // default value - wgpu::AdapterInfo info{}; + wgpu::AdapterInfo info{}; + wgpu::AdapterPropertiesSubgroupMatrixConfigs subgroup_matrix_configs{}; + if (ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) { + info.nextInChain = &subgroup_matrix_configs; + } ctx->adapter.GetInfo(&info); + wgpu::SupportedFeatures features; + ctx->adapter.GetFeatures(&features); + // we require f16 support + GGML_ASSERT(ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16)); + + // Only support square f16 matrices of size 8 or 16 for now + bool valid_subgroup_matrix_config = false; + if (ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) { + for (size_t i = 0; i < subgroup_matrix_configs.configCount; i++) { + const wgpu::SubgroupMatrixConfig config = subgroup_matrix_configs.configs[i]; + if (config.M == config.N && config.N == config.K && (config.K == 8 || config.K == 16) && + config.componentType == wgpu::SubgroupMatrixComponentType::F16 && + config.resultComponentType == wgpu::SubgroupMatrixComponentType::F16) { + ctx->subgroup_matrix_config = config; + valid_subgroup_matrix_config = true; + break; + } + } + } + + // For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate. + // Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter. 
+ ctx->subgroup_size = info.subgroupMaxSize; + ctx->supports_subgroup_matrix = valid_subgroup_matrix_config; + // Initialize device std::vector required_features = { wgpu::FeatureName::ShaderF16, wgpu::FeatureName::ImplicitDeviceSynchronization }; + if (ctx->supports_subgroup_matrix) { + required_features.push_back(wgpu::FeatureName::Subgroups); + required_features.push_back(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix); + } + #ifdef GGML_WEBGPU_GPU_PROFILE required_features.push_back(wgpu::FeatureName::TimestampQuery); #endif diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py b/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py index 251051eaeca..ed8068d416e 100755 --- a/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +++ b/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py @@ -72,9 +72,12 @@ def generate_variants(fname, input_dir, output_dir, outfile): except ValueError: decls_map = {} - with open(os.path.join(input_dir, "common_decls.tmpl"), "r", encoding="utf-8") as f: - common_decls = f.read() - decls_map.update(parse_decls(common_decls)) + for fname in sorted(os.listdir(input_dir)): + if fname.endswith(".tmpl"): + tmpl_path = os.path.join(input_dir, fname) + with open(tmpl_path, "r", encoding="utf-8") as f_tmpl: + decls = f_tmpl.read() + decls_map.update(parse_decls(decls)) shader_template = extract_block(text, "SHADER") for variant in variants: diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl index 141db9b39d9..0f8e6e5ac3d 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl @@ -864,8 +864,8 @@ struct MulMatParams { broadcast3: u32 }; -@group(0) @binding(0) var src0: array<{{SRC0_TYPE}}>; // N rows, K columns -@group(0) @binding(1) var src1: array<{{SRC1_TYPE}}>; // M rows, K columns (transposed) +@group(0) @binding(0) var src0: array<{{SRC0_TYPE}}>; // M rows, K columns +@group(0) @binding(1) var src1: 
array<{{SRC1_TYPE}}>; // K rows, N columns (transposed) @group(0) @binding(2) var dst: array; // M rows, N columns @group(0) @binding(3) var params: MulMatParams; @@ -891,8 +891,8 @@ fn main(@builtin(global_invocation_id) global_id: vec3) { let dst2_rem = dst3_rem % dst2_stride; - let row = dst2_rem / params.n; // output row - let col = dst2_rem % params.n; // output column + let row = dst2_rem / params.m; // output row + let col = dst2_rem % params.m; // output column let src0_idx_base = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02 + col * params.stride_01; let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12 + row * params.stride_11; @@ -901,7 +901,7 @@ fn main(@builtin(global_invocation_id) global_id: vec3) { for (var i: u32 = 0u; i < params.k/{{BLOCK_SIZE}}; i = i + 1u) { sum += multiply_add(src0_idx_base, src1_idx_base, i); } - dst[params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row * params.n + col] = sum; + dst[params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row * params.m + col] = sum; } #end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl new file mode 100644 index 00000000000..109ff8d6159 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl @@ -0,0 +1,97 @@ +#decl(SHMEM_VEC) +fn store_shmem(val: vec4, idx: u32) { + shmem[idx] = val.x; + shmem[idx + 1] = val.y; + shmem[idx + 2] = val.z; + shmem[idx + 3] = val.w; +} +#enddecl(SHMEM_VEC) + +#decl(SHMEM_SCALAR) +fn store_shmem(val: f16, idx: u32) { + shmem[idx] = val; +} +#enddecl(SHMEM_SCALAR) + +#decl(INIT_SRC0_SHMEM_FLOAT) + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) { + let tile_m = elem_idx / TILE_K; + let tile_k = 
elem_idx % TILE_K; + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + let src0_idx = batch_offset + global_m * params.stride_01 + global_k; + let src0_val = select( // taking a slight performance hit to avoid oob + {{SRC0_TYPE}}(0.0), + src0[src0_idx/{{VEC_SIZE}}], + global_m < params.m && global_k < params.k); + store_shmem({{SHMEM_TYPE}}(src0_val), elem_idx); + } +} + +#enddecl(INIT_SRC0_SHMEM_FLOAT) + +#decl(INIT_SRC1_SHMEM) + +fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u32) { + for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC1_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) { + let tile_n = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + let global_n = offset_n + tile_n; + let global_k = k_outer + tile_k; + let src1_idx = batch_offset + global_n * params.stride_11 + global_k; + let src1_val = select( + {{SRC1_TYPE}}(0.0), + src1[src1_idx/{{VEC_SIZE}}], + global_n < params.n && global_k < params.k); + store_shmem({{SHMEM_TYPE}}(src1_val), TILE_SRC0_SHMEM + elem_idx); + } +} + +#enddecl(INIT_SRC1_SHMEM) + +#decl(INIT_SRC0_SHMEM_Q4_0) + +const BLOCK_SIZE = 32u; +// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. 
+override BLOCKS_K = TILE_K/BLOCK_SIZE; +const NQ = 16u; +const F16_PER_BLOCK = 9u; // 1 scale + 8x4 packed weights +const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 +const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) { + let blck_idx = i / BLOCK_SIZE; + let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; + let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; + + let tile_m = blck_idx / BLOCKS_K; + let global_m = offset_m + tile_m; + let block_k = blck_idx % BLOCKS_K; + let global_k = k_outer / BLOCK_SIZE + block_k; + + if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { + let src0_idx = batch_offset + global_m * params.stride_01 + global_k; + let scale_idx = src0_idx * F16_PER_BLOCK; + let d = src0[scale_idx]; + + for (var j = 0u; j < F16_PER_THREAD; j += 2) { + let q_0 = src0[scale_idx + 1u + block_offset + j]; + let q_1 = src0[scale_idx + 1u + block_offset + j + 1]; + + let q_packed = bitcast(vec2(q_0, q_1)); + for (var k = 0u; k < 4u; k++) { + let q_byte = get_byte(q_packed, k); + let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; + let q_lo = (f16(q_byte & 0xF) - 8.0) * d; + shmem[shmem_idx + j * 2 + k] = q_lo; + shmem[shmem_idx + j * 2 + k + 16u] = q_hi; + } + } + } + } +} + +#enddecl(INIT_SRC0_SHMEM_Q4_0) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl new file mode 100644 index 00000000000..6b1dd26cd9e --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl @@ -0,0 +1,247 @@ +#define(VARIANTS) +[ + { + "SHADER_SUFFIX": "f32_f32_vec", + "REPLS": { + "SRC0_TYPE" : "vec4", + "SRC1_TYPE" : "vec4", + "DST_TYPE" : "vec4", + "SHMEM_TYPE" : "vec4", + "VEC_SIZE" : 4, + }, + "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] + }, + { + "SHADER_SUFFIX": 
"f32_f32", + "REPLS": { + "SRC0_TYPE" : "f32", + "SRC1_TYPE" : "f32", + "DST_TYPE" : "f32", + "SHMEM_TYPE" : "f16", + "VEC_SIZE" : 1, + }, + "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] + }, + { + "SHADER_SUFFIX": "f16_f32_vec", + "REPLS": { + "SRC0_TYPE" : "vec4", + "SRC1_TYPE" : "vec4", + "DST_TYPE" : "vec4", + "SHMEM_TYPE" : "vec4", + "VEC_SIZE" : 4, + }, + "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] + }, + { + "SHADER_SUFFIX": "f16_f32", + "REPLS": { + "SRC0_TYPE" : "f16", + "SRC1_TYPE" : "f32", + "DST_TYPE" : "f32", + "SHMEM_TYPE" : "f16", + "VEC_SIZE" : 1, + }, + "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] + }, + { + "SHADER_SUFFIX": "f16_f16_vec", + "REPLS": { + "SRC0_TYPE" : "vec4", + "SRC1_TYPE" : "vec4", + "DST_TYPE" : "vec4", + "SHMEM_TYPE" : "vec4", + "VEC_SIZE" : 4, + }, + "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] + }, + { + "SHADER_SUFFIX": "f16_f16", + "REPLS": { + "SRC0_TYPE" : "f16", + "SRC1_TYPE" : "f16", + "DST_TYPE" : "f32", + "SHMEM_TYPE" : "f16", + "VEC_SIZE" : 1, + }, + "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] + }, + { + "SHADER_SUFFIX": "q4_0_f32_vec", + "REPLS": { + "SRC0_TYPE" : "f16", + "SRC1_TYPE" : "vec4", + "DST_TYPE" : "vec4", + "SHMEM_TYPE" : "vec4", + "VEC_SIZE" : 4, + }, + "DECLS": ["BYTE_HELPERS", "VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"] + }, + { + "SHADER_SUFFIX": "q4_0_f32", + "REPLS": { + "SRC0_TYPE" : "f16", + "SRC1_TYPE" : "f32", + "DST_TYPE" : "f32", + "SHMEM_TYPE" : "f16", + "VEC_SIZE" : 1, + }, + "DECLS": ["BYTE_HELPERS", "SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"] + } +] + +#end(VARIANTS) + +#define(DECLS) + +#decl(VEC) +fn store_val(acc: array, TILE_M>, tn: u32, tm: u32) -> vec4 { + return vec4(f32(acc[tm][tn]), f32(acc[tm + 1][tn]), f32(acc[tm + 2][tn]), f32(acc[tm + 3][tn])); +} +#enddecl(VEC) + 
+#decl(SCALAR) +fn store_val(acc: array, TILE_M>, tn: u32, tm: u32) -> f32 { + return f32(acc[tm][tn]); +} +#enddecl(SCALAR) + +#end(DECLS) + +#define(SHADER) +enable f16; + +struct MulMatParams { + offset_src0: u32, + offset_src1: u32, + offset_dst: u32, + m: u32, + n: u32, + k: u32, + stride_01: u32, + stride_11: u32, + stride_02: u32, + stride_12: u32, + stride_03: u32, + stride_13: u32, + bs02: u32, + bs03: u32, + broadcast2: u32, + broadcast3: u32 +}; + +@group(0) @binding(0) var src0: array<{{SRC0_TYPE}}>; // M rows, K columns +@group(0) @binding(1) var src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed) +@group(0) @binding(2) var dst: array<{{DST_TYPE}}>; // M rows, N columns (transposed) + +@group(0) @binding(3) var params: MulMatParams; + +DECLS + +fn get_local_n(thread_id: u32) -> u32 { + return thread_id / WORKGROUP_SIZE_M; +} +fn get_local_m(thread_id: u32) -> u32 { + return thread_id % WORKGROUP_SIZE_M; +} + +// TILE_M must be multiple of 4 for vec4 loads +const TILE_M = {{WEBGPU_TILE_M}}u; +const TILE_N = {{WEBGPU_TILE_N}}u; + +override WORKGROUP_SIZE_M: u32; +override WORKGROUP_SIZE_N: u32; +override TILE_K: u32; + +override TOTAL_WORKGROUP_SIZE = WORKGROUP_SIZE_M * WORKGROUP_SIZE_N; +override TILE_SRC0_SHMEM = TILE_K * WORKGROUP_SIZE_M * TILE_M; +override TILE_SRC1_SHMEM = TILE_K * WORKGROUP_SIZE_N * TILE_N; + +var shmem: array; + +@compute @workgroup_size(TOTAL_WORKGROUP_SIZE) +fn main(@builtin(workgroup_id) wg_id: vec3, + @builtin(local_invocation_id) local_id: vec3) { + + let thread_id = local_id.x; + let local_m = get_local_m(thread_id); + let local_n = get_local_n(thread_id); + + let wg_n_count = (params.n + WORKGROUP_SIZE_N * TILE_N - 1u) / (WORKGROUP_SIZE_N * TILE_N); + let wg_m_count = (params.m + WORKGROUP_SIZE_M * TILE_M - 1u) / (WORKGROUP_SIZE_M * TILE_M); + let wg_per_matrix = wg_m_count * wg_n_count; + + let batch_idx = wg_id.x / wg_per_matrix; + + let wg_in_batch = wg_id.x % wg_per_matrix; + let wg_m = wg_in_batch % 
wg_m_count; + let wg_n = wg_in_batch / wg_m_count; + + let output_row_base = wg_m * WORKGROUP_SIZE_M * TILE_M + local_m * TILE_M; + let output_col_base = wg_n * WORKGROUP_SIZE_N * TILE_N + local_n * TILE_N; + + let dst2_stride = params.m * params.n; + let dst3_stride = dst2_stride * params.bs02 * params.broadcast2; + + let dst3_idx = batch_idx / (params.bs02 * params.broadcast2); + let src03_idx = dst3_idx / params.broadcast3; + let src13_idx = dst3_idx; + let dst2_idx = batch_idx % (params.bs02 * params.broadcast2); + let src02_idx = dst2_idx / params.broadcast2; + let src12_idx = dst2_idx; + + let src0_batch_offset = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02; + let src1_batch_offset = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12; + + let offset_m = wg_m * WORKGROUP_SIZE_M * TILE_M; + let offset_n = wg_n * WORKGROUP_SIZE_N * TILE_N; + + var acc: array, TILE_M>; + + for (var k_outer = 0u; k_outer < params.k; k_outer += TILE_K) { + + // see mul_mat_decls.tmpl + init_shmem_src0(thread_id, src0_batch_offset, offset_m, k_outer); + init_shmem_src1(thread_id, src1_batch_offset, offset_n, k_outer); + + workgroupBarrier(); + + let k_end = min(TILE_K, params.k - k_outer); + + for (var k_inner = 0u; k_inner < k_end; k_inner++) { + var src0_tile: array; + for (var tm = 0u; tm < TILE_M; tm++) { + let src0_m = local_m * TILE_M + tm; + let src0_idx = k_inner + src0_m * TILE_K; + src0_tile[tm] = shmem[src0_idx]; + } + for (var tn = 0u; tn < TILE_N; tn++) { + let src1_n = local_n * TILE_N + tn; + let src1_idx = src1_n * TILE_K + k_inner; + let src1_val = shmem[TILE_SRC0_SHMEM + src1_idx]; + for (var tm = 0u; tm < TILE_M; tm++) { + acc[tm][tn] += src0_tile[tm] * src1_val; + } + } + } + + workgroupBarrier(); + } + + let dst_batch_offset = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride; + + for (var tn = 0u; tn < TILE_N; tn++) { + let global_col = output_col_base + tn; + if (global_col < 
params.n) { + for (var tm = 0u; tm < TILE_M; tm += {{VEC_SIZE}}) { + let global_row = output_row_base + tm; + if (global_row < params.m) { + let dst_idx = dst_batch_offset + global_col * params.m + global_row; + dst[dst_idx/{{VEC_SIZE}}] = store_val(acc, tn, tm); + } + } + } + } +} + +#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl new file mode 100644 index 00000000000..47c8ce36ab3 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl @@ -0,0 +1,302 @@ +#define(VARIANTS) +[ + { + "SHADER_SUFFIX": "f32_f32_vec", + "REPLS": { + "SRC0_TYPE" : "vec4", + "SRC1_TYPE" : "vec4", + "DST_TYPE" : "vec4", + "SHMEM_TYPE" : "vec4", + "VEC_SIZE" : 4, + }, + "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] + }, + { + "SHADER_SUFFIX": "f32_f32", + "REPLS": { + "SRC0_TYPE" : "f32", + "SRC1_TYPE" : "f32", + "DST_TYPE" : "f32", + "SHMEM_TYPE" : "f16", + "VEC_SIZE" : 1, + }, + "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] + }, + { + "SHADER_SUFFIX": "f16_f32_vec", + "REPLS": { + "SRC0_TYPE" : "vec4", + "SRC1_TYPE" : "vec4", + "DST_TYPE" : "vec4", + "SHMEM_TYPE" : "vec4", + "VEC_SIZE" : 4, + }, + "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] + }, + { + "SHADER_SUFFIX": "f16_f32", + "REPLS": { + "SRC0_TYPE" : "f16", + "SRC1_TYPE" : "f32", + "DST_TYPE" : "f32", + "SHMEM_TYPE" : "f16", + "VEC_SIZE" : 1, + }, + "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] + }, + { + "SHADER_SUFFIX": "f16_f16_vec", + "REPLS": { + "SRC0_TYPE" : "vec4", + "SRC1_TYPE" : "vec4", + "DST_TYPE" : "vec4", + "SHMEM_TYPE" : "vec4", + "VEC_SIZE" : 4, + }, + "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] + }, + { + "SHADER_SUFFIX": "f16_f16", + "REPLS": { + "SRC0_TYPE" : "f16", + "SRC1_TYPE" : "f16", + "DST_TYPE" : "f32", + 
"SHMEM_TYPE" : "f16", + "VEC_SIZE" : 1, + }, + "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] + }, + { + "SHADER_SUFFIX": "q4_0_f32_vec", + "REPLS": { + "SRC0_TYPE" : "f16", + "SRC1_TYPE" : "vec4", + "DST_TYPE" : "vec4", + "SHMEM_TYPE" : "vec4", + "VEC_SIZE" : 4, + }, + "DECLS": ["BYTE_HELPERS", "VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"] + }, + { + "SHADER_SUFFIX": "q4_0_f32", + "REPLS": { + "SRC0_TYPE" : "f16", + "SRC1_TYPE" : "f32", + "DST_TYPE" : "f32", + "SHMEM_TYPE" : "f16", + "VEC_SIZE" : 1, + }, + "DECLS": ["BYTE_HELPERS", "SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"] + } +] + +#end(VARIANTS) + +#define(DECLS) + +#decl(VEC) +fn store_dst(shmem_idx: u32, dst_idx: u32) { + dst[dst_idx] = vec4( + f32(shmem[shmem_idx]), + f32(shmem[shmem_idx + 1]), + f32(shmem[shmem_idx + 2]), + f32(shmem[shmem_idx + 3]) + ); +} +#enddecl(VEC) + +#decl(SCALAR) +fn store_dst(shmem_idx: u32, dst_idx: u32) { + dst[dst_idx] = f32(shmem[shmem_idx]); +} +#enddecl(SCALAR) + +#end(DECLS) + +#define(SHADER) +diagnostic(off, chromium.subgroup_matrix_uniformity); +enable f16; +enable subgroups; +enable chromium_experimental_subgroup_matrix; + +struct MulMatParams { + offset_src0: u32, + offset_src1: u32, + offset_dst: u32, + m: u32, + n: u32, + k: u32, + stride_01: u32, + stride_11: u32, + stride_02: u32, + stride_12: u32, + stride_03: u32, + stride_13: u32, + bs02: u32, + bs03: u32, + broadcast2: u32, + broadcast3: u32 +}; + +@group(0) @binding(0) var src0: array<{{SRC0_TYPE}}>; // M rows, K columns +@group(0) @binding(1) var src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed) +@group(0) @binding(2) var dst: array<{{DST_TYPE}}>; // M rows, N columns (transposed) + +@group(0) @binding(3) var params: MulMatParams; + +DECLS + +// Note: These are string interpolated at build time, cannot use override constants due to limitations in +// current Dawn version type definitions/matrix load requirements for 
constant memory sizes. +const SUBGROUP_M = {{WEBGPU_SUBGROUP_M}}u; +const SUBGROUP_N = {{WEBGPU_SUBGROUP_N}}u; +// For portability we assume the max subgroup size, meaning some subgroups will be masked out if the +// runtime subgroup size is smaller. +const MAX_SUBGROUP_SIZE = {{WEBGPU_MAX_SUBGROUP_SIZE}}u; + +const EXPECTED_SUBGROUPS = SUBGROUP_M * SUBGROUP_N; + +const SUBGROUP_MATRIX_M_SIZE = {{WEBGPU_SG_MAT_M_SIZE}}u; +const SUBGROUP_MATRIX_N_SIZE = {{WEBGPU_SG_MAT_N_SIZE}}u; +const SUBGROUP_MATRIX_K_SIZE = {{WEBGPU_SG_MAT_K_SIZE}}u; + +const SUBGROUP_MATRIX_M = {{WEBGPU_SUBGROUP_MATRIX_M}}u; +const SUBGROUP_MATRIX_N = {{WEBGPU_SUBGROUP_MATRIX_N}}u; + +const TILE_K = {{WEBGPU_TILE_K}}u; + +const WG_M_SG_TILE_SIZE = SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE; +const WG_N_SG_TILE_SIZE = SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE; + +const TOTAL_WORKGROUP_SIZE = SUBGROUP_M * SUBGROUP_N * MAX_SUBGROUP_SIZE; +const TILE_SRC0_SHMEM = TILE_K * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE; +const TILE_SRC1_SHMEM = TILE_K * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE; + +const SG_MAT_ACCUM_SHMEM = SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_M_SIZE * SUBGROUP_MATRIX_N_SIZE; + +// We reuse shmem for accumulation matrices +const SHMEM_SIZE = max(TILE_SRC0_SHMEM + TILE_SRC1_SHMEM, SG_MAT_ACCUM_SHMEM); + +var shmem: array; + +@compute @workgroup_size(TOTAL_WORKGROUP_SIZE) +fn main(@builtin(workgroup_id) wg_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(subgroup_id) subgroup_id: u32) { + + let thread_id = local_id.x; + let subgroup_m = subgroup_id % SUBGROUP_M; + let subgroup_n = subgroup_id / SUBGROUP_M; + + let wg_m_count = (params.m + WG_M_SG_TILE_SIZE - 1) / WG_M_SG_TILE_SIZE; + let wg_n_count = (params.n + WG_N_SG_TILE_SIZE - 1) / WG_N_SG_TILE_SIZE; + let wg_per_matrix = wg_m_count * wg_n_count; + + let batch_idx = wg_id.x / wg_per_matrix; + + let wg_in_batch = 
wg_id.x % wg_per_matrix; + let wg_m = wg_in_batch % wg_m_count; + let wg_n = wg_in_batch / wg_m_count; + + let dst2_stride = params.m * params.n; + let dst3_stride = dst2_stride * params.bs02 * params.broadcast2; + + let dst3_idx = batch_idx / (params.bs02 * params.broadcast2); + let src03_idx = dst3_idx / params.broadcast3; + let src13_idx = dst3_idx; + let dst2_idx = batch_idx % (params.bs02 * params.broadcast2); + let src02_idx = dst2_idx / params.broadcast2; + let src12_idx = dst2_idx; + + let src0_batch_offset = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02; + let src1_batch_offset = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12; + + let offset_m = wg_m * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE; + let offset_n = wg_n * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE; + + var acc_sg_mat : array, SUBGROUP_MATRIX_N>, SUBGROUP_MATRIX_M>; + + for (var k_outer = 0u; k_outer < params.k; k_outer += TILE_K) { + + // see mul_mat_decls.tmpl + init_shmem_src0(thread_id, src0_batch_offset, offset_m, k_outer); + init_shmem_src1(thread_id, src1_batch_offset, offset_n, k_outer); + + workgroupBarrier(); + + if (subgroup_id < EXPECTED_SUBGROUPS) { + + for (var k_inner = 0u; k_inner < TILE_K; k_inner += SUBGROUP_MATRIX_K_SIZE) { + + let src0_shmem_idx_base = subgroup_m * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE * TILE_K + k_inner; + var src0_sg_mats: array, SUBGROUP_MATRIX_M>; + for (var m = 0u; m < SUBGROUP_MATRIX_M; m++) { + src0_sg_mats[m] = subgroupMatrixLoad>( + &shmem, + src0_shmem_idx_base + m * SUBGROUP_MATRIX_M_SIZE * TILE_K, + false, + TILE_K + ); + } + + let src1_shmem_idx_base = TILE_SRC0_SHMEM + subgroup_n * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE * TILE_K + k_inner; + for (var n = 0u; n < SUBGROUP_MATRIX_N; n++) { + let src1_sg_mat = subgroupMatrixLoad>( + &shmem, + src1_shmem_idx_base + n * SUBGROUP_MATRIX_N_SIZE * TILE_K, + true, + TILE_K + ); + for (var m = 0u; m < 
SUBGROUP_MATRIX_M; m++) { + acc_sg_mat[m][n] = subgroupMatrixMultiplyAccumulate(src0_sg_mats[m], src1_sg_mat, acc_sg_mat[m][n]); + } + } + } + } + + workgroupBarrier(); + } + + let dst_batch_offset = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride; + + // Stage the subgroup matrix tiles into shared memory + // This uses WG_M_SG_TILE_SIZE as the stride (number of columns in the workgroup tile). + let WG_TILE_STRIDE = WG_M_SG_TILE_SIZE; + let tile_row_base_local = subgroup_n * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE; + let tile_col_base_local = subgroup_m * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE; + + if (subgroup_id < EXPECTED_SUBGROUPS) { // 2-5% performance hit :( + for (var n = 0u; n < SUBGROUP_MATRIX_N; n++) { + for (var m = 0u; m < SUBGROUP_MATRIX_M; m++) { + let local_row = tile_row_base_local + n * SUBGROUP_MATRIX_N_SIZE; + let local_col = tile_col_base_local + m * SUBGROUP_MATRIX_M_SIZE; + let out_base = local_row * WG_TILE_STRIDE + local_col; + subgroupMatrixStore(&shmem, out_base, acc_sg_mat[m][n], true, WG_TILE_STRIDE); + } + } + } + + workgroupBarrier(); + + // Cooperative write: iterate over the entire workgroup tile + let tile_rows = WG_N_SG_TILE_SIZE; + let tile_cols = WG_M_SG_TILE_SIZE; + let total_tile_elems = tile_rows * tile_cols; + let tile_dst_row_base = wg_m * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE; + let tile_dst_col_base = wg_n * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE; + + for (var idx = thread_id * {{VEC_SIZE}}; idx < total_tile_elems; idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) { + let local_row = idx % WG_TILE_STRIDE; + let local_col = idx / WG_TILE_STRIDE; + + let global_row = tile_dst_row_base + local_row; + let global_col = tile_dst_col_base + local_col; + + if (global_col < params.n && global_row < params.m) { + let dst_idx = dst_batch_offset + global_col * params.m + global_row; + store_dst(idx, dst_idx/{{VEC_SIZE}}); + } + } +} + +#end(SHADER) diff --git 
a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl new file mode 100644 index 00000000000..ffbb6403285 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl @@ -0,0 +1,267 @@ +#define(VARIANTS) +[ + { + "SHADER_SUFFIX": "f32_f32_vec", + "REPLS": { + "SRC0_TYPE" : "vec4", + "SRC1_TYPE" : "vec4", + "DST_TYPE": "vec4", + "VEC_SIZE" : 4, + }, + "DECLS": ["VEC", "MUL_ACC_FLOAT"] + }, + { + "SHADER_SUFFIX": "f32_f32", + "REPLS": { + "SRC0_TYPE" : "f32", + "SRC1_TYPE" : "f32", + "DST_TYPE": "f32", + "VEC_SIZE" : 1, + }, + "DECLS": ["SCALAR", "MUL_ACC_FLOAT"] + }, + { + "SHADER_SUFFIX": "f16_f32_vec", + "REPLS": { + "SRC0_TYPE" : "vec4", + "SRC1_TYPE" : "vec4", + "DST_TYPE": "vec4", + "VEC_SIZE" : 4, + }, + "DECLS": ["VEC", "MUL_ACC_FLOAT"] + }, + { + "SHADER_SUFFIX": "f16_f32", + "REPLS": { + "SRC0_TYPE" : "f16", + "SRC1_TYPE" : "f32", + "DST_TYPE": "f32", + "VEC_SIZE" : 1, + }, + "DECLS": ["SCALAR", "MUL_ACC_FLOAT"] + }, + { + "SHADER_SUFFIX": "f16_f16_vec", + "REPLS": { + "SRC0_TYPE" : "vec4", + "SRC1_TYPE" : "vec4", + "DST_TYPE": "vec4", + "VEC_SIZE" : 4, + }, + "DECLS": ["VEC", "MUL_ACC_FLOAT"] + }, + { + "SHADER_SUFFIX": "f16_f16", + "REPLS": { + "SRC0_TYPE" : "f16", + "SRC1_TYPE" : "f16", + "DST_TYPE": "f32", + "VEC_SIZE" : 1, + }, + "DECLS": ["SCALAR", "MUL_ACC_FLOAT"] + }, + { + "SHADER_SUFFIX": "q4_0_f32", + "REPLS": { + "SRC0_TYPE" : "f16", + "SRC1_TYPE" : "f32", + "DST_TYPE": "f32", + "VEC_SIZE" : 1, + }, + "DECLS": ["BYTE_HELPERS", "SCALAR", "MUL_ACC_Q4_0"] + } +] + +#end(VARIANTS) + +#define(DECLS) + +#decl(VEC) +fn inner_dot(src0_val: {{SRC0_TYPE}}, src1_val: {{SRC1_TYPE}}) -> f32 { + return f32(dot({{SRC1_TYPE}}(src0_val), src1_val)); +} + +fn store_val(group_base: u32) -> vec4 { + return vec4(partial_sums[group_base], + partial_sums[group_base + THREADS_PER_OUTPUT], + partial_sums[group_base + THREADS_PER_OUTPUT * 2], + partial_sums[group_base + THREADS_PER_OUTPUT * 3]); 
+} +#enddecl(VEC) + +#decl(SCALAR) +fn inner_dot(src0_val: {{SRC0_TYPE}}, src1_val: {{SRC1_TYPE}}) -> f32 { + return f32(src0_val) * f32(src1_val); +} + +fn store_val(group_base: u32) -> f32 { + return partial_sums[group_base]; +} +#enddecl(SCALAR) + +#decl(MUL_ACC_FLOAT) + +fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { + var local_sum = 0.0; + for (var i = tig * {{VEC_SIZE}}; i < tile_size; i += THREADS_PER_OUTPUT * {{VEC_SIZE}}) { + let a = src0[(idx_base + k_outer + i) / {{VEC_SIZE}}]; + let b = shared_vector[i / {{VEC_SIZE}}]; + local_sum += inner_dot(a, b); + } + return local_sum; +} + +#enddecl(MUL_ACC_FLOAT) + +#decl(MUL_ACC_Q4_0) + +const BLOCK_SIZE = 32; +const NQ = 16u; // number of weights per thread +const F16_PER_BLOCK = 9u; // 1 scale + 8x4 packed weights +const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 +const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; + +fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { + var local_sum = 0.0; + for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) { + let blck_idx = i / BLOCK_SIZE; + let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; + let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK; + // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] + let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; + let d = f32(src0[scale_idx]); + for (var j = 0u; j < F16_PER_THREAD; j += 2) { + let q_0 = src0[scale_idx + 1 + block_offset + j]; + let q_1 = src0[scale_idx + 1 + block_offset + j + 1]; + let q_packed = bitcast(vec2(q_0, q_1)); + for (var k: u32 = 0; k < 4; k++) { + let q_byte = get_byte(q_packed, k); + let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * d; + let q_lo = (f32(q_byte & 0xF) - 8.0) * d; + local_sum += q_lo * shared_vector[shmem_idx + j * 2 + k]; + local_sum += q_hi * shared_vector[shmem_idx + j * 2 + k + 16]; + } + } + } + return local_sum; +} + 
+#enddecl(MUL_ACC_Q4_0) + +#end(DECLS) + +#define(SHADER) +enable f16; + +DECLS + +struct MulMatParams { + offset_src0: u32, + offset_src1: u32, + offset_dst: u32, + m: u32, + n: u32, + k: u32, + stride_01: u32, + stride_11: u32, + stride_02: u32, + stride_12: u32, + stride_03: u32, + stride_13: u32, + bs02: u32, + bs03: u32, + broadcast2: u32, + broadcast3: u32 +}; + +@group(0) @binding(0) var src0: array<{{SRC0_TYPE}}>; // Matrix (M x K) +@group(0) @binding(1) var src1: array<{{SRC1_TYPE}}>; // Vector (K x 1, transposed) +@group(0) @binding(2) var dst: array<{{DST_TYPE}}>; // Result vector (transposed) + +@group(0) @binding(3) var params: MulMatParams; + +override WORKGROUP_SIZE: u32; +override TILE_K: u32; +override OUTPUTS_PER_WG: u32; +override THREADS_PER_OUTPUT = WORKGROUP_SIZE / OUTPUTS_PER_WG; + +// Shared memory for collaborative loading and reduction +var shared_vector: array<{{SRC1_TYPE}}, TILE_K/{{VEC_SIZE}}>; // Cache vector tile +var partial_sums: array; // For reduction + +@compute @workgroup_size(WORKGROUP_SIZE) +fn main( + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) wg_id: vec3, + @builtin(num_workgroups) num_wg: vec3) { + let thread_id = local_id.x; + + // Handle batch dimensions + let total_batches = params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3; + let wg_linear = wg_id.y * num_wg.x + wg_id.x; + let output_groups = (params.m + OUTPUTS_PER_WG - 1u) / OUTPUTS_PER_WG; + let batch_idx = wg_linear / output_groups; + if (batch_idx >= total_batches) { + return; + } + + // Which of the outputs does this thread belong to? 
+ let thread_group = thread_id / THREADS_PER_OUTPUT; + let thread_in_group = thread_id % THREADS_PER_OUTPUT; + + // Each workgroup computes OUTPUTS_PER_WG consecutive outputs + let output_row = (wg_linear % output_groups) * OUTPUTS_PER_WG + thread_group; + + let dst2_stride = params.m * params.n; + let dst2_idx = batch_idx % (params.bs02 * params.broadcast2); + let dst3_stride = dst2_stride * params.bs02 * params.broadcast2; + let dst3_idx = batch_idx / (params.bs02 * params.broadcast2); + let src03_idx = dst3_idx / params.broadcast3; + let src13_idx = dst3_idx; + let src02_idx = dst2_idx / params.broadcast2; + let src12_idx = dst2_idx; + + let src0_idx_base = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02 + output_row * params.stride_01; + let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12; + let dst_idx = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + output_row; + + var local_sum = 0.0; + + // Each thread processes multiple K elements and accumulates + for (var k_tile = 0u; k_tile < params.k; k_tile += TILE_K) { + let tile_size = min(TILE_K, params.k - k_tile); + + // Cooperatively load vector tile into shared memory (all threads) + for (var i = thread_id * {{VEC_SIZE}}; i < tile_size; i += WORKGROUP_SIZE * {{VEC_SIZE}}) { + shared_vector[i / {{VEC_SIZE}}] = src1[(src1_idx_base + k_tile + i) / {{VEC_SIZE}}]; + } + + workgroupBarrier(); + + if (output_row < params.m) { + local_sum += mul_acc(thread_in_group, tile_size, src0_idx_base, k_tile); + } + + workgroupBarrier(); + } + + // Store partial sums and reduce within each partition + partial_sums[thread_id] = local_sum; + workgroupBarrier(); + let group_base = thread_group * THREADS_PER_OUTPUT; + let thread_base = group_base + thread_in_group; + var offset = THREADS_PER_OUTPUT / 2; + while (offset > 0) { + if (thread_in_group < offset) { + partial_sums[thread_base] += partial_sums[thread_base + offset]; + } 
+ offset = offset / 2; + workgroupBarrier(); + } + + // Store back to global memory + if (output_row < params.m && thread_group % {{VEC_SIZE}} == 0 && thread_in_group == 0) { + dst[dst_idx / {{VEC_SIZE}}] = store_val(group_base); + } +} +#end(SHADER) From 7c2b2ef237d0c877ec4115f2429feaa52853c5cc Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Fri, 7 Nov 2025 19:06:08 -0800 Subject: [PATCH 18/40] Use map for shader replacements instead of pair of strings --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 30 ++++++++++++++-------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 2b9c28bc3b6..9e8cbc477ed 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -357,8 +357,8 @@ struct ggml_backend_webgpu_buffer_context { // Process a WGSL shader string, replacing tokens of the form {{KEY}} with // the corresponding values provided in `repls`. -static std::string ggml_webgpu_process_shader_repls(const char * src, - const std::vector> & repls) { +static std::string ggml_webgpu_process_shader_repls(const char * src, + const std::map & repls) { if (!src) { return std::string(); } @@ -1759,16 +1759,16 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { wgsl_mul_mat_iq4_xs_f32, "mul_mat_iq4_xs_f32"); if (webgpu_ctx->supports_subgroup_matrix) { - std::vector> sg_matrix_repls; - sg_matrix_repls.emplace_back("WEBGPU_MAX_SUBGROUP_SIZE", std::to_string(webgpu_ctx->subgroup_size)); - sg_matrix_repls.emplace_back("WEBGPU_TILE_K", std::to_string(WEBGPU_MUL_MAT_TILE_K)); - sg_matrix_repls.emplace_back("WEBGPU_SUBGROUP_M", std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M)); - sg_matrix_repls.emplace_back("WEBGPU_SUBGROUP_N", std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N)); - sg_matrix_repls.emplace_back("WEBGPU_SUBGROUP_MATRIX_M", std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M)); - 
sg_matrix_repls.emplace_back("WEBGPU_SUBGROUP_MATRIX_N", std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N)); - sg_matrix_repls.emplace_back("WEBGPU_SG_MAT_M_SIZE", std::to_string(webgpu_ctx->subgroup_matrix_config.M)); - sg_matrix_repls.emplace_back("WEBGPU_SG_MAT_N_SIZE", std::to_string(webgpu_ctx->subgroup_matrix_config.N)); - sg_matrix_repls.emplace_back("WEBGPU_SG_MAT_K_SIZE", std::to_string(webgpu_ctx->subgroup_matrix_config.K)); + std::map sg_matrix_repls; + sg_matrix_repls["WEBGPU_MAX_SUBGROUP_SIZE"] = std::to_string(webgpu_ctx->subgroup_size); + sg_matrix_repls["WEBGPU_TILE_K"] = std::to_string(WEBGPU_MUL_MAT_TILE_K); + sg_matrix_repls["WEBGPU_SUBGROUP_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M); + sg_matrix_repls["WEBGPU_SUBGROUP_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N); + sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M); + sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N); + sg_matrix_repls["WEBGPU_SG_MAT_M_SIZE"] = std::to_string(webgpu_ctx->subgroup_matrix_config.M); + sg_matrix_repls["WEBGPU_SG_MAT_N_SIZE"] = std::to_string(webgpu_ctx->subgroup_matrix_config.N); + sg_matrix_repls["WEBGPU_SG_MAT_K_SIZE"] = std::to_string(webgpu_ctx->subgroup_matrix_config.K); std::string proc_mul_mat_subgroup_matrix_f32_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32, sg_matrix_repls); @@ -1816,9 +1816,9 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { mul_mat_reg_tile_constants[2].key = "WORKGROUP_SIZE_N"; mul_mat_reg_tile_constants[2].value = WEBGPU_MUL_MAT_WG_SIZE_N; - std::vector> reg_repls; - reg_repls.emplace_back("WEBGPU_TILE_M", std::to_string(WEBGPU_MUL_MAT_TILE_M)); - reg_repls.emplace_back("WEBGPU_TILE_N", std::to_string(WEBGPU_MUL_MAT_TILE_N)); + std::map reg_repls; + reg_repls["WEBGPU_TILE_M"] = std::to_string(WEBGPU_MUL_MAT_TILE_M); + reg_repls["WEBGPU_TILE_N"] = 
std::to_string(WEBGPU_MUL_MAT_TILE_N); // Process each reg-tile shader with tile replacements. // Keep the processed strings in-scope so .c_str() remains valid. From c201d0de7743c0f45d8d16d554bb0482435a256a Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Tue, 11 Nov 2025 15:34:14 -0800 Subject: [PATCH 19/40] Wasm (#9) * webgpu : fix build on emscripten * more debugging stuff * test-backend-ops: force single thread on wasm * fix single-thread case for init_tensor_uniform * use jspi * add pthread * test: remember to set n_thread for cpu backend * Add buffer label and enable dawn-specific toggles to turn off some checks * Intermediate state * Fast working f16/f32 vec4 * Working float fast mul mat * Clean up naming of mul_mat to match logical model, start work on q mul_mat * Setup for subgroup matrix mat mul * Basic working subgroup matrix * Working subgroup matrix tiling * Handle weirder sg matrix sizes (but still % sg matrix size) * Working start to gemv * working f16 accumulation with shared memory staging * Print out available subgroup matrix configurations * Vectorize dst stores for sg matrix shader * Gemv working scalar * Minor set_rows optimization (#4) * updated optimization, fixed errors * non vectorized version now dispatches one thread per element * Simplify * Change logic for set_rows pipelines --------- Co-authored-by: Neha Abbas Co-authored-by: Neha Abbas Co-authored-by: Reese Levine * Comment on dawn toggles * Working subgroup matrix code for (semi)generic sizes * Remove some comments * Cleanup code * Update dawn version and move to portable subgroup size * Try to fix new dawn release * Update subgroup size comment * Only check for subgroup matrix configs if they are supported * Add toggles for subgroup matrix/f16 support on nvidia+vulkan * Make row/col naming consistent * Refactor shared memory loading * Move sg matrix stores to correct file * Working q4_0 * Formatting * Work with emscripten builds * Fix test-backend-ops emscripten for f16/quantized 
types * Use emscripten memory64 to support get_memory * Add build flags and try ci --------- Co-authored-by: Xuan Son Nguyen --- .github/workflows/build.yml | 40 +++++ .gitignore | 2 + CMakeLists.txt | 12 +- common/arg.cpp | 4 + common/common.cpp | 2 + ggml/CMakeLists.txt | 2 +- ggml/src/ggml-webgpu/CMakeLists.txt | 22 ++- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 229 ++++++++++++++------------- scripts/serve-static.js | 110 +++++++++++++ tests/test-backend-ops.cpp | 64 +++++--- 10 files changed, 348 insertions(+), 139 deletions(-) create mode 100644 scripts/serve-static.js diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 36084c55078..81d57b039d4 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -547,6 +547,46 @@ jobs: # This is using llvmpipe and runs slower than other backends ctest -L main --verbose --timeout 3600 + ubuntu-24-wasm-webgpu: + runs-on: ubuntu-24.04 + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: ccache + uses: ggml-org/ccache-action@v1.2.16 + with: + key: ubuntu-latest-wasm-webgpu + evict-old-files: 1d + + - name: Install Emscripten + run: | + git clone https://github.com/emscripten-core/emsdk.git + cd emsdk + ./emsdk install latest + ./emsdk activate latest + + - name: Fetch emdawnwebgpu + run: | + DAWN_TAG="v20251027.212519" + EMDAWN_PKG="emdawnwebgpu_pkg-${DAWN_TAG}.zip" + echo "Downloading ${EMDAWN_PKG}" + curl -L -o emdawn.zip \ + "https://github.com/google/dawn/releases/download/${DAWN_TAG}/${EMDAWN_PKG}" + unzip emdawn.zip + + - name: Build WASM WebGPU + run: | + source emsdk/emsdk_env.sh + emcmake cmake -B build-wasm \ + -DGGML_WEBGPU=ON \ + -DLLAMA_CURL=OFF \ + -DEMDAWNWEBGPU_DIR=emdawnwebgpu_pkg + + cmake --build build-wasm --target test-backend-ops -j $(nproc) + ubuntu-22-cmake-hip: runs-on: ubuntu-22.04 container: rocm/dev-ubuntu-22.04:6.1.2 diff --git a/.gitignore b/.gitignore index c7d00097857..33f469020c4 100644 --- a/.gitignore +++ b/.gitignore @@ 
-152,3 +152,5 @@ poetry.toml # IDE *.code-workspace .windsurf/ +# emscripten +a.out.* diff --git a/CMakeLists.txt b/CMakeLists.txt index 4bf8b2789ae..1c69a865b93 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -36,7 +36,17 @@ option(LLAMA_USE_SYSTEM_GGML "Use system libggml" OFF) if (EMSCRIPTEN) set(BUILD_SHARED_LIBS_DEFAULT OFF) - option(LLAMA_WASM_SINGLE_FILE "llama: embed WASM inside the generated llama.js" ON) + # Use 64-bit memory to support backend_get_memory queries + # TODO: analyze performance impact, see https://spidermonkey.dev/blog/2025/01/15/is-memory64-actually-worth-using + add_compile_options("-sMEMORY64=1") + add_link_options("-sMEMORY64=1") + add_link_options("-sALLOW_MEMORY_GROWTH=1") + + option(LLAMA_WASM_SINGLE_FILE "llama: embed WASM inside the generated llama.js" OFF) + option(LLAMA_BUILD_HTML "llama: build HTML file" ON) + if (LLAMA_BUILD_HTML) + set(CMAKE_EXECUTABLE_SUFFIX ".html") + endif() else() if (MINGW) set(BUILD_SHARED_LIBS_DEFAULT OFF) diff --git a/common/arg.cpp b/common/arg.cpp index 4316917d745..8f008281c05 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -39,6 +39,7 @@ #include "http.h" #endif +#ifndef __EMSCRIPTEN__ #ifdef __linux__ #include #elif defined(_WIN32) @@ -50,8 +51,11 @@ #else #include #endif +#endif + #define LLAMA_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083 + // isatty #if defined(_WIN32) #include diff --git a/common/common.cpp b/common/common.cpp index b0591e84b06..d12feffce09 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -889,6 +889,8 @@ std::string fs_get_cache_directory() { cache_directory = std::getenv("HOME") + std::string("/Library/Caches/"); #elif defined(_WIN32) cache_directory = std::getenv("LOCALAPPDATA"); +#elif defined(__EMSCRIPTEN__) + GGML_ABORT("not implemented on this platform"); #else # error Unknown architecture #endif diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 181f179ed17..fd3bb2124a0 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ 
-224,7 +224,7 @@ option(GGML_WEBGPU "ggml: use WebGPU" option(GGML_WEBGPU_DEBUG "ggml: enable WebGPU debug output" OFF) option(GGML_WEBGPU_CPU_PROFILE "ggml: enable WebGPU profiling (CPU)" OFF) option(GGML_WEBGPU_GPU_PROFILE "ggml: enable WebGPU profiling (GPU)" OFF) - +option(GGML_WEBGPU_JSPI "ggml: use JSPI for WebGPU" ON) option(GGML_ZDNN "ggml: use zDNN" OFF) option(GGML_METAL "ggml: use Metal" ${GGML_METAL_DEFAULT}) option(GGML_METAL_NDEBUG "ggml: disable Metal debugging" OFF) diff --git a/ggml/src/ggml-webgpu/CMakeLists.txt b/ggml/src/ggml-webgpu/CMakeLists.txt index c6a95d51512..3ccce58aa39 100644 --- a/ggml/src/ggml-webgpu/CMakeLists.txt +++ b/ggml/src/ggml-webgpu/CMakeLists.txt @@ -39,8 +39,23 @@ add_dependencies(ggml-webgpu generate_shaders) if(EMSCRIPTEN) set(EMDAWNWEBGPU_DIR "" CACHE PATH "Path to emdawnwebgpu_pkg") - target_compile_options(ggml-webgpu PRIVATE "--use-port=${EMDAWNWEBGPU_DIR}/emdawnwebgpu.port.py") - target_link_options(ggml-webgpu PRIVATE "--use-port=${EMDAWNWEBGPU_DIR}/emdawnwebgpu.port.py") + if(NOT EMDAWNWEBGPU_DIR) + # default built-in port + target_compile_options(ggml-webgpu PRIVATE "--use-port=emdawnwebgpu") + target_link_options(ggml-webgpu INTERFACE "--use-port=emdawnwebgpu") + else() + # custom port + target_compile_options(ggml-webgpu PRIVATE "--use-port=${EMDAWNWEBGPU_DIR}/emdawnwebgpu.port.py") + target_link_options(ggml-webgpu INTERFACE "--use-port=${EMDAWNWEBGPU_DIR}/emdawnwebgpu.port.py") + endif() + + if (GGML_WEBGPU_JSPI) + target_compile_options(ggml-webgpu PRIVATE "-fwasm-exceptions") + target_link_options(ggml-webgpu INTERFACE "-sJSPI" "-fwasm-exceptions") + else() + target_compile_options(ggml-webgpu PRIVATE "-fexceptions") + target_link_options(ggml-webgpu INTERFACE "-sASYNCIFY" "-exceptions") + endif() else() find_package(Dawn REQUIRED) set(DawnWebGPU_TARGET dawn::webgpu_dawn) @@ -48,6 +63,9 @@ endif() if (GGML_WEBGPU_DEBUG) target_compile_definitions(ggml-webgpu PRIVATE GGML_WEBGPU_DEBUG=1) + if(EMSCRIPTEN) + 
target_link_options(ggml-webgpu INTERFACE "-sASSERTIONS=2") + endif() endif() if (GGML_WEBGPU_CPU_PROFILE) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 9e8cbc477ed..a7476b109df 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -9,6 +9,10 @@ #include "ggml-impl.h" #include "ggml-wgsl-shaders.hpp" +#ifdef __EMSCRIPTEN__ +# include +#endif + #include #include @@ -261,9 +265,12 @@ struct webgpu_context_struct { wgpu::Queue queue; wgpu::Limits limits; + uint32_t subgroup_size; + +#ifndef __EMSCRIPTEN__ bool supports_subgroup_matrix = false; - uint32_t subgroup_size; wgpu::SubgroupMatrixConfig subgroup_matrix_config; +#endif // Separate this out from limits since on some Metal systems, the limit returned by // querying the limits is higher than the actual allowed maximum. @@ -449,8 +456,8 @@ static void ggml_backend_webgpu_wait(webgpu_context & ct // If we have too many in-flight submissions, wait on the oldest one first. If there are many threads, // inflight_max may be 0, meaning that we must wait on all futures. uint64_t timeout_ms = block ? UINT64_MAX : 0; - uint inflight_threads = ctx->inflight_threads; - uint inflight_max = WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD / std::max(inflight_threads, 1u); + uint32_t inflight_threads = ctx->inflight_threads; + uint32_t inflight_max = WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD / std::max(inflight_threads, 1u); while (futures.size() >= inflight_max && futures.size() > 0) { ctx->instance.WaitAny(futures[0].futures.size(), futures[0].futures.data(), UINT64_MAX); futures.erase(futures.begin()); @@ -986,6 +993,7 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][vectorized]; uint32_t wg_m; uint32_t wg_n; +#ifndef __EMSCRIPTEN__ if (ctx->supports_subgroup_matrix) { // The total number of subgroups/workgroups needed per matrix. 
uint32_t wg_m_sg_tile = @@ -995,11 +1003,15 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, WEBGPU_MUL_MAT_SUBGROUP_N * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N * ctx->subgroup_matrix_config.N; wg_n = (dst->ne[1] + wg_n_sg_tile - 1) / wg_n_sg_tile; } else { +#endif uint32_t tile_m_s = WEBGPU_MUL_MAT_TILE_M * WEBGPU_MUL_MAT_WG_SIZE_M; uint32_t tile_n_s = WEBGPU_MUL_MAT_TILE_N * WEBGPU_MUL_MAT_WG_SIZE_N; wg_m = (dst->ne[0] + tile_m_s - 1) / tile_m_s; wg_n = (dst->ne[1] + tile_n_s - 1) / tile_n_s; +#ifndef __EMSCRIPTEN__ } +#endif + wg_x = wg_m * wg_n * dst->ne[2] * dst->ne[3]; } } @@ -1419,9 +1431,9 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str commands.push_back(*cmd); } // compute the batch size based on the number of inflight threads - uint inflight_threads = ctx->inflight_threads; - uint batch_size = std::min(std::max(1u, WEBGPU_NUM_PARAM_BUFS / std::max(inflight_threads, 1u)), - WEBGPU_COMMAND_SUBMIT_BATCH_SIZE); + uint32_t inflight_threads = ctx->inflight_threads; + uint32_t batch_size = std::min(std::max(1u, WEBGPU_NUM_PARAM_BUFS / std::max(inflight_threads, 1u)), + WEBGPU_COMMAND_SUBMIT_BATCH_SIZE); if (commands.size() >= batch_size) { futures.push_back(ggml_backend_webgpu_submit(ctx, commands)); // Process events and check for completed submissions @@ -1758,6 +1770,17 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ4_XS][GGML_TYPE_F32], wgsl_mul_mat_iq4_xs_f32, "mul_mat_iq4_xs_f32"); + std::string proc_mul_mat_f32_f32; + std::string proc_mul_mat_f32_f32_vec; + std::string proc_mul_mat_f16_f32; + std::string proc_mul_mat_f16_f32_vec; + std::string proc_mul_mat_f16_f16; + std::string proc_mul_mat_f16_f16_vec; + std::string proc_mul_mat_q4_0_f32; + std::string proc_mul_mat_q4_0_f32_vec; + + std::vector mul_mat_constants; +#ifndef __EMSCRIPTEN__ if (webgpu_ctx->supports_subgroup_matrix) { 
std::map sg_matrix_repls; sg_matrix_repls["WEBGPU_MAX_SUBGROUP_SIZE"] = std::to_string(webgpu_ctx->subgroup_size); @@ -1770,100 +1793,57 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { sg_matrix_repls["WEBGPU_SG_MAT_N_SIZE"] = std::to_string(webgpu_ctx->subgroup_matrix_config.N); sg_matrix_repls["WEBGPU_SG_MAT_K_SIZE"] = std::to_string(webgpu_ctx->subgroup_matrix_config.K); - std::string proc_mul_mat_subgroup_matrix_f32_f32 = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32, sg_matrix_repls); - std::string proc_mul_mat_subgroup_matrix_f32_f32_vec = + proc_mul_mat_f32_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32, sg_matrix_repls); + proc_mul_mat_f32_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32_vec, sg_matrix_repls); - std::string proc_mul_mat_subgroup_matrix_f16_f32 = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32, sg_matrix_repls); - std::string proc_mul_mat_subgroup_matrix_f16_f32_vec = + proc_mul_mat_f16_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32, sg_matrix_repls); + proc_mul_mat_f16_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32_vec, sg_matrix_repls); - std::string proc_mul_mat_subgroup_matrix_f16_f16 = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16, sg_matrix_repls); - std::string proc_mul_mat_subgroup_matrix_f16_f16_vec = + proc_mul_mat_f16_f16 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16, sg_matrix_repls); + proc_mul_mat_f16_f16_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16_vec, sg_matrix_repls); - std::string proc_mul_mat_subgroup_matrix_q4_0_f32 = + proc_mul_mat_q4_0_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_q4_0_f32, sg_matrix_repls); - std::string proc_mul_mat_subgroup_matrix_q4_0_f32_vec = + proc_mul_mat_q4_0_f32_vec = 
ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_q4_0_f32_vec, sg_matrix_repls); - - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( - webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f32_f32.c_str(), "mul_mat_subgroup_matrix_f32_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f32_f32_vec.c_str(), - "mul_mat_subgroup_matrix_f32_f32_vec"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( - webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f16_f32.c_str(), "mul_mat_subgroup_matrix_f16_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f16_f32_vec.c_str(), - "mul_mat_subgroup_matrix_f16_f32_vec"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline2( - webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f16_f16.c_str(), "mul_mat_subgroup_matrix_f16_f16"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f16_f16_vec.c_str(), - "mul_mat_subgroup_matrix_f16_f16_vec"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( - webgpu_ctx->device, proc_mul_mat_subgroup_matrix_q4_0_f32.c_str(), "mul_mat_subgroup_matrix_q4_0_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_subgroup_matrix_q4_0_f32_vec.c_str(), - "mul_mat_subgroup_matrix_q4_0_f32_vec"); } else { - std::vector mul_mat_reg_tile_constants(3); - mul_mat_reg_tile_constants[0].key = "TILE_K"; - mul_mat_reg_tile_constants[0].value = WEBGPU_MUL_MAT_TILE_K; - mul_mat_reg_tile_constants[1].key = "WORKGROUP_SIZE_M"; - 
mul_mat_reg_tile_constants[1].value = WEBGPU_MUL_MAT_WG_SIZE_M; - mul_mat_reg_tile_constants[2].key = "WORKGROUP_SIZE_N"; - mul_mat_reg_tile_constants[2].value = WEBGPU_MUL_MAT_WG_SIZE_N; +#endif + mul_mat_constants.push_back({ .key = "TILE_K", .value = WEBGPU_MUL_MAT_TILE_K }); + mul_mat_constants.push_back({ .key = "WORKGROUP_SIZE_M", .value = WEBGPU_MUL_MAT_WG_SIZE_M }); + mul_mat_constants.push_back({ .key = "WORKGROUP_SIZE_N", .value = WEBGPU_MUL_MAT_WG_SIZE_N }); std::map reg_repls; reg_repls["WEBGPU_TILE_M"] = std::to_string(WEBGPU_MUL_MAT_TILE_M); reg_repls["WEBGPU_TILE_N"] = std::to_string(WEBGPU_MUL_MAT_TILE_N); - // Process each reg-tile shader with tile replacements. - // Keep the processed strings in-scope so .c_str() remains valid. - std::string proc_mul_mat_reg_tile_f32_f32 = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32, reg_repls); - std::string proc_mul_mat_reg_tile_f32_f32_vec = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32_vec, reg_repls); - std::string proc_mul_mat_reg_tile_f16_f32 = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32, reg_repls); - std::string proc_mul_mat_reg_tile_f16_f32_vec = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32_vec, reg_repls); - std::string proc_mul_mat_reg_tile_f16_f16 = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16, reg_repls); - std::string proc_mul_mat_reg_tile_f16_f16_vec = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16_vec, reg_repls); - std::string proc_mul_mat_reg_tile_q4_0_f32 = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32, reg_repls); - std::string proc_mul_mat_reg_tile_q4_0_f32_vec = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32_vec, reg_repls); - - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f32_f32.c_str(), - "mul_mat_reg_tile_f32_f32", 
mul_mat_reg_tile_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f32_f32_vec.c_str(), - "mul_mat_reg_tile_f32_f32_vec", mul_mat_reg_tile_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f16_f32.c_str(), - "mul_mat_reg_tile_f16_f32", mul_mat_reg_tile_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f16_f32_vec.c_str(), - "mul_mat_reg_tile_f16_f32_vec", mul_mat_reg_tile_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f16_f16.c_str(), - "mul_mat_reg_tile_f16_f16", mul_mat_reg_tile_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f16_f16_vec.c_str(), - "mul_mat_reg_tile_f16_f16_vec", mul_mat_reg_tile_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_q4_0_f32.c_str(), - "mul_mat_reg_tile_q4_0_f32", mul_mat_reg_tile_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_q4_0_f32_vec.c_str(), - "mul_mat_reg_tile_q4_0_f32_vec", mul_mat_reg_tile_constants); + proc_mul_mat_f32_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32, reg_repls); + proc_mul_mat_f32_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32_vec, reg_repls); + proc_mul_mat_f16_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32, reg_repls); + proc_mul_mat_f16_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32_vec, reg_repls); 
+ proc_mul_mat_f16_f16 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16, reg_repls); + proc_mul_mat_f16_f16_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16_vec, reg_repls); + proc_mul_mat_q4_0_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32, reg_repls); + proc_mul_mat_q4_0_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32_vec, reg_repls); +#ifndef __EMSCRIPTEN__ } +#endif + + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, proc_mul_mat_f32_f32.c_str(), "mul_mat_f32_f32", mul_mat_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, proc_mul_mat_f32_f32_vec.c_str(), "mul_mat_f32_f32_vec", mul_mat_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, proc_mul_mat_f16_f32.c_str(), "mul_mat_f16_f32", mul_mat_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, proc_mul_mat_f16_f32_vec.c_str(), "mul_mat_f16_f32_vec", mul_mat_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, proc_mul_mat_f16_f16.c_str(), "mul_mat_f16_f16", mul_mat_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, proc_mul_mat_f16_f16_vec.c_str(), "mul_mat_f16_f16_vec", mul_mat_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, proc_mul_mat_q4_0_f32.c_str(), "mul_mat_q4_0_f32", mul_mat_constants); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline2( + webgpu_ctx->device, proc_mul_mat_q4_0_f32_vec.c_str(), "mul_mat_q4_0_f32_vec", 
mul_mat_constants); std::vector mul_mat_vec_constants(3); mul_mat_vec_constants[0].key = "WORKGROUP_SIZE"; @@ -2384,13 +2364,17 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t webgpu_context ctx = reg_ctx->webgpu_ctx; + wgpu::RequestAdapterOptions options = {}; + +#ifndef __EMSCRIPTEN__ // TODO: track need for these toggles: https://issues.chromium.org/issues/42251215 const char * const adapterEnabledToggles[] = { "vulkan_enable_f16_on_nvidia", "use_vulkan_memory_model" }; wgpu::DawnTogglesDescriptor adapterTogglesDesc; adapterTogglesDesc.enabledToggles = adapterEnabledToggles; adapterTogglesDesc.enabledToggleCount = 2; - wgpu::RequestAdapterOptions options = {}; options.nextInChain = &adapterTogglesDesc; +#endif + ctx->instance.WaitAny(ctx->instance.RequestAdapter( &options, wgpu::CallbackMode::AllowSpontaneous, [&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) { @@ -2406,11 +2390,13 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t ctx->adapter.GetLimits(&ctx->limits); ctx->max_wg_size_x = 288; // default value - wgpu::AdapterInfo info{}; + wgpu::AdapterInfo info{}; +#ifndef __EMSCRIPTEN__ wgpu::AdapterPropertiesSubgroupMatrixConfigs subgroup_matrix_configs{}; if (ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) { info.nextInChain = &subgroup_matrix_configs; } +#endif ctx->adapter.GetInfo(&info); wgpu::SupportedFeatures features; @@ -2418,6 +2404,7 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t // we require f16 support GGML_ASSERT(ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16)); +#ifndef __EMSCRIPTEN__ // Only support square f16 matrices of size 8 or 16 for now bool valid_subgroup_matrix_config = false; if (ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) { @@ -2433,36 +2420,27 @@ static ggml_backend_dev_t 
ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t } } + ctx->supports_subgroup_matrix = valid_subgroup_matrix_config; +#endif // For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate. // Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter. - ctx->subgroup_size = info.subgroupMaxSize; - ctx->supports_subgroup_matrix = valid_subgroup_matrix_config; + ctx->subgroup_size = info.subgroupMaxSize; // Initialize device - std::vector required_features = { wgpu::FeatureName::ShaderF16, - wgpu::FeatureName::ImplicitDeviceSynchronization }; + std::vector required_features = { wgpu::FeatureName::ShaderF16 }; + +#ifndef __EMSCRIPTEN__ + required_features.push_back(wgpu::FeatureName::ImplicitDeviceSynchronization); if (ctx->supports_subgroup_matrix) { required_features.push_back(wgpu::FeatureName::Subgroups); required_features.push_back(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix); } +#endif #ifdef GGML_WEBGPU_GPU_PROFILE required_features.push_back(wgpu::FeatureName::TimestampQuery); #endif - // Enable Dawn-specific toggles to increase native performance - // TODO: Don't enable for WASM builds, they won't have an effect anyways - // TODO: Maybe WebGPU needs a "fast" mode where you can request compilers skip adding checks like these, - // only for native performance? 
- const char * const deviceEnabledToggles[] = { "skip_validation", "disable_robustness", "disable_workgroup_init", - "disable_polyfills_on_integer_div_and_mod" }; - const char * const deviceDisabledToggles[] = { "timestamp_quantization" }; - wgpu::DawnTogglesDescriptor deviceTogglesDesc; - deviceTogglesDesc.enabledToggles = deviceEnabledToggles; - deviceTogglesDesc.enabledToggleCount = 4; - deviceTogglesDesc.disabledToggles = deviceDisabledToggles; - deviceTogglesDesc.disabledToggleCount = 1; - wgpu::DeviceDescriptor dev_desc; dev_desc.requiredLimits = &ctx->limits; dev_desc.requiredFeatures = required_features.data(); @@ -2480,7 +2458,23 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t GGML_ABORT("ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast(reason), std::string(message).c_str()); }); + +#ifndef __EMSCRIPTEN__ + // Enable Dawn-specific toggles to increase native performance + // TODO: Maybe WebGPU needs a "fast" mode where you can request compilers skip adding checks like these, + // only for native performance? 
+ const char * const deviceEnabledToggles[] = { "skip_validation", "disable_robustness", "disable_workgroup_init", + "disable_polyfills_on_integer_div_and_mod" }; + const char * const deviceDisabledToggles[] = { "timestamp_quantization" }; + wgpu::DawnTogglesDescriptor deviceTogglesDesc; + deviceTogglesDesc.enabledToggles = deviceEnabledToggles; + deviceTogglesDesc.enabledToggleCount = 4; + deviceTogglesDesc.disabledToggles = deviceDisabledToggles; + deviceTogglesDesc.disabledToggleCount = 1; + dev_desc.nextInChain = &deviceTogglesDesc; +#endif + ctx->instance.WaitAny(ctx->adapter.RequestDevice( &dev_desc, wgpu::CallbackMode::AllowSpontaneous, [ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) { @@ -2578,18 +2572,27 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() { ctx.name = GGML_WEBGPU_NAME; ctx.device_count = 1; - const char * const instanceEnabledToggles[] = { "allow_unsafe_apis" }; - - wgpu::DawnTogglesDescriptor instanceTogglesDesc; - instanceTogglesDesc.enabledToggles = instanceEnabledToggles; - instanceTogglesDesc.enabledToggleCount = 1; wgpu::InstanceDescriptor instance_descriptor{}; std::vector instance_features = { wgpu::InstanceFeatureName::TimedWaitAny }; instance_descriptor.requiredFeatures = instance_features.data(); instance_descriptor.requiredFeatureCount = instance_features.size(); - instance_descriptor.nextInChain = &instanceTogglesDesc; + +#ifndef __EMSCRIPTEN__ + const char * const instanceEnabledToggles[] = { "allow_unsafe_apis" }; + wgpu::DawnTogglesDescriptor instanceTogglesDesc; + instanceTogglesDesc.enabledToggles = instanceEnabledToggles; + instanceTogglesDesc.enabledToggleCount = 1; + instance_descriptor.nextInChain = &instanceTogglesDesc; +#endif webgpu_ctx->instance = wgpu::CreateInstance(&instance_descriptor); + +#ifdef __EMSCRIPTEN__ + if (webgpu_ctx->instance == nullptr) { + GGML_LOG_ERROR("ggml_webgpu: Failed to create WebGPU instance. 
Make sure either -sASYNCIFY or -sJSPI is set\n"); + return nullptr; + } +#endif GGML_ASSERT(webgpu_ctx->instance != nullptr); static ggml_backend_reg reg = { diff --git a/scripts/serve-static.js b/scripts/serve-static.js new file mode 100644 index 00000000000..df6cf534055 --- /dev/null +++ b/scripts/serve-static.js @@ -0,0 +1,110 @@ +const http = require('http'); +const fs = require('fs').promises; +const path = require('path'); + +// This file is used for testing wasm build from emscripten +// Example build command: +// emcmake cmake -B build-wasm -DGGML_WEBGPU=ON -DLLAMA_CURL=OFF +// cmake --build build-wasm --target test-backend-ops -j + +const PORT = 8080; +const STATIC_DIR = path.join(__dirname, '../build-wasm/bin'); +console.log(`Serving static files from: ${STATIC_DIR}`); + +const mimeTypes = { + '.html': 'text/html', + '.js': 'text/javascript', + '.css': 'text/css', + '.png': 'image/png', + '.jpg': 'image/jpeg', + '.gif': 'image/gif', + '.svg': 'image/svg+xml', + '.json': 'application/json', + '.woff': 'font/woff', + '.woff2': 'font/woff2', +}; + +async function generateDirListing(dirPath, reqUrl) { + const files = await fs.readdir(dirPath); + let html = ` + + + + Directory Listing + + + +

Directory: ${reqUrl}

+
    + `; + + if (reqUrl !== '/') { + html += `
  • ../ (Parent Directory)
  • `; + } + + for (const file of files) { + const filePath = path.join(dirPath, file); + const stats = await fs.stat(filePath); + const link = encodeURIComponent(file) + (stats.isDirectory() ? '/' : ''); + html += `
  • ${file}${stats.isDirectory() ? '/' : ''}
  • `; + } + + html += ` +
+ + + `; + return html; +} + +const server = http.createServer(async (req, res) => { + try { + // Set COOP and COEP headers + res.setHeader('Cross-Origin-Opener-Policy', 'same-origin'); + res.setHeader('Cross-Origin-Embedder-Policy', 'require-corp'); + res.setHeader('Cache-Control', 'no-store, no-cache, must-revalidate, proxy-revalidate'); + res.setHeader('Pragma', 'no-cache'); + res.setHeader('Expires', '0'); + + const filePath = path.join(STATIC_DIR, decodeURIComponent(req.url)); + const stats = await fs.stat(filePath); + + if (stats.isDirectory()) { + const indexPath = path.join(filePath, 'index.html'); + try { + const indexData = await fs.readFile(indexPath); + res.writeHeader(200, { 'Content-Type': 'text/html' }); + res.end(indexData); + } catch { + // No index.html, generate directory listing + const dirListing = await generateDirListing(filePath, req.url); + res.writeHeader(200, { 'Content-Type': 'text/html' }); + res.end(dirListing); + } + } else { + const ext = path.extname(filePath).toLowerCase(); + const contentType = mimeTypes[ext] || 'application/octet-stream'; + const data = await fs.readFile(filePath); + res.writeHeader(200, { 'Content-Type': contentType }); + res.end(data); + } + } catch (err) { + if (err.code === 'ENOENT') { + res.writeHeader(404, { 'Content-Type': 'text/plain' }); + res.end('404 Not Found'); + } else { + res.writeHeader(500, { 'Content-Type': 'text/plain' }); + res.end('500 Internal Server Error'); + } + } +}); + +server.listen(PORT, () => { + console.log(`Server running at http://localhost:${PORT}/`); +}); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 967a53c63d8..e949b64a0a0 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #include #include @@ -40,12 +41,18 @@ #include #include +#ifdef __EMSCRIPTEN__ +# define N_THREADS 1 +#else +# define N_THREADS std::thread::hardware_concurrency() +#endif + static void 
init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) { size_t nels = ggml_nelements(tensor); std::vector data(nels); { // parallel initialization - static const size_t n_threads = std::thread::hardware_concurrency(); + static const size_t n_threads = N_THREADS; // static RNG initialization (revisit if n_threads stops being constant) static std::vector generators = []() { std::random_device rd; @@ -64,15 +71,19 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m } }; - std::vector> tasks; - tasks.reserve(n_threads); - for (size_t i = 0; i < n_threads; i++) { - size_t start = i*nels/n_threads; - size_t end = (i+1)*nels/n_threads; - tasks.push_back(std::async(std::launch::async, init_thread, i, start, end)); - } - for (auto & t : tasks) { - t.get(); + if (n_threads == 1) { + init_thread(0, 0, nels); + } else { + std::vector> tasks; + tasks.reserve(n_threads); + for (size_t i = 0; i < n_threads; i++) { + size_t start = i*nels/n_threads; + size_t end = (i+1)*nels/n_threads; + tasks.push_back(std::async(std::launch::async, init_thread, i, start, end)); + } + for (auto & t : tasks) { + t.get(); + } } } @@ -104,17 +115,23 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m }; const size_t min_blocks_per_thread = 1; - const size_t n_threads = std::min(std::thread::hardware_concurrency()/2, - std::max(1, n_blocks / min_blocks_per_thread)); - std::vector> tasks; - tasks.reserve(n_threads); - for (size_t i = 0; i < n_threads; i++) { - size_t start = i*n_blocks/n_threads; - size_t end = (i+1)*n_blocks/n_threads; - tasks.push_back(std::async(std::launch::async, quantize_thread, start, end)); - } - for (auto & t : tasks) { - t.get(); + const size_t n_quant_threads = std::min(std::max(N_THREADS/2, 1), + std::max(1, n_blocks / min_blocks_per_thread)); + + if (n_quant_threads == 1) { + // single-threaded quantization: do all blocks in the current thread + quantize_thread(0, n_blocks); + } else 
{ + std::vector> tasks; + tasks.reserve(n_quant_threads); + for (size_t i = 0; i < n_quant_threads; i++) { + size_t start = i*n_blocks/n_quant_threads; + size_t end = (i+1)*n_blocks/n_quant_threads; + tasks.push_back(std::async(std::launch::async, quantize_thread, start, end)); + } + for (auto & t : tasks) { + t.get(); + } } } ggml_backend_tensor_set(tensor, dataq.data(), 0, dataq.size()); @@ -7522,6 +7539,9 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op return false; } + // TODO: find a better way to set the number of threads for the CPU backend + ggml_backend_cpu_set_n_threads(backend_cpu, N_THREADS); + size_t n_ok = 0; size_t tests_run = 0; std::vector failed_tests; @@ -7799,7 +7819,7 @@ int main(int argc, char ** argv) { auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads"); if (ggml_backend_set_n_threads_fn) { // TODO: better value for n_threads - ggml_backend_set_n_threads_fn(backend, std::thread::hardware_concurrency()); + ggml_backend_set_n_threads_fn(backend, N_THREADS); } size_t free, total; // NOLINT From 56e6959f93bc3f0701f25c197223537a21b1b05f Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Tue, 11 Nov 2025 21:40:57 -0800 Subject: [PATCH 20/40] Remove extra whitespace --- scripts/serve-static.js | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scripts/serve-static.js b/scripts/serve-static.js index df6cf534055..8ddc04aad98 100644 --- a/scripts/serve-static.js +++ b/scripts/serve-static.js @@ -43,18 +43,18 @@ async function generateDirListing(dirPath, reqUrl) {

Directory: ${reqUrl}

    `; - + if (reqUrl !== '/') { html += `
  • ../ (Parent Directory)
  • `; } - + for (const file of files) { const filePath = path.join(dirPath, file); const stats = await fs.stat(filePath); const link = encodeURIComponent(file) + (stats.isDirectory() ? '/' : ''); html += `
  • ${file}${stats.isDirectory() ? '/' : ''}
  • `; } - + html += `
From 5bcd57722f2dbfc2f1e09447e11e3dee9c4b7c1c Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Fri, 14 Nov 2025 09:53:44 -0800 Subject: [PATCH 21/40] Move wasm single-thread logic out of test-backend-ops for cpu backend --- ggml/include/ggml.h | 6 ++++++ ggml/src/ggml-cpu/ggml-cpu.cpp | 7 +++++++ tests/test-backend-ops.cpp | 4 ---- 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 605fcfcb9c2..47bfa645361 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -225,7 +225,13 @@ # define GGML_MAX_NAME 64 #endif +// For single-thread WASM builds, only use 1 thread +#if !defined(__EMSCRIPTEN__) || defined(__EMSCRIPTEN_PTHREADS__) #define GGML_DEFAULT_N_THREADS 4 +#else +#define GGML_DEFAULT_N_THREADS 1 +#endif + #define GGML_DEFAULT_GRAPH_SIZE 2048 #if UINTPTR_MAX == 0xFFFFFFFF diff --git a/ggml/src/ggml-cpu/ggml-cpu.cpp b/ggml/src/ggml-cpu/ggml-cpu.cpp index 3191faaa4cd..3c5cc141022 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.cpp +++ b/ggml/src/ggml-cpu/ggml-cpu.cpp @@ -246,8 +246,11 @@ bool ggml_backend_is_cpu(ggml_backend_t backend) { void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads) { GGML_ASSERT(ggml_backend_is_cpu(backend_cpu)); +// For single-thread WASM builds, do not allow changing the number of threads +#if !defined(_EMSCRIPTEN_) || defined(__EMSCRIPTEN_PTHREADS__) struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context; ctx->n_threads = n_threads; +#endif } void ggml_backend_cpu_set_threadpool(ggml_backend_t backend_cpu, ggml_threadpool_t threadpool) { @@ -622,10 +625,14 @@ static ggml_backend_feature * ggml_backend_cpu_get_features(ggml_backend_reg_t r } static void * ggml_backend_cpu_get_proc_address(ggml_backend_reg_t reg, const char * name) { + +// For single-thread WASM builds, do not expose a set_n_threads function +#if !defined(__EMSCRIPTEN__) || defined(__EMSCRIPTEN_PTHREADS__) if (strcmp(name, 
"ggml_backend_set_n_threads") == 0) { ggml_backend_set_n_threads_t fct = ggml_backend_cpu_set_n_threads; return (void *)fct; } +#endif if (strcmp(name, "ggml_backend_dev_get_extra_bufts") == 0) { ggml_backend_dev_get_extra_bufts_t fct = ggml_backend_cpu_device_get_extra_buffers_type; return (void *)fct; diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index a1df5d66275..5e95e411e92 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -19,7 +19,6 @@ #include #include #include -#include #include #include @@ -7905,9 +7904,6 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op return false; } - // TODO: find a better way to set the number of threads for the CPU backend - ggml_backend_cpu_set_n_threads(backend_cpu, N_THREADS); - size_t n_ok = 0; size_t tests_run = 0; std::vector failed_tests; From f9ba81953b4dd1ac9103baf1222fc3a6de1462b5 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Mon, 17 Nov 2025 13:53:00 -0800 Subject: [PATCH 22/40] Disable multiple threads for emscripten single-thread builds in ggml_graph_plan --- ggml/src/ggml-cpu/ggml-cpu.c | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index c7348cc26c1..45b42926776 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -2701,6 +2701,11 @@ struct ggml_cplan ggml_graph_plan( n_threads = threadpool ? 
threadpool->n_threads_max : GGML_DEFAULT_N_THREADS; } +#if defined(__EMSCRIPTEN__) && !defined(__EMSCRIPTEN_PTHREADS__) + // Emscripten without pthreads support can only use a single thread + n_threads = 1; +#endif + size_t work_size = 0; struct ggml_cplan cplan; From 5ca9b5e49ea7cddc9ab7c8b43a11a9c76a4dff4a Mon Sep 17 00:00:00 2001 From: neha-ha <137219201+neha-ha@users.noreply.github.com> Date: Tue, 18 Nov 2025 12:17:00 -0800 Subject: [PATCH 23/40] Refactored pipelines and workgroup calculations (#10) * refactored pipelines * refactored workgroup calculation * removed commented out block of prior maps * Clean up ceiling division pattern --------- Co-authored-by: Neha Abbas Co-authored-by: Reese Levine --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 716 +++++++++++++-------------- 1 file changed, 353 insertions(+), 363 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index a7476b109df..1c272f069c2 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -25,6 +25,9 @@ #include #include +#define ROUNDUP_POW2(x, pow2) (((x) + ((pow2) - 1)) & ~((pow2) - 1)) +#define CEIL_DIV(M, N) (((M) + (N) - 1) / (N)) + #ifdef GGML_WEBGPU_DEBUG # define WEBGPU_LOG_DEBUG(msg) std::cout << msg << std::endl # define WEBGPU_DEBUG_BUF_ELEMS 32 @@ -64,6 +67,9 @@ /* Constants */ +// Track https://github.com/gpuweb/gpuweb/issues/5315 for fixes to implementations so this can be removed. +#define WEBGPU_MAX_WG_SIZE 288 + #define WEBGPU_MUL_MAT_WG_SIZE 256 #define WEBGPU_NUM_PARAM_BUFS 32u #define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 8u @@ -272,36 +278,32 @@ struct webgpu_context_struct { wgpu::SubgroupMatrixConfig subgroup_matrix_config; #endif - // Separate this out from limits since on some Metal systems, the limit returned by - // querying the limits is higher than the actual allowed maximum. 
- uint32_t max_wg_size_x; - std::recursive_mutex mutex; std::atomic_uint inflight_threads = 0; webgpu_buf_pool param_buf_pool; webgpu_buf_pool set_rows_error_buf_pool; - webgpu_pipeline memset_pipeline; + std::map memset_pipelines; // variant or type index std::map>> mul_mat_pipelines; // src0_type, src1_type, vectorized std::map>> mul_mat_vec_pipelines; // src0_type, src1_type, vectorized - webgpu_pipeline mul_mat_pipeline[30][2]; - webgpu_pipeline set_rows_pipeline[1][2]; // dst->type, vectorized - webgpu_pipeline get_rows_pipeline[30]; - webgpu_pipeline get_rows_f32_no_vec_pipeline; - webgpu_pipeline cpy_pipeline[2][2]; // src type, dst type - webgpu_pipeline add_pipeline[2][2]; // type, inplace - webgpu_pipeline sub_pipeline[2][2]; // type, inplace - webgpu_pipeline mul_pipeline[2][2]; // type, inplace - webgpu_pipeline div_pipeline[2][2]; // type, inplace - webgpu_pipeline rms_norm_pipeline[2]; // inplace - webgpu_pipeline rope_pipeline[2][2][2]; // type, ff, inplace - webgpu_pipeline glu_pipeline[7][2][2]; // glu-op, type, split - webgpu_pipeline scale_pipeline[2]; // inplace - webgpu_pipeline soft_max_pipeline[3][2][2]; // (no_mask, f32_mask, f16_mask), has_sink, inplace + std::map> set_rows_pipelines; // dst_type, vectorized + std::map> get_rows_pipelines; // src_type, vectorized + + std::map> cpy_pipelines; // src_type, dst_type + std::map> add_pipelines; // type, inplace + std::map> sub_pipelines; // type, inplace + std::map> mul_pipelines; // type, inplace + std::map> div_pipelines; // type, inplace + + std::map rms_norm_pipelines; // inplace + std::map>> rope_pipelines; // type, ff, inplace + std::map>> glu_pipelines; // glu_op, type, split + std::map scale_pipelines; // inplace + std::map>> soft_max_pipelines; // mask_type, has_sink, inplace size_t memset_bytes_per_thread; @@ -381,35 +383,10 @@ static std::string ggml_webgpu_process_shader_repls(const char * return s; } -static void ggml_webgpu_create_pipeline(wgpu::Device & device, - webgpu_pipeline & 
pipeline, - const char * shader_code, - const char * label, - const std::vector & constants = {}) { - wgpu::ShaderSourceWGSL shader_source; - shader_source.code = shader_code; - - wgpu::ShaderModuleDescriptor shader_desc; - shader_desc.nextInChain = &shader_source; - - wgpu::ShaderModule shader_module = device.CreateShaderModule(&shader_desc); - - wgpu::ComputePipelineDescriptor pipeline_desc; - pipeline_desc.label = label; - pipeline_desc.compute.module = shader_module; - pipeline_desc.compute.entryPoint = "main"; // Entry point in the WGSL code - pipeline_desc.layout = nullptr; // nullptr means auto layout - if (constants.size() > 0) { - pipeline_desc.compute.constants = constants.data(); - pipeline_desc.compute.constantCount = constants.size(); - } - pipeline = { device.CreateComputePipeline(&pipeline_desc), label }; -} - -static webgpu_pipeline ggml_webgpu_create_pipeline2(wgpu::Device & device, - const char * shader_code, - const char * label, - const std::vector & constants = {}) { +static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device, + const char * shader_code, + const char * label, + const std::vector & constants = {}) { wgpu::ShaderSourceWGSL shader_source; shader_source.code = shader_code; @@ -678,10 +655,10 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_context & ctx, std::vector entries = { { .binding = 0, .buffer = buf, .offset = 0, .size = buf.GetSize() } }; - size_t bytes_per_wg = ctx->max_wg_size_x * ctx->memset_bytes_per_thread; - uint32_t wg_x = ((size + 3) + bytes_per_wg - 1) / bytes_per_wg; + size_t bytes_per_wg = WEBGPU_MAX_WG_SIZE * ctx->memset_bytes_per_thread; + uint32_t wg_x = CEIL_DIV(size + 3, bytes_per_wg); - webgpu_command command = ggml_backend_webgpu_build(ctx, ctx->memset_pipeline, params, entries, wg_x); + webgpu_command command = ggml_backend_webgpu_build(ctx, ctx->memset_pipelines[0], params, entries, wg_x); std::vector futures = { ggml_backend_webgpu_submit(ctx, { command }) }; 
ggml_backend_webgpu_wait(ctx, futures); } @@ -763,8 +740,7 @@ static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, ggml_tensor } static size_t ggml_webgpu_tensor_binding_size(webgpu_context & ctx, ggml_tensor * t) { - return (ggml_nbytes(t) + ggml_webgpu_tensor_misalignment(ctx, t) + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & - ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1); + return ROUNDUP_POW2(ggml_nbytes(t) + ggml_webgpu_tensor_misalignment(ctx, t), WEBGPU_STORAGE_BUF_BINDING_MULT); } // Used to determine if two tensors are the same for in-place operations @@ -800,9 +776,8 @@ static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, g .size = ggml_webgpu_tensor_binding_size(ctx, dst) } }; - size_t max_wg_size = ctx->max_wg_size_x; - uint32_t wg_x = (ne + max_wg_size - 1) / max_wg_size; - return ggml_backend_webgpu_build(ctx, ctx->cpy_pipeline[src->type][dst->type], params, entries, wg_x); + uint32_t wg_x = CEIL_DIV(ne, WEBGPU_MAX_WG_SIZE); + return ggml_backend_webgpu_build(ctx, ctx->cpy_pipelines[src->type][dst->type], params, entries, wg_x); } static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, @@ -851,10 +826,8 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, { .binding = 3, .buffer = error_bufs.dev_buf, .offset = 0, .size = error_bufs.dev_buf.GetSize() } }; - size_t max_wg_size = ctx->max_wg_size_x; - int vectorized = src->ne[0] % 4 == 0; - webgpu_pipeline pipeline = ctx->set_rows_pipeline[0][vectorized]; + webgpu_pipeline pipeline = ctx->set_rows_pipelines[0][vectorized]; uint32_t threads; if (vectorized) { threads = (src->ne[1] * src->ne[2] * src->ne[3]) * (src->ne[0] / 4); @@ -862,7 +835,7 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, threads = src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3]; } - uint32_t wg_x = (threads + max_wg_size - 1) / max_wg_size; + uint32_t wg_x = CEIL_DIV(threads, WEBGPU_MAX_WG_SIZE); return ggml_backend_webgpu_build(ctx, pipeline, params, entries, 
wg_x, 1, error_bufs); } @@ -902,13 +875,10 @@ static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, dst) } }; - size_t max_wg_size = ctx->max_wg_size_x; - uint32_t wg_x = (dst->ne[1] * dst->ne[2] * dst->ne[3] + max_wg_size - 1) / max_wg_size; + uint32_t wg_x = CEIL_DIV(dst->ne[1] * dst->ne[2] * dst->ne[3], WEBGPU_MAX_WG_SIZE); - webgpu_pipeline pipeline = ctx->get_rows_pipeline[src->type]; - if (src->type == GGML_TYPE_F32 && dst->ne[0] % 4 != 0) { - pipeline = ctx->get_rows_f32_no_vec_pipeline; - } + uint32_t vectorized = src->type == GGML_TYPE_F32 && dst->ne[0] % 4 == 0; + webgpu_pipeline pipeline = ctx->get_rows_pipelines[src->type][vectorized]; return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } @@ -950,10 +920,9 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, dst) }, }; - webgpu_pipeline pipeline = ctx->mul_mat_pipeline[src0->type][src1->type]; + webgpu_pipeline pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][0]; - uint32_t wg_x = - (dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3] + WEBGPU_MUL_MAT_WG_SIZE - 1) / WEBGPU_MUL_MAT_WG_SIZE; + uint32_t wg_x = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], WEBGPU_MUL_MAT_WG_SIZE); uint32_t wg_y = 1; bool use_fast = false; @@ -980,15 +949,13 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, int vectorized = src0->ne[0] % 4 == 0 && dst->ne[0] % 4 == 0 && dst->ne[1] % 4 == 0; if (dst->ne[1] == 1) { // We don't support vectorized mul_mat_vec for quantized types - vectorized = vectorized && (src0->type < 2); - pipeline = ctx->mul_mat_vec_pipelines[src0->type][src1->type][vectorized]; - uint32_t batches = dst->ne[2] * dst->ne[3]; - uint32_t output_groups = - (dst->ne[0] + WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG - 1) / WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG; - uint32_t total_wg = output_groups * batches; - wg_x = total_wg % 
ctx->limits.maxComputeWorkgroupsPerDimension; - wg_y = (total_wg + ctx->limits.maxComputeWorkgroupsPerDimension - 1) / - ctx->limits.maxComputeWorkgroupsPerDimension; + vectorized = vectorized && (src0->type < 2); + pipeline = ctx->mul_mat_vec_pipelines[src0->type][src1->type][vectorized]; + uint32_t batches = dst->ne[2] * dst->ne[3]; + uint32_t output_groups = CEIL_DIV(dst->ne[0], WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG); + uint32_t total_wg = output_groups * batches; + wg_x = total_wg % ctx->limits.maxComputeWorkgroupsPerDimension; + wg_y = CEIL_DIV(total_wg, ctx->limits.maxComputeWorkgroupsPerDimension); } else { pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][vectorized]; uint32_t wg_m; @@ -998,16 +965,16 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, // The total number of subgroups/workgroups needed per matrix. uint32_t wg_m_sg_tile = WEBGPU_MUL_MAT_SUBGROUP_M * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M * ctx->subgroup_matrix_config.M; - wg_m = (dst->ne[0] + wg_m_sg_tile - 1) / wg_m_sg_tile; + wg_m = CEIL_DIV(dst->ne[0], wg_m_sg_tile); uint32_t wg_n_sg_tile = WEBGPU_MUL_MAT_SUBGROUP_N * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N * ctx->subgroup_matrix_config.N; - wg_n = (dst->ne[1] + wg_n_sg_tile - 1) / wg_n_sg_tile; + wg_n = CEIL_DIV(dst->ne[1], wg_n_sg_tile); } else { #endif uint32_t tile_m_s = WEBGPU_MUL_MAT_TILE_M * WEBGPU_MUL_MAT_WG_SIZE_M; uint32_t tile_n_s = WEBGPU_MUL_MAT_TILE_N * WEBGPU_MUL_MAT_WG_SIZE_N; - wg_m = (dst->ne[0] + tile_m_s - 1) / tile_m_s; - wg_n = (dst->ne[1] + tile_n_s - 1) / tile_n_s; + wg_m = CEIL_DIV(dst->ne[0], tile_m_s); + wg_n = CEIL_DIV(dst->ne[1], tile_n_s); #ifndef __EMSCRIPTEN__ } #endif @@ -1059,8 +1026,7 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); } - size_t max_wg_size = ctx->max_wg_size_x; - uint32_t wg_x = (ggml_nelements(dst) + max_wg_size - 1) / max_wg_size; + uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE); return 
ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } @@ -1096,7 +1062,7 @@ static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * s .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); } - return ggml_backend_webgpu_build(ctx, ctx->rms_norm_pipeline[inplace], params, entries, ggml_nrows(src)); + return ggml_backend_webgpu_build(ctx, ctx->rms_norm_pipelines[inplace], params, entries, ggml_nrows(src)); } static webgpu_command ggml_webgpu_rope(webgpu_context & ctx, @@ -1181,9 +1147,8 @@ static webgpu_command ggml_webgpu_rope(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); } - webgpu_pipeline pipeline = ctx->rope_pipeline[dst->type][has_freq_factor][inplace]; - size_t max_wg_size = ctx->max_wg_size_x; - uint32_t wg_x = (ggml_nelements(src0) / 2 + max_wg_size - 1) / max_wg_size; + webgpu_pipeline pipeline = ctx->rope_pipelines[dst->type][has_freq_factor][inplace]; + uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE); return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } @@ -1234,9 +1199,8 @@ static webgpu_command ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, .offset = ggml_webgpu_tensor_align_offset(ctx, dst), .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); - webgpu_pipeline pipeline = ctx->glu_pipeline[ggml_get_glu_op(dst)][dst->type][split]; - size_t max_wg_size = ctx->max_wg_size_x; - uint32_t wg_x = (ggml_nelements(dst) + max_wg_size - 1) / max_wg_size; + webgpu_pipeline pipeline = ctx->glu_pipelines[ggml_get_glu_op(dst)][dst->type][split]; + uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE); return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } @@ -1273,9 +1237,8 @@ static webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); } - size_t max_wg_size = ctx->max_wg_size_x; - uint32_t wg_x = (ggml_nelements(dst) + max_wg_size - 1) / 
max_wg_size; - return ggml_backend_webgpu_build(ctx, ctx->scale_pipeline[inplace], params, entries, wg_x); + uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE); + return ggml_backend_webgpu_build(ctx, ctx->scale_pipelines[inplace], params, entries, wg_x); } static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx, @@ -1347,7 +1310,7 @@ static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); } - return ggml_backend_webgpu_build(ctx, ctx->soft_max_pipeline[mask_type][has_sink][inplace], params, entries, + return ggml_backend_webgpu_build(ctx, ctx->soft_max_pipelines[mask_type][has_sink][inplace], params, entries, ggml_nrows(dst)); } @@ -1382,22 +1345,22 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, case GGML_OP_ADD: { int inplace = ggml_webgpu_tensor_equal(src0, node); - return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_pipeline[node->type][inplace], inplace); + return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_pipelines[node->type][inplace], inplace); } case GGML_OP_SUB: { int inplace = ggml_webgpu_tensor_equal(src0, node); - return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->sub_pipeline[node->type][inplace], inplace); + return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->sub_pipelines[node->type][inplace], inplace); } case GGML_OP_MUL: { int inplace = ggml_webgpu_tensor_equal(src0, node); - return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_pipeline[node->type][inplace], inplace); + return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_pipelines[node->type][inplace], inplace); } case GGML_OP_DIV: { int inplace = ggml_webgpu_tensor_equal(src0, node); - return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->div_pipeline[node->type][inplace], inplace); + return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->div_pipelines[node->type][inplace], inplace); } case GGML_OP_RMS_NORM: return 
ggml_webgpu_rms_norm(ctx, src0, node); @@ -1641,8 +1604,7 @@ static ggml_backend_buffer_t ggml_backend_webgpu_buffer_type_alloc_buffer(ggml_b ggml_backend_webgpu_device_context * ctx = static_cast(buft->device->context); wgpu::Buffer buf; - ggml_webgpu_create_buffer(ctx->webgpu_ctx->device, buf, - (size + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1), + ggml_webgpu_create_buffer(ctx->webgpu_ctx->device, buf, ROUNDUP_POW2(size, WEBGPU_STORAGE_BUF_BINDING_MULT), wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst, buf_name.c_str()); @@ -1717,58 +1679,64 @@ static std::vector ggml_webgpu_wg_size_entry(uint32_t wg_si static void ggml_webgpu_init_memset_pipeline(webgpu_context & webgpu_ctx) { // we use the maximum workgroup size for the memset pipeline - size_t max_wg_size = webgpu_ctx->max_wg_size_x; - size_t max_threads = max_wg_size * webgpu_ctx->limits.maxComputeWorkgroupsPerDimension; + size_t max_threads = WEBGPU_MAX_WG_SIZE * webgpu_ctx->limits.maxComputeWorkgroupsPerDimension; // Size the bytes_per_thread so that the largest buffer size can be handled - webgpu_ctx->memset_bytes_per_thread = - (webgpu_ctx->limits.maxStorageBufferBindingSize + max_threads - 1) / max_threads; + webgpu_ctx->memset_bytes_per_thread = CEIL_DIV(webgpu_ctx->limits.maxStorageBufferBindingSize, max_threads); std::vector constants(2); - constants[0].key = "wg_size"; - constants[0].value = max_wg_size; - constants[1].key = "bytes_per_thread"; - constants[1].value = webgpu_ctx->memset_bytes_per_thread; - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->memset_pipeline, wgsl_memset, "memset", constants); + constants[0].key = "wg_size"; + constants[0].value = WEBGPU_MAX_WG_SIZE; + constants[1].key = "bytes_per_thread"; + constants[1].value = webgpu_ctx->memset_bytes_per_thread; + webgpu_ctx->memset_pipelines[0] = ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_memset, "memset", constants); } static void 
ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_0][GGML_TYPE_F32], - wgsl_mul_mat_q4_0_f32, "mul_mat_q4_0_f32"); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_1][GGML_TYPE_F32], - wgsl_mul_mat_q4_1_f32, "mul_mat_q4_1_f32"); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q5_0][GGML_TYPE_F32], - wgsl_mul_mat_q5_0_f32, "mul_mat_q5_0_f32"); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q5_1][GGML_TYPE_F32], - wgsl_mul_mat_q5_1_f32, "mul_mat_q5_1_f32"); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q8_0][GGML_TYPE_F32], - wgsl_mul_mat_q8_0_f32, "mul_mat_q8_0_f32"); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q2_K][GGML_TYPE_F32], - wgsl_mul_mat_q2_k_f32, "mul_mat_q2_k_f32"); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q3_K][GGML_TYPE_F32], - wgsl_mul_mat_q3_k_f32, "mul_mat_q3_k_f32"); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_K][GGML_TYPE_F32], - wgsl_mul_mat_q4_k_f32, "mul_mat_q4_k_f32"); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q5_K][GGML_TYPE_F32], - wgsl_mul_mat_q5_k_f32, "mul_mat_q5_k_f32"); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q6_K][GGML_TYPE_F32], - wgsl_mul_mat_q6_k_f32, "mul_mat_q6_k_f32"); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ2_XXS][GGML_TYPE_F32], - wgsl_mul_mat_iq2_xxs_f32, "mul_mat_iq2_xxs_f32"); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ2_XS][GGML_TYPE_F32], - wgsl_mul_mat_iq2_xs_f32, "mul_mat_iq2_xs_f32"); - 
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ2_S][GGML_TYPE_F32], - wgsl_mul_mat_iq2_s_f32, "mul_mat_iq2_s_f32"); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ3_XXS][GGML_TYPE_F32], - wgsl_mul_mat_iq3_xxs_f32, "mul_mat_iq3_xxs_f32"); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ3_S][GGML_TYPE_F32], - wgsl_mul_mat_iq3_s_f32, "mul_mat_iq3_s_f32"); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ1_S][GGML_TYPE_F32], - wgsl_mul_mat_iq1_s_f32, "mul_mat_iq1_s_f32"); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ1_M][GGML_TYPE_F32], - wgsl_mul_mat_iq1_m_f32, "mul_mat_iq1_m_f32"); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ4_NL][GGML_TYPE_F32], - wgsl_mul_mat_iq4_nl_f32, "mul_mat_iq4_nl_f32"); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ4_XS][GGML_TYPE_F32], - wgsl_mul_mat_iq4_xs_f32, "mul_mat_iq4_xs_f32"); + // Q4/Q5/Q8 classic quantizations + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q4_0_f32, "mul_mat_q4_0_f32"); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_1][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q4_1_f32, "mul_mat_q4_1_f32"); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_0][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q5_0_f32, "mul_mat_q5_0_f32"); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_1][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q5_1_f32, "mul_mat_q5_1_f32"); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q8_0][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q8_0_f32, "mul_mat_q8_0_f32"); + + // K-quantizations 
+ webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q2_K][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q2_k_f32, "mul_mat_q2_k_f32"); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q3_K][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q3_k_f32, "mul_mat_q3_k_f32"); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_K][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q4_k_f32, "mul_mat_q4_k_f32"); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_K][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q5_k_f32, "mul_mat_q5_k_f32"); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q6_K][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q6_k_f32, "mul_mat_q6_k_f32"); + + // IQ quantizations (2-, 3-, 4-bit variants) + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_XXS][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq2_xxs_f32, "mul_mat_iq2_xxs_f32"); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_XS][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq2_xs_f32, "mul_mat_iq2_xs_f32"); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_S][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq2_s_f32, "mul_mat_iq2_s_f32"); + + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ3_XXS][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq3_xxs_f32, "mul_mat_iq3_xxs_f32"); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ3_S][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq3_s_f32, "mul_mat_iq3_s_f32"); + + // 1-bit and 4-bit IQ variants + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ1_S][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq1_s_f32, "mul_mat_iq1_s_f32"); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ1_M][GGML_TYPE_F32][0] = + 
ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq1_m_f32, "mul_mat_iq1_m_f32"); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ4_NL][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq4_nl_f32, "mul_mat_iq4_nl_f32"); + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ4_XS][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq4_xs_f32, "mul_mat_iq4_xs_f32"); std::string proc_mul_mat_f32_f32; std::string proc_mul_mat_f32_f32_vec; @@ -1828,21 +1796,21 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { } #endif - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline( webgpu_ctx->device, proc_mul_mat_f32_f32.c_str(), "mul_mat_f32_f32", mul_mat_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline2( + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( webgpu_ctx->device, proc_mul_mat_f32_f32_vec.c_str(), "mul_mat_f32_f32_vec", mul_mat_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline( webgpu_ctx->device, proc_mul_mat_f16_f32.c_str(), "mul_mat_f16_f32", mul_mat_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline2( + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( webgpu_ctx->device, proc_mul_mat_f16_f32_vec.c_str(), "mul_mat_f16_f32_vec", mul_mat_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline( webgpu_ctx->device, proc_mul_mat_f16_f16.c_str(), 
"mul_mat_f16_f16", mul_mat_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline2( + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline( webgpu_ctx->device, proc_mul_mat_f16_f16_vec.c_str(), "mul_mat_f16_f16_vec", mul_mat_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline( webgpu_ctx->device, proc_mul_mat_q4_0_f32.c_str(), "mul_mat_q4_0_f32", mul_mat_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline2( + webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( webgpu_ctx->device, proc_mul_mat_q4_0_f32_vec.c_str(), "mul_mat_q4_0_f32_vec", mul_mat_constants); std::vector mul_mat_vec_constants(3); @@ -1853,257 +1821,280 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { mul_mat_vec_constants[2].key = "OUTPUTS_PER_WG"; mul_mat_vec_constants[2].value = WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG; - webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline( webgpu_ctx->device, wgsl_mul_mat_vec_f32_f32, "mul_mat_vec_f32_f32", mul_mat_vec_constants); - webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline2( + webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( webgpu_ctx->device, wgsl_mul_mat_vec_f32_f32_vec, "mul_mat_vec_f32_f32_vec", mul_mat_vec_constants); - webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline( webgpu_ctx->device, 
wgsl_mul_mat_vec_f16_f32, "mul_mat_vec_f16_f32", mul_mat_vec_constants); - webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline2( + webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( webgpu_ctx->device, wgsl_mul_mat_vec_f16_f32_vec, "mul_mat_vec_f16_f32_vec", mul_mat_vec_constants); - webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline( webgpu_ctx->device, wgsl_mul_mat_vec_f16_f16, "mul_mat_vec_f16_f16", mul_mat_vec_constants); - webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline2( + webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline( webgpu_ctx->device, wgsl_mul_mat_vec_f16_f16_vec, "mul_mat_vec_f16_f16_vec", mul_mat_vec_constants); - webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2( + webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline( webgpu_ctx->device, wgsl_mul_mat_vec_q4_0_f32, "mul_mat_vec_q4_0_f32", mul_mat_vec_constants); } static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) { - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_pipeline[0][0], wgsl_set_rows_f16, - "set_rows_f16", ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x)); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_pipeline[0][1], wgsl_set_rows_f16_vec, - "set_rows_f16_vec", ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x)); + webgpu_ctx->set_rows_pipelines[0][0] = ggml_webgpu_create_pipeline( + webgpu_ctx->device, wgsl_set_rows_f16, "set_rows_f16", ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE)); + webgpu_ctx->set_rows_pipelines[0][1] = ggml_webgpu_create_pipeline( + webgpu_ctx->device, 
wgsl_set_rows_f16_vec, "set_rows_f16_vec", ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE)); } static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_F32], wgsl_get_rows_f32_vec, - "get_rows_f32_vec", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_f32_no_vec_pipeline, wgsl_get_rows_f32, - "get_rows_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_F16], wgsl_get_rows_f16, - "get_rows_f16", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_I32], wgsl_get_rows_i32, - "get_rows_i32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q4_0], wgsl_get_rows_q4_0, - "get_rows_q4_0", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q4_1], wgsl_get_rows_q4_1, - "get_rows_q4_1", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q5_0], wgsl_get_rows_q5_0, - "get_rows_q5_0", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q5_1], wgsl_get_rows_q5_1, - "get_rows_q5_1", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q8_0], wgsl_get_rows_q8_0, - "get_rows_q8_0", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q2_K], wgsl_get_rows_q2_k, - "get_rows_q2_k", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q3_K], wgsl_get_rows_q3_k, - "get_rows_q3_k", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q4_K], wgsl_get_rows_q4_k, - 
"get_rows_q4_k", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q5_K], wgsl_get_rows_q5_k, - "get_rows_q5_k", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_Q6_K], wgsl_get_rows_q6_k, - "get_rows_q6_k", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ2_XXS], - wgsl_get_rows_iq2_xxs, "get_rows_iq2_xxs", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ2_XS], - wgsl_get_rows_iq2_xs, "get_rows_iq2_xs", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ2_S], wgsl_get_rows_iq2_s, - "get_rows_iq2_s", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ3_XXS], - wgsl_get_rows_iq3_xxs, "get_rows_iq3_xxs", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ3_S], wgsl_get_rows_iq3_s, - "get_rows_iq3_s", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ1_S], wgsl_get_rows_iq1_s, - "get_rows_iq1_s", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ1_M], wgsl_get_rows_iq1_m, - "get_rows_iq1_m", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ4_NL], - wgsl_get_rows_iq4_nl, "get_rows_iq4_nl", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_IQ4_XS], - wgsl_get_rows_iq4_xs, "get_rows_iq4_xs", constants); + std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); + + webgpu_ctx->get_rows_pipelines[GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_f32, "get_rows_f32", constants); + webgpu_ctx->get_rows_pipelines[GGML_TYPE_F32][1] = + 
ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_f32_vec, "get_rows_f32_vec", constants); + + webgpu_ctx->get_rows_pipelines[GGML_TYPE_F16][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_f16, "get_rows_f16", constants); + webgpu_ctx->get_rows_pipelines[GGML_TYPE_I32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_i32, "get_rows_i32", constants); + webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q4_0][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q4_0, "get_rows_q4_0", constants); + webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q4_1][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q4_1, "get_rows_q4_1", constants); + webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q5_0][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q5_0, "get_rows_q5_0", constants); + webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q5_1][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q5_1, "get_rows_q5_1", constants); + webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q8_0][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q8_0, "get_rows_q8_0", constants); + + webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q2_K][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q2_k, "get_rows_q2_k", constants); + webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q3_K][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q3_k, "get_rows_q3_k", constants); + webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q4_K][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q4_k, "get_rows_q4_k", constants); + webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q5_K][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q5_k, "get_rows_q5_k", constants); + webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q6_K][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q6_k, "get_rows_q6_k", constants); + + webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_XXS][0] = + 
ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq2_xxs, "get_rows_iq2_xxs", constants); + webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_XS][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq2_xs, "get_rows_iq2_xs", constants); + webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_S][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq2_s, "get_rows_iq2_s", constants); + webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ3_XXS][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq3_xxs, "get_rows_iq3_xxs", constants); + webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ3_S][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq3_s, "get_rows_iq3_s", constants); + webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ1_S][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq1_s, "get_rows_iq1_s", constants); + webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ1_M][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq1_m, "get_rows_iq1_m", constants); + webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ4_NL][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq4_nl, "get_rows_iq4_nl", constants); + webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ4_XS][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq4_xs, "get_rows_iq4_xs", constants); } static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline[GGML_TYPE_F32][GGML_TYPE_F32], - wgsl_cpy_f32_f32, "cpy_f32_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline[GGML_TYPE_F32][GGML_TYPE_F16], - wgsl_cpy_f32_f16, "cpy_f32_f16", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline[GGML_TYPE_F16][GGML_TYPE_F32], - wgsl_cpy_f16_f32, "cpy_f16_f32", constants); - 
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline[GGML_TYPE_F16][GGML_TYPE_F16], - wgsl_cpy_f16_f16, "cpy_f16_f16", constants); + std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); + + webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_F32] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_cpy_f32_f32, "cpy_f32_f32", constants); + webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_F16] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_cpy_f32_f16, "cpy_f32_f16", constants); + webgpu_ctx->cpy_pipelines[GGML_TYPE_F16][GGML_TYPE_F32] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_cpy_f16_f32, "cpy_f16_f32", constants); + webgpu_ctx->cpy_pipelines[GGML_TYPE_F16][GGML_TYPE_F16] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_cpy_f16_f16, "cpy_f16_f16", constants); } static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32][0], wgsl_add_f32, "add_f32", - constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16][0], wgsl_add_f16, "add_f16", - constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32][1], wgsl_add_f32_inplace, - "add_f32_inplace", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16][1], wgsl_add_f16_inplace, - "add_f16_inplace", constants); + std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); + + webgpu_ctx->add_pipelines[GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_add_f32, "add_f32", constants); + webgpu_ctx->add_pipelines[GGML_TYPE_F16][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_add_f16, "add_f16", constants); + webgpu_ctx->add_pipelines[GGML_TYPE_F32][1] = + 
ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_add_f32_inplace, "add_f32_inplace", constants); + webgpu_ctx->add_pipelines[GGML_TYPE_F16][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_add_f16_inplace, "add_f16_inplace", constants); } static void ggml_webgpu_init_sub_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F32][0], wgsl_sub_f32, "sub_f32", - constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F16][0], wgsl_sub_f16, "sub_f16", - constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F32][1], wgsl_sub_f32_inplace, - "sub_f32_inplace", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F16][1], wgsl_sub_f16_inplace, - "sub_f16_inplace", constants); + std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); + + webgpu_ctx->sub_pipelines[GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sub_f32, "sub_f32", constants); + webgpu_ctx->sub_pipelines[GGML_TYPE_F16][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sub_f16, "sub_f16", constants); + webgpu_ctx->sub_pipelines[GGML_TYPE_F32][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sub_f32_inplace, "sub_f32_inplace", constants); + webgpu_ctx->sub_pipelines[GGML_TYPE_F16][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sub_f16_inplace, "sub_f16_inplace", constants); } static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32][0], wgsl_mul_f32, "mul_f32", - constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16][0], 
wgsl_mul_f16, "mul_f16", - constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32][1], wgsl_mul_f32_inplace, - "mul_f32_inplace", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16][1], wgsl_mul_f16_inplace, - "mul_f16_inplace", constants); + std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); + + webgpu_ctx->mul_pipelines[GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_f32, "mul_f32", constants); + webgpu_ctx->mul_pipelines[GGML_TYPE_F16][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_f16, "mul_f16", constants); + webgpu_ctx->mul_pipelines[GGML_TYPE_F32][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_f32_inplace, "mul_f32_inplace", constants); + webgpu_ctx->mul_pipelines[GGML_TYPE_F16][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_f16_inplace, "mul_f16_inplace", constants); } static void ggml_webgpu_init_div_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F32][0], wgsl_div_f32, "div_f32", - constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F16][0], wgsl_div_f16, "div_f16", - constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F32][1], wgsl_div_f32_inplace, - "div_f32_inplace", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F16][1], wgsl_div_f16_inplace, - "div_f16_inplace", constants); + std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); + + webgpu_ctx->div_pipelines[GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_div_f32, "div_f32", constants); + webgpu_ctx->div_pipelines[GGML_TYPE_F16][0] = + 
ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_div_f16, "div_f16", constants); + webgpu_ctx->div_pipelines[GGML_TYPE_F32][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_div_f32_inplace, "div_f32_inplace", constants); + webgpu_ctx->div_pipelines[GGML_TYPE_F16][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_div_f16_inplace, "div_f16_inplace", constants); } static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) { std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_pipeline[0], wgsl_rms_norm, "rms_norm", - constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_pipeline[1], wgsl_rms_norm_inplace, - "rms_norm_inplace", constants); + + webgpu_ctx->rms_norm_pipelines[0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rms_norm, "rms_norm", constants); + webgpu_ctx->rms_norm_pipelines[1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rms_norm_inplace, "rms_norm_inplace", constants); } static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][0][0], wgsl_rope_f32, - "rope_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][0][1], - wgsl_rope_f32_inplace, "rope_f32_inplace", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][1][0], wgsl_rope_f32_ff, - "rope_f32_ff", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][1][1], - wgsl_rope_f32_ff_inplace, "rope_f32_ff_inplace", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F16][0][0], wgsl_rope_f16, - "rope_f16", constants); - 
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F16][0][1], - wgsl_rope_f16_inplace, "rope_f16_inplace", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F16][1][0], wgsl_rope_f16_ff, - "rope_f16_ff", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F16][1][1], - wgsl_rope_f16_ff_inplace, "rope_f16_ff_inplace", constants); + std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); + + webgpu_ctx->rope_pipelines[GGML_TYPE_F32][0][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f32, "rope_f32", constants); + webgpu_ctx->rope_pipelines[GGML_TYPE_F32][0][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f32_inplace, "rope_f32_inplace", constants); + webgpu_ctx->rope_pipelines[GGML_TYPE_F32][1][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f32_ff, "rope_f32_ff", constants); + webgpu_ctx->rope_pipelines[GGML_TYPE_F32][1][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f32_ff_inplace, "rope_f32_ff_inplace", constants); + + webgpu_ctx->rope_pipelines[GGML_TYPE_F16][0][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f16, "rope_f16", constants); + webgpu_ctx->rope_pipelines[GGML_TYPE_F16][0][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f16_inplace, "rope_f16_inplace", constants); + webgpu_ctx->rope_pipelines[GGML_TYPE_F16][1][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f16_ff, "rope_f16_ff", constants); + webgpu_ctx->rope_pipelines[GGML_TYPE_F16][1][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f16_ff_inplace, "rope_f16_ff_inplace", constants); } static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x); - // reglu - ggml_webgpu_create_pipeline(webgpu_ctx->device, 
webgpu_ctx->glu_pipeline[GGML_GLU_OP_REGLU][GGML_TYPE_F32][0], - wgsl_reglu_f32, "reglu_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_REGLU][GGML_TYPE_F16][0], - wgsl_reglu_f16, "reglu_f16", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_REGLU][GGML_TYPE_F32][1], - wgsl_reglu_f32_split, "reglu_f32_split", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_REGLU][GGML_TYPE_F16][1], - wgsl_reglu_f16_split, "reglu_f16_split", constants); - // geglu - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][0], - wgsl_geglu_f32, "geglu_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][0], - wgsl_geglu_f16, "geglu_f16", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][1], - wgsl_geglu_f32_split, "geglu_f32_split", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][1], - wgsl_geglu_f16_split, "geglu_f16_split", constants); - // swiglu - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][0], - wgsl_swiglu_f32, "swiglu_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][0], - wgsl_swiglu_f16, "swiglu_f16", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][1], - wgsl_swiglu_f32_split, "swiglu_f32_split", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][1], - wgsl_swiglu_f16_split, "swiglu_f16_split", constants); - // swiglu_oai - ggml_webgpu_create_pipeline(webgpu_ctx->device, 
webgpu_ctx->glu_pipeline[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][0], - wgsl_swiglu_oai_f32, "swiglu_oai_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][1], - wgsl_swiglu_oai_f32_split, "swiglu_oai_f32_split", constants); - // geglu_erf - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][0], - wgsl_geglu_erf_f32, "geglu_erf_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][0], - wgsl_geglu_erf_f16, "geglu_erf_f16", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][1], - wgsl_geglu_erf_f32_split, "geglu_erf_f32_split", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][1], - wgsl_geglu_erf_f16_split, "geglu_erf_f16_split", constants); - // geglu_quick - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][0], - wgsl_geglu_quick_f32, "geglu_quick_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][0], - wgsl_geglu_quick_f16, "geglu_quick_f16", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][1], - wgsl_geglu_quick_f32_split, "geglu_quick_f32_split", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][1], - wgsl_geglu_quick_f16_split, "geglu_quick_f16_split", constants); + std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); + + // REGLU + webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_reglu_f32, "reglu_f32", constants); + 
webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F16][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_reglu_f16, "reglu_f16", constants); + webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F32][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_reglu_f32_split, "reglu_f32_split", constants); + webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F16][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_reglu_f16_split, "reglu_f16_split", constants); + + // GEGLU + webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_f32, "geglu_f32", constants); + webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_f16, "geglu_f16", constants); + webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_f32_split, "geglu_f32_split", constants); + webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_f16_split, "geglu_f16_split", constants); + + // SWIGLU + webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_f32, "swiglu_f32", constants); + webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_f16, "swiglu_f16", constants); + webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_f32_split, "swiglu_f32_split", constants); + webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_f16_split, "swiglu_f16_split", constants); + + // SWIGLU_OAI + webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, 
wgsl_swiglu_oai_f32, "swiglu_oai_f32", constants); + webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_oai_f32_split, "swiglu_oai_f32_split", constants); + + // GEGLU_ERF + webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_erf_f32, "geglu_erf_f32", constants); + webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_erf_f16, "geglu_erf_f16", constants); + webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_erf_f32_split, "geglu_erf_f32_split", constants); + webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_erf_f16_split, "geglu_erf_f16_split", constants); + + // GEGLU_QUICK + webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_quick_f32, "geglu_quick_f32", constants); + webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_quick_f16, "geglu_quick_f16", constants); + webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_quick_f32_split, "geglu_quick_f32_split", constants); + webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_quick_f16_split, "geglu_quick_f16_split", constants); } static void ggml_webgpu_init_scale_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->scale_pipeline[0], wgsl_scale_f32, "scale_f32", - constants); - 
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->scale_pipeline[1], wgsl_scale_f32_inplace, - "scale_f32_inplace", constants); + std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); + + webgpu_ctx->scale_pipelines[0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_scale_f32, "scale_f32", constants); + webgpu_ctx->scale_pipelines[1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_scale_f32_inplace, "scale_f32_inplace", constants); } static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) { std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[2][0][0], wgsl_soft_max_f32, - "soft_max_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[2][0][1], wgsl_soft_max_f32_inplace, - "soft_max_f32_inplace", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[2][1][0], wgsl_soft_max_f32_sink, - "soft_max_f32_sink", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[2][1][1], - wgsl_soft_max_f32_sink_inplace, "soft_max_f32_sink_inplace", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[0][0][0], wgsl_soft_max_f32_mask_f32, - "soft_max_f32_mask_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[0][0][1], - wgsl_soft_max_f32_mask_f32_inplace, "soft_max_f32_mask_f32_inplace", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[1][0][0], wgsl_soft_max_f32_mask_f16, - "soft_max_f32_mask_f16", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[1][0][1], - wgsl_soft_max_f32_mask_f16_inplace, "soft_max_f32_mask_f16_inplace", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[0][1][0], - 
wgsl_soft_max_f32_mask_f32_sink, "soft_max_f32_mask_f32_sink", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[0][1][1], - wgsl_soft_max_f32_mask_f32_sink_inplace, "soft_max_f32_mask_f32_sink_inplace", - constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[1][1][0], - wgsl_soft_max_f32_mask_f16_sink, "soft_max_f32_mask_f16_sink", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[1][1][1], - wgsl_soft_max_f32_mask_f16_sink_inplace, "soft_max_f32_mask_f16_sink_inplace", - constants); + + // f32 (no mask) + webgpu_ctx->soft_max_pipelines[2][0][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_soft_max_f32, "soft_max_f32", constants); + webgpu_ctx->soft_max_pipelines[2][0][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_soft_max_f32_inplace, "soft_max_f32_inplace", constants); + webgpu_ctx->soft_max_pipelines[2][1][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_soft_max_f32_sink, "soft_max_f32_sink", constants); + webgpu_ctx->soft_max_pipelines[2][1][1] = ggml_webgpu_create_pipeline( + webgpu_ctx->device, wgsl_soft_max_f32_sink_inplace, "soft_max_f32_sink_inplace", constants); + + // f32 mask (mask_type = 0) + webgpu_ctx->soft_max_pipelines[0][0][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_soft_max_f32_mask_f32, "soft_max_f32_mask_f32", constants); + webgpu_ctx->soft_max_pipelines[0][0][1] = ggml_webgpu_create_pipeline( + webgpu_ctx->device, wgsl_soft_max_f32_mask_f32_inplace, "soft_max_f32_mask_f32_inplace", constants); + webgpu_ctx->soft_max_pipelines[0][1][0] = ggml_webgpu_create_pipeline( + webgpu_ctx->device, wgsl_soft_max_f32_mask_f32_sink, "soft_max_f32_mask_f32_sink", constants); + webgpu_ctx->soft_max_pipelines[0][1][1] = ggml_webgpu_create_pipeline( + webgpu_ctx->device, wgsl_soft_max_f32_mask_f32_sink_inplace, "soft_max_f32_mask_f32_sink_inplace", constants); + + // f16 mask (mask_type 
= 1) + webgpu_ctx->soft_max_pipelines[1][0][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_soft_max_f32_mask_f16, "soft_max_f32_mask_f16", constants); + webgpu_ctx->soft_max_pipelines[1][0][1] = ggml_webgpu_create_pipeline( + webgpu_ctx->device, wgsl_soft_max_f32_mask_f16_inplace, "soft_max_f32_mask_f16_inplace", constants); + webgpu_ctx->soft_max_pipelines[1][1][0] = ggml_webgpu_create_pipeline( + webgpu_ctx->device, wgsl_soft_max_f32_mask_f16_sink, "soft_max_f32_mask_f16_sink", constants); + webgpu_ctx->soft_max_pipelines[1][1][1] = ggml_webgpu_create_pipeline( + webgpu_ctx->device, wgsl_soft_max_f32_mask_f16_sink_inplace, "soft_max_f32_mask_f16_sink_inplace", constants); } static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) { @@ -2388,7 +2379,6 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t GGML_ASSERT(ctx->adapter != nullptr); ctx->adapter.GetLimits(&ctx->limits); - ctx->max_wg_size_x = 288; // default value wgpu::AdapterInfo info{}; #ifndef __EMSCRIPTEN__ From e35099e632fd98980f4dcf7728bd5385926e36ab Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Wed, 3 Dec 2025 11:23:24 -0800 Subject: [PATCH 24/40] Start work on flash attention --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 106 ++++++++++++++++++ .../ggml-webgpu/wgsl-shaders/flash_attn.wgsl | 83 ++++++++++++++ 2 files changed, 189 insertions(+) create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index e684db9e210..f15faf2ee97 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -108,6 +108,11 @@ #define WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG 64 #define WEBGPU_MUL_MAT_VEC_TILE_K 256 +// Flash Attention parameters +#define WEBGPU_FLASH_ATTN_WG_SIZE 32 +#define WEBGPU_FLASH_ATTN_Q_TILE 8 +#define WEBGPU_FLASH_ATTN_KV_TILE 16 + /* End Constants */ // This is a "fake" base 
pointer, since WebGPU buffers do not have pointers to their locations. @@ -290,6 +295,8 @@ struct webgpu_context_struct { std::map>> mul_mat_vec_pipelines; // src0_type, src1_type, vectorized + webgpu_pipeline flash_attn_pipeline; + std::map> set_rows_pipelines; // dst_type, vectorized std::map> get_rows_pipelines; // src_type, vectorized @@ -986,6 +993,95 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); } +static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, + ggml_tensor * Q, + ggml_tensor * K, + ggml_tensor * V, + ggml_tensor * mask, + ggml_tensor * sinks, + ggml_tensor * dst) { +// For now we assume everything (mask, sink) + float max_bias; + memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); + float n_head_log2 = float(1u << (uint32_t) floor(log2(Q->ne[2]))); + float m0 = powf(2.0f, -(max_bias) / n_head_log2); + float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + // print type and dimensions of Q/K/V/mask/sinks/dst + std::cout << "ggml_webgpu_flash_attn: Q type: " << ggml_type_name(Q->type) << ", ne: [" << Q->ne[0] << ", " << Q->ne[1] << ", " << Q->ne[2] + << ", " << Q->ne[3] << "]\n"; + std::cout << "ggml_webgpu_flash_attn: K type: " << ggml_type_name(K->type) << ", ne: [" << K->ne[0] << ", " << K->ne[1] << ", " << K->ne[2] + << ", " << K->ne[3] << "]\n"; + std::cout << "ggml_webgpu_flash_attn: V type: " << ggml_type_name(V->type) << ", ne: [" << V->ne[0] << ", " << V->ne[1] << ", " << V->ne[2] + << ", " << V->ne[3] << "]\n"; + std::cout << "ggml_webgpu_flash_attn: mask type: " << ggml_type_name(mask->type) << ", ne: [" << mask->ne[0] << ", " << mask->ne[1] << ", " << mask->ne[2] + << ", " << mask->ne[3] << "]\n"; + std::cout << "ggml_webgpu_flash_attn: sinks type: " << ggml_type_name(sinks->type) << ", ne: [" << sinks->ne[0] << ", " << sinks->ne[1] << ", " << sinks->ne[2] + << ", " << sinks->ne[3] << "]\n"; + std::cout << 
"ggml_webgpu_flash_attn: dst type: " << ggml_type_name(dst->type) << ", ne: [" << dst->ne[0] << ", " << dst->ne[1] << ", " << dst->ne[2] + << ", " << dst->ne[3] << "]\n"; + + std::vector params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, Q) / ggml_type_size(Q->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, sinks) / ggml_type_size(sinks->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) Q->ne[0], // head dimension (Q/K) + (uint32_t) V->ne[0], // head dimension (V) + (uint32_t) Q->ne[2], // number of heads + (uint32_t) Q->ne[1], // sequence length (Q) + (uint32_t) K->ne[1], // sequence length (K/V) + (uint32_t) (Q->nb[1] / ggml_type_size(Q->type)), // stride (elements/blocks) of Q in dimension 1 + (uint32_t) (Q->nb[2] / ggml_type_size(Q->type)), // stride (elements/blocks) of Q in dimension 2 + (uint32_t) (Q->nb[3] / ggml_type_size(Q->type)), // stride (elements/blocks) of Q in dimension 3 + (uint32_t) (K->nb[1] / ggml_type_size(K->type)), // stride (elements/blocks) of K in dimension 1 + (uint32_t) (K->nb[2] / ggml_type_size(K->type)), // stride (elements/blocks) of K in dimension 2 + (uint32_t) (K->nb[3] / ggml_type_size(K->type)), // stride (elements/blocks) of K in dimension 3 + (uint32_t) (V->nb[1] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 1 + (uint32_t) (V->nb[2] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 2 + (uint32_t) (V->nb[3] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 3 + *(uint32_t *) dst->op_params, // scale + *(uint32_t *) &max_bias, + *(uint32_t *) &n_head_log2, + *(uint32_t *) &m0, + *(uint32_t *) &m1 + + }; + std::vector entries = { 
+ { .binding = 0, + .buffer = ggml_webgpu_tensor_buf(Q), + .offset = ggml_webgpu_tensor_align_offset(ctx, Q), + .size = ggml_webgpu_tensor_binding_size(ctx, Q) }, + { .binding = 1, + .buffer = ggml_webgpu_tensor_buf(K), + .offset = ggml_webgpu_tensor_align_offset(ctx, K), + .size = ggml_webgpu_tensor_binding_size(ctx, K) }, + { .binding = 2, + .buffer = ggml_webgpu_tensor_buf(V), + .offset = ggml_webgpu_tensor_align_offset(ctx, V), + .size = ggml_webgpu_tensor_binding_size(ctx, V) }, + { .binding = 3, + .buffer = ggml_webgpu_tensor_buf(mask), + .offset = ggml_webgpu_tensor_align_offset(ctx, mask), + .size = ggml_webgpu_tensor_binding_size(ctx, mask) }, + { .binding = 4, + .buffer = ggml_webgpu_tensor_buf(sinks), + .offset = ggml_webgpu_tensor_align_offset(ctx, sinks), + .size = ggml_webgpu_tensor_binding_size(ctx, sinks) }, + { .binding = 5, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst) }, + }; + + uint32_t wg_per_head = CEIL_DIV(Q->ne[1], WEBGPU_FLASH_ATTN_Q_TILE); + uint32_t wg_x = wg_per_head * Q->ne[2]; // wg per head * number of heads + std::cout << "ggml_webgpu_flash_attn: wg_x: " << wg_x << "\n"; + return ggml_backend_webgpu_build(ctx, ctx->flash_attn_pipeline, params, entries, wg_x); +} + static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { uint32_t ne = (uint32_t) ggml_nelements(dst); ggml_unary_op unary_op = ggml_get_unary_op(dst); @@ -1389,6 +1485,8 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, return ggml_webgpu_get_rows(ctx, src0, src1, node); case GGML_OP_MUL_MAT: return ggml_webgpu_mul_mat(ctx, src0, src1, node); + case GGML_OP_FLASH_ATTN_EXT: + return ggml_webgpu_flash_attn(ctx, src0, src1, src2, node->src[3], node->src[4], node); case GGML_OP_ADD: { int inplace = ggml_webgpu_tensor_equal(src0, node); @@ -1886,6 +1984,10 @@ static void 
ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { webgpu_ctx->device, wgsl_mul_mat_vec_q4_0_f32, "mul_mat_vec_q4_0_f32", mul_mat_vec_constants); } +static void ggml_webgpu_init_flash_attn_pipeline(webgpu_context & webgpu_ctx) { + webgpu_ctx->flash_attn_pipeline = ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_flash_attn, "flash_attn"); +} + static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) { webgpu_ctx->set_rows_pipelines[0][0] = ggml_webgpu_create_pipeline( webgpu_ctx->device, wgsl_set_rows_f16, "set_rows_f16", ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE)); @@ -2471,6 +2573,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const } break; } + case GGML_OP_FLASH_ATTN_EXT: + supports_op = true; + break; case GGML_OP_RMS_NORM: supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32; break; @@ -2744,6 +2849,7 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t ggml_webgpu_init_memset_pipeline(ctx); ggml_webgpu_init_mul_mat_pipeline(ctx); + ggml_webgpu_init_flash_attn_pipeline(ctx); ggml_webgpu_init_set_rows_pipeline(ctx); ggml_webgpu_init_get_rows_pipeline(ctx); ggml_webgpu_init_cpy_pipeline(ctx); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl new file mode 100644 index 00000000000..0f7d34cf13d --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl @@ -0,0 +1,83 @@ +diagnostic(off, chromium.subgroup_matrix_uniformity); +enable f16; +enable subgroups; +enable chromium_experimental_subgroup_matrix; + +struct Params { + offset_q: u32, + offset_k: u32, + offset_v: u32, + offset_mask: u32, + offset_sinks: u32, + offset_dst: u32, + + // shapes of Q/K/V + head_dim_qk: u32, + head_dim_v: u32, + n_heads: u32, + seq_len_q: u32, + seq_len_kv: u32, + + // strides (in elements) + stride_q1: u32, + stride_q2: u32, + stride_q3: u32, + stride_k1: u32, + stride_k2: u32, + stride_k3: 
u32, + stride_v1: u32, + stride_v2: u32, + stride_v3: u32, + + // TODO: still need to consider broadcast + + // softmax params + scale: f32, + max_bias: f32, + n_head_log2: f32, + m0: f32, + m1: f32, +}; + +@group(0) @binding(0) var Q: array; +@group(0) @binding(1) var K: array; +@group(0) @binding(2) var V: array; +@group(0) @binding(3) var mask: array; +@group(0) @binding(4) var sinks: array; +@group(0) @binding(5) var dst: array; + +// The number of Q rows processed per workgroup +const Q_TILE = 8u; +var q_shmem: array; // assuming max head_dim_qk of 64 + +const KV_TILE = 16u; +// we can reuse the same shmem for K and V since we only need one at a time right? +var kv_shmem: array; // assuming max head_dim_qkv of 64 + +const WG_SIZE = 32; + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(workgroup_id) wg_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(subgroup_id) subgroup_id: u32) { + + dst[0] = Q[0] + K[0] + V[0] + f32(mask[0]) + sinks[0]; // dummy line to avoid compile error + + // workgroups per head + // batch index + + // load q into shared memory + + // for each kv tile + // load k into shared memory + + // compute qk scores + // apply mask + // softmax + + // load v into shared memory + + // compute output + + // write output to dst +} \ No newline at end of file From cc5ff8686ff7102e1376484464a16afb2a49f4c2 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Tue, 16 Dec 2025 14:05:20 -0800 Subject: [PATCH 25/40] Shader structure set up (many bugs still) --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 2 +- .../ggml-webgpu/wgsl-shaders/flash_attn.wgsl | 236 ++++++++++++++++-- 2 files changed, 221 insertions(+), 17 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index f15faf2ee97..765f1475ce5 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1077,7 +1077,7 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, }; uint32_t 
wg_per_head = CEIL_DIV(Q->ne[1], WEBGPU_FLASH_ATTN_Q_TILE); - uint32_t wg_x = wg_per_head * Q->ne[2]; // wg per head * number of heads + uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches std::cout << "ggml_webgpu_flash_attn: wg_x: " << wg_x << "\n"; return ggml_backend_webgpu_build(ctx, ctx->flash_attn_pipeline, params, entries, wg_x); } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl index 0f7d34cf13d..edd9abf8842 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl @@ -45,39 +45,243 @@ struct Params { @group(0) @binding(3) var mask: array; @group(0) @binding(4) var sinks: array; @group(0) @binding(5) var dst: array; +@group(0) @binding(6) var params: Params; // The number of Q rows processed per workgroup const Q_TILE = 8u; -var q_shmem: array; // assuming max head_dim_qk of 64 +var q_shmem: array; // assumes max head_dim_qk of 64 -const KV_TILE = 16u; +const KV_TILE = 8u; // we can reuse the same shmem for K and V since we only need one at a time right? 
-var kv_shmem: array; // assuming max head_dim_qkv of 64 +var k_shmem: array; // assuming max head_dim_qkv of 64 +var v_shmem: array; // assuming max head_dim_qkv of 64 -const WG_SIZE = 32; +var o_shmem: array; // output shmem + +// storage for output of Q*K^T scores for online softmax (S matrix from paper) +// also storage for diagonal matrix during online softmax (P matrix from paper) +// note that we reuse the same storage for both since we only need one at a time +var inter_shmem: array; + +const WG_SIZE = 32u; +const SUBGROUP_SIZE = 32u; @compute @workgroup_size(WG_SIZE) fn main(@builtin(workgroup_id) wg_id: vec3, @builtin(local_invocation_id) local_id: vec3, - @builtin(subgroup_id) subgroup_id: u32) { + @builtin(subgroup_id) subgroup_id: u32, + @builtin(subgroup_invocation_id) sg_inv_id: u32) { dst[0] = Q[0] + K[0] + V[0] + f32(mask[0]) + sinks[0]; // dummy line to avoid compile error - // workgroups per head - // batch index + // each thread maintains its own cache for softmax intermediates + var row_max = array(-65504, -65504, -65504, -65504, -65504, -65504, -65504, -65504); + var exp_sum = array(); + + // workgroups per head/batch + let wg_per_head = (params.seq_len_q + Q_TILE - 1u) / Q_TILE; + let wg_per_batch = wg_per_head * params.n_heads; + + let dst2_stride = params.head_dim_v * params.n_heads; + let dst3_stride = dst2_stride * params.seq_len_q; + + // batch index + let batch_idx = wg_id.x / wg_per_batch; + let q_batch_offset = params.offset_q + batch_idx * params.stride_q3; + let k_batch_offset = params.offset_k + batch_idx * params.stride_k3; + let v_batch_offset = params.offset_v + batch_idx * params.stride_v3; + let dst_batch_offset = params.offset_dst + batch_idx * dst3_stride; + let wg_in_batch = wg_id.x % wg_per_batch; + + // head index + let head_idx = wg_in_batch / wg_per_head; + let q_head_offset = q_batch_offset + head_idx * params.stride_q2; + let k_head_offset = k_batch_offset + head_idx * params.stride_k2; + let v_head_offset = 
v_batch_offset + head_idx * params.stride_v2; + let wg_in_head = wg_in_batch % wg_per_head; + + // starting Q row for this workgroup + let q_row_start = wg_in_head * Q_TILE; + + // note that the output is permuted, the layout is [head_dim_v, n_heads, seq_len_q, batch_size] + let dst_global_offset = dst_batch_offset + q_row_start * dst2_stride + head_idx * params.head_dim_v; + + // Which mask row to use. TODO: support broadcasting + let mask_seq_offset = params.offset_mask + q_row_start * params.stride_m1; + + let head = f32(head_idx); + let slope = select(1, select(pow(params.m1, 2 * (head - params.n_head_log2) + 1), pow(params.m0, head + 1), head < params.n_head_log2), params.max_bias > 0); + + // load q tile into shared memory + for (var elem_idx = local_id.x; elem_idx < Q_TILE * params.head_dim_qk; elem_idx += WG_SIZE) { + let q_row = elem_idx / params.head_dim_qk; + let q_col = elem_idx % params.head_dim_qk; + let head_q_row = q_row_start + q_row; + let global_q_row_offset = q_head_offset + head_q_row * params.stride_q1; + let q_val: f16 = select( + 0.0, + f16(Q[global_q_row_offset + q_col]), + head_q_row < params.seq_len_q && q_col < params.head_dim_qk); + q_shmem[elem_idx] = q_val; + } + workgroupBarrier(); + + for (var kv_tile = 0u; kv_tile < params.seq_len_kv; kv_tile += KV_TILE) { + // load k tile into shared memory + for (var elem_idx = local_id.x; elem_idx < KV_TILE * params.head_dim_qk; elem_idx += WG_SIZE) { + let k_row = elem_idx / params.head_dim_qk; + let k_col = elem_idx % params.head_dim_qk; + let global_k_row = kv_tile + k_row; + let global_k_row_offset = k_head_offset + global_k_row * params.stride_k1; + let k_val: f16 = select( + 0.0, + f16(K[global_k_row_offset + k_col]), + global_k_row < params.seq_len_kv && k_col < params.head_dim_qk); + k_shmem[elem_idx] = k_val; + } + + workgroupBarrier(); + + // accumulate q block * k block into registers + var acc: subgroup_matrix_result; + for (var head_dim_block = 0u; head_dim_block < 
params.head_dim_qk; head_dim_block += 8u) { + // load q submatrix from shared memory + var q_sg_mat: subgroup_matrix_left = subgroupMatrixLoad>( + &q_shmem, + head_dim_block, + false, + params.head_dim_qk + ); + + // load k submatrix from shared memory + var k_sg_mat: subgroup_matrix_right = subgroupMatrixLoad>( + &k_shmem, + head_dim_block, + true, + params.head_dim_qk + ); + + acc = subgroupMatrixMultiplyAccumulate(q_sg_mat, k_sg_mat, acc); + } + + // store acc to shared memory for softmax + subgroupMatrixStore(&inter_shmem, 0, acc, false, KV_TILE); + + // online softmax + for (var q_tile_row = 0u; q_tile_row < Q_TILE; q_tile_row++) { + // no need to process rows beyond seq_len_q + let global_q_row = q_row_start + q_tile_row; + if (global_q_row >= params.seq_len_q) { + break; + } + + // calculate running max + let prev_max = row_max[q_tile_row]; + // The mask value for this Q row and K col + let mask_val = select(f16(0.0), f16(mask[mask_seq_offset + q_tile_row * params.stride_m1 + kv_tile + sg_inv_id]), kv_tile + sg_inv_id < params.seq_len_kv && sg_inv_id < KV_TILE); + let thread_tile_row_max = select(f16(-65504), inter_shmem[sg_inv_id + q_tile_row * KV_TILE] * f16(params.scale) + slope * mask_val, sg_inv_id < KV_TILE); + row_max[q_tile_row] = subgroupMax(row_max[q_tile_row], max(prev_max, thread_tile_row_max)); + + // calculate running exp sum + let cur_exp = exp(prev_max - row_max[q_tile_row]); + let cur_p = select(0, exp(thread_tile_row_max - row_max[q_tile_row]), sg_inv_id < KV_TILE); + exp_sum[q_tile_row] = exp_sum[q_tile_row] * cur_exp + subgroupSum(cur_p); + + // store back to shared memory (P matrix) + if (sg_inv_id < KV_TILE) { + inter_shmem[sg_inv_id + q_tile_row * KV_TILE] = cur_p; + } + + for (var elem_idx = sg_inv_id; elem_idx < params.head_dim_v; elem_idx += SUBGROUP_SIZE) { + o_shmem[q_tile_row * params.head_dim_v + elem_idx] *= cur_exp; + } + } + + // load v tile into shared memory + for (var elem_idx = local_id.x; elem_idx < KV_TILE * 
params.head_dim_v; elem_idx += WG_SIZE) { + let v_row = elem_idx / params.head_dim_v; + let v_col = elem_idx % params.head_dim_v; + let global_v_row = kv_tile + v_row; + let global_v_row_offset = v_head_offset + global_v_row * params.stride_v1; + let v_val: f16 = select( + 0.0, + f16(V[global_v_row_offset + v_col]), + global_v_row < params.seq_len_kv && v_col < params.head_dim_v); + v_shmem[elem_idx] = v_val; + } + + workgroupBarrier(); + + // we have P (8x8 tile, or Q_TILE x KV_TILE) in inter_shmem and V (8 x head_dim_v, or KV_TILE x head_dim_v) in v_shmem + // we want to compute O += P * V + // load P submatrix from shared memory + var p_sg_mat: subgroup_matrix_left = subgroupMatrixLoad>( + &inter_shmem, + 0, + false, + KV_TILE + ); + + for (var head_dim_block = 0u; head_dim_block < params.head_dim_v; head_dim_block += 8u) { + // load V submatrix from shared memory + var v_sg_mat: subgroup_matrix_right = subgroupMatrixLoad>( + &v_shmem, + head_dim_block, + true, // or false? is this transposed? 
+ params.head_dim_v + ); + + // load O submatrix from shared memory + var o_sg_mat: subgroup_matrix_result = subgroupMatrixLoad>( + &o_shmem, + head_dim_block, + false, + params.head_dim_v + ); + + // O += P * V + o_sg_mat = subgroupMatrixMultiplyAccumulate(p_sg_mat, v_sg_mat, o_sg_mat); + + // store O back to shared memory + subgroupMatrixStore(&o_shmem, head_dim_block, o_sg_mat, false, params.head_dim_v); + } + + workgroupBarrier(); + + // add sinks + for (var q_tile_row = 0u; q_tile_row < Q_TILE; q_tile_row++) { + // no need to process rows beyond seq_len_q + let global_q_row = q_row_start + q_tile_row; + if (global_q_row >= params.seq_len_q) { + break; + } - // load q into shared memory + let max_val = row_max[q_tile_row]; + // for non-sink threads, exp(-65504) effectively zeroes them out + let sink_val = select(-65504.0, f16(sinks[params.offset_sinks + head_idx]), sg_inv_id == 0); + row_max[q_tile_row] = subgroupMax(max_val, sink_val); + let max_exp = exp(max_val - row_max[q_tile_row]); + let sink_exp = exp(sink_val - row_max[q_tile_row]); - // for each kv tile - // load k into shared memory + exp_sum[q_tile_row] = exp_sum[q_tile_row] * max_exp + subgroupSum(sink_exp); + for (var elem_idx = sg_inv_id; elem_idx < params.head_dim_v; elem_idx += SUBGROUP_SIZE) { + o_shmem[q_tile_row * params.head_dim_v + elem_idx] *= max_exp; + } - // compute qk scores - // apply mask - // softmax + } + } - // load v into shared memory + // write output back to global memory + for (var q_tile_row = 0u; q_tile_row < Q_TILE; q_tile_row++) { + let global_q_row = q_row_start + q_tile_row; + if (global_q_row >= params.seq_len_q) { + break; + } - // compute output + let scale = select(0, 1 / exp_sum[q_tile_row], exp_sum[q_tile_row] > 0); - // write output to dst + for (var elem_idx = sg_inv_id; elem_idx < params.head_dim_v; elem_idx += SUBGROUP_SIZE) { + dst[dst_global_offset + q_tile_row * dst2_stride + elem_idx] = f32(o_shmem[q_tile_row * params.head_dim_v + elem_idx] * scale); + } + 
} } \ No newline at end of file From ff4badb6d03cfe7919443624278ff8742fee2e6c Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Wed, 17 Dec 2025 10:14:33 -0800 Subject: [PATCH 26/40] debugging --- ggml/src/ggml-cpu/ops.cpp | 2 +- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 47 ++++++++--- .../ggml-webgpu/wgsl-shaders/flash_attn.wgsl | 84 ++++++++++--------- tests/test-backend-ops.cpp | 10 ++- 4 files changed, 88 insertions(+), 55 deletions(-) diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index b6209588db1..f1b3bab72e6 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -8051,6 +8051,7 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( } S = S*ms + vs; // scale and increment sum with partial sum + } if (v->type == GGML_TYPE_F16) { @@ -8079,7 +8080,6 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( // V /= S const float S_inv = S == 0.0f ? 0.0f : 1.0f/S; ggml_vec_scale_f32(DV, VKQ32, S_inv); - // dst indices const int i1 = iq1; const int i2 = iq2; diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 765f1475ce5..3c926c0f1fc 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -18,6 +18,7 @@ #include #include #include +#include #include #include #include @@ -30,7 +31,7 @@ #ifdef GGML_WEBGPU_DEBUG # define WEBGPU_LOG_DEBUG(msg) std::cout << msg << std::endl -# define WEBGPU_DEBUG_BUF_ELEMS 32 +# define WEBGPU_DEBUG_BUF_ELEMS 512 #else # define WEBGPU_LOG_DEBUG(msg) ((void) 0) #endif // GGML_WEBGPU_DEBUG @@ -493,12 +494,10 @@ static void ggml_backend_webgpu_debug(webgpu_context & ctx) { ctx->queue.Submit(1, &commands); ggml_backend_webgpu_map_buffer(ctx, ctx->debug_host_buf, wgpu::MapMode::Read, 0, ctx->debug_host_buf.GetSize()); - const uint32_t * debug_data = (const uint32_t *) ctx->debug_host_buf.GetConstMappedRange(); - std::cout << "debug data:"; - for (size_t i = 0; i < WEBGPU_DEBUG_BUF_ELEMS; i++) { - std::cout 
<< " " << i << ": " << debug_data[i]; - } - std::cout << "\n"; + const float * debug_data = (const float *) ctx->debug_host_buf.GetConstMappedRange(); + + std::cout << "[GPU] debug[0] = " << std::fixed << std::setprecision(6) << debug_data[0] << "\n"; + ctx->debug_host_buf.Unmap(); } #endif @@ -1021,12 +1020,22 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, std::cout << "ggml_webgpu_flash_attn: dst type: " << ggml_type_name(dst->type) << ", ne: [" << dst->ne[0] << ", " << dst->ne[1] << ", " << dst->ne[2] << ", " << dst->ne[3] << "]\n"; + uint32_t offset_q = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, Q) / ggml_type_size(Q->type)); + uint32_t offset_k = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type)); + uint32_t offset_v = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type)); + uint32_t offset_mask = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)); + uint32_t offset_sinks = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, sinks) / ggml_type_size(sinks->type)); + uint32_t offset_dst = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)); + + std::cout << "ggml_webgpu_flash_attn: offsets: Q=" << offset_q << ", K=" << offset_k << ", V=" << offset_v + << ", mask=" << offset_mask << ", sinks=" << offset_sinks << ", dst=" << offset_dst << "\n"; + std::vector params = { - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, Q) / ggml_type_size(Q->type)), - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type)), - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type)), - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)), - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, sinks) / ggml_type_size(sinks->type)), + offset_q, + offset_k, + offset_v, + offset_mask, + offset_sinks, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / 
ggml_type_size(dst->type)), (uint32_t) Q->ne[0], // head dimension (Q/K) (uint32_t) V->ne[0], // head dimension (V) @@ -1074,6 +1083,10 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, .buffer = ggml_webgpu_tensor_buf(dst), .offset = ggml_webgpu_tensor_align_offset(ctx, dst), .size = ggml_webgpu_tensor_binding_size(ctx, dst) }, + { .binding = 6, + .buffer = ctx->debug_dev_buf, + .offset = 0, + .size = ctx->debug_dev_buf.GetSize() } }; uint32_t wg_per_head = CEIL_DIV(Q->ne[1], WEBGPU_FLASH_ATTN_Q_TILE); @@ -1536,8 +1549,12 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str std::vector commands; std::vector futures; + bool contains_flash_attn = false; for (int i = 0; i < cgraph->n_nodes; i++) { if (auto cmd = ggml_webgpu_encode_node(ctx, cgraph->nodes[i])) { + if (cgraph->nodes[i]->op == GGML_OP_FLASH_ATTN_EXT) { + contains_flash_attn = true; + } commands.push_back(*cmd); } // compute the batch size based on the number of inflight threads @@ -1556,6 +1573,12 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str webgpu_submission_futures new_futures = ggml_backend_webgpu_submit(ctx, commands); futures.push_back(new_futures); } + +#ifdef GGML_WEBGPU_DEBUG + if (contains_flash_attn) + ggml_backend_webgpu_debug(ctx); +#endif + ggml_backend_webgpu_wait(ctx, futures); ctx->inflight_threads--; WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl index edd9abf8842..df67b106eee 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl @@ -45,7 +45,8 @@ struct Params { @group(0) @binding(3) var mask: array; @group(0) @binding(4) var sinks: array; @group(0) @binding(5) var dst: array; -@group(0) @binding(6) var params: Params; +@group(0) @binding(6) var debug: array; +@group(0) @binding(7) var params: Params; // The number of 
Q rows processed per workgroup const Q_TILE = 8u; @@ -72,11 +73,9 @@ fn main(@builtin(workgroup_id) wg_id: vec3, @builtin(subgroup_id) subgroup_id: u32, @builtin(subgroup_invocation_id) sg_inv_id: u32) { - dst[0] = Q[0] + K[0] + V[0] + f32(mask[0]) + sinks[0]; // dummy line to avoid compile error - // each thread maintains its own cache for softmax intermediates - var row_max = array(-65504, -65504, -65504, -65504, -65504, -65504, -65504, -65504); - var exp_sum = array(); + var row_max = array(-3.4e38, -3.4e38, -3.4e38, -3.4e38, -3.4e38, -3.4e38, -3.4e38, -3.4e38); + var exp_sum = array(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0); // workgroups per head/batch let wg_per_head = (params.seq_len_q + Q_TILE - 1u) / Q_TILE; @@ -107,10 +106,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let dst_global_offset = dst_batch_offset + q_row_start * dst2_stride + head_idx * params.head_dim_v; // Which mask row to use. TODO: support broadcasting - let mask_seq_offset = params.offset_mask + q_row_start * params.stride_m1; + let mask_seq_offset = params.offset_mask + q_row_start * params.seq_len_kv; let head = f32(head_idx); - let slope = select(1, select(pow(params.m1, 2 * (head - params.n_head_log2) + 1), pow(params.m0, head + 1), head < params.n_head_log2), params.max_bias > 0); + let slope = select(1.0, select(pow(params.m1, 2.0 * (head - params.n_head_log2) + 1.0), pow(params.m0, head + 1.0), head < params.n_head_log2), params.max_bias > 0); // load q tile into shared memory for (var elem_idx = local_id.x; elem_idx < Q_TILE * params.head_dim_qk; elem_idx += WG_SIZE) { @@ -124,6 +123,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, head_q_row < params.seq_len_q && q_col < params.head_dim_qk); q_shmem[elem_idx] = q_val; } + workgroupBarrier(); for (var kv_tile = 0u; kv_tile < params.seq_len_kv; kv_tile += KV_TILE) { @@ -167,6 +167,8 @@ fn main(@builtin(workgroup_id) wg_id: vec3, // store acc to shared memory for softmax subgroupMatrixStore(&inter_shmem, 0, acc, false, KV_TILE); 
+ workgroupBarrier(); + // online softmax for (var q_tile_row = 0u; q_tile_row < Q_TILE; q_tile_row++) { // no need to process rows beyond seq_len_q @@ -178,22 +180,25 @@ fn main(@builtin(workgroup_id) wg_id: vec3, // calculate running max let prev_max = row_max[q_tile_row]; // The mask value for this Q row and K col - let mask_val = select(f16(0.0), f16(mask[mask_seq_offset + q_tile_row * params.stride_m1 + kv_tile + sg_inv_id]), kv_tile + sg_inv_id < params.seq_len_kv && sg_inv_id < KV_TILE); - let thread_tile_row_max = select(f16(-65504), inter_shmem[sg_inv_id + q_tile_row * KV_TILE] * f16(params.scale) + slope * mask_val, sg_inv_id < KV_TILE); - row_max[q_tile_row] = subgroupMax(row_max[q_tile_row], max(prev_max, thread_tile_row_max)); + let mask_val = select(0.0, f32(mask[mask_seq_offset + q_tile_row * params.seq_len_kv + kv_tile + sg_inv_id]), kv_tile + sg_inv_id < params.seq_len_kv && sg_inv_id < KV_TILE); + let mask_term = slope * mask_val; + let thread_tile_row_max = select(-3.4e38, f32(inter_shmem[sg_inv_id + q_tile_row * KV_TILE]) * params.scale + mask_term, sg_inv_id < KV_TILE); + row_max[q_tile_row] = subgroupMax(max(row_max[q_tile_row], max(prev_max, thread_tile_row_max))); // calculate running exp sum let cur_exp = exp(prev_max - row_max[q_tile_row]); - let cur_p = select(0, exp(thread_tile_row_max - row_max[q_tile_row]), sg_inv_id < KV_TILE); - exp_sum[q_tile_row] = exp_sum[q_tile_row] * cur_exp + subgroupSum(cur_p); + let cur_p = select(0.0, exp(thread_tile_row_max - row_max[q_tile_row]), sg_inv_id < KV_TILE); + exp_sum[q_tile_row] = exp_sum[q_tile_row] * cur_exp + subgroupAdd(cur_p); // store back to shared memory (P matrix) if (sg_inv_id < KV_TILE) { - inter_shmem[sg_inv_id + q_tile_row * KV_TILE] = cur_p; + inter_shmem[sg_inv_id + q_tile_row * KV_TILE] = f16(cur_p); } for (var elem_idx = sg_inv_id; elem_idx < params.head_dim_v; elem_idx += SUBGROUP_SIZE) { - o_shmem[q_tile_row * params.head_dim_v + elem_idx] *= cur_exp; + let idx = q_tile_row * 
params.head_dim_v + elem_idx; + let val = f32(o_shmem[idx]) * cur_exp; + o_shmem[idx] = f16(val); } } @@ -227,7 +232,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, var v_sg_mat: subgroup_matrix_right = subgroupMatrixLoad>( &v_shmem, head_dim_block, - true, // or false? is this transposed? + false, // or false? is this transposed? params.head_dim_v ); @@ -247,28 +252,29 @@ fn main(@builtin(workgroup_id) wg_id: vec3, } workgroupBarrier(); + } - // add sinks - for (var q_tile_row = 0u; q_tile_row < Q_TILE; q_tile_row++) { - // no need to process rows beyond seq_len_q - let global_q_row = q_row_start + q_tile_row; - if (global_q_row >= params.seq_len_q) { - break; - } - - let max_val = row_max[q_tile_row]; - // for non-sink threads, exp(-65504) effectively zeroes them out - let sink_val = select(-65504.0, f16(sinks[params.offset_sinks + head_idx]), sg_inv_id == 0); - row_max[q_tile_row] = subgroupMax(max_val, sink_val); - let max_exp = exp(max_val - row_max[q_tile_row]); - let sink_exp = exp(sink_val - row_max[q_tile_row]); + // add sinks (applied once after processing all KV tiles) + for (var q_tile_row = 0u; q_tile_row < Q_TILE; q_tile_row++) { + // no need to process rows beyond seq_len_q + let global_q_row = q_row_start + q_tile_row; + if (global_q_row >= params.seq_len_q) { + break; + } - exp_sum[q_tile_row] = exp_sum[q_tile_row] * max_exp + subgroupSum(sink_exp); - for (var elem_idx = sg_inv_id; elem_idx < params.head_dim_v; elem_idx += SUBGROUP_SIZE) { - o_shmem[q_tile_row * params.head_dim_v + elem_idx] *= max_exp; - } + let max_val = row_max[q_tile_row]; + // for non-sink threads, exp(-65504) effectively zeroes them out + let sink_val = select(-3.4e38, sinks[params.offset_sinks + head_idx], sg_inv_id == 0); + row_max[q_tile_row] = subgroupMax(max(max_val, sink_val)); + let max_exp = exp(max_val - row_max[q_tile_row]); + let sink_exp = exp(sink_val - row_max[q_tile_row]); - } + exp_sum[q_tile_row] = exp_sum[q_tile_row] * max_exp + subgroupAdd(sink_exp); + for 
(var elem_idx = sg_inv_id; elem_idx < params.head_dim_v; elem_idx += SUBGROUP_SIZE) { + let idx = q_tile_row * params.head_dim_v + elem_idx; + let val = f32(o_shmem[idx]) * max_exp; + o_shmem[idx] = f16(val); + } } // write output back to global memory @@ -278,10 +284,12 @@ fn main(@builtin(workgroup_id) wg_id: vec3, break; } - let scale = select(0, 1 / exp_sum[q_tile_row], exp_sum[q_tile_row] > 0); + let scale = select(0.0, 1.0 / exp_sum[q_tile_row], exp_sum[q_tile_row] > 0); for (var elem_idx = sg_inv_id; elem_idx < params.head_dim_v; elem_idx += SUBGROUP_SIZE) { - dst[dst_global_offset + q_tile_row * dst2_stride + elem_idx] = f32(o_shmem[q_tile_row * params.head_dim_v + elem_idx] * scale); - } + let o_val = f32(o_shmem[q_tile_row * params.head_dim_v + elem_idx]); + let scaled = o_val * scale; + dst[dst_global_offset + q_tile_row * dst2_stride + elem_idx] = scaled; + } } -} \ No newline at end of file +} diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 5e95e411e92..1e57194d4fb 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1333,10 +1333,12 @@ struct test_case { double err = nmse(f1.data(), f2.data(), f1.size()); if (err > ud->max_err) { printf("[%s] NMSE = %.9f > %.9f ", ggml_op_desc(t1), err, ud->max_err); - //for (int i = 0; i < (int) f1.size(); i++) { - // printf("%5d %9.6f %9.6f, diff = %9.6f\n", i, f1[i], f2[i], f1[i] - f2[i]); - //} - //printf("\n"); + // printing first 40 values for flash attention debugging + //for (int i = 0; i < 40; i++) { + for (int i = 0; i < (int) f1.size(); i++) { + printf("%5d %9.6f %9.6f, diff = %9.6f\n", i, f1[i], f2[i], f1[i] - f2[i]); + } + printf("\n"); //exit(1); ud->ok = false; } From abbc5b2d7940721082fadb989f0741efbcf88583 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Wed, 17 Dec 2025 12:47:50 -0800 Subject: [PATCH 27/40] Working first test --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 5 +- .../ggml-webgpu/wgsl-shaders/flash_attn.wgsl | 56 ++++++++++--------- 2 files 
changed, 30 insertions(+), 31 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 3c926c0f1fc..05d7d07485b 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -492,12 +492,9 @@ static void ggml_backend_webgpu_debug(webgpu_context & ctx) { encoder.CopyBufferToBuffer(ctx->debug_dev_buf, 0, ctx->debug_host_buf, 0, ctx->debug_host_buf.GetSize()); wgpu::CommandBuffer commands = encoder.Finish(); ctx->queue.Submit(1, &commands); - ggml_backend_webgpu_map_buffer(ctx, ctx->debug_host_buf, wgpu::MapMode::Read, 0, ctx->debug_host_buf.GetSize()); const float * debug_data = (const float *) ctx->debug_host_buf.GetConstMappedRange(); - - std::cout << "[GPU] debug[0] = " << std::fixed << std::setprecision(6) << debug_data[0] << "\n"; - + std::cout << "debug[0]: " << debug_data[0] << "\n"; ctx->debug_host_buf.Unmap(); } #endif diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl index df67b106eee..ddd53415e90 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl @@ -50,19 +50,19 @@ struct Params { // The number of Q rows processed per workgroup const Q_TILE = 8u; -var q_shmem: array; // assumes max head_dim_qk of 64 +var q_shmem: array; // assumes max head_dim_qk of 64 const KV_TILE = 8u; // we can reuse the same shmem for K and V since we only need one at a time right? 
-var k_shmem: array; // assuming max head_dim_qkv of 64 -var v_shmem: array; // assuming max head_dim_qkv of 64 +var k_shmem: array; // assuming max head_dim_qkv of 64 +var v_shmem: array; // assuming max head_dim_qkv of 64 -var o_shmem: array; // output shmem +var o_shmem: array; // output shmem // storage for output of Q*K^T scores for online softmax (S matrix from paper) // also storage for diagonal matrix during online softmax (P matrix from paper) // note that we reuse the same storage for both since we only need one at a time -var inter_shmem: array; +var inter_shmem: array; const WG_SIZE = 32u; const SUBGROUP_SIZE = 32u; @@ -73,6 +73,8 @@ fn main(@builtin(workgroup_id) wg_id: vec3, @builtin(subgroup_id) subgroup_id: u32, @builtin(subgroup_invocation_id) sg_inv_id: u32) { + debug[0] = 42; + // each thread maintains its own cache for softmax intermediates var row_max = array(-3.4e38, -3.4e38, -3.4e38, -3.4e38, -3.4e38, -3.4e38, -3.4e38, -3.4e38); var exp_sum = array(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0); @@ -117,9 +119,9 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let q_col = elem_idx % params.head_dim_qk; let head_q_row = q_row_start + q_row; let global_q_row_offset = q_head_offset + head_q_row * params.stride_q1; - let q_val: f16 = select( + let q_val: f32 = select( 0.0, - f16(Q[global_q_row_offset + q_col]), + Q[global_q_row_offset + q_col], head_q_row < params.seq_len_q && q_col < params.head_dim_qk); q_shmem[elem_idx] = q_val; } @@ -133,9 +135,9 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let k_col = elem_idx % params.head_dim_qk; let global_k_row = kv_tile + k_row; let global_k_row_offset = k_head_offset + global_k_row * params.stride_k1; - let k_val: f16 = select( + let k_val: f32 = select( 0.0, - f16(K[global_k_row_offset + k_col]), + K[global_k_row_offset + k_col], global_k_row < params.seq_len_kv && k_col < params.head_dim_qk); k_shmem[elem_idx] = k_val; } @@ -143,10 +145,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3, workgroupBarrier(); 
// accumulate q block * k block into registers - var acc: subgroup_matrix_result; + var acc: subgroup_matrix_result; for (var head_dim_block = 0u; head_dim_block < params.head_dim_qk; head_dim_block += 8u) { // load q submatrix from shared memory - var q_sg_mat: subgroup_matrix_left = subgroupMatrixLoad>( + var q_sg_mat: subgroup_matrix_left = subgroupMatrixLoad>( &q_shmem, head_dim_block, false, @@ -154,7 +156,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, ); // load k submatrix from shared memory - var k_sg_mat: subgroup_matrix_right = subgroupMatrixLoad>( + var k_sg_mat: subgroup_matrix_right = subgroupMatrixLoad>( &k_shmem, head_dim_block, true, @@ -182,23 +184,23 @@ fn main(@builtin(workgroup_id) wg_id: vec3, // The mask value for this Q row and K col let mask_val = select(0.0, f32(mask[mask_seq_offset + q_tile_row * params.seq_len_kv + kv_tile + sg_inv_id]), kv_tile + sg_inv_id < params.seq_len_kv && sg_inv_id < KV_TILE); let mask_term = slope * mask_val; - let thread_tile_row_max = select(-3.4e38, f32(inter_shmem[sg_inv_id + q_tile_row * KV_TILE]) * params.scale + mask_term, sg_inv_id < KV_TILE); - row_max[q_tile_row] = subgroupMax(max(row_max[q_tile_row], max(prev_max, thread_tile_row_max))); + let thread_tile_row_max = select(-3.4e38, inter_shmem[sg_inv_id + q_tile_row * KV_TILE] * params.scale + mask_term, sg_inv_id < KV_TILE); + row_max[q_tile_row] = subgroupMax(max(prev_max, thread_tile_row_max)); // calculate running exp sum let cur_exp = exp(prev_max - row_max[q_tile_row]); - let cur_p = select(0.0, exp(thread_tile_row_max - row_max[q_tile_row]), sg_inv_id < KV_TILE); + let cur_p = select(0.0, exp(thread_tile_row_max - row_max[q_tile_row]), kv_tile + sg_inv_id < params.seq_len_kv && sg_inv_id < KV_TILE); exp_sum[q_tile_row] = exp_sum[q_tile_row] * cur_exp + subgroupAdd(cur_p); // store back to shared memory (P matrix) if (sg_inv_id < KV_TILE) { - inter_shmem[sg_inv_id + q_tile_row * KV_TILE] = f16(cur_p); + inter_shmem[sg_inv_id + q_tile_row * 
KV_TILE] = cur_p; } for (var elem_idx = sg_inv_id; elem_idx < params.head_dim_v; elem_idx += SUBGROUP_SIZE) { let idx = q_tile_row * params.head_dim_v + elem_idx; - let val = f32(o_shmem[idx]) * cur_exp; - o_shmem[idx] = f16(val); + let val = o_shmem[idx] * cur_exp; + o_shmem[idx] = val; } } @@ -208,9 +210,9 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let v_col = elem_idx % params.head_dim_v; let global_v_row = kv_tile + v_row; let global_v_row_offset = v_head_offset + global_v_row * params.stride_v1; - let v_val: f16 = select( + let v_val: f32 = select( 0.0, - f16(V[global_v_row_offset + v_col]), + f32(V[global_v_row_offset + v_col]), global_v_row < params.seq_len_kv && v_col < params.head_dim_v); v_shmem[elem_idx] = v_val; } @@ -220,7 +222,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, // we have P (8x8 tile, or Q_TILE x KV_TILE) in inter_shmem and V (8 x head_dim_v, or KV_TILE x head_dim_v) in v_shmem // we want to compute O += P * V // load P submatrix from shared memory - var p_sg_mat: subgroup_matrix_left = subgroupMatrixLoad>( + var p_sg_mat: subgroup_matrix_left = subgroupMatrixLoad>( &inter_shmem, 0, false, @@ -229,7 +231,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, for (var head_dim_block = 0u; head_dim_block < params.head_dim_v; head_dim_block += 8u) { // load V submatrix from shared memory - var v_sg_mat: subgroup_matrix_right = subgroupMatrixLoad>( + var v_sg_mat: subgroup_matrix_right = subgroupMatrixLoad>( &v_shmem, head_dim_block, false, // or false? is this transposed? 
@@ -237,7 +239,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, ); // load O submatrix from shared memory - var o_sg_mat: subgroup_matrix_result = subgroupMatrixLoad>( + var o_sg_mat: subgroup_matrix_result = subgroupMatrixLoad>( &o_shmem, head_dim_block, false, @@ -272,8 +274,8 @@ fn main(@builtin(workgroup_id) wg_id: vec3, exp_sum[q_tile_row] = exp_sum[q_tile_row] * max_exp + subgroupAdd(sink_exp); for (var elem_idx = sg_inv_id; elem_idx < params.head_dim_v; elem_idx += SUBGROUP_SIZE) { let idx = q_tile_row * params.head_dim_v + elem_idx; - let val = f32(o_shmem[idx]) * max_exp; - o_shmem[idx] = f16(val); + let val = o_shmem[idx] * max_exp; + o_shmem[idx] = val; } } @@ -284,10 +286,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3, break; } - let scale = select(0.0, 1.0 / exp_sum[q_tile_row], exp_sum[q_tile_row] > 0); + let scale = select(0.0, 1.0 / exp_sum[q_tile_row], exp_sum[q_tile_row] != 0); for (var elem_idx = sg_inv_id; elem_idx < params.head_dim_v; elem_idx += SUBGROUP_SIZE) { - let o_val = f32(o_shmem[q_tile_row * params.head_dim_v + elem_idx]); + let o_val = o_shmem[q_tile_row * params.head_dim_v + elem_idx]; let scaled = o_val * scale; dst[dst_global_offset + q_tile_row * dst2_stride + elem_idx] = scaled; } From fd1e3db22247836a9b9e79769bd7b825bdf813ae Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Wed, 17 Dec 2025 14:50:25 -0800 Subject: [PATCH 28/40] Working with head grouping, head sizes to 128, logit softcap, mask/sinks enabled, f32 --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 151 +++++++++--------- .../ggml-webgpu/wgsl-shaders/flash_attn.wgsl | 54 ++++--- tests/test-backend-ops.cpp | 8 +- 3 files changed, 113 insertions(+), 100 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 05d7d07485b..3ec953df2ce 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -111,7 +111,7 @@ // Flash Attention parameters #define WEBGPU_FLASH_ATTN_WG_SIZE 32 -#define 
WEBGPU_FLASH_ATTN_Q_TILE 8 +#define WEBGPU_FLASH_ATTN_Q_TILE 8 #define WEBGPU_FLASH_ATTN_KV_TILE 16 /* End Constants */ @@ -298,14 +298,14 @@ struct webgpu_context_struct { webgpu_pipeline flash_attn_pipeline; - std::map> set_rows_pipelines; // dst_type, vectorized - std::map> get_rows_pipelines; // src_type, vectorized + std::map> set_rows_pipelines; // dst_type, vectorized + std::map> get_rows_pipelines; // src_type, vectorized - std::map> cpy_pipelines; // src_type, dst_type - std::map> add_pipelines; // type, inplace - std::map> sub_pipelines; // type, inplace - std::map> mul_pipelines; // type, inplace - std::map> div_pipelines; // type, inplace + std::map> cpy_pipelines; // src_type, dst_type + std::map> add_pipelines; // type, inplace + std::map> sub_pipelines; // type, inplace + std::map> mul_pipelines; // type, inplace + std::map> div_pipelines; // type, inplace std::map rms_norm_pipelines; // inplace std::map>> rope_pipelines; // type, ff, inplace @@ -990,66 +990,68 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, } static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, - ggml_tensor * Q, - ggml_tensor * K, - ggml_tensor * V, - ggml_tensor * mask, - ggml_tensor * sinks, - ggml_tensor * dst) { -// For now we assume everything (mask, sink) - float max_bias; + ggml_tensor * Q, + ggml_tensor * K, + ggml_tensor * V, + ggml_tensor * mask, + ggml_tensor * sinks, + ggml_tensor * dst) { + // For now we assume everything (mask, sink) + float scale = *(float *) dst->op_params; + float max_bias; memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); + float logit_softcap; + memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float)); + if (logit_softcap != 0.0f) { + scale /= logit_softcap; + } float n_head_log2 = float(1u << (uint32_t) floor(log2(Q->ne[2]))); float m0 = powf(2.0f, -(max_bias) / n_head_log2); float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); // print type and dimensions of Q/K/V/mask/sinks/dst - 
std::cout << "ggml_webgpu_flash_attn: Q type: " << ggml_type_name(Q->type) << ", ne: [" << Q->ne[0] << ", " << Q->ne[1] << ", " << Q->ne[2] - << ", " << Q->ne[3] << "]\n"; - std::cout << "ggml_webgpu_flash_attn: K type: " << ggml_type_name(K->type) << ", ne: [" << K->ne[0] << ", " << K->ne[1] << ", " << K->ne[2] - << ", " << K->ne[3] << "]\n"; - std::cout << "ggml_webgpu_flash_attn: V type: " << ggml_type_name(V->type) << ", ne: [" << V->ne[0] << ", " << V->ne[1] << ", " << V->ne[2] - << ", " << V->ne[3] << "]\n"; - std::cout << "ggml_webgpu_flash_attn: mask type: " << ggml_type_name(mask->type) << ", ne: [" << mask->ne[0] << ", " << mask->ne[1] << ", " << mask->ne[2] - << ", " << mask->ne[3] << "]\n"; - std::cout << "ggml_webgpu_flash_attn: sinks type: " << ggml_type_name(sinks->type) << ", ne: [" << sinks->ne[0] << ", " << sinks->ne[1] << ", " << sinks->ne[2] - << ", " << sinks->ne[3] << "]\n"; - std::cout << "ggml_webgpu_flash_attn: dst type: " << ggml_type_name(dst->type) << ", ne: [" << dst->ne[0] << ", " << dst->ne[1] << ", " << dst->ne[2] - << ", " << dst->ne[3] << "]\n"; - - uint32_t offset_q = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, Q) / ggml_type_size(Q->type)); - uint32_t offset_k = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type)); - uint32_t offset_v = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type)); - uint32_t offset_mask = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)); - uint32_t offset_sinks = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, sinks) / ggml_type_size(sinks->type)); - uint32_t offset_dst = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)); - - std::cout << "ggml_webgpu_flash_attn: offsets: Q=" << offset_q << ", K=" << offset_k << ", V=" << offset_v - << ", mask=" << offset_mask << ", sinks=" << offset_sinks << ", dst=" << offset_dst << "\n"; + // std::cout << "ggml_webgpu_flash_attn: Q type: " << 
ggml_type_name(Q->type) << ", ne: [" << Q->ne[0] << ", " << Q->ne[1] << ", " << Q->ne[2] + // << ", " << Q->ne[3] << "]\n"; + // std::cout << "ggml_webgpu_flash_attn: K type: " << ggml_type_name(K->type) << ", ne: [" << K->ne[0] << ", " << K->ne[1] << ", " << K->ne[2] + // << ", " << K->ne[3] << "]\n"; + // std::cout << "ggml_webgpu_flash_attn: V type: " << ggml_type_name(V->type) << ", ne: [" << V->ne[0] << ", " << V->ne[1] << ", " << V->ne[2] + // << ", " << V->ne[3] << "]\n"; + // std::cout << "ggml_webgpu_flash_attn: mask type: " << ggml_type_name(mask->type) << ", ne: [" << mask->ne[0] << ", " << mask->ne[1] << ", " << mask->ne[2] + // << ", " << mask->ne[3] << "]\n"; + // std::cout << "ggml_webgpu_flash_attn: sinks type: " << ggml_type_name(sinks->type) << ", ne: [" << sinks->ne[0] << ", " << sinks->ne[1] << ", " << sinks->ne[2] + // << ", " << sinks->ne[3] << "]\n"; + // std::cout << "ggml_webgpu_flash_attn: dst type: " << ggml_type_name(dst->type) << ", ne: [" << dst->ne[0] << ", " << dst->ne[1] << ", " << dst->ne[2] + // << ", " << dst->ne[3] << "]\n"; + + // std::cout << "ggml_webgpu_flash_attn: offsets: Q=" << offset_q << ", K=" << offset_k << ", V=" << offset_v + // << ", mask=" << offset_mask << ", sinks=" << offset_sinks << ", dst=" << offset_dst << "\n"; std::vector params = { - offset_q, - offset_k, - offset_v, - offset_mask, - offset_sinks, + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, Q) / ggml_type_size(Q->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, sinks) / ggml_type_size(sinks->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), - (uint32_t) Q->ne[0], // head dimension (Q/K) - (uint32_t) V->ne[0], // head dimension (V) - (uint32_t) 
Q->ne[2], // number of heads - (uint32_t) Q->ne[1], // sequence length (Q) - (uint32_t) K->ne[1], // sequence length (K/V) - (uint32_t) (Q->nb[1] / ggml_type_size(Q->type)), // stride (elements/blocks) of Q in dimension 1 - (uint32_t) (Q->nb[2] / ggml_type_size(Q->type)), // stride (elements/blocks) of Q in dimension 2 - (uint32_t) (Q->nb[3] / ggml_type_size(Q->type)), // stride (elements/blocks) of Q in dimension 3 - (uint32_t) (K->nb[1] / ggml_type_size(K->type)), // stride (elements/blocks) of K in dimension 1 - (uint32_t) (K->nb[2] / ggml_type_size(K->type)), // stride (elements/blocks) of K in dimension 2 - (uint32_t) (K->nb[3] / ggml_type_size(K->type)), // stride (elements/blocks) of K in dimension 3 - (uint32_t) (V->nb[1] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 1 - (uint32_t) (V->nb[2] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 2 - (uint32_t) (V->nb[3] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 3 - *(uint32_t *) dst->op_params, // scale + (uint32_t) Q->ne[0], // head dimension (Q/K) + (uint32_t) V->ne[0], // head dimension (V) + (uint32_t) Q->ne[2], // number of heads + (uint32_t) Q->ne[1], // sequence length (Q) + (uint32_t) K->ne[1], // sequence length (K/V) + (uint32_t) (Q->nb[1] / ggml_type_size(Q->type)), // stride (elements/blocks) of Q in dimension 1 + (uint32_t) (Q->nb[2] / ggml_type_size(Q->type)), // stride (elements/blocks) of Q in dimension 2 + (uint32_t) (Q->nb[3] / ggml_type_size(Q->type)), // stride (elements/blocks) of Q in dimension 3 + (uint32_t) (K->nb[1] / ggml_type_size(K->type)), // stride (elements/blocks) of K in dimension 1 + (uint32_t) (K->nb[2] / ggml_type_size(K->type)), // stride (elements/blocks) of K in dimension 2 + (uint32_t) (K->nb[3] / ggml_type_size(K->type)), // stride (elements/blocks) of K in dimension 3 + (uint32_t) (V->nb[1] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 1 + (uint32_t) (V->nb[2] 
/ ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 2 + (uint32_t) (V->nb[3] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 3 + (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)), // stride of mask dim 3 + (uint32_t) (Q->ne[2] / K->ne[2]), // repeat factor for K/V in dim 2 (MHA/MQA/GQA) + *(uint32_t *) &scale, // scale (possibly adjusted for logit softcap) *(uint32_t *) &max_bias, + *(uint32_t *) &logit_softcap, *(uint32_t *) &n_head_log2, *(uint32_t *) &m0, *(uint32_t *) &m1 @@ -1059,19 +1061,19 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, { .binding = 0, .buffer = ggml_webgpu_tensor_buf(Q), .offset = ggml_webgpu_tensor_align_offset(ctx, Q), - .size = ggml_webgpu_tensor_binding_size(ctx, Q) }, + .size = ggml_webgpu_tensor_binding_size(ctx, Q) }, { .binding = 1, .buffer = ggml_webgpu_tensor_buf(K), .offset = ggml_webgpu_tensor_align_offset(ctx, K), - .size = ggml_webgpu_tensor_binding_size(ctx, K) }, + .size = ggml_webgpu_tensor_binding_size(ctx, K) }, { .binding = 2, .buffer = ggml_webgpu_tensor_buf(V), .offset = ggml_webgpu_tensor_align_offset(ctx, V), - .size = ggml_webgpu_tensor_binding_size(ctx, V) }, + .size = ggml_webgpu_tensor_binding_size(ctx, V) }, { .binding = 3, .buffer = ggml_webgpu_tensor_buf(mask), .offset = ggml_webgpu_tensor_align_offset(ctx, mask), - .size = ggml_webgpu_tensor_binding_size(ctx, mask) }, + .size = ggml_webgpu_tensor_binding_size(ctx, mask) }, { .binding = 4, .buffer = ggml_webgpu_tensor_buf(sinks), .offset = ggml_webgpu_tensor_align_offset(ctx, sinks), @@ -1079,16 +1081,16 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, { .binding = 5, .buffer = ggml_webgpu_tensor_buf(dst), .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }, - { .binding = 6, - .buffer = ctx->debug_dev_buf, - .offset = 0, - .size = ctx->debug_dev_buf.GetSize() } + .size = ggml_webgpu_tensor_binding_size(ctx, dst) 
}, + // { .binding = 6, + // .buffer = ctx->debug_dev_buf, + // .offset = 0, + // .size = ctx->debug_dev_buf.GetSize() } }; uint32_t wg_per_head = CEIL_DIV(Q->ne[1], WEBGPU_FLASH_ATTN_Q_TILE); - uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches - std::cout << "ggml_webgpu_flash_attn: wg_x: " << wg_x << "\n"; + uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches + //std::cout << "ggml_webgpu_flash_attn: wg_x: " << wg_x << "\n"; return ggml_backend_webgpu_build(ctx, ctx->flash_attn_pipeline, params, entries, wg_x); } @@ -1546,7 +1548,7 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str std::vector commands; std::vector futures; - bool contains_flash_attn = false; + bool contains_flash_attn = false; for (int i = 0; i < cgraph->n_nodes; i++) { if (auto cmd = ggml_webgpu_encode_node(ctx, cgraph->nodes[i])) { if (cgraph->nodes[i]->op == GGML_OP_FLASH_ATTN_EXT) { @@ -1572,8 +1574,9 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str } #ifdef GGML_WEBGPU_DEBUG - if (contains_flash_attn) - ggml_backend_webgpu_debug(ctx); + if (contains_flash_attn) { + ggml_backend_webgpu_debug(ctx); + } #endif ggml_backend_webgpu_wait(ctx, futures); @@ -2594,7 +2597,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const break; } case GGML_OP_FLASH_ATTN_EXT: - supports_op = true; + supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && + src2->type == GGML_TYPE_F32 && op->src[3] != nullptr && op->src[4] != nullptr; + supports_op &= op->ne[0] <= 128 && src0->ne[0] <= 128; // max head dim 128 for qkv break; case GGML_OP_RMS_NORM: supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl index ddd53415e90..1dffa129fa7
100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl @@ -28,12 +28,15 @@ struct Params { stride_v1: u32, stride_v2: u32, stride_v3: u32, + stride_mask3: u32, - // TODO: still need to consider broadcast + // repeat factors for K/V, e.g., MHA vs. MQA vs. GQA + q_per_kv: u32, // softmax params scale: f32, max_bias: f32, + logit_softcap: f32, n_head_log2: f32, m0: f32, m1: f32, @@ -45,19 +48,18 @@ struct Params { @group(0) @binding(3) var mask: array; @group(0) @binding(4) var sinks: array; @group(0) @binding(5) var dst: array; -@group(0) @binding(6) var debug: array; -@group(0) @binding(7) var params: Params; +//@group(0) @binding(6) var debug: array; +@group(0) @binding(6) var params: Params; // The number of Q rows processed per workgroup const Q_TILE = 8u; -var q_shmem: array; // assumes max head_dim_qk of 64 +var q_shmem: array; // assumes max head_dim_qk of 64 const KV_TILE = 8u; // we can reuse the same shmem for K and V since we only need one at a time right? 
-var k_shmem: array; // assuming max head_dim_qkv of 64 -var v_shmem: array; // assuming max head_dim_qkv of 64 +var kv_shmem: array; // assuming max head_dim_qkv of 64 -var o_shmem: array; // output shmem +var o_shmem: array; // output shmem // storage for output of Q*K^T scores for online softmax (S matrix from paper) // also storage for diagonal matrix during online softmax (P matrix from paper) @@ -73,7 +75,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, @builtin(subgroup_id) subgroup_id: u32, @builtin(subgroup_invocation_id) sg_inv_id: u32) { - debug[0] = 42; + //debug[0] = 42; // each thread maintains its own cache for softmax intermediates var row_max = array(-3.4e38, -3.4e38, -3.4e38, -3.4e38, -3.4e38, -3.4e38, -3.4e38, -3.4e38); @@ -97,19 +99,21 @@ fn main(@builtin(workgroup_id) wg_id: vec3, // head index let head_idx = wg_in_batch / wg_per_head; let q_head_offset = q_batch_offset + head_idx * params.stride_q2; - let k_head_offset = k_batch_offset + head_idx * params.stride_k2; - let v_head_offset = v_batch_offset + head_idx * params.stride_v2; - let wg_in_head = wg_in_batch % wg_per_head; + let k_head_idx = head_idx / params.q_per_kv; + let v_head_idx = k_head_idx; + let k_head_offset = k_batch_offset + k_head_idx * params.stride_k2; + let v_head_offset = v_batch_offset + v_head_idx * params.stride_v2; // starting Q row for this workgroup + let wg_in_head = wg_in_batch % wg_per_head; let q_row_start = wg_in_head * Q_TILE; + // mask offset + let mask_global_offset = params.offset_mask + batch_idx * params.stride_mask3 + q_row_start * params.seq_len_kv; + // note that the output is permuted, the layout is [head_dim_v, n_heads, seq_len_q, batch_size] let dst_global_offset = dst_batch_offset + q_row_start * dst2_stride + head_idx * params.head_dim_v; - // Which mask row to use. 
TODO: support broadcasting - let mask_seq_offset = params.offset_mask + q_row_start * params.seq_len_kv; - let head = f32(head_idx); let slope = select(1.0, select(pow(params.m1, 2.0 * (head - params.n_head_log2) + 1.0), pow(params.m0, head + 1.0), head < params.n_head_log2), params.max_bias > 0); @@ -126,8 +130,6 @@ fn main(@builtin(workgroup_id) wg_id: vec3, q_shmem[elem_idx] = q_val; } - workgroupBarrier(); - for (var kv_tile = 0u; kv_tile < params.seq_len_kv; kv_tile += KV_TILE) { // load k tile into shared memory for (var elem_idx = local_id.x; elem_idx < KV_TILE * params.head_dim_qk; elem_idx += WG_SIZE) { @@ -139,7 +141,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, 0.0, K[global_k_row_offset + k_col], global_k_row < params.seq_len_kv && k_col < params.head_dim_qk); - k_shmem[elem_idx] = k_val; + kv_shmem[elem_idx] = k_val; } workgroupBarrier(); @@ -157,7 +159,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, // load k submatrix from shared memory var k_sg_mat: subgroup_matrix_right = subgroupMatrixLoad>( - &k_shmem, + &kv_shmem, head_dim_block, true, params.head_dim_qk @@ -182,9 +184,13 @@ fn main(@builtin(workgroup_id) wg_id: vec3, // calculate running max let prev_max = row_max[q_tile_row]; // The mask value for this Q row and K col - let mask_val = select(0.0, f32(mask[mask_seq_offset + q_tile_row * params.seq_len_kv + kv_tile + sg_inv_id]), kv_tile + sg_inv_id < params.seq_len_kv && sg_inv_id < KV_TILE); + let mask_val = select(0.0, f32(mask[mask_global_offset + q_tile_row * params.seq_len_kv + kv_tile + sg_inv_id]), kv_tile + sg_inv_id < params.seq_len_kv && sg_inv_id < KV_TILE); let mask_term = slope * mask_val; - let thread_tile_row_max = select(-3.4e38, inter_shmem[sg_inv_id + q_tile_row * KV_TILE] * params.scale + mask_term, sg_inv_id < KV_TILE); + var thread_tile_row_max = select(-3.4e38, inter_shmem[sg_inv_id + q_tile_row * KV_TILE] * params.scale, sg_inv_id < KV_TILE); + if (params.logit_softcap != 0.0) { + thread_tile_row_max = 
params.logit_softcap * tanh(thread_tile_row_max); + } + thread_tile_row_max += mask_term; row_max[q_tile_row] = subgroupMax(max(prev_max, thread_tile_row_max)); // calculate running exp sum @@ -214,12 +220,12 @@ fn main(@builtin(workgroup_id) wg_id: vec3, 0.0, f32(V[global_v_row_offset + v_col]), global_v_row < params.seq_len_kv && v_col < params.head_dim_v); - v_shmem[elem_idx] = v_val; + kv_shmem[elem_idx] = v_val; } workgroupBarrier(); - // we have P (8x8 tile, or Q_TILE x KV_TILE) in inter_shmem and V (8 x head_dim_v, or KV_TILE x head_dim_v) in v_shmem + // we have P (8x8 tile, or Q_TILE x KV_TILE) in inter_shmem and V (8 x head_dim_v, or KV_TILE x head_dim_v) in kv_shmem // we want to compute O += P * V // load P submatrix from shared memory var p_sg_mat: subgroup_matrix_left = subgroupMatrixLoad>( @@ -232,7 +238,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, for (var head_dim_block = 0u; head_dim_block < params.head_dim_v; head_dim_block += 8u) { // load V submatrix from shared memory var v_sg_mat: subgroup_matrix_right = subgroupMatrixLoad>( - &v_shmem, + &kv_shmem, head_dim_block, false, // or false? is this transposed? 
params.head_dim_v @@ -279,6 +285,8 @@ fn main(@builtin(workgroup_id) wg_id: vec3, } } + workgroupBarrier(); + // write output back to global memory for (var q_tile_row = 0u; q_tile_row < Q_TILE; q_tile_row++) { let global_q_row = q_row_start + q_tile_row; diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 1e57194d4fb..c28813ceb03 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1335,10 +1335,10 @@ struct test_case { printf("[%s] NMSE = %.9f > %.9f ", ggml_op_desc(t1), err, ud->max_err); // printing first 40 values for flash attention debugging //for (int i = 0; i < 40; i++) { - for (int i = 0; i < (int) f1.size(); i++) { - printf("%5d %9.6f %9.6f, diff = %9.6f\n", i, f1[i], f2[i], f1[i] - f2[i]); - } - printf("\n"); + //for (int i = 0; i < (int) f1.size(); i++) { + // printf("%5d %9.6f %9.6f, diff = %9.6f\n", i, f1[i], f2[i], f1[i] - f2[i]); + //} + //printf("\n"); //exit(1); ud->ok = false; } From 2f39c2a5270e41c487359ff1870fd445a049a183 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Fri, 19 Dec 2025 10:00:26 -0800 Subject: [PATCH 29/40] Generalize softmax to work with multiple subgroups, f16 accumulation, mask shared memory tiling --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 4 +- .../ggml-webgpu/wgsl-shaders/flash_attn.wgsl | 356 +++++++++++------- 2 files changed, 218 insertions(+), 142 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 3ec953df2ce..f4ceedd691f 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -111,8 +111,8 @@ // Flash Attention parameters #define WEBGPU_FLASH_ATTN_WG_SIZE 32 -#define WEBGPU_FLASH_ATTN_Q_TILE 8 -#define WEBGPU_FLASH_ATTN_KV_TILE 16 +#define WEBGPU_FLASH_ATTN_Q_TILE 16 +#define WEBGPU_FLASH_ATTN_KV_TILE 8 /* End Constants */ diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl index 1dffa129fa7..7be0bbc0107 100644 
--- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl @@ -1,4 +1,5 @@ diagnostic(off, chromium.subgroup_matrix_uniformity); +diagnostic(off, subgroup_uniformity); enable f16; enable subgroups; enable chromium_experimental_subgroup_matrix; @@ -51,35 +52,57 @@ struct Params { //@group(0) @binding(6) var debug: array; @group(0) @binding(6) var params: Params; +const FLOAT_MIN: f16 = -65504.0; + // The number of Q rows processed per workgroup -const Q_TILE = 8u; -var q_shmem: array; // assumes max head_dim_qk of 64 +const Q_TILE = 16u; +var q_shmem: array; // assumes max head_dim_qk of 128 -const KV_TILE = 8u; +const KV_TILE = 16u; // we can reuse the same shmem for K and V since we only need one at a time right? -var kv_shmem: array; // assuming max head_dim_qkv of 64 +var kv_shmem: array; // assuming max head_dim_v of 128 + +var o_shmem: array; // output shmem -var o_shmem: array; // output shmem +// storage for mask values +var mask_shmem: array; // storage for output of Q*K^T scores for online softmax (S matrix from paper) // also storage for diagonal matrix during online softmax (P matrix from paper) // note that we reuse the same storage for both since we only need one at a time -var inter_shmem: array; +var inter_shmem: array; + +// Storage for row max and exp sum during online softmax +var row_max_shmem: array; +var exp_sum_shmem: array; -const WG_SIZE = 32u; -const SUBGROUP_SIZE = 32u; +// The number of rows/columns/k in a subgroup matrix. MxK * KxN = MxN +// Note that the "K" here does not correspond to the K in attention's Q/K/V, it's just the common dimension. +const SG_MAT_M = 8u; +const SG_MAT_N = 8u; +const SG_MAT_K = 8u; + +// Number of blocks this workgroup handles at the subgroup matrix level. SG_MAT_M must divide Q_TILE. +const Q_BLOCKS = Q_TILE / SG_MAT_M; +// Number of subgroup-matrix-width blocks that span the KV tile. SG_MAT_N must divide KV_TILE. 
+const KV_BLOCKS = KV_TILE / SG_MAT_N; + +const WG_SIZE = 64u; @compute @workgroup_size(WG_SIZE) fn main(@builtin(workgroup_id) wg_id: vec3, @builtin(local_invocation_id) local_id: vec3, @builtin(subgroup_id) subgroup_id: u32, + @builtin(subgroup_size) subgroup_size: u32, + @builtin(num_subgroups) num_subgroups: u32, @builtin(subgroup_invocation_id) sg_inv_id: u32) { //debug[0] = 42; - // each thread maintains its own cache for softmax intermediates - var row_max = array(-3.4e38, -3.4e38, -3.4e38, -3.4e38, -3.4e38, -3.4e38, -3.4e38, -3.4e38); - var exp_sum = array(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0); + // initialize row max for online softmax + for (var i = local_id.x; i < Q_TILE; i += WG_SIZE) { + row_max_shmem[i] = FLOAT_MIN; + } // workgroups per head/batch let wg_per_head = (params.seq_len_q + Q_TILE - 1u) / Q_TILE; @@ -115,7 +138,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let dst_global_offset = dst_batch_offset + q_row_start * dst2_stride + head_idx * params.head_dim_v; let head = f32(head_idx); - let slope = select(1.0, select(pow(params.m1, 2.0 * (head - params.n_head_log2) + 1.0), pow(params.m0, head + 1.0), head < params.n_head_log2), params.max_bias > 0); + let slope = f16(select(1.0, select(pow(params.m1, 2.0 * (head - params.n_head_log2) + 1.0), pow(params.m0, head + 1.0), head < params.n_head_log2), params.max_bias > 0)); // load q tile into shared memory for (var elem_idx = local_id.x; elem_idx < Q_TILE * params.head_dim_qk; elem_idx += WG_SIZE) { @@ -123,11 +146,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let q_col = elem_idx % params.head_dim_qk; let head_q_row = q_row_start + q_row; let global_q_row_offset = q_head_offset + head_q_row * params.stride_q1; - let q_val: f32 = select( + q_shmem[elem_idx] = f16(select( 0.0, Q[global_q_row_offset + q_col], - head_q_row < params.seq_len_q && q_col < params.head_dim_qk); - q_shmem[elem_idx] = q_val; + head_q_row < params.seq_len_q && q_col < params.head_dim_qk)); } for (var kv_tile = 0u; 
kv_tile < params.seq_len_kv; kv_tile += KV_TILE) { @@ -137,76 +159,107 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let k_col = elem_idx % params.head_dim_qk; let global_k_row = kv_tile + k_row; let global_k_row_offset = k_head_offset + global_k_row * params.stride_k1; - let k_val: f32 = select( + kv_shmem[elem_idx] = f16(select( 0.0, K[global_k_row_offset + k_col], - global_k_row < params.seq_len_kv && k_col < params.head_dim_qk); - kv_shmem[elem_idx] = k_val; + global_k_row < params.seq_len_kv && k_col < params.head_dim_qk)); } workgroupBarrier(); - // accumulate q block * k block into registers - var acc: subgroup_matrix_result; - for (var head_dim_block = 0u; head_dim_block < params.head_dim_qk; head_dim_block += 8u) { - // load q submatrix from shared memory - var q_sg_mat: subgroup_matrix_left = subgroupMatrixLoad>( - &q_shmem, - head_dim_block, - false, - params.head_dim_qk - ); - - // load k submatrix from shared memory - var k_sg_mat: subgroup_matrix_right = subgroupMatrixLoad>( - &kv_shmem, - head_dim_block, - true, - params.head_dim_qk - ); - - acc = subgroupMatrixMultiplyAccumulate(q_sg_mat, k_sg_mat, acc); + // accumulate q block * k block into registers across the entire KV tile + for (var sg_block = subgroup_id; sg_block < Q_BLOCKS; sg_block += num_subgroups) { + let q_block_offset = sg_block * SG_MAT_M * params.head_dim_qk; + for (var kv_block = 0u; kv_block < KV_BLOCKS; kv_block++) { + var acc: subgroup_matrix_result; + let k_block_offset = kv_block * SG_MAT_N * params.head_dim_qk; + for (var head_dim_block = 0u; head_dim_block < params.head_dim_qk; head_dim_block += SG_MAT_K) { + // load q submatrix from shared memory + var q_sg_mat: subgroup_matrix_left = subgroupMatrixLoad>( + &q_shmem, + q_block_offset + head_dim_block, + false, + params.head_dim_qk + ); + + // load k submatrix from shared memory + var k_sg_mat: subgroup_matrix_right = subgroupMatrixLoad>( + &kv_shmem, + k_block_offset + head_dim_block, + true, + params.head_dim_qk + ); + + 
acc = subgroupMatrixMultiplyAccumulate(q_sg_mat, k_sg_mat, acc); + } + + // store acc to shared memory for softmax + let inter_offset = sg_block * SG_MAT_M * KV_TILE + kv_block * SG_MAT_N; + subgroupMatrixStore(&inter_shmem, inter_offset, acc, false, KV_TILE); + } } - // store acc to shared memory for softmax - subgroupMatrixStore(&inter_shmem, 0, acc, false, KV_TILE); + // load mask tile into shared memory for this KV block + // TODO: optimize and skip if mask is -INF for the entire tile + for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) { + let mask_row = elem_idx / KV_TILE; + let mask_col = elem_idx % KV_TILE; + let global_q_row = q_row_start + mask_row; + let global_k_col = kv_tile + mask_col; + let mask_in_bounds = global_q_row < params.seq_len_q && global_k_col < params.seq_len_kv; + let mask_idx = mask_global_offset + mask_row * params.seq_len_kv + global_k_col; + + mask_shmem[elem_idx] = select(0.0, mask[mask_idx], mask_in_bounds); + } workgroupBarrier(); // online softmax - for (var q_tile_row = 0u; q_tile_row < Q_TILE; q_tile_row++) { - // no need to process rows beyond seq_len_q - let global_q_row = q_row_start + q_tile_row; - if (global_q_row >= params.seq_len_q) { - break; - } - - // calculate running max - let prev_max = row_max[q_tile_row]; - // The mask value for this Q row and K col - let mask_val = select(0.0, f32(mask[mask_global_offset + q_tile_row * params.seq_len_kv + kv_tile + sg_inv_id]), kv_tile + sg_inv_id < params.seq_len_kv && sg_inv_id < KV_TILE); - let mask_term = slope * mask_val; - var thread_tile_row_max = select(-3.4e38, inter_shmem[sg_inv_id + q_tile_row * KV_TILE] * params.scale, sg_inv_id < KV_TILE); - if (params.logit_softcap != 0.0) { - thread_tile_row_max = params.logit_softcap * tanh(thread_tile_row_max); - } - thread_tile_row_max += mask_term; - row_max[q_tile_row] = subgroupMax(max(prev_max, thread_tile_row_max)); - - // calculate running exp sum - let cur_exp = exp(prev_max - 
row_max[q_tile_row]); - let cur_p = select(0.0, exp(thread_tile_row_max - row_max[q_tile_row]), kv_tile + sg_inv_id < params.seq_len_kv && sg_inv_id < KV_TILE); - exp_sum[q_tile_row] = exp_sum[q_tile_row] * cur_exp + subgroupAdd(cur_p); - - // store back to shared memory (P matrix) - if (sg_inv_id < KV_TILE) { - inter_shmem[sg_inv_id + q_tile_row * KV_TILE] = cur_p; - } - - for (var elem_idx = sg_inv_id; elem_idx < params.head_dim_v; elem_idx += SUBGROUP_SIZE) { - let idx = q_tile_row * params.head_dim_v + elem_idx; - let val = o_shmem[idx] * cur_exp; - o_shmem[idx] = val; + for (var sg_block = subgroup_id; sg_block < Q_BLOCKS; sg_block += num_subgroups) { + let block_row_start = sg_block * SG_MAT_M; + let block_row_end = block_row_start + SG_MAT_M; + for (var q_tile_row = block_row_start; q_tile_row < block_row_end; q_tile_row++) { + // no need to process rows beyond seq_len_q + let global_q_row = q_row_start + q_tile_row; + if (global_q_row >= params.seq_len_q) { + break; + } + + // calculate running max + // only the first thread in the subgroup needs to read from shared memory. + // TODO: is this faster than having all threads read shared memory? 
+ var prev_max = select(0.0, row_max_shmem[q_tile_row], sg_inv_id == 0); + prev_max = subgroupBroadcastFirst(prev_max); + // The mask value for this Q row and K col + let mask_val = select(0.0, mask_shmem[q_tile_row * KV_TILE + sg_inv_id], sg_inv_id < KV_TILE); + let mask_term = slope * mask_val; + var thread_tile_row_max = select(FLOAT_MIN, inter_shmem[sg_inv_id + q_tile_row * KV_TILE] * f16(params.scale), sg_inv_id < KV_TILE); + if (params.logit_softcap != 0.0) { + thread_tile_row_max = f16(params.logit_softcap) * tanh(thread_tile_row_max); + } + thread_tile_row_max += mask_term; + let new_max = subgroupMax(max(prev_max, thread_tile_row_max)); + + // calculate running exp sum + let cur_p = select(0.0, exp(thread_tile_row_max - new_max), kv_tile + sg_inv_id < params.seq_len_kv && sg_inv_id < KV_TILE); + let new_exp_term = subgroupAdd(cur_p); + + // store back to shared memory (P matrix) + if (sg_inv_id < KV_TILE) { + inter_shmem[sg_inv_id + q_tile_row * KV_TILE] = cur_p; + } + + let cur_exp = exp(prev_max - new_max); + if (sg_inv_id == 0) { + row_max_shmem[q_tile_row] = new_max; + exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * cur_exp + new_exp_term; + } + + for (var elem_idx = sg_inv_id; elem_idx < params.head_dim_v; elem_idx += subgroup_size) { + let idx = q_tile_row * params.head_dim_v + elem_idx; + let val = o_shmem[idx] * cur_exp; + o_shmem[idx] = val; + } } } @@ -216,90 +269,113 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let v_col = elem_idx % params.head_dim_v; let global_v_row = kv_tile + v_row; let global_v_row_offset = v_head_offset + global_v_row * params.stride_v1; - let v_val: f32 = select( + kv_shmem[elem_idx] = f16(select( 0.0, - f32(V[global_v_row_offset + v_col]), - global_v_row < params.seq_len_kv && v_col < params.head_dim_v); - kv_shmem[elem_idx] = v_val; + V[global_v_row_offset + v_col], + global_v_row < params.seq_len_kv && v_col < params.head_dim_v)); } workgroupBarrier(); - // we have P (8x8 tile, or Q_TILE x KV_TILE) in 
inter_shmem and V (8 x head_dim_v, or KV_TILE x head_dim_v) in kv_shmem - // we want to compute O += P * V - // load P submatrix from shared memory - var p_sg_mat: subgroup_matrix_left = subgroupMatrixLoad>( - &inter_shmem, - 0, - false, - KV_TILE - ); - - for (var head_dim_block = 0u; head_dim_block < params.head_dim_v; head_dim_block += 8u) { - // load V submatrix from shared memory - var v_sg_mat: subgroup_matrix_right = subgroupMatrixLoad>( - &kv_shmem, - head_dim_block, - false, // or false? is this transposed? - params.head_dim_v - ); - - // load O submatrix from shared memory - var o_sg_mat: subgroup_matrix_result = subgroupMatrixLoad>( - &o_shmem, - head_dim_block, - false, - params.head_dim_v - ); - - // O += P * V - o_sg_mat = subgroupMatrixMultiplyAccumulate(p_sg_mat, v_sg_mat, o_sg_mat); - - // store O back to shared memory - subgroupMatrixStore(&o_shmem, head_dim_block, o_sg_mat, false, params.head_dim_v); + // we have P (Q_TILE x KV_TILE) in inter_shmem and V (KV_TILE x head_dim_v) in kv_shmem + // we want to compute O += P * V across the full KV tile + for (var sg_block = subgroup_id; sg_block < Q_BLOCKS; sg_block += num_subgroups) { + let o_row_offset = sg_block * SG_MAT_M * params.head_dim_v; + for (var head_dim_block = 0u; head_dim_block < params.head_dim_v; head_dim_block += SG_MAT_N) { + // load O submatrix from shared memory + var o_sg_mat: subgroup_matrix_result = subgroupMatrixLoad>( + &o_shmem, + o_row_offset + head_dim_block, + false, + params.head_dim_v + ); + + for (var kv_block = 0u; kv_block < KV_BLOCKS; kv_block++) { + let p_offset = sg_block * SG_MAT_M * KV_TILE + kv_block * SG_MAT_N; + var p_sg_mat: subgroup_matrix_left = subgroupMatrixLoad>( + &inter_shmem, + p_offset, + false, + KV_TILE + ); + + // load V submatrix from shared memory + let v_block_offset = kv_block * SG_MAT_N * params.head_dim_v; + var v_sg_mat: subgroup_matrix_right = subgroupMatrixLoad>( + &kv_shmem, + v_block_offset + head_dim_block, + false, + params.head_dim_v 
+ ); + + // O += P * V + o_sg_mat = subgroupMatrixMultiplyAccumulate(p_sg_mat, v_sg_mat, o_sg_mat); + } + + // store O back to shared memory + subgroupMatrixStore(&o_shmem, o_row_offset + head_dim_block, o_sg_mat, false, params.head_dim_v); + } } workgroupBarrier(); } // add sinks (applied once after processing all KV tiles) - for (var q_tile_row = 0u; q_tile_row < Q_TILE; q_tile_row++) { - // no need to process rows beyond seq_len_q - let global_q_row = q_row_start + q_tile_row; - if (global_q_row >= params.seq_len_q) { - break; - } - - let max_val = row_max[q_tile_row]; - // for non-sink threads, exp(-65504) effectively zeroes them out - let sink_val = select(-3.4e38, sinks[params.offset_sinks + head_idx], sg_inv_id == 0); - row_max[q_tile_row] = subgroupMax(max(max_val, sink_val)); - let max_exp = exp(max_val - row_max[q_tile_row]); - let sink_exp = exp(sink_val - row_max[q_tile_row]); - - exp_sum[q_tile_row] = exp_sum[q_tile_row] * max_exp + subgroupAdd(sink_exp); - for (var elem_idx = sg_inv_id; elem_idx < params.head_dim_v; elem_idx += SUBGROUP_SIZE) { - let idx = q_tile_row * params.head_dim_v + elem_idx; - let val = o_shmem[idx] * max_exp; - o_shmem[idx] = val; + for (var sg_block = subgroup_id; sg_block < Q_BLOCKS; sg_block += num_subgroups) { + let block_row_start = sg_block * SG_MAT_M; + let block_row_end = block_row_start + SG_MAT_M; + for (var q_tile_row = block_row_start; q_tile_row < block_row_end; q_tile_row++) { + // no need to process rows beyond seq_len_q + let global_q_row = q_row_start + q_tile_row; + if (global_q_row >= params.seq_len_q) { + break; + } + + var prev_max = select(0.0, row_max_shmem[q_tile_row], sg_inv_id == 0); + prev_max = subgroupBroadcastFirst(prev_max); + + // for non-sink threads, exp(-65504) effectively zeroes out their contribution to the sum + let sink_val = select(FLOAT_MIN, f16(sinks[params.offset_sinks + head_idx]), sg_inv_id == 0); + let new_max = subgroupMax(max(prev_max, sink_val)); + let max_exp = exp(prev_max -
new_max); + let sink_exp = exp(sink_val - new_max); + + let sink_exp_sum = subgroupAdd(sink_exp); + + if (sg_inv_id == 0) { + exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * max_exp + sink_exp_sum; + } + + for (var elem_idx = sg_inv_id; elem_idx < params.head_dim_v; elem_idx += subgroup_size) { + let idx = q_tile_row * params.head_dim_v + elem_idx; + let val = o_shmem[idx] * max_exp; + o_shmem[idx] = val; + } } } workgroupBarrier(); // write output back to global memory - for (var q_tile_row = 0u; q_tile_row < Q_TILE; q_tile_row++) { - let global_q_row = q_row_start + q_tile_row; - if (global_q_row >= params.seq_len_q) { - break; - } - - let scale = select(0.0, 1.0 / exp_sum[q_tile_row], exp_sum[q_tile_row] != 0); - - for (var elem_idx = sg_inv_id; elem_idx < params.head_dim_v; elem_idx += SUBGROUP_SIZE) { - let o_val = o_shmem[q_tile_row * params.head_dim_v + elem_idx]; - let scaled = o_val * scale; - dst[dst_global_offset + q_tile_row * dst2_stride + elem_idx] = scaled; + for (var sg_block = subgroup_id; sg_block < Q_BLOCKS; sg_block += num_subgroups) { + let block_row_start = sg_block * SG_MAT_M; + let block_row_end = block_row_start + SG_MAT_M; + for (var q_tile_row = block_row_start; q_tile_row < block_row_end; q_tile_row++) { + let global_q_row = q_row_start + q_tile_row; + if (global_q_row >= params.seq_len_q) { + break; + } + + var exp_sum = select(0.0, exp_sum_shmem[q_tile_row], sg_inv_id == 0); + exp_sum = subgroupBroadcastFirst(exp_sum); + + let scale = select(0.0, 1.0 / exp_sum, exp_sum != 0); + + for (var elem_idx = sg_inv_id; elem_idx < params.head_dim_v; elem_idx += subgroup_size) { + let o_val = o_shmem[q_tile_row * params.head_dim_v + elem_idx]; + let scaled = o_val * scale; + dst[dst_global_offset + q_tile_row * dst2_stride + elem_idx] = f32(scaled); + } } } } From efd49e1d7e8615d7191f69063d4278611f8fe3cd Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Mon, 29 Dec 2025 10:51:37 -0800 Subject: [PATCH 30/40] Start work on integrating 
pre-wgsl --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 187 ++++-- ggml/src/ggml-webgpu/pre_wgsl.hpp | 619 ++++++++++++++++++ .../ggml-webgpu/wgsl-shaders/flash_attn.wgsl | 7 +- 3 files changed, 773 insertions(+), 40 deletions(-) create mode 100644 ggml/src/ggml-webgpu/pre_wgsl.hpp diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index f4ceedd691f..286413894f6 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -8,6 +8,7 @@ #include "ggml-backend-impl.h" #include "ggml-impl.h" #include "ggml-wgsl-shaders.hpp" +#include "pre_wgsl.hpp" #ifdef __EMSCRIPTEN__ # include @@ -18,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -269,6 +271,67 @@ struct webgpu_command { #endif }; +// Pipeline keys + +template +static inline void webgpu_hash_combine(size_t & seed, const T & value) { + seed ^= std::hash{}(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2); +} + +static inline const char * ggml_webgpu_wgsl_kv_type(ggml_type type) { + switch (type) { + case GGML_TYPE_F16: return "f16"; + case GGML_TYPE_F32: return "f32"; + default: return nullptr; + } +} + +struct flash_attn_pipeline_key { + int q_type; + int kv_type; + int mask_type; + int sinks_type; + int dst_type; + uint32_t head_dim_q; + uint32_t head_dim_v; + uint32_t n_heads; + bool has_mask; + bool has_sinks; + bool uses_logit_softcap; + + bool operator==(const flash_attn_pipeline_key & other) const { + return q_type == other.q_type && + kv_type == other.kv_type && + mask_type == other.mask_type && + sinks_type == other.sinks_type && + dst_type == other.dst_type && + head_dim_q == other.head_dim_q && + head_dim_v == other.head_dim_v && + n_heads == other.n_heads && + has_mask == other.has_mask && + has_sinks == other.has_sinks && + uses_logit_softcap == other.uses_logit_softcap; + } +}; + +struct flash_attn_pipeline_key_hash { + size_t operator()(const flash_attn_pipeline_key & key) const { + size_t seed = 0; + 
webgpu_hash_combine(seed, key.q_type); + webgpu_hash_combine(seed, key.kv_type); + webgpu_hash_combine(seed, key.mask_type); + webgpu_hash_combine(seed, key.sinks_type); + webgpu_hash_combine(seed, key.dst_type); + webgpu_hash_combine(seed, key.head_dim_q); + webgpu_hash_combine(seed, key.head_dim_v); + webgpu_hash_combine(seed, key.n_heads); + webgpu_hash_combine(seed, key.has_mask); + webgpu_hash_combine(seed, key.has_sinks); + webgpu_hash_combine(seed, key.uses_logit_softcap); + return seed; + } +}; + // All the base objects needed to run operations on a WebGPU device struct webgpu_context_struct { wgpu::Instance instance; @@ -290,13 +353,16 @@ struct webgpu_context_struct { webgpu_buf_pool param_buf_pool; webgpu_buf_pool set_rows_error_buf_pool; + pre_wgsl::Preprocessor p; + std::map memset_pipelines; // variant or type index std::map>> mul_mat_pipelines; // src0_type, src1_type, vectorized std::map>> mul_mat_vec_pipelines; // src0_type, src1_type, vectorized - webgpu_pipeline flash_attn_pipeline; + std::unordered_map + flash_attn_pipelines; std::map> set_rows_pipelines; // dst_type, vectorized std::map> get_rows_pipelines; // src_type, vectorized @@ -373,6 +439,21 @@ struct ggml_backend_webgpu_buffer_context { /* WebGPU object initializations */ +static void ggml_webgpu_create_buffer(wgpu::Device & device, + wgpu::Buffer & buffer, + size_t size, + wgpu::BufferUsage usage, + const char * label) { + wgpu::BufferDescriptor buffer_desc; + buffer_desc.size = size; + buffer_desc.usage = usage; + buffer_desc.label = label; + buffer_desc.mappedAtCreation = false; + + // TODO: error handling + buffer = device.CreateBuffer(&buffer_desc); +} + // Process a WGSL shader string, replacing tokens of the form {{KEY}} with // the corresponding values provided in `repls`. 
static std::string ggml_webgpu_process_shader_repls(const char * src, @@ -416,19 +497,45 @@ static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & return { device.CreateComputePipeline(&pipeline_desc), label }; } -static void ggml_webgpu_create_buffer(wgpu::Device & device, - wgpu::Buffer & buffer, - size_t size, - wgpu::BufferUsage usage, - const char * label) { - wgpu::BufferDescriptor buffer_desc; - buffer_desc.size = size; - buffer_desc.usage = usage; - buffer_desc.label = label; - buffer_desc.mappedAtCreation = false; +static webgpu_pipeline ggml_webgpu_get_flash_attn_pipeline(webgpu_context & ctx, + ggml_tensor * Q, + ggml_tensor * K, + ggml_tensor * V, + ggml_tensor * mask, + ggml_tensor * sinks, + ggml_tensor * dst, + float logit_softcap) { + GGML_ASSERT(K->type == V->type); + + flash_attn_pipeline_key key = { + .q_type = Q->type, + .kv_type = K->type, + .mask_type = mask->type, + .sinks_type = sinks->type, + .dst_type = dst->type, + .head_dim_q = (uint32_t) Q->ne[0], + .head_dim_v = (uint32_t) V->ne[0], + .n_heads = (uint32_t) Q->ne[2], + .has_mask = true, + .has_sinks = true, + .uses_logit_softcap = logit_softcap != 0.0f, + }; - // TODO: error handling - buffer = device.CreateBuffer(&buffer_desc); + if (ctx->flash_attn_pipelines.count(key)) { + return ctx->flash_attn_pipelines[key]; + } + + std::lock_guard lock(ctx->mutex); + if (ctx->flash_attn_pipelines.count(key)) { + return ctx->flash_attn_pipelines[key]; + } + + const char * kv_type = ggml_webgpu_wgsl_kv_type(K->type); + std::string label = std::string("flash_attn_kv_") + kv_type; + std::string shader = ctx->p.preprocess(wgsl_flash_attn, { std::string("KV_TYPE=") + kv_type }); + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(ctx->device, shader.c_str(), label.c_str()); + ctx->flash_attn_pipelines.emplace(key, pipeline); + return pipeline; } /** End WebGPU object initializations */ @@ -1009,22 +1116,19 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, float m0 = 
powf(2.0f, -(max_bias) / n_head_log2); float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - // print type and dimensions of Q/K/V/mask/sinks/dst - // std::cout << "ggml_webgpu_flash_attn: Q type: " << ggml_type_name(Q->type) << ", ne: [" << Q->ne[0] << ", " << Q->ne[1] << ", " << Q->ne[2] - // << ", " << Q->ne[3] << "]\n"; - // std::cout << "ggml_webgpu_flash_attn: K type: " << ggml_type_name(K->type) << ", ne: [" << K->ne[0] << ", " << K->ne[1] << ", " << K->ne[2] - // << ", " << K->ne[3] << "]\n"; - // std::cout << "ggml_webgpu_flash_attn: V type: " << ggml_type_name(V->type) << ", ne: [" << V->ne[0] << ", " << V->ne[1] << ", " << V->ne[2] - // << ", " << V->ne[3] << "]\n"; - // std::cout << "ggml_webgpu_flash_attn: mask type: " << ggml_type_name(mask->type) << ", ne: [" << mask->ne[0] << ", " << mask->ne[1] << ", " << mask->ne[2] - // << ", " << mask->ne[3] << "]\n"; - // std::cout << "ggml_webgpu_flash_attn: sinks type: " << ggml_type_name(sinks->type) << ", ne: [" << sinks->ne[0] << ", " << sinks->ne[1] << ", " << sinks->ne[2] - // << ", " << sinks->ne[3] << "]\n"; - // std::cout << "ggml_webgpu_flash_attn: dst type: " << ggml_type_name(dst->type) << ", ne: [" << dst->ne[0] << ", " << dst->ne[1] << ", " << dst->ne[2] - // << ", " << dst->ne[3] << "]\n"; - - // std::cout << "ggml_webgpu_flash_attn: offsets: Q=" << offset_q << ", K=" << offset_k << ", V=" << offset_v - // << ", mask=" << offset_mask << ", sinks=" << offset_sinks << ", dst=" << offset_dst << "\n"; + // print type and dimensions of Q/K/V/mask/sinks/dst +// std::cout << "ggml_webgpu_flash_attn: Q type: " << ggml_type_name(Q->type) << ", ne: [" << Q->ne[0] << ", " << Q->ne[1] << ", " << Q->ne[2] +// << ", " << Q->ne[3] << "]\n"; +// std::cout << "ggml_webgpu_flash_attn: K type: " << ggml_type_name(K->type) << ", ne: [" << K->ne[0] << ", " << K->ne[1] << ", " << K->ne[2] +// << ", " << K->ne[3] << "]\n"; +// std::cout << "ggml_webgpu_flash_attn: V type: " << ggml_type_name(V->type) << ", ne: [" 
<< V->ne[0] << ", " << V->ne[1] << ", " << V->ne[2] +// << ", " << V->ne[3] << "]\n"; +// std::cout << "ggml_webgpu_flash_attn: mask type: " << ggml_type_name(mask->type) << ", ne: [" << mask->ne[0] << ", " << mask->ne[1] << ", " << mask->ne[2] +// << ", " << mask->ne[3] << "]\n"; +// std::cout << "ggml_webgpu_flash_attn: sinks type: " << ggml_type_name(sinks->type) << ", ne: [" << sinks->ne[0] << ", " << sinks->ne[1] << ", " << sinks->ne[2] +// << ", " << sinks->ne[3] << "]\n"; +// std::cout << "ggml_webgpu_flash_attn: dst type: " << ggml_type_name(dst->type) << ", ne: [" << dst->ne[0] << ", " << dst->ne[1] << ", " << dst->ne[2] +// << ", " << dst->ne[3] << "]\n"; std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, Q) / ggml_type_size(Q->type)), @@ -1088,10 +1192,12 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, // .size = ctx->debug_dev_buf.GetSize() } }; + webgpu_pipeline pipeline = ggml_webgpu_get_flash_attn_pipeline(ctx, Q, K, V, mask, sinks, dst, logit_softcap); + uint32_t wg_per_head = CEIL_DIV(Q->ne[1], WEBGPU_FLASH_ATTN_Q_TILE); uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches //std::cout << "ggml_webgpu_flash_attn: wg_x: " << wg_x << "\n"; - return ggml_backend_webgpu_build(ctx, ctx->flash_attn_pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { @@ -2007,10 +2113,6 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { webgpu_ctx->device, wgsl_mul_mat_vec_q4_0_f32, "mul_mat_vec_q4_0_f32", mul_mat_vec_constants); } -static void ggml_webgpu_init_flash_attn_pipeline(webgpu_context & webgpu_ctx) { - webgpu_ctx->flash_attn_pipeline = ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_flash_attn, "flash_attn"); -} - static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & 
webgpu_ctx) { webgpu_ctx->set_rows_pipelines[0][0] = ggml_webgpu_create_pipeline( webgpu_ctx->device, wgsl_set_rows_f16, "set_rows_f16", ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE)); @@ -2597,10 +2699,20 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const break; } case GGML_OP_FLASH_ATTN_EXT: - supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && - src2->type == GGML_TYPE_F32 && op->src[3] != nullptr && op->src[4] != nullptr; - supports_op &= op->ne[0] <= 128 && src0->ne[0] <= 128; // max seq len 128 for qkv + { + supports_op = true; + // Q-type + supports_op &= src0->type == GGML_TYPE_F32; + // KV-type + supports_op &= src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16; + // Mask-type + supports_op &= op->src[3] != nullptr; + // Sink-type + supports_op &= op->src[4] != nullptr; + // qkv sequence length + supports_op &= op->ne[0] <= 128 && src0->ne[0] <= 128; break; + } case GGML_OP_RMS_NORM: supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32; break; @@ -2874,7 +2986,6 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t ggml_webgpu_init_memset_pipeline(ctx); ggml_webgpu_init_mul_mat_pipeline(ctx); - ggml_webgpu_init_flash_attn_pipeline(ctx); ggml_webgpu_init_set_rows_pipeline(ctx); ggml_webgpu_init_get_rows_pipeline(ctx); ggml_webgpu_init_cpy_pipeline(ctx); diff --git a/ggml/src/ggml-webgpu/pre_wgsl.hpp b/ggml/src/ggml-webgpu/pre_wgsl.hpp new file mode 100644 index 00000000000..70dc19bf58c --- /dev/null +++ b/ggml/src/ggml-webgpu/pre_wgsl.hpp @@ -0,0 +1,619 @@ +#ifndef PRE_WGSL_HPP +#define PRE_WGSL_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace pre_wgsl { + +//============================================================== +// Options +//============================================================== +struct Options { + std::string include_path = "."; + 
std::vector macros; +}; + +//============================================================== +// Utility: trim +//============================================================== +static inline std::string trim(const std::string& s) { + size_t a = 0; + while (a < s.size() && std::isspace((unsigned char)s[a])) a++; + size_t b = s.size(); + while (b > a && std::isspace((unsigned char)s[b - 1])) b--; + return s.substr(a, b - a); +} + +static inline std::string trim_value(std::istream& is) { + std::string str; + std::getline(is, str); + return trim(str); +} + +//============================================================== +// Tokenizer for expressions in #if/#elif +//============================================================== +class ExprLexer { +public: + enum Kind { + END, IDENT, NUMBER, + OP, LPAREN, RPAREN + }; + + struct Tok { + Kind kind; + std::string text; + }; + + explicit ExprLexer(std::string_view sv) : src(sv), pos(0) {} + + Tok next() { + skipWS(); + if (pos >= src.size()) return { END, "" }; + + char c = src[pos]; + + // number + if (std::isdigit((unsigned char)c)) { + size_t start = pos; + while (pos < src.size() && std::isdigit((unsigned char)src[pos])) + pos++; + return { NUMBER, std::string(src.substr(start, pos - start)) }; + } + + // identifier + if (std::isalpha((unsigned char)c) || c == '_') { + size_t start = pos; + while (pos < src.size() && + (std::isalnum((unsigned char)src[pos]) || src[pos] == '_')) + pos++; + return { IDENT, std::string(src.substr(start, pos - start)) }; + } + + if (c == '(') { pos++; return { LPAREN, "(" }; } + if (c == ')') { pos++; return { RPAREN, ")" }; } + + // multi-char operators + static const char* two_ops[] = { + "==","!=", "<=", ">=", "&&","||", "<<",">>" + }; + for (auto op : two_ops) { + if (src.substr(pos, 2) == op) { + pos += 2; + return { OP, std::string(op) }; + } + } + + // single-char operators + if (std::string("+-*/%<>!").find(c) != std::string::npos) { + pos++; + return { OP, std::string(1, c) }; + } 
+ + // unexpected + pos++; + return { END, "" }; + } + +private: + std::string_view src; + size_t pos; + + void skipWS() { + while (pos < src.size() && std::isspace((unsigned char)src[pos])) + pos++; + } +}; + +//============================================================== +// Expression Parser (recursive descent) +//============================================================== +class ExprParser { +public: + ExprParser(std::string_view expr, + const std::unordered_map& macros) + : lex(expr), macros(macros) + { + advance(); + } + + int parse() { + return parseLogicalOr(); + } + +private: + ExprLexer lex; + ExprLexer::Tok tok; + const std::unordered_map& macros; + + void advance() { tok = lex.next(); } + + bool acceptOp(const std::string& s) { + if (tok.kind == ExprLexer::OP && tok.text == s) { + advance(); + return true; + } + return false; + } + + bool acceptKind(ExprLexer::Kind k) { + if (tok.kind == k) { + advance(); + return true; + } + return false; + } + + int parseLogicalOr() { + int v = parseLogicalAnd(); + while (acceptOp("||")) { + int rhs = parseLogicalAnd(); + v = (v || rhs); + } + return v; + } + + int parseLogicalAnd() { + int v = parseEquality(); + while (acceptOp("&&")) { + int rhs = parseEquality(); + v = (v && rhs); + } + return v; + } + + int parseEquality() { + int v = parseRelational(); + for (;;) { + if (acceptOp("==")) { + int rhs = parseRelational(); + v = (v == rhs); + } else if (acceptOp("!=")) { + int rhs = parseRelational(); + v = (v != rhs); + } else break; + } + return v; + } + + int parseRelational() { + int v = parseShift(); + for (;;) { + if (acceptOp("<")) { int rhs = parseShift(); v = (v < rhs); } + else if (acceptOp(">")) { int rhs = parseShift(); v = (v > rhs); } + else if (acceptOp("<=")){ int rhs = parseShift(); v = (v <= rhs); } + else if (acceptOp(">=")){ int rhs = parseShift(); v = (v >= rhs); } + else break; + } + return v; + } + + int parseShift() { + int v = parseAdd(); + for (;;) { + if (acceptOp("<<")) { int rhs = 
parseAdd(); v = (v << rhs); } + else if (acceptOp(">>")) { int rhs = parseAdd(); v = (v >> rhs); } + else break; + } + return v; + } + + int parseAdd() { + int v = parseMult(); + for (;;) { + if (acceptOp("+")) { int rhs = parseMult(); v = (v + rhs); } + else if (acceptOp("-")) { int rhs = parseMult(); v = (v - rhs); } + else break; + } + return v; + } + + int parseMult() { + int v = parseUnary(); + for (;;) { + if (acceptOp("*")) { int rhs = parseUnary(); v = (v * rhs); } + else if (acceptOp("/")) { int rhs = parseUnary(); v = (rhs == 0 ? 0 : v / rhs); } + else if (acceptOp("%")) { int rhs = parseUnary(); v = (rhs == 0 ? 0 : v % rhs); } + else break; + } + return v; + } + + int parseUnary() { + if (acceptOp("!")) return !parseUnary(); + if (acceptOp("-")) return -parseUnary(); + if (acceptOp("+")) return +parseUnary(); + return parsePrimary(); + } + + int parsePrimary() { + // '(' expr ')' + if (acceptKind(ExprLexer::LPAREN)) { + int v = parse(); + if (!acceptKind(ExprLexer::RPAREN)) + throw std::runtime_error("missing ')'"); + return v; + } + + // number + if (tok.kind == ExprLexer::NUMBER) { + int v = std::stoi(tok.text); + advance(); + return v; + } + + // defined(identifier) + if (tok.kind == ExprLexer::IDENT && tok.text == "defined") { + advance(); + if (acceptKind(ExprLexer::LPAREN)) { + if (tok.kind != ExprLexer::IDENT) + throw std::runtime_error("expected identifier in defined()"); + std::string name = tok.text; + advance(); + if (!acceptKind(ExprLexer::RPAREN)) + throw std::runtime_error("missing ) in defined()"); + return macros.count(name) ? 1 : 0; + } else { + // defined NAME + if (tok.kind != ExprLexer::IDENT) + throw std::runtime_error("expected identifier in defined NAME"); + std::string name = tok.text; + advance(); + return macros.count(name) ? 
1 : 0; + } + } + + // identifier -> treat as integer, if defined use its value else 0 + if (tok.kind == ExprLexer::IDENT) { + std::string name = tok.text; + advance(); + auto it = macros.find(name); + if (it == macros.end()) return 0; + if (it->second.empty()) return 1; + return std::stoi(it->second); + } + + // unexpected + return 0; + } +}; + +//============================================================== +// Preprocessor +//============================================================== +class Preprocessor { +public: + explicit Preprocessor(Options opts = {}) + : opts_(std::move(opts)) { + // Treat empty include path as current directory + if (opts_.include_path.empty()) { + opts_.include_path = "."; + } + parseMacroDefinitions(opts_.macros); + } + + std::string preprocess_file(const std::string& filename, + const std::vector& additional_macros = {}) { + std::unordered_map macros; + std::unordered_set predefined; + std::unordered_set include_stack; + buildMacros(additional_macros, macros, predefined); + + std::string result = processFile(filename, macros, predefined, include_stack); + return result; + } + + std::string preprocess(const std::string& contents, + const std::vector& additional_macros = {}) + { + std::unordered_map macros; + std::unordered_set predefined; + std::unordered_set include_stack; + buildMacros(additional_macros, macros, predefined); + + std::string result = processString(contents, macros, predefined, include_stack); + return result; + } + +private: + Options opts_; + std::unordered_map global_macros; + + struct Cond { + bool parent_active; + bool active; + bool taken; + }; + + //---------------------------------------------------------- + // Parse macro definitions into global_macros + //---------------------------------------------------------- + void parseMacroDefinitions(const std::vector& macro_defs) { + for (const auto& def : macro_defs) { + size_t eq_pos = def.find('='); + if (eq_pos != std::string::npos) { + // Format: NAME=VALUE + 
std::string name = trim(def.substr(0, eq_pos)); + std::string value = trim(def.substr(eq_pos + 1)); + global_macros[name] = value; + } else { + // Format: NAME + std::string name = trim(def); + global_macros[name] = ""; + } + } + } + + //---------------------------------------------------------- + // Build combined macro map and predefined set for a preprocessing operation + //---------------------------------------------------------- + void buildMacros( + const std::vector& additional_macros, + std::unordered_map& macros, + std::unordered_set& predefined) { + macros = global_macros; + predefined.clear(); + + for (const auto& [name, value] : global_macros) { + predefined.insert(name); + } + + for (const auto& def : additional_macros) { + size_t eq_pos = def.find('='); + std::string name, value; + if (eq_pos != std::string::npos) { + name = trim(def.substr(0, eq_pos)); + value = trim(def.substr(eq_pos + 1)); + } else { + name = trim(def); + value = ""; + } + + // Add to macros map (will override global if same name) + macros[name] = value; + predefined.insert(name); + } + } + + //---------------------------------------------------------- + // Helpers + //---------------------------------------------------------- + std::string loadFile(const std::string& fname) { + std::ifstream f(fname); + if (!f.is_open()) + throw std::runtime_error("Could not open file: " + fname); + std::stringstream ss; + ss << f.rdbuf(); + return ss.str(); + } + + bool condActive(const std::vector& cond) const { + if (cond.empty()) return true; + return cond.back().active; + } + + //---------------------------------------------------------- + // Helper to check if a character can be part of an identifier + //---------------------------------------------------------- + static bool isIdent(char c) { + return std::isalnum(static_cast(c)) || c == '_'; + } + + //---------------------------------------------------------- + // Expand macros in a line of code + 
//---------------------------------------------------------- + std::string expandMacros(const std::string& line, + const std::unordered_map& macros) { + std::string result; + result.reserve(line.size()); + + size_t i = 0; + while (i < line.size()) { + if (isIdent(line[i])) { + size_t start = i; + while (i < line.size() && isIdent(line[i])) { + i++; + } + std::string token = line.substr(start, i - start); + + auto it = macros.find(token); + if (it != macros.end()) { + result += it->second; + } else { + result += token; + } + } else { + result += line[i]; + i++; + } + } + + return result; + } + + //---------------------------------------------------------- + // Process a file + //---------------------------------------------------------- + std::string processFile(const std::string& name, + std::unordered_map& macros, + const std::unordered_set& predefined_macros, + std::unordered_set& include_stack) { + if (include_stack.count(name)) + throw std::runtime_error("Recursive include: " + name); + + include_stack.insert(name); + std::string shader_code = loadFile(name); + std::string out = processString(shader_code, macros, predefined_macros, include_stack); + include_stack.erase(name); + return out; + } + + std::string processIncludeFile(const std::string& fname, + std::unordered_map& macros, + const std::unordered_set& predefined_macros, + std::unordered_set& include_stack) { + std::string full_path = opts_.include_path + "/" + fname; + return processFile(full_path, macros, predefined_macros, include_stack); + } + + //---------------------------------------------------------- + // Process text + //---------------------------------------------------------- + std::string processString(const std::string& shader_code, + std::unordered_map& macros, + const std::unordered_set& predefined_macros, + std::unordered_set& include_stack) + { + std::vector cond; // Conditional stack for this shader + std::stringstream out; + std::istringstream in(shader_code); + std::string line; + 
+ while (std::getline(in, line)) { + std::string t = trim(line); + + if (!t.empty() && t[0] == '#') { + handleDirective(t, out, macros, predefined_macros, cond, include_stack); + } else { + if (condActive(cond)) { + // Expand macros in the line before outputting + std::string expanded = expandMacros(line, macros); + out << expanded << "\n"; + } + } + } + + if (!cond.empty()) + throw std::runtime_error("Unclosed #if directive"); + + return out.str(); + } + + //---------------------------------------------------------- + // Directive handler + //---------------------------------------------------------- + void handleDirective(const std::string& t, std::stringstream& out, + std::unordered_map& macros, + const std::unordered_set& predefined_macros, + std::vector& cond, + std::unordered_set& include_stack) { + // split into tokens + std::string body = t.substr(1); + std::istringstream iss(body); + std::string cmd; + iss >> cmd; + + if (cmd == "include") { + if (!condActive(cond)) return; + std::string file; + iss >> file; + if (file.size() >= 2 && file.front()=='"' && file.back()=='"') + file = file.substr(1, file.size()-2); + out << processIncludeFile(file, macros, predefined_macros, include_stack); + return; + } + + if (cmd == "define") { + if (!condActive(cond)) return; + std::string name; + iss >> name; + // Don't override predefined macros from options + if (predefined_macros.count(name)) return; + std::string value = trim_value(iss); + macros[name] = value; + return; + } + + if (cmd == "ifdef") { + std::string name; iss >> name; + bool p = condActive(cond); + bool v = macros.count(name); + cond.push_back({p, p && v, p && v}); + return; + } + + if (cmd == "ifndef") { + std::string name; iss >> name; + bool p = condActive(cond); + bool v = !macros.count(name); + cond.push_back({p, p && v, p && v}); + return; + } + + if (cmd == "if") { + std::string expr = trim_value(iss); + bool p = condActive(cond); + bool v = false; + if (p) { + ExprParser ep(expr, macros); + v = 
ep.parse() != 0; + } + cond.push_back({p, p && v, p && v}); + return; + } + + if (cmd == "elif") { + std::string expr = trim_value(iss); + + if (cond.empty()) + throw std::runtime_error("#elif without #if"); + + Cond& c = cond.back(); + if (!c.parent_active) { + c.active = false; + return; + } + + if (c.taken) { + c.active = false; + return; + } + + ExprParser ep(expr, macros); + bool v = ep.parse() != 0; + c.active = v; + if (v) c.taken = true; + return; + } + + if (cmd == "else") { + if (cond.empty()) + throw std::runtime_error("#else without #if"); + + Cond& c = cond.back(); + if (!c.parent_active) { + c.active = false; + return; + } + if (c.taken) { + c.active = false; + } else { + c.active = true; + c.taken = true; + } + return; + } + + if (cmd == "endif") { + if (cond.empty()) + throw std::runtime_error("#endif without #if"); + cond.pop_back(); + return; + } + + // Unknown directive + throw std::runtime_error("Unknown directive: #" + cmd); + } +}; + +} // namespace pre_wgsl + +#endif // PRE_WGSL_HPP \ No newline at end of file diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl index 7be0bbc0107..2fe434b1b66 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl @@ -4,6 +4,9 @@ enable f16; enable subgroups; enable chromium_experimental_subgroup_matrix; +// Default values +#define KV_TYPE f32 + struct Params { offset_q: u32, offset_k: u32, @@ -44,8 +47,8 @@ struct Params { }; @group(0) @binding(0) var Q: array; -@group(0) @binding(1) var K: array; -@group(0) @binding(2) var V: array; +@group(0) @binding(1) var K: array; +@group(0) @binding(2) var V: array; @group(0) @binding(3) var mask: array; @group(0) @binding(4) var sinks: array; @group(0) @binding(5) var dst: array; From 1dc20ce952d49dc8aba14e4c1ad08d472a7b22ac Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Mon, 29 Dec 2025 16:31:24 -0800 Subject: [PATCH 31/40] Separate 
structs/initial shader compilation library into separate files --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 70 +++ ggml/src/ggml-webgpu/ggml-webgpu-structs.hpp | 303 ++++++++++++ ggml/src/ggml-webgpu/ggml-webgpu.cpp | 460 +++--------------- 3 files changed, 429 insertions(+), 404 deletions(-) create mode 100644 ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp create mode 100644 ggml/src/ggml-webgpu/ggml-webgpu-structs.hpp diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp new file mode 100644 index 00000000000..99acb5e6964 --- /dev/null +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -0,0 +1,70 @@ +#ifndef GGML_WEBGPU_SHADER_LIB_HPP +#define GGML_WEBGPU_SHADER_LIB_HPP + +#include "ggml-webgpu-structs.hpp" + +#include +#include + +extern const char * wgsl_flash_attn; + +webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device, + const char * shader_code, + const char * label, + const std::vector & constants = {}); + +static inline const char * ggml_webgpu_wgsl_kv_type(ggml_type type) { + switch (type) { + case GGML_TYPE_F16: + return "f16"; + case GGML_TYPE_F32: + return "f32"; + default: + return nullptr; + } +} + +inline webgpu_pipeline ggml_webgpu_get_flash_attn_pipeline(webgpu_context & ctx, + ggml_tensor * Q, + ggml_tensor * K, + ggml_tensor * V, + ggml_tensor * mask, + ggml_tensor * sinks, + ggml_tensor * dst, + float logit_softcap) { + GGML_ASSERT(K->type == V->type); + + flash_attn_pipeline_key key = { + .q_type = Q->type, + .kv_type = K->type, + .mask_type = mask->type, + .sinks_type = sinks->type, + .dst_type = dst->type, + .head_dim_q = (uint32_t) Q->ne[0], + .head_dim_v = (uint32_t) V->ne[0], + .n_heads = (uint32_t) Q->ne[2], + .has_mask = true, + .has_sinks = true, + .uses_logit_softcap = logit_softcap != 0.0f, + }; + + auto it = ctx->flash_attn_pipelines.find(key); + if (it != ctx->flash_attn_pipelines.end()) { + return it->second; + } + + std::lock_guard lock(ctx->mutex); 
+ it = ctx->flash_attn_pipelines.find(key); + if (it != ctx->flash_attn_pipelines.end()) { + return it->second; + } + + const char * kv_type = ggml_webgpu_wgsl_kv_type(K->type); + std::string label = std::string("flash_attn_kv_") + kv_type; + std::string shader = ctx->p.preprocess(wgsl_flash_attn, { std::string("KV_TYPE=") + kv_type }); + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(ctx->device, shader.c_str(), label.c_str()); + ctx->flash_attn_pipelines.emplace(key, pipeline); + return pipeline; +} + +#endif // GGML_WEBGPU_SHADER_LIB_HPP diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-structs.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-structs.hpp new file mode 100644 index 00000000000..029c557a14c --- /dev/null +++ b/ggml/src/ggml-webgpu/ggml-webgpu-structs.hpp @@ -0,0 +1,303 @@ +#ifndef GGML_WEBGPU_STRUCTS_HPP +#define GGML_WEBGPU_STRUCTS_HPP + +#include "ggml.h" +#include "pre_wgsl.hpp" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +void ggml_webgpu_create_buffer(wgpu::Device & device, + wgpu::Buffer & buffer, + size_t size, + wgpu::BufferUsage usage, + const char * label); + +struct webgpu_pool_bufs { + wgpu::Buffer host_buf; + wgpu::Buffer dev_buf; +}; + +// The futures to wait on for a single queue submission +struct webgpu_submission_futures { + std::vector futures; +}; + +// Holds a pool of parameter buffers for WebGPU operations +struct webgpu_buf_pool { + std::vector free; + + std::mutex mutex; + + std::condition_variable cv; + + void init(wgpu::Device device, + int num_bufs, + size_t buf_size, + wgpu::BufferUsage dev_buf_usage, + wgpu::BufferUsage host_buf_usage) { + for (int i = 0; i < num_bufs; i++) { + wgpu::Buffer host_buf; + wgpu::Buffer dev_buf; + ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_pool_buf"); + ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf"); + free.push_back({ host_buf, 
dev_buf }); + } + } + + webgpu_pool_bufs alloc_bufs() { + std::unique_lock lock(mutex); + cv.wait(lock, [this] { return !free.empty(); }); + webgpu_pool_bufs bufs = free.back(); + free.pop_back(); + return bufs; + } + + void free_bufs(std::vector bufs) { + std::lock_guard lock(mutex); + free.insert(free.end(), bufs.begin(), bufs.end()); + cv.notify_all(); + } + + void cleanup() { + std::lock_guard lock(mutex); + for (auto & bufs : free) { + bufs.host_buf.Destroy(); + bufs.dev_buf.Destroy(); + } + free.clear(); + } +}; + +#ifdef GGML_WEBGPU_GPU_PROFILE +struct webgpu_gpu_profile_bufs { + wgpu::Buffer host_buf; + wgpu::Buffer dev_buf; + wgpu::QuerySet query_set; +}; + +// Holds a pool of parameter buffers for WebGPU operations +struct webgpu_gpu_profile_buf_pool { + std::vector free; + + std::mutex mutex; + + std::condition_variable cv; + + void init(wgpu::Device device, + int num_bufs, + size_t buf_size, + wgpu::BufferUsage dev_buf_usage, + wgpu::BufferUsage host_buf_usage) { + for (int i = 0; i < num_bufs; i++) { + wgpu::Buffer host_buf; + wgpu::Buffer dev_buf; + ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_profile_buf"); + ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_profile_buf"); + // Create a query set for 2 timestamps + wgpu::QuerySetDescriptor ts_query_set_desc = {}; + + ts_query_set_desc.type = wgpu::QueryType::Timestamp; + ts_query_set_desc.count = 2; + wgpu::QuerySet ts_query_set = device.CreateQuerySet(&ts_query_set_desc); + + free.push_back({ host_buf, dev_buf, ts_query_set }); + } + } + + webgpu_gpu_profile_bufs alloc_bufs() { + std::unique_lock lock(mutex); + cv.wait(lock, [this] { return !free.empty(); }); + webgpu_gpu_profile_bufs bufs = free.back(); + free.pop_back(); + return bufs; + } + + void free_bufs(std::vector bufs) { + std::lock_guard lock(mutex); + free.insert(free.end(), bufs.begin(), bufs.end()); + cv.notify_all(); + } + + void cleanup() { + std::lock_guard 
lock(mutex); + for (auto & bufs : free) { + bufs.host_buf.Destroy(); + bufs.dev_buf.Destroy(); + bufs.query_set.Destroy(); + } + free.clear(); + } +}; +#endif + +struct webgpu_pipeline { + wgpu::ComputePipeline pipeline; + std::string name; +}; + +struct webgpu_command { + wgpu::CommandBuffer commands; + webgpu_pool_bufs params_bufs; + std::optional set_rows_error_bufs; +#ifdef GGML_WEBGPU_GPU_PROFILE + webgpu_gpu_profile_bufs timestamp_query_bufs; + std::string pipeline_name; +#endif +}; + +struct flash_attn_pipeline_key { + int q_type; + int kv_type; + int mask_type; + int sinks_type; + int dst_type; + uint32_t head_dim_q; + uint32_t head_dim_v; + uint32_t n_heads; + bool has_mask; + bool has_sinks; + bool uses_logit_softcap; + + bool operator==(const flash_attn_pipeline_key & other) const { + return q_type == other.q_type && kv_type == other.kv_type && mask_type == other.mask_type && + sinks_type == other.sinks_type && dst_type == other.dst_type && head_dim_q == other.head_dim_q && + head_dim_v == other.head_dim_v && n_heads == other.n_heads && has_mask == other.has_mask && + has_sinks == other.has_sinks && uses_logit_softcap == other.uses_logit_softcap; + } +}; + +struct flash_attn_pipeline_key_hash { + size_t operator()(const flash_attn_pipeline_key & key) const { + size_t seed = 0; + auto mix = [&seed](size_t value) { + seed ^= value + 0x9e3779b9 + (seed << 6) + (seed >> 2); + }; + mix(std::hash{}(key.q_type)); + mix(std::hash{}(key.kv_type)); + mix(std::hash{}(key.mask_type)); + mix(std::hash{}(key.sinks_type)); + mix(std::hash{}(key.dst_type)); + mix(std::hash{}(key.head_dim_q)); + mix(std::hash{}(key.head_dim_v)); + mix(std::hash{}(key.n_heads)); + mix(std::hash{}(key.has_mask)); + mix(std::hash{}(key.has_sinks)); + mix(std::hash{}(key.uses_logit_softcap)); + return seed; + } +}; + +// All the base objects needed to run operations on a WebGPU device +struct webgpu_context_struct { + wgpu::Instance instance; + wgpu::Adapter adapter; + wgpu::Device device; + 
wgpu::Queue queue; + wgpu::Limits limits; + + uint32_t subgroup_size; + +#ifndef __EMSCRIPTEN__ + bool supports_subgroup_matrix = false; + wgpu::SubgroupMatrixConfig subgroup_matrix_config; +#endif + + std::recursive_mutex mutex; + std::atomic_uint inflight_threads = 0; + + webgpu_buf_pool param_buf_pool; + webgpu_buf_pool set_rows_error_buf_pool; + + pre_wgsl::Preprocessor p; + + std::map memset_pipelines; // variant or type index + + std::map>> mul_mat_pipelines; // src0_type, src1_type, vectorized + std::map>> + mul_mat_vec_pipelines; // src0_type, src1_type, vectorized + + std::unordered_map flash_attn_pipelines; + + std::map> set_rows_pipelines; // dst_type, vectorized + std::map> get_rows_pipelines; // src_type, vectorized + + std::map> cpy_pipelines; // src_type, dst_type + std::map> add_pipelines; // type, inplace + std::map> sub_pipelines; // type, inplace + std::map> mul_pipelines; // type, inplace + std::map> div_pipelines; // type, inplace + + std::map rms_norm_pipelines; // inplace + std::map>> rope_pipelines; // type, ff, inplace + std::map>> glu_pipelines; // glu_op, type, split + std::map scale_pipelines; // inplace + std::map>> soft_max_pipelines; // mask_type, has_sink, inplace + std::map>> unary_pipelines; // unary_op, type, inplace + + size_t memset_bytes_per_thread; + + // Staging buffer for reading data from the GPU + wgpu::Buffer get_tensor_staging_buf; + +#ifdef GGML_WEBGPU_DEBUG + wgpu::Buffer debug_host_buf; + wgpu::Buffer debug_dev_buf; +#endif + +#ifdef GGML_WEBGPU_CPU_PROFILE + // Profiling: labeled CPU time in ms (total) + std::unordered_map cpu_time_ms; + // Profiling: detailed CPU time in ms + std::unordered_map cpu_detail_ms; +#endif + +#ifdef GGML_WEBGPU_GPU_PROFILE + // Profiling: per-shader GPU time in ms + std::unordered_map shader_gpu_time_ms; + // Profiling: pool of timestamp query buffers (one per operation) + webgpu_gpu_profile_buf_pool timestamp_query_buf_pool; +#endif +}; + +using webgpu_context = std::shared_ptr; + 
+struct ggml_backend_webgpu_reg_context { + webgpu_context webgpu_ctx; + size_t device_count; + const char * name; +}; + +struct ggml_backend_webgpu_device_context { + webgpu_context webgpu_ctx; + std::string device_name; + std::string device_desc; +}; + +struct ggml_backend_webgpu_context { + webgpu_context webgpu_ctx; + std::string name; +}; + +struct ggml_backend_webgpu_buffer_context { + webgpu_context webgpu_ctx; + wgpu::Buffer buffer; + std::string label; + + ggml_backend_webgpu_buffer_context(webgpu_context ctx, wgpu::Buffer buf, std::string lbl) : + webgpu_ctx(std::move(ctx)), + buffer(std::move(buf)), + label(std::move(lbl)) {} +}; + +#endif // GGML_WEBGPU_STRUCTS_HPP diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 286413894f6..c46921ed0f2 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -7,8 +7,9 @@ #include "ggml-backend-impl.h" #include "ggml-impl.h" +#include "ggml-webgpu-shader-lib.hpp" +#include "ggml-webgpu-structs.hpp" #include "ggml-wgsl-shaders.hpp" -#include "pre_wgsl.hpp" #ifdef __EMSCRIPTEN__ # include @@ -129,321 +130,13 @@ static uint64_t webgpu_tensor_offset(const ggml_tensor * tensor) { return (uint8_t *) tensor->data - (uint8_t *) webgpu_ptr_base; } -/* Struct definitions */ - -// Forward reference -static void ggml_webgpu_create_buffer(wgpu::Device & device, - wgpu::Buffer & buffer, - size_t size, - wgpu::BufferUsage usage, - const char * label); - -struct webgpu_pool_bufs { - wgpu::Buffer host_buf; - wgpu::Buffer dev_buf; -}; - -// The futures to wait on for a single queue submission -struct webgpu_submission_futures { - std::vector futures; -}; - -// Holds a pool of parameter buffers for WebGPU operations -struct webgpu_buf_pool { - std::vector free; - - std::mutex mutex; - - std::condition_variable cv; - - void init(wgpu::Device device, - int num_bufs, - size_t buf_size, - wgpu::BufferUsage dev_buf_usage, - wgpu::BufferUsage host_buf_usage) { 
- for (int i = 0; i < num_bufs; i++) { - wgpu::Buffer host_buf; - wgpu::Buffer dev_buf; - ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_pool_buf"); - ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf"); - free.push_back({ host_buf, dev_buf }); - } - } - - webgpu_pool_bufs alloc_bufs() { - std::unique_lock lock(mutex); - cv.wait(lock, [this] { return !free.empty(); }); - webgpu_pool_bufs bufs = free.back(); - free.pop_back(); - return bufs; - } - - void free_bufs(std::vector bufs) { - std::lock_guard lock(mutex); - free.insert(free.end(), bufs.begin(), bufs.end()); - cv.notify_all(); - } - - void cleanup() { - std::lock_guard lock(mutex); - for (auto & bufs : free) { - bufs.host_buf.Destroy(); - bufs.dev_buf.Destroy(); - } - free.clear(); - } -}; - -#ifdef GGML_WEBGPU_GPU_PROFILE -struct webgpu_gpu_profile_bufs { - wgpu::Buffer host_buf; - wgpu::Buffer dev_buf; - wgpu::QuerySet query_set; -}; - -// Holds a pool of parameter buffers for WebGPU operations -struct webgpu_gpu_profile_buf_pool { - std::vector free; - - std::mutex mutex; - - std::condition_variable cv; - - void init(wgpu::Device device, - int num_bufs, - size_t buf_size, - wgpu::BufferUsage dev_buf_usage, - wgpu::BufferUsage host_buf_usage) { - for (int i = 0; i < num_bufs; i++) { - wgpu::Buffer host_buf; - wgpu::Buffer dev_buf; - ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_profile_buf"); - ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_profile_buf"); - // Create a query set for 2 timestamps - wgpu::QuerySetDescriptor ts_query_set_desc = {}; - - ts_query_set_desc.type = wgpu::QueryType::Timestamp; - ts_query_set_desc.count = 2; - wgpu::QuerySet ts_query_set = device.CreateQuerySet(&ts_query_set_desc); - - free.push_back({ host_buf, dev_buf, ts_query_set }); - } - } - - webgpu_gpu_profile_bufs alloc_bufs() { - std::unique_lock lock(mutex); - 
cv.wait(lock, [this] { return !free.empty(); }); - webgpu_gpu_profile_bufs bufs = free.back(); - free.pop_back(); - return bufs; - } - - void free_bufs(std::vector bufs) { - std::lock_guard lock(mutex); - free.insert(free.end(), bufs.begin(), bufs.end()); - cv.notify_all(); - } - - void cleanup() { - std::lock_guard lock(mutex); - for (auto & bufs : free) { - bufs.host_buf.Destroy(); - bufs.dev_buf.Destroy(); - bufs.query_set.Destroy(); - } - free.clear(); - } -}; -#endif - -struct webgpu_pipeline { - wgpu::ComputePipeline pipeline; - std::string name; -}; - -struct webgpu_command { - wgpu::CommandBuffer commands; - webgpu_pool_bufs params_bufs; - std::optional set_rows_error_bufs; -#ifdef GGML_WEBGPU_GPU_PROFILE - webgpu_gpu_profile_bufs timestamp_query_bufs; - std::string pipeline_name; -#endif -}; - -// Pipeline keys - -template -static inline void webgpu_hash_combine(size_t & seed, const T & value) { - seed ^= std::hash{}(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2); -} - -static inline const char * ggml_webgpu_wgsl_kv_type(ggml_type type) { - switch (type) { - case GGML_TYPE_F16: return "f16"; - case GGML_TYPE_F32: return "f32"; - default: return nullptr; - } -} - -struct flash_attn_pipeline_key { - int q_type; - int kv_type; - int mask_type; - int sinks_type; - int dst_type; - uint32_t head_dim_q; - uint32_t head_dim_v; - uint32_t n_heads; - bool has_mask; - bool has_sinks; - bool uses_logit_softcap; - - bool operator==(const flash_attn_pipeline_key & other) const { - return q_type == other.q_type && - kv_type == other.kv_type && - mask_type == other.mask_type && - sinks_type == other.sinks_type && - dst_type == other.dst_type && - head_dim_q == other.head_dim_q && - head_dim_v == other.head_dim_v && - n_heads == other.n_heads && - has_mask == other.has_mask && - has_sinks == other.has_sinks && - uses_logit_softcap == other.uses_logit_softcap; - } -}; - -struct flash_attn_pipeline_key_hash { - size_t operator()(const flash_attn_pipeline_key & key) const { 
- size_t seed = 0; - webgpu_hash_combine(seed, key.q_type); - webgpu_hash_combine(seed, key.kv_type); - webgpu_hash_combine(seed, key.mask_type); - webgpu_hash_combine(seed, key.sinks_type); - webgpu_hash_combine(seed, key.dst_type); - webgpu_hash_combine(seed, key.head_dim_q); - webgpu_hash_combine(seed, key.head_dim_v); - webgpu_hash_combine(seed, key.n_heads); - webgpu_hash_combine(seed, key.has_mask); - webgpu_hash_combine(seed, key.has_sinks); - webgpu_hash_combine(seed, key.uses_logit_softcap); - return seed; - } -}; - -// All the base objects needed to run operations on a WebGPU device -struct webgpu_context_struct { - wgpu::Instance instance; - wgpu::Adapter adapter; - wgpu::Device device; - wgpu::Queue queue; - wgpu::Limits limits; - - uint32_t subgroup_size; - -#ifndef __EMSCRIPTEN__ - bool supports_subgroup_matrix = false; - wgpu::SubgroupMatrixConfig subgroup_matrix_config; -#endif - - std::recursive_mutex mutex; - std::atomic_uint inflight_threads = 0; - - webgpu_buf_pool param_buf_pool; - webgpu_buf_pool set_rows_error_buf_pool; - - pre_wgsl::Preprocessor p; - - std::map memset_pipelines; // variant or type index - - std::map>> mul_mat_pipelines; // src0_type, src1_type, vectorized - std::map>> - mul_mat_vec_pipelines; // src0_type, src1_type, vectorized - - std::unordered_map - flash_attn_pipelines; - - std::map> set_rows_pipelines; // dst_type, vectorized - std::map> get_rows_pipelines; // src_type, vectorized - - std::map> cpy_pipelines; // src_type, dst_type - std::map> add_pipelines; // type, inplace - std::map> sub_pipelines; // type, inplace - std::map> mul_pipelines; // type, inplace - std::map> div_pipelines; // type, inplace - - std::map rms_norm_pipelines; // inplace - std::map>> rope_pipelines; // type, ff, inplace - std::map>> glu_pipelines; // glu_op, type, split - std::map scale_pipelines; // inplace - std::map>> soft_max_pipelines; // mask_type, has_sink, inplace - std::map>> unary_pipelines; // unary_op, type, inplace - - size_t 
memset_bytes_per_thread; - - // Staging buffer for reading data from the GPU - wgpu::Buffer get_tensor_staging_buf; - -#ifdef GGML_WEBGPU_DEBUG - wgpu::Buffer debug_host_buf; - wgpu::Buffer debug_dev_buf; -#endif - -#ifdef GGML_WEBGPU_CPU_PROFILE - // Profiling: labeled CPU time in ms (total) - std::unordered_map cpu_time_ms; - // Profiling: detailed CPU time in ms - std::unordered_map cpu_detail_ms; -#endif - -#ifdef GGML_WEBGPU_GPU_PROFILE - // Profiling: per-shader GPU time in ms - std::unordered_map shader_gpu_time_ms; - // Profiling: pool of timestamp query buffers (one per operation) - webgpu_gpu_profile_buf_pool timestamp_query_buf_pool; -#endif -}; - -typedef std::shared_ptr webgpu_context; - -struct ggml_backend_webgpu_reg_context { - webgpu_context webgpu_ctx; - size_t device_count; - const char * name; -}; - -struct ggml_backend_webgpu_device_context { - webgpu_context webgpu_ctx; - std::string device_name; - std::string device_desc; -}; - -struct ggml_backend_webgpu_context { - webgpu_context webgpu_ctx; - std::string name; -}; - -struct ggml_backend_webgpu_buffer_context { - webgpu_context webgpu_ctx; - wgpu::Buffer buffer; - std::string label; - - ggml_backend_webgpu_buffer_context(webgpu_context ctx, wgpu::Buffer buf, std::string lbl) : - webgpu_ctx(std::move(ctx)), - buffer(std::move(buf)), - label(std::move(lbl)) {} -}; - -/* End struct definitions */ - /* WebGPU object initializations */ -static void ggml_webgpu_create_buffer(wgpu::Device & device, - wgpu::Buffer & buffer, - size_t size, - wgpu::BufferUsage usage, - const char * label) { +void ggml_webgpu_create_buffer(wgpu::Device & device, + wgpu::Buffer & buffer, + size_t size, + wgpu::BufferUsage usage, + const char * label) { wgpu::BufferDescriptor buffer_desc; buffer_desc.size = size; buffer_desc.usage = usage; @@ -473,10 +166,10 @@ static std::string ggml_webgpu_process_shader_repls(const char * return s; } -static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device, - const 
char * shader_code, - const char * label, - const std::vector & constants = {}) { +webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device, + const char * shader_code, + const char * label, + const std::vector & constants) { wgpu::ShaderSourceWGSL shader_source; shader_source.code = shader_code; @@ -497,47 +190,6 @@ static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & return { device.CreateComputePipeline(&pipeline_desc), label }; } -static webgpu_pipeline ggml_webgpu_get_flash_attn_pipeline(webgpu_context & ctx, - ggml_tensor * Q, - ggml_tensor * K, - ggml_tensor * V, - ggml_tensor * mask, - ggml_tensor * sinks, - ggml_tensor * dst, - float logit_softcap) { - GGML_ASSERT(K->type == V->type); - - flash_attn_pipeline_key key = { - .q_type = Q->type, - .kv_type = K->type, - .mask_type = mask->type, - .sinks_type = sinks->type, - .dst_type = dst->type, - .head_dim_q = (uint32_t) Q->ne[0], - .head_dim_v = (uint32_t) V->ne[0], - .n_heads = (uint32_t) Q->ne[2], - .has_mask = true, - .has_sinks = true, - .uses_logit_softcap = logit_softcap != 0.0f, - }; - - if (ctx->flash_attn_pipelines.count(key)) { - return ctx->flash_attn_pipelines[key]; - } - - std::lock_guard lock(ctx->mutex); - if (ctx->flash_attn_pipelines.count(key)) { - return ctx->flash_attn_pipelines[key]; - } - - const char * kv_type = ggml_webgpu_wgsl_kv_type(K->type); - std::string label = std::string("flash_attn_kv_") + kv_type; - std::string shader = ctx->p.preprocess(wgsl_flash_attn, { std::string("KV_TYPE=") + kv_type }); - webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(ctx->device, shader.c_str(), label.c_str()); - ctx->flash_attn_pipelines.emplace(key, pipeline); - return pipeline; -} - /** End WebGPU object initializations */ /** WebGPU Actions */ @@ -1104,7 +756,7 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, ggml_tensor * sinks, ggml_tensor * dst) { // For now we assume everything (mask, sink) - float scale = *(float *) dst->op_params; + float 
scale = *(float *) dst->op_params; float max_bias; memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); float logit_softcap; @@ -1116,19 +768,19 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, float m0 = powf(2.0f, -(max_bias) / n_head_log2); float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - // print type and dimensions of Q/K/V/mask/sinks/dst -// std::cout << "ggml_webgpu_flash_attn: Q type: " << ggml_type_name(Q->type) << ", ne: [" << Q->ne[0] << ", " << Q->ne[1] << ", " << Q->ne[2] -// << ", " << Q->ne[3] << "]\n"; -// std::cout << "ggml_webgpu_flash_attn: K type: " << ggml_type_name(K->type) << ", ne: [" << K->ne[0] << ", " << K->ne[1] << ", " << K->ne[2] -// << ", " << K->ne[3] << "]\n"; -// std::cout << "ggml_webgpu_flash_attn: V type: " << ggml_type_name(V->type) << ", ne: [" << V->ne[0] << ", " << V->ne[1] << ", " << V->ne[2] -// << ", " << V->ne[3] << "]\n"; -// std::cout << "ggml_webgpu_flash_attn: mask type: " << ggml_type_name(mask->type) << ", ne: [" << mask->ne[0] << ", " << mask->ne[1] << ", " << mask->ne[2] -// << ", " << mask->ne[3] << "]\n"; -// std::cout << "ggml_webgpu_flash_attn: sinks type: " << ggml_type_name(sinks->type) << ", ne: [" << sinks->ne[0] << ", " << sinks->ne[1] << ", " << sinks->ne[2] -// << ", " << sinks->ne[3] << "]\n"; -// std::cout << "ggml_webgpu_flash_attn: dst type: " << ggml_type_name(dst->type) << ", ne: [" << dst->ne[0] << ", " << dst->ne[1] << ", " << dst->ne[2] -// << ", " << dst->ne[3] << "]\n"; + // print type and dimensions of Q/K/V/mask/sinks/dst + // std::cout << "ggml_webgpu_flash_attn: Q type: " << ggml_type_name(Q->type) << ", ne: [" << Q->ne[0] << ", " << Q->ne[1] << ", " << Q->ne[2] + // << ", " << Q->ne[3] << "]\n"; + // std::cout << "ggml_webgpu_flash_attn: K type: " << ggml_type_name(K->type) << ", ne: [" << K->ne[0] << ", " << K->ne[1] << ", " << K->ne[2] + // << ", " << K->ne[3] << "]\n"; + // std::cout << "ggml_webgpu_flash_attn: V type: " << 
ggml_type_name(V->type) << ", ne: [" << V->ne[0] << ", " << V->ne[1] << ", " << V->ne[2] + // << ", " << V->ne[3] << "]\n"; + // std::cout << "ggml_webgpu_flash_attn: mask type: " << ggml_type_name(mask->type) << ", ne: [" << mask->ne[0] << ", " << mask->ne[1] << ", " << mask->ne[2] + // << ", " << mask->ne[3] << "]\n"; + // std::cout << "ggml_webgpu_flash_attn: sinks type: " << ggml_type_name(sinks->type) << ", ne: [" << sinks->ne[0] << ", " << sinks->ne[1] << ", " << sinks->ne[2] + // << ", " << sinks->ne[3] << "]\n"; + // std::cout << "ggml_webgpu_flash_attn: dst type: " << ggml_type_name(dst->type) << ", ne: [" << dst->ne[0] << ", " << dst->ne[1] << ", " << dst->ne[2] + // << ", " << dst->ne[3] << "]\n"; std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, Q) / ggml_type_size(Q->type)), @@ -1137,23 +789,23 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, sinks) / ggml_type_size(sinks->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), - (uint32_t) Q->ne[0], // head dimension (Q/K) - (uint32_t) V->ne[0], // head dimension (V) - (uint32_t) Q->ne[2], // number of heads - (uint32_t) Q->ne[1], // sequence length (Q) - (uint32_t) K->ne[1], // sequence length (K/V) - (uint32_t) (Q->nb[1] / ggml_type_size(Q->type)), // stride (elements/blocks) of Q in dimension 1 - (uint32_t) (Q->nb[2] / ggml_type_size(Q->type)), // stride (elements/blocks) of Q in dimension 2 - (uint32_t) (Q->nb[3] / ggml_type_size(Q->type)), // stride (elements/blocks) of Q in dimension 3 - (uint32_t) (K->nb[1] / ggml_type_size(K->type)), // stride (elements/blocks) of K in dimension 1 - (uint32_t) (K->nb[2] / ggml_type_size(K->type)), // stride (elements/blocks) of K in dimension 2 - (uint32_t) (K->nb[3] / ggml_type_size(K->type)), // stride (elements/blocks) of K in dimension 3 - 
(uint32_t) (V->nb[1] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 1 - (uint32_t) (V->nb[2] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 2 - (uint32_t) (V->nb[3] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 3 - (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)), // stride of mask dim 3 - (uint32_t) (Q->ne[2] / K->ne[2]), // repeat factor for K/V in dim 2 (MHA/MQA/GQA) - *(uint32_t *) &scale, // scale (possibly adjusted for logit softcap) + (uint32_t) Q->ne[0], // head dimension (Q/K) + (uint32_t) V->ne[0], // head dimension (V) + (uint32_t) Q->ne[2], // number of heads + (uint32_t) Q->ne[1], // sequence length (Q) + (uint32_t) K->ne[1], // sequence length (K/V) + (uint32_t) (Q->nb[1] / ggml_type_size(Q->type)), // stride (elements/blocks) of Q in dimension 1 + (uint32_t) (Q->nb[2] / ggml_type_size(Q->type)), // stride (elements/blocks) of Q in dimension 2 + (uint32_t) (Q->nb[3] / ggml_type_size(Q->type)), // stride (elements/blocks) of Q in dimension 3 + (uint32_t) (K->nb[1] / ggml_type_size(K->type)), // stride (elements/blocks) of K in dimension 1 + (uint32_t) (K->nb[2] / ggml_type_size(K->type)), // stride (elements/blocks) of K in dimension 2 + (uint32_t) (K->nb[3] / ggml_type_size(K->type)), // stride (elements/blocks) of K in dimension 3 + (uint32_t) (V->nb[1] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 1 + (uint32_t) (V->nb[2] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 2 + (uint32_t) (V->nb[3] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 3 + (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)), // stride of mask dim 3 + (uint32_t) (Q->ne[2] / K->ne[2]), // repeat factor for K/V in dim 2 (MHA/MQA/GQA) + *(uint32_t *) &scale, // scale (possibly adjusted for logit softcap) *(uint32_t *) &max_bias, *(uint32_t *) &logit_softcap, *(uint32_t *) &n_head_log2, @@ -2699,20 +2351,20 @@ static 
bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const break; } case GGML_OP_FLASH_ATTN_EXT: - { - supports_op = true; - // Q-type - supports_op &= src0->type == GGML_TYPE_F32; - // KV-type - supports_op &= src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16; - // Mask-type - supports_op &= op->src[3] != nullptr; - // Sink-type - supports_op &= op->src[4] != nullptr; - // qkv sequence length - supports_op &= op->ne[0] <= 128 && src0->ne[0] <= 128; - break; - } + { + supports_op = true; + // Q-type + supports_op &= src0->type == GGML_TYPE_F32; + // KV-type + supports_op &= src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16; + // Mask-type + supports_op &= op->src[3] != nullptr; + // Sink-type + supports_op &= op->src[4] != nullptr; + // qkv sequence length + supports_op &= op->ne[0] <= 128 && src0->ne[0] <= 128; + break; + } case GGML_OP_RMS_NORM: supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32; break; From b072b4b99aef049a94fe319fcce47d00bfbf1d71 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Tue, 30 Dec 2025 08:40:26 -0800 Subject: [PATCH 32/40] Work on compilation choices for flashattention --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 105 ++--- ggml/src/ggml-webgpu/ggml-webgpu-structs.hpp | 303 ------------- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 425 ++++++++++++++++-- .../ggml-webgpu/wgsl-shaders/flash_attn.wgsl | 67 ++- 4 files changed, 478 insertions(+), 422 deletions(-) delete mode 100644 ggml/src/ggml-webgpu/ggml-webgpu-structs.hpp diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 99acb5e6964..968737b5317 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -1,70 +1,71 @@ #ifndef GGML_WEBGPU_SHADER_LIB_HPP #define GGML_WEBGPU_SHADER_LIB_HPP -#include "ggml-webgpu-structs.hpp" +#include "pre_wgsl.hpp" #include #include -extern const char * wgsl_flash_attn; +struct 
ggml_webgpu_flash_attn_shader_lib_context { + const char * kv_type; + uint32_t head_dim_qk; + uint32_t head_dim_v; + bool has_mask; + bool has_sinks; + bool uses_logit_softcap; + uint32_t sg_mat_m; + uint32_t sg_mat_n; + uint32_t sg_mat_k; +}; -webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device, - const char * shader_code, - const char * label, - const std::vector & constants = {}); +struct ggml_webgpu_flash_attn_shader_decisions { + int unused = 0; +}; -static inline const char * ggml_webgpu_wgsl_kv_type(ggml_type type) { - switch (type) { - case GGML_TYPE_F16: - return "f16"; - case GGML_TYPE_F32: - return "f32"; - default: - return nullptr; - } -} +struct ggml_webgpu_processed_shader { + std::string wgsl; + std::string variant; + ggml_webgpu_flash_attn_shader_decisions decisions; +}; -inline webgpu_pipeline ggml_webgpu_get_flash_attn_pipeline(webgpu_context & ctx, - ggml_tensor * Q, - ggml_tensor * K, - ggml_tensor * V, - ggml_tensor * mask, - ggml_tensor * sinks, - ggml_tensor * dst, - float logit_softcap) { - GGML_ASSERT(K->type == V->type); +inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader( + pre_wgsl::Preprocessor & preprocessor, + const char * shader_src, + const ggml_webgpu_flash_attn_shader_lib_context & context) { + std::vector defines; + std::string variant = "flash_attn"; - flash_attn_pipeline_key key = { - .q_type = Q->type, - .kv_type = K->type, - .mask_type = mask->type, - .sinks_type = sinks->type, - .dst_type = dst->type, - .head_dim_q = (uint32_t) Q->ne[0], - .head_dim_v = (uint32_t) V->ne[0], - .n_heads = (uint32_t) Q->ne[2], - .has_mask = true, - .has_sinks = true, - .uses_logit_softcap = logit_softcap != 0.0f, - }; + defines.push_back(std::string("KV_TYPE=") + context.kv_type); + variant += std::string("_") + context.kv_type; - auto it = ctx->flash_attn_pipelines.find(key); - if (it != ctx->flash_attn_pipelines.end()) { - return it->second; + if (context.has_mask) { + defines.push_back("MASK"); + 
variant += "_mask"; } - - std::lock_guard lock(ctx->mutex); - it = ctx->flash_attn_pipelines.find(key); - if (it != ctx->flash_attn_pipelines.end()) { - return it->second; + if (context.has_sinks) { + defines.push_back("SINKS"); + variant += "_sinks"; + } + if (context.uses_logit_softcap) { + defines.push_back("LOGIT_SOFTCAP"); + variant += "_lgsc"; } - const char * kv_type = ggml_webgpu_wgsl_kv_type(K->type); - std::string label = std::string("flash_attn_kv_") + kv_type; - std::string shader = ctx->p.preprocess(wgsl_flash_attn, { std::string("KV_TYPE=") + kv_type }); - webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(ctx->device, shader.c_str(), label.c_str()); - ctx->flash_attn_pipelines.emplace(key, pipeline); - return pipeline; + defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(context.head_dim_qk)); + variant += std::string("_hsqk") + std::to_string(context.head_dim_qk); + + defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.head_dim_v)); + variant += std::string("_hsv") + std::to_string(context.head_dim_v); + + // For now these are not part of the variant name + defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m)); + defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n)); + defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k)); + + ggml_webgpu_processed_shader result; + result.wgsl = preprocessor.preprocess(shader_src, defines); + result.variant = variant; + return result; } #endif // GGML_WEBGPU_SHADER_LIB_HPP diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-structs.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-structs.hpp deleted file mode 100644 index 029c557a14c..00000000000 --- a/ggml/src/ggml-webgpu/ggml-webgpu-structs.hpp +++ /dev/null @@ -1,303 +0,0 @@ -#ifndef GGML_WEBGPU_STRUCTS_HPP -#define GGML_WEBGPU_STRUCTS_HPP - -#include "ggml.h" -#include "pre_wgsl.hpp" - -#include - -#include -#include -#include -#include -#include -#include -#include 
-#include -#include -#include - -void ggml_webgpu_create_buffer(wgpu::Device & device, - wgpu::Buffer & buffer, - size_t size, - wgpu::BufferUsage usage, - const char * label); - -struct webgpu_pool_bufs { - wgpu::Buffer host_buf; - wgpu::Buffer dev_buf; -}; - -// The futures to wait on for a single queue submission -struct webgpu_submission_futures { - std::vector futures; -}; - -// Holds a pool of parameter buffers for WebGPU operations -struct webgpu_buf_pool { - std::vector free; - - std::mutex mutex; - - std::condition_variable cv; - - void init(wgpu::Device device, - int num_bufs, - size_t buf_size, - wgpu::BufferUsage dev_buf_usage, - wgpu::BufferUsage host_buf_usage) { - for (int i = 0; i < num_bufs; i++) { - wgpu::Buffer host_buf; - wgpu::Buffer dev_buf; - ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_pool_buf"); - ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf"); - free.push_back({ host_buf, dev_buf }); - } - } - - webgpu_pool_bufs alloc_bufs() { - std::unique_lock lock(mutex); - cv.wait(lock, [this] { return !free.empty(); }); - webgpu_pool_bufs bufs = free.back(); - free.pop_back(); - return bufs; - } - - void free_bufs(std::vector bufs) { - std::lock_guard lock(mutex); - free.insert(free.end(), bufs.begin(), bufs.end()); - cv.notify_all(); - } - - void cleanup() { - std::lock_guard lock(mutex); - for (auto & bufs : free) { - bufs.host_buf.Destroy(); - bufs.dev_buf.Destroy(); - } - free.clear(); - } -}; - -#ifdef GGML_WEBGPU_GPU_PROFILE -struct webgpu_gpu_profile_bufs { - wgpu::Buffer host_buf; - wgpu::Buffer dev_buf; - wgpu::QuerySet query_set; -}; - -// Holds a pool of parameter buffers for WebGPU operations -struct webgpu_gpu_profile_buf_pool { - std::vector free; - - std::mutex mutex; - - std::condition_variable cv; - - void init(wgpu::Device device, - int num_bufs, - size_t buf_size, - wgpu::BufferUsage dev_buf_usage, - wgpu::BufferUsage host_buf_usage) { - for 
(int i = 0; i < num_bufs; i++) { - wgpu::Buffer host_buf; - wgpu::Buffer dev_buf; - ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_profile_buf"); - ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_profile_buf"); - // Create a query set for 2 timestamps - wgpu::QuerySetDescriptor ts_query_set_desc = {}; - - ts_query_set_desc.type = wgpu::QueryType::Timestamp; - ts_query_set_desc.count = 2; - wgpu::QuerySet ts_query_set = device.CreateQuerySet(&ts_query_set_desc); - - free.push_back({ host_buf, dev_buf, ts_query_set }); - } - } - - webgpu_gpu_profile_bufs alloc_bufs() { - std::unique_lock lock(mutex); - cv.wait(lock, [this] { return !free.empty(); }); - webgpu_gpu_profile_bufs bufs = free.back(); - free.pop_back(); - return bufs; - } - - void free_bufs(std::vector bufs) { - std::lock_guard lock(mutex); - free.insert(free.end(), bufs.begin(), bufs.end()); - cv.notify_all(); - } - - void cleanup() { - std::lock_guard lock(mutex); - for (auto & bufs : free) { - bufs.host_buf.Destroy(); - bufs.dev_buf.Destroy(); - bufs.query_set.Destroy(); - } - free.clear(); - } -}; -#endif - -struct webgpu_pipeline { - wgpu::ComputePipeline pipeline; - std::string name; -}; - -struct webgpu_command { - wgpu::CommandBuffer commands; - webgpu_pool_bufs params_bufs; - std::optional set_rows_error_bufs; -#ifdef GGML_WEBGPU_GPU_PROFILE - webgpu_gpu_profile_bufs timestamp_query_bufs; - std::string pipeline_name; -#endif -}; - -struct flash_attn_pipeline_key { - int q_type; - int kv_type; - int mask_type; - int sinks_type; - int dst_type; - uint32_t head_dim_q; - uint32_t head_dim_v; - uint32_t n_heads; - bool has_mask; - bool has_sinks; - bool uses_logit_softcap; - - bool operator==(const flash_attn_pipeline_key & other) const { - return q_type == other.q_type && kv_type == other.kv_type && mask_type == other.mask_type && - sinks_type == other.sinks_type && dst_type == other.dst_type && head_dim_q == 
other.head_dim_q && - head_dim_v == other.head_dim_v && n_heads == other.n_heads && has_mask == other.has_mask && - has_sinks == other.has_sinks && uses_logit_softcap == other.uses_logit_softcap; - } -}; - -struct flash_attn_pipeline_key_hash { - size_t operator()(const flash_attn_pipeline_key & key) const { - size_t seed = 0; - auto mix = [&seed](size_t value) { - seed ^= value + 0x9e3779b9 + (seed << 6) + (seed >> 2); - }; - mix(std::hash{}(key.q_type)); - mix(std::hash{}(key.kv_type)); - mix(std::hash{}(key.mask_type)); - mix(std::hash{}(key.sinks_type)); - mix(std::hash{}(key.dst_type)); - mix(std::hash{}(key.head_dim_q)); - mix(std::hash{}(key.head_dim_v)); - mix(std::hash{}(key.n_heads)); - mix(std::hash{}(key.has_mask)); - mix(std::hash{}(key.has_sinks)); - mix(std::hash{}(key.uses_logit_softcap)); - return seed; - } -}; - -// All the base objects needed to run operations on a WebGPU device -struct webgpu_context_struct { - wgpu::Instance instance; - wgpu::Adapter adapter; - wgpu::Device device; - wgpu::Queue queue; - wgpu::Limits limits; - - uint32_t subgroup_size; - -#ifndef __EMSCRIPTEN__ - bool supports_subgroup_matrix = false; - wgpu::SubgroupMatrixConfig subgroup_matrix_config; -#endif - - std::recursive_mutex mutex; - std::atomic_uint inflight_threads = 0; - - webgpu_buf_pool param_buf_pool; - webgpu_buf_pool set_rows_error_buf_pool; - - pre_wgsl::Preprocessor p; - - std::map memset_pipelines; // variant or type index - - std::map>> mul_mat_pipelines; // src0_type, src1_type, vectorized - std::map>> - mul_mat_vec_pipelines; // src0_type, src1_type, vectorized - - std::unordered_map flash_attn_pipelines; - - std::map> set_rows_pipelines; // dst_type, vectorized - std::map> get_rows_pipelines; // src_type, vectorized - - std::map> cpy_pipelines; // src_type, dst_type - std::map> add_pipelines; // type, inplace - std::map> sub_pipelines; // type, inplace - std::map> mul_pipelines; // type, inplace - std::map> div_pipelines; // type, inplace - - std::map 
rms_norm_pipelines; // inplace - std::map>> rope_pipelines; // type, ff, inplace - std::map>> glu_pipelines; // glu_op, type, split - std::map scale_pipelines; // inplace - std::map>> soft_max_pipelines; // mask_type, has_sink, inplace - std::map>> unary_pipelines; // unary_op, type, inplace - - size_t memset_bytes_per_thread; - - // Staging buffer for reading data from the GPU - wgpu::Buffer get_tensor_staging_buf; - -#ifdef GGML_WEBGPU_DEBUG - wgpu::Buffer debug_host_buf; - wgpu::Buffer debug_dev_buf; -#endif - -#ifdef GGML_WEBGPU_CPU_PROFILE - // Profiling: labeled CPU time in ms (total) - std::unordered_map cpu_time_ms; - // Profiling: detailed CPU time in ms - std::unordered_map cpu_detail_ms; -#endif - -#ifdef GGML_WEBGPU_GPU_PROFILE - // Profiling: per-shader GPU time in ms - std::unordered_map shader_gpu_time_ms; - // Profiling: pool of timestamp query buffers (one per operation) - webgpu_gpu_profile_buf_pool timestamp_query_buf_pool; -#endif -}; - -using webgpu_context = std::shared_ptr; - -struct ggml_backend_webgpu_reg_context { - webgpu_context webgpu_ctx; - size_t device_count; - const char * name; -}; - -struct ggml_backend_webgpu_device_context { - webgpu_context webgpu_ctx; - std::string device_name; - std::string device_desc; -}; - -struct ggml_backend_webgpu_context { - webgpu_context webgpu_ctx; - std::string name; -}; - -struct ggml_backend_webgpu_buffer_context { - webgpu_context webgpu_ctx; - wgpu::Buffer buffer; - std::string label; - - ggml_backend_webgpu_buffer_context(webgpu_context ctx, wgpu::Buffer buf, std::string lbl) : - webgpu_ctx(std::move(ctx)), - buffer(std::move(buf)), - label(std::move(lbl)) {} -}; - -#endif // GGML_WEBGPU_STRUCTS_HPP diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index c46921ed0f2..32f9afade4b 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -8,8 +8,8 @@ #include "ggml-backend-impl.h" #include "ggml-impl.h" #include 
"ggml-webgpu-shader-lib.hpp" -#include "ggml-webgpu-structs.hpp" #include "ggml-wgsl-shaders.hpp" +#include "pre_wgsl.hpp" #ifdef __EMSCRIPTEN__ # include @@ -119,6 +119,286 @@ /* End Constants */ +void ggml_webgpu_create_buffer(wgpu::Device & device, + wgpu::Buffer & buffer, + size_t size, + wgpu::BufferUsage usage, + const char * label); + +struct webgpu_pool_bufs { + wgpu::Buffer host_buf; + wgpu::Buffer dev_buf; +}; + +// The futures to wait on for a single queue submission +struct webgpu_submission_futures { + std::vector futures; +}; + +// Holds a pool of parameter buffers for WebGPU operations +struct webgpu_buf_pool { + std::vector free; + + std::mutex mutex; + + std::condition_variable cv; + + void init(wgpu::Device device, + int num_bufs, + size_t buf_size, + wgpu::BufferUsage dev_buf_usage, + wgpu::BufferUsage host_buf_usage) { + for (int i = 0; i < num_bufs; i++) { + wgpu::Buffer host_buf; + wgpu::Buffer dev_buf; + ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_pool_buf"); + ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf"); + free.push_back({ host_buf, dev_buf }); + } + } + + webgpu_pool_bufs alloc_bufs() { + std::unique_lock lock(mutex); + cv.wait(lock, [this] { return !free.empty(); }); + webgpu_pool_bufs bufs = free.back(); + free.pop_back(); + return bufs; + } + + void free_bufs(std::vector bufs) { + std::lock_guard lock(mutex); + free.insert(free.end(), bufs.begin(), bufs.end()); + cv.notify_all(); + } + + void cleanup() { + std::lock_guard lock(mutex); + for (auto & bufs : free) { + bufs.host_buf.Destroy(); + bufs.dev_buf.Destroy(); + } + free.clear(); + } +}; + +#ifdef GGML_WEBGPU_GPU_PROFILE +struct webgpu_gpu_profile_bufs { + wgpu::Buffer host_buf; + wgpu::Buffer dev_buf; + wgpu::QuerySet query_set; +}; + +// Holds a pool of parameter buffers for WebGPU operations +struct webgpu_gpu_profile_buf_pool { + std::vector free; + + std::mutex mutex; + + 
std::condition_variable cv; + + void init(wgpu::Device device, + int num_bufs, + size_t buf_size, + wgpu::BufferUsage dev_buf_usage, + wgpu::BufferUsage host_buf_usage) { + for (int i = 0; i < num_bufs; i++) { + wgpu::Buffer host_buf; + wgpu::Buffer dev_buf; + ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_profile_buf"); + ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_profile_buf"); + // Create a query set for 2 timestamps + wgpu::QuerySetDescriptor ts_query_set_desc = {}; + + ts_query_set_desc.type = wgpu::QueryType::Timestamp; + ts_query_set_desc.count = 2; + wgpu::QuerySet ts_query_set = device.CreateQuerySet(&ts_query_set_desc); + + free.push_back({ host_buf, dev_buf, ts_query_set }); + } + } + + webgpu_gpu_profile_bufs alloc_bufs() { + std::unique_lock lock(mutex); + cv.wait(lock, [this] { return !free.empty(); }); + webgpu_gpu_profile_bufs bufs = free.back(); + free.pop_back(); + return bufs; + } + + void free_bufs(std::vector bufs) { + std::lock_guard lock(mutex); + free.insert(free.end(), bufs.begin(), bufs.end()); + cv.notify_all(); + } + + void cleanup() { + std::lock_guard lock(mutex); + for (auto & bufs : free) { + bufs.host_buf.Destroy(); + bufs.dev_buf.Destroy(); + bufs.query_set.Destroy(); + } + free.clear(); + } +}; +#endif + +struct webgpu_pipeline { + wgpu::ComputePipeline pipeline; + std::string name; +}; + +struct webgpu_command { + wgpu::CommandBuffer commands; + webgpu_pool_bufs params_bufs; + std::optional set_rows_error_bufs; +#ifdef GGML_WEBGPU_GPU_PROFILE + webgpu_gpu_profile_bufs timestamp_query_bufs; + std::string pipeline_name; +#endif +}; + +struct flash_attn_pipeline_key { + int q_type; + int kv_type; + int dst_type; + uint32_t head_dim_qk; + uint32_t head_dim_v; + uint32_t n_heads; + bool has_mask; + bool has_sinks; + bool uses_logit_softcap; + + bool operator==(const flash_attn_pipeline_key & other) const { + return q_type == other.q_type && kv_type == 
other.kv_type && dst_type == other.dst_type && + head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v && n_heads == other.n_heads && + has_mask == other.has_mask && has_sinks == other.has_sinks && + uses_logit_softcap == other.uses_logit_softcap; + } +}; + +template inline void ggml_webgpu_hash_combine(size_t & seed, const T & value) { + seed ^= std::hash{}(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2); +} + +struct flash_attn_pipeline_key_hash { + size_t operator()(const flash_attn_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.q_type); + ggml_webgpu_hash_combine(seed, key.kv_type); + ggml_webgpu_hash_combine(seed, key.dst_type); + ggml_webgpu_hash_combine(seed, key.head_dim_qk); + ggml_webgpu_hash_combine(seed, key.head_dim_v); + ggml_webgpu_hash_combine(seed, key.n_heads); + ggml_webgpu_hash_combine(seed, key.has_mask); + ggml_webgpu_hash_combine(seed, key.has_sinks); + ggml_webgpu_hash_combine(seed, key.uses_logit_softcap); + return seed; + } +}; + +// All the base objects needed to run operations on a WebGPU device +struct webgpu_context_struct { + wgpu::Instance instance; + wgpu::Adapter adapter; + wgpu::Device device; + wgpu::Queue queue; + wgpu::Limits limits; + + uint32_t subgroup_size; + +#ifndef __EMSCRIPTEN__ + bool supports_subgroup_matrix = false; + wgpu::SubgroupMatrixConfig subgroup_matrix_config; +#endif + + std::recursive_mutex mutex; + std::atomic_uint inflight_threads = 0; + + webgpu_buf_pool param_buf_pool; + webgpu_buf_pool set_rows_error_buf_pool; + + pre_wgsl::Preprocessor p; + + std::map memset_pipelines; // variant or type index + + std::map>> mul_mat_pipelines; // src0_type, src1_type, vectorized + std::map>> + mul_mat_vec_pipelines; // src0_type, src1_type, vectorized + + std::unordered_map flash_attn_pipelines; + + std::map> set_rows_pipelines; // dst_type, vectorized + std::map> get_rows_pipelines; // src_type, vectorized + + std::map> cpy_pipelines; // src_type, dst_type + std::map> 
add_pipelines; // type, inplace + std::map> sub_pipelines; // type, inplace + std::map> mul_pipelines; // type, inplace + std::map> div_pipelines; // type, inplace + + std::map rms_norm_pipelines; // inplace + std::map>> rope_pipelines; // type, ff, inplace + std::map>> glu_pipelines; // glu_op, type, split + std::map scale_pipelines; // inplace + std::map>> soft_max_pipelines; // mask_type, has_sink, inplace + std::map>> unary_pipelines; // unary_op, type, inplace + + size_t memset_bytes_per_thread; + + // Staging buffer for reading data from the GPU + wgpu::Buffer get_tensor_staging_buf; + +#ifdef GGML_WEBGPU_DEBUG + wgpu::Buffer debug_host_buf; + wgpu::Buffer debug_dev_buf; +#endif + +#ifdef GGML_WEBGPU_CPU_PROFILE + // Profiling: labeled CPU time in ms (total) + std::unordered_map cpu_time_ms; + // Profiling: detailed CPU time in ms + std::unordered_map cpu_detail_ms; +#endif + +#ifdef GGML_WEBGPU_GPU_PROFILE + // Profiling: per-shader GPU time in ms + std::unordered_map shader_gpu_time_ms; + // Profiling: pool of timestamp query buffers (one per operation) + webgpu_gpu_profile_buf_pool timestamp_query_buf_pool; +#endif +}; + +using webgpu_context = std::shared_ptr; + +struct ggml_backend_webgpu_reg_context { + webgpu_context webgpu_ctx; + size_t device_count; + const char * name; +}; + +struct ggml_backend_webgpu_device_context { + webgpu_context webgpu_ctx; + std::string device_name; + std::string device_desc; +}; + +struct ggml_backend_webgpu_context { + webgpu_context webgpu_ctx; + std::string name; +}; + +struct ggml_backend_webgpu_buffer_context { + webgpu_context webgpu_ctx; + wgpu::Buffer buffer; + std::string label; + + ggml_backend_webgpu_buffer_context(webgpu_context ctx, wgpu::Buffer buf, std::string lbl) : + webgpu_ctx(std::move(ctx)), + buffer(std::move(buf)), + label(std::move(lbl)) {} +}; + // This is a "fake" base pointer, since WebGPU buffers do not have pointers to their locations. 
static void * const webgpu_ptr_base = (void *) (uintptr_t) 0x1000; // NOLINT @@ -166,10 +446,10 @@ static std::string ggml_webgpu_process_shader_repls(const char * return s; } -webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device, - const char * shader_code, - const char * label, - const std::vector & constants) { +static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device, + const char * shader_code, + const char * label, + const std::vector & constants = {}) { wgpu::ShaderSourceWGSL shader_source; shader_source.code = shader_code; @@ -755,7 +1035,6 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, ggml_tensor * mask, ggml_tensor * sinks, ggml_tensor * dst) { - // For now we assume everything (mask, sink) float scale = *(float *) dst->op_params; float max_bias; memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); @@ -768,6 +1047,9 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, float m0 = powf(2.0f, -(max_bias) / n_head_log2); float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + const int has_mask = (mask != nullptr); + const int has_sinks = (sinks != nullptr); + // print type and dimensions of Q/K/V/mask/sinks/dst // std::cout << "ggml_webgpu_flash_attn: Q type: " << ggml_type_name(Q->type) << ", ne: [" << Q->ne[0] << ", " << Q->ne[1] << ", " << Q->ne[2] // << ", " << Q->ne[3] << "]\n"; @@ -786,26 +1068,26 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, Q) / ggml_type_size(Q->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type)), - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)), - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, sinks) / ggml_type_size(sinks->type)), + has_mask ? 
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)) : 0, + has_sinks ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, sinks) / ggml_type_size(sinks->type)) : 0, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), - (uint32_t) Q->ne[0], // head dimension (Q/K) - (uint32_t) V->ne[0], // head dimension (V) - (uint32_t) Q->ne[2], // number of heads - (uint32_t) Q->ne[1], // sequence length (Q) - (uint32_t) K->ne[1], // sequence length (K/V) - (uint32_t) (Q->nb[1] / ggml_type_size(Q->type)), // stride (elements/blocks) of Q in dimension 1 - (uint32_t) (Q->nb[2] / ggml_type_size(Q->type)), // stride (elements/blocks) of Q in dimension 2 - (uint32_t) (Q->nb[3] / ggml_type_size(Q->type)), // stride (elements/blocks) of Q in dimension 3 - (uint32_t) (K->nb[1] / ggml_type_size(K->type)), // stride (elements/blocks) of K in dimension 1 - (uint32_t) (K->nb[2] / ggml_type_size(K->type)), // stride (elements/blocks) of K in dimension 2 - (uint32_t) (K->nb[3] / ggml_type_size(K->type)), // stride (elements/blocks) of K in dimension 3 - (uint32_t) (V->nb[1] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 1 - (uint32_t) (V->nb[2] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 2 - (uint32_t) (V->nb[3] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 3 - (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)), // stride of mask dim 3 - (uint32_t) (Q->ne[2] / K->ne[2]), // repeat factor for K/V in dim 2 (MHA/MQA/GQA) - *(uint32_t *) &scale, // scale (possibly adjusted for logit softcap) + (uint32_t) Q->ne[0], // head dimension (Q/K) + (uint32_t) V->ne[0], // head dimension (V) + (uint32_t) Q->ne[2], // number of heads + (uint32_t) Q->ne[1], // sequence length (Q) + (uint32_t) K->ne[1], // sequence length (K/V) + (uint32_t) (Q->nb[1] / ggml_type_size(Q->type)), // stride (elements/blocks) of Q in dimension 1 + (uint32_t) (Q->nb[2] / 
ggml_type_size(Q->type)), // stride (elements/blocks) of Q in dimension 2 + (uint32_t) (Q->nb[3] / ggml_type_size(Q->type)), // stride (elements/blocks) of Q in dimension 3 + (uint32_t) (K->nb[1] / ggml_type_size(K->type)), // stride (elements/blocks) of K in dimension 1 + (uint32_t) (K->nb[2] / ggml_type_size(K->type)), // stride (elements/blocks) of K in dimension 2 + (uint32_t) (K->nb[3] / ggml_type_size(K->type)), // stride (elements/blocks) of K in dimension 3 + (uint32_t) (V->nb[1] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 1 + (uint32_t) (V->nb[2] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 2 + (uint32_t) (V->nb[3] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 3 + has_mask ? (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)) : 0, // stride of mask dim 3 + (uint32_t) (Q->ne[2] / K->ne[2]), // repeat factor for K/V in dim 2 (MHA/MQA/GQA) + *(uint32_t *) &scale, // scale (possibly adjusted for logit softcap) *(uint32_t *) &max_bias, *(uint32_t *) &logit_softcap, *(uint32_t *) &n_head_log2, @@ -817,34 +1099,83 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, { .binding = 0, .buffer = ggml_webgpu_tensor_buf(Q), .offset = ggml_webgpu_tensor_align_offset(ctx, Q), - .size = ggml_webgpu_tensor_binding_size(ctx, Q) }, + .size = ggml_webgpu_tensor_binding_size(ctx, Q) }, { .binding = 1, .buffer = ggml_webgpu_tensor_buf(K), .offset = ggml_webgpu_tensor_align_offset(ctx, K), - .size = ggml_webgpu_tensor_binding_size(ctx, K) }, + .size = ggml_webgpu_tensor_binding_size(ctx, K) }, { .binding = 2, .buffer = ggml_webgpu_tensor_buf(V), .offset = ggml_webgpu_tensor_align_offset(ctx, V), - .size = ggml_webgpu_tensor_binding_size(ctx, V) }, - { .binding = 3, - .buffer = ggml_webgpu_tensor_buf(mask), - .offset = ggml_webgpu_tensor_align_offset(ctx, mask), - .size = ggml_webgpu_tensor_binding_size(ctx, mask) }, - { .binding = 4, - .buffer = 
ggml_webgpu_tensor_buf(sinks), - .offset = ggml_webgpu_tensor_align_offset(ctx, sinks), - .size = ggml_webgpu_tensor_binding_size(ctx, sinks) }, - { .binding = 5, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }, - // { .binding = 6, - // .buffer = ctx->debug_dev_buf, - // .offset = 0, - // .size = ctx->debug_dev_buf.GetSize() } + .size = ggml_webgpu_tensor_binding_size(ctx, V) } + }; + uint binding_index = 3; + if (has_mask) { + entries.push_back({ .binding = binding_index++, + .buffer = ggml_webgpu_tensor_buf(mask), + .offset = ggml_webgpu_tensor_align_offset(ctx, mask), + .size = ggml_webgpu_tensor_binding_size(ctx, mask) }); + } + if (has_sinks) { + entries.push_back({ .binding = binding_index++, + .buffer = ggml_webgpu_tensor_buf(sinks), + .offset = ggml_webgpu_tensor_align_offset(ctx, sinks), + .size = ggml_webgpu_tensor_binding_size(ctx, sinks) }); + } + entries.push_back({ .binding = binding_index++, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + // Debug buffer binding (for development only) + // entries.push_back( + // { .binding = binding_index, + // .buffer = ctx->debug_dev_buf, + // .offset = 0, + // .size = ctx->debug_dev_buf.GetSize() }); + + GGML_ASSERT(K->type == V->type); + + flash_attn_pipeline_key key = { + .q_type = Q->type, + .kv_type = K->type, + .dst_type = dst->type, + .head_dim_qk = (uint32_t) Q->ne[0], + .head_dim_v = (uint32_t) V->ne[0], + .n_heads = (uint32_t) Q->ne[2], + .has_mask = mask != nullptr, + .has_sinks = sinks != nullptr, + .uses_logit_softcap = logit_softcap != 0.0f, }; - webgpu_pipeline pipeline = ggml_webgpu_get_flash_attn_pipeline(ctx, Q, K, V, mask, sinks, dst, logit_softcap); + webgpu_pipeline pipeline; + auto it = ctx->flash_attn_pipelines.find(key); + + if (it != ctx->flash_attn_pipelines.end()) { + pipeline = 
it->second; + } else { + std::lock_guard lock(ctx->mutex); + it = ctx->flash_attn_pipelines.find(key); + if (it != ctx->flash_attn_pipelines.end()) { + pipeline = it->second; + } else { + ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = { + .kv_type = ggml_type_name(K->type), + .head_dim_qk = (uint32_t) Q->ne[0], + .head_dim_v = (uint32_t) V->ne[0], + .has_mask = mask != nullptr, + .has_sinks = sinks != nullptr, + .uses_logit_softcap = logit_softcap != 0.0f, + .sg_mat_m = ctx->subgroup_matrix_config.M, + .sg_mat_n = ctx->subgroup_matrix_config.N, + .sg_mat_k = ctx->subgroup_matrix_config.K, + }; + ggml_webgpu_processed_shader processed = + ggml_webgpu_preprocess_flash_attn_shader(ctx->p, wgsl_flash_attn, shader_lib_ctx); + + pipeline = ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); + ctx->flash_attn_pipelines.emplace(key, pipeline); + } + } uint32_t wg_per_head = CEIL_DIV(Q->ne[1], WEBGPU_FLASH_ATTN_Q_TILE); uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches @@ -2357,10 +2688,6 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const supports_op &= src0->type == GGML_TYPE_F32; // KV-type supports_op &= src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16; - // Mask-type - supports_op &= op->src[3] != nullptr; - // Sink-type - supports_op &= op->src[4] != nullptr; // qkv sequence length supports_op &= op->ne[0] <= 128 && src0->ne[0] <= 128; break; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl index 2fe434b1b66..edf42dc02d2 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl @@ -6,6 +6,17 @@ enable chromium_experimental_subgroup_matrix; // Default values #define KV_TYPE f32 +#define HEAD_DIM_QK 64 +#define HEAD_DIM_V 64 + +#define Q_TILE 16 +#define KV_TILE 16 + +// The number of rows/columns/k 
in a subgroup matrix. MxK * KxN = MxN +// Note that the "K" here does not correspond to the K in attention's Q/K/V, it's just the common dimension. +#define SG_MAT_M 8 +#define SG_MAT_N 8 +#define SG_MAT_K 8 struct Params { offset_q: u32, @@ -49,26 +60,44 @@ struct Params { @group(0) @binding(0) var Q: array; @group(0) @binding(1) var K: array; @group(0) @binding(2) var V: array; + +#if defined(MASK) && defined(SINKS) @group(0) @binding(3) var mask: array; @group(0) @binding(4) var sinks: array; -@group(0) @binding(5) var dst: array; +#define DST_BINDING 5 +#define PARAMS_BINDING 6 +#elif defined(MASK) +@group(0) @binding(3) var mask: array; +#define DST_BINDING 4 +#define PARAMS_BINDING 5 +#elif defined(SINKS) +@group(0) @binding(3) var sinks: array; +#define DST_BINDING 4 +#define PARAMS_BINDING 5 +#else +#define DST_BINDING 3 +#define PARAMS_BINDING 4 +#endif + +@group(0) @binding(DST_BINDING) var dst: array; //@group(0) @binding(6) var debug: array; -@group(0) @binding(6) var params: Params; +@group(0) @binding(PARAMS_BINDING) var params: Params; const FLOAT_MIN: f16 = -65504.0; // The number of Q rows processed per workgroup -const Q_TILE = 16u; -var q_shmem: array; // assumes max head_dim_qk of 128 +var q_shmem: array; -const KV_TILE = 16u; -// we can reuse the same shmem for K and V since we only need one at a time right? 
-var kv_shmem: array; // assuming max head_dim_v of 128 +// we can reuse the same shmem for K and V since we only need one at a time +const kv_shmem_size = KV_TILE * max(HEAD_DIM_QK, HEAD_DIM_V); +var kv_shmem: array; -var o_shmem: array; // output shmem +var o_shmem: array; // output shmem +#ifdef MASK // storage for mask values var mask_shmem: array; +#endif // storage for output of Q*K^T scores for online softmax (S matrix from paper) // also storage for diagonal matrix during online softmax (P matrix from paper) @@ -79,12 +108,6 @@ var inter_shmem: array; var row_max_shmem: array; var exp_sum_shmem: array; -// The number of rows/columns/k in a subgroup matrix. MxK * KxN = MxN -// Note that the "K" here does not correspond to the K in attention's Q/K/V, it's just the common dimension. -const SG_MAT_M = 8u; -const SG_MAT_N = 8u; -const SG_MAT_K = 8u; - // Number of blocks this workgroup handles at the subgroup matrix level. SG_MAT_M must divide Q_TILE. const Q_BLOCKS = Q_TILE / SG_MAT_M; // Number of subgroup-matrix-width blocks that span the KV tile. SG_MAT_N must divide KV_TILE. 
@@ -134,8 +157,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let wg_in_head = wg_in_batch % wg_per_head; let q_row_start = wg_in_head * Q_TILE; +#ifdef MASK // mask offset let mask_global_offset = params.offset_mask + batch_idx * params.stride_mask3 + q_row_start * params.seq_len_kv; +#endif // note that the output is permuted, the layout is [head_dim_v, n_heads, seq_len_q, batch_size] let dst_global_offset = dst_batch_offset + q_row_start * dst2_stride + head_idx * params.head_dim_v; @@ -202,6 +227,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, } } +#ifdef MASK // load mask tile into shared memory for this KV block // TODO: optimize and skip if mask is -INF for the entire tile for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) { @@ -214,6 +240,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, mask_shmem[elem_idx] = select(0.0, mask[mask_idx], mask_in_bounds); } +#endif workgroupBarrier(); @@ -233,14 +260,16 @@ fn main(@builtin(workgroup_id) wg_id: vec3, // TODO: is this faster than having all threads read shared memory? 
var prev_max = select(0.0, row_max_shmem[q_tile_row], sg_inv_id == 0); prev_max = subgroupBroadcastFirst(prev_max); + var thread_tile_row_max = select(FLOAT_MIN, inter_shmem[sg_inv_id + q_tile_row * KV_TILE] * f16(params.scale), sg_inv_id < KV_TILE); +#ifdef LOGIT_SOFTCAP + thread_tile_row_max = f16(params.logit_softcap) * tanh(thread_tile_row_max); +#endif +#ifdef MASK // The mask value for this Q row and K col let mask_val = select(0.0, mask_shmem[q_tile_row * KV_TILE + sg_inv_id], sg_inv_id < KV_TILE); let mask_term = slope * mask_val; - var thread_tile_row_max = select(FLOAT_MIN, inter_shmem[sg_inv_id + q_tile_row * KV_TILE] * f16(params.scale), sg_inv_id < KV_TILE); - if (params.logit_softcap != 0.0) { - thread_tile_row_max = f16(params.logit_softcap) * tanh(thread_tile_row_max); - } thread_tile_row_max += mask_term; +#endif let new_max = subgroupMax(max(prev_max, thread_tile_row_max)); // calculate running exp sum @@ -323,6 +352,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, workgroupBarrier(); } +#ifdef SINKS // add sinks (applied once after processing all KV tiles) for (var sg_block = subgroup_id; sg_block < Q_BLOCKS; sg_block += num_subgroups) { let block_row_start = sg_block * SG_MAT_M; @@ -358,6 +388,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, } workgroupBarrier(); +#endif // write output back to global memory for (var sg_block = subgroup_id; sg_block < Q_BLOCKS; sg_block += num_subgroups) { From 7886418b556c9045855bde3b48da6da375bc6236 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Tue, 30 Dec 2025 18:50:53 -0800 Subject: [PATCH 33/40] Work on subgroup matrix/tile size portability --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 91 ++++++++++++++++++- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 44 ++++++--- .../ggml-webgpu/wgsl-shaders/flash_attn.wgsl | 3 +- 3 files changed, 122 insertions(+), 16 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 968737b5317..6654c4e4d6f 
100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -6,6 +6,9 @@ #include #include +#define GGML_WEBGPU_F16_SIZE_BYTES 2 +#define GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE 64u + struct ggml_webgpu_flash_attn_shader_lib_context { const char * kv_type; uint32_t head_dim_qk; @@ -16,10 +19,14 @@ struct ggml_webgpu_flash_attn_shader_lib_context { uint32_t sg_mat_m; uint32_t sg_mat_n; uint32_t sg_mat_k; + size_t wg_mem_limit_bytes; + uint32_t max_subgroup_size; }; struct ggml_webgpu_flash_attn_shader_decisions { - int unused = 0; + uint32_t q_tile = 0; + uint32_t kv_tile = 0; + uint32_t wg_size = 0; }; struct ggml_webgpu_processed_shader { @@ -28,6 +35,71 @@ struct ggml_webgpu_processed_shader { ggml_webgpu_flash_attn_shader_decisions decisions; }; +// This is exposed because it's necessary in supports_op +inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile, + uint32_t kv_tile, + uint32_t head_dim_qk, + uint32_t head_dim_v, + bool has_mask) { + const uint32_t max_head_dim = std::max(head_dim_qk, head_dim_v); + size_t elems = 0; + elems += q_tile * head_dim_qk; // q_shmem + elems += kv_tile * max_head_dim; // kv_shmem + elems += q_tile * head_dim_v; // o_shmem + if (has_mask) { + elems += q_tile * kv_tile; // mask_shmem + } + elems += q_tile * kv_tile; // inter_shmem + elems += q_tile; // row_max_shmem + elems += q_tile; // exp_sum_shmem + return elems * GGML_WEBGPU_F16_SIZE_BYTES; +} + +// Returns a pair of (q_tile, kv_tile) that best fits within the workgroup memory limit +// Currently set to prefer the configuration that comes closest to using half of the limit +// Assumes that the base minimum tile sizes fits within the limit +static std::pair ggml_webgpu_flash_attn_tile_sizes( + const ggml_webgpu_flash_attn_shader_lib_context & context) { + std::pair best_pair = { 0, 0 }; + size_t best_delta = 0; + + const uint32_t min_q_tile = context.sg_mat_m; + const uint32_t min_kv_tile = 
context.sg_mat_n; + const size_t limit_bytes = context.wg_mem_limit_bytes; + const size_t target_bytes = limit_bytes / 2; + const uint32_t max_head_dim = std::max(context.head_dim_qk, context.head_dim_v); + + // These sizes come from the equations for wg_mem_bytes, solving for q_tile or kv_tile respectively + const size_t base_kv_bytes = min_kv_tile * max_head_dim * GGML_WEBGPU_F16_SIZE_BYTES; + const size_t bytes_per_q = + (context.head_dim_qk + context.head_dim_v + (context.has_mask ? min_kv_tile : 0) + min_kv_tile + 2) * + GGML_WEBGPU_F16_SIZE_BYTES; + const uint32_t max_q_tile = (limit_bytes - base_kv_bytes) / bytes_per_q; + + const size_t base_q_bytes = + (context.head_dim_qk + context.head_dim_v + 2) * min_q_tile * GGML_WEBGPU_F16_SIZE_BYTES; + const size_t bytes_per_kv = + (max_head_dim + (context.has_mask ? min_q_tile : 0) + min_q_tile) * GGML_WEBGPU_F16_SIZE_BYTES; + const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv; + + // step by minimum tile sizes + for (uint32_t q = min_q_tile; q <= max_q_tile; q += min_q_tile) { + for (uint32_t kv = min_kv_tile; kv <= max_kv_tile; kv += min_kv_tile) { + size_t bytes = + ggml_webgpu_flash_attn_wg_mem_bytes(q, kv, context.head_dim_qk, context.head_dim_v, context.has_mask); + if (bytes <= limit_bytes) { + size_t delta = bytes > target_bytes ? 
bytes - target_bytes : target_bytes - bytes; + if (best_pair.first == 0 || delta < best_delta) { + best_pair = { q, kv }; + best_delta = delta; + } + } + } + } + + return best_pair; +} + inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader( pre_wgsl::Preprocessor & preprocessor, const char * shader_src, @@ -62,9 +134,22 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader( defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n)); defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k)); + // Add chosen Q/KV tile sizes + auto [q_tile, kv_tile] = ggml_webgpu_flash_attn_tile_sizes(context); + kv_tile = std::min(kv_tile, 32); // TODO: temporary cap to 32 for testing + defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile)); + defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile)); + + // workgroup size + uint32_t wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE); + defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); + ggml_webgpu_processed_shader result; - result.wgsl = preprocessor.preprocess(shader_src, defines); - result.variant = variant; + result.wgsl = preprocessor.preprocess(shader_src, defines); + result.variant = variant; + result.decisions.q_tile = q_tile; + result.decisions.kv_tile = kv_tile; + result.decisions.wg_size = wg_size; return result; } diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 32f9afade4b..e506e0e21d2 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -112,11 +112,6 @@ #define WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG 64 #define WEBGPU_MUL_MAT_VEC_TILE_K 256 -// Flash Attention parameters -#define WEBGPU_FLASH_ATTN_WG_SIZE 32 -#define WEBGPU_FLASH_ATTN_Q_TILE 16 -#define WEBGPU_FLASH_ATTN_KV_TILE 8 - /* End Constants */ void ggml_webgpu_create_buffer(wgpu::Device & device, @@ -246,6 
+241,7 @@ struct webgpu_gpu_profile_buf_pool { struct webgpu_pipeline { wgpu::ComputePipeline pipeline; std::string name; + void * user_data = nullptr; }; struct webgpu_command { @@ -305,7 +301,7 @@ struct webgpu_context_struct { wgpu::Queue queue; wgpu::Limits limits; - uint32_t subgroup_size; + uint32_t max_subgroup_size; #ifndef __EMSCRIPTEN__ bool supports_subgroup_matrix = false; @@ -761,6 +757,11 @@ static void ggml_backend_webgpu_free(ggml_backend_t backend) { #if !defined(GGML_WEBGPU_CPU_PROFILE) && !defined(GGML_WEBGPU_GPU_PROFILE) GGML_UNUSED(ctx); #endif + + for (auto & kv : ctx->webgpu_ctx->flash_attn_pipelines) { + delete static_cast(kv.second.user_data); + kv.second.user_data = nullptr; + } } static size_t ggml_webgpu_tensor_offset(const ggml_tensor * tensor) { @@ -1148,15 +1149,18 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, }; webgpu_pipeline pipeline; + ggml_webgpu_flash_attn_shader_decisions decisions = {}; auto it = ctx->flash_attn_pipelines.find(key); if (it != ctx->flash_attn_pipelines.end()) { pipeline = it->second; + decisions = *static_cast(pipeline.user_data); } else { std::lock_guard lock(ctx->mutex); it = ctx->flash_attn_pipelines.find(key); if (it != ctx->flash_attn_pipelines.end()) { pipeline = it->second; + decisions = *static_cast(pipeline.user_data); } else { ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = { .kv_type = ggml_type_name(K->type), @@ -1168,16 +1172,20 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, .sg_mat_m = ctx->subgroup_matrix_config.M, .sg_mat_n = ctx->subgroup_matrix_config.N, .sg_mat_k = ctx->subgroup_matrix_config.K, + .wg_mem_limit_bytes = ctx->limits.maxComputeWorkgroupStorageSize, + .max_subgroup_size = ctx->max_subgroup_size }; ggml_webgpu_processed_shader processed = ggml_webgpu_preprocess_flash_attn_shader(ctx->p, wgsl_flash_attn, shader_lib_ctx); pipeline = ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); 
+ pipeline.user_data = new ggml_webgpu_flash_attn_shader_decisions(processed.decisions); ctx->flash_attn_pipelines.emplace(key, pipeline); + decisions = processed.decisions; } } - uint32_t wg_per_head = CEIL_DIV(Q->ne[1], WEBGPU_FLASH_ATTN_Q_TILE); + uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions.q_tile); uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches //std::cout << "ggml_webgpu_flash_attn: wg_x: " << wg_x << "\n"; return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); @@ -2010,7 +2018,7 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { #ifndef __EMSCRIPTEN__ if (webgpu_ctx->supports_subgroup_matrix) { std::map sg_matrix_repls; - sg_matrix_repls["WEBGPU_MAX_SUBGROUP_SIZE"] = std::to_string(webgpu_ctx->subgroup_size); + sg_matrix_repls["WEBGPU_MAX_SUBGROUP_SIZE"] = std::to_string(webgpu_ctx->max_subgroup_size); sg_matrix_repls["WEBGPU_TILE_K"] = std::to_string(WEBGPU_MUL_MAT_TILE_K); sg_matrix_repls["WEBGPU_SUBGROUP_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M); sg_matrix_repls["WEBGPU_SUBGROUP_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N); @@ -2683,13 +2691,27 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const } case GGML_OP_FLASH_ATTN_EXT: { + if (!webgpu_ctx->supports_subgroup_matrix) { + break; + } + // Head dimensions must fit in workgroup memory with minimum tile sizes + size_t limit_bytes = webgpu_ctx->limits.maxComputeWorkgroupStorageSize; + const bool has_mask = op->src[3] != nullptr; + const size_t min_bytes = + ggml_webgpu_flash_attn_wg_mem_bytes(webgpu_ctx->subgroup_matrix_config.M, + webgpu_ctx->subgroup_matrix_config.N, + (uint32_t) src0->ne[0], + (uint32_t) src2->ne[0], + has_mask); + if (min_bytes > limit_bytes) { + break; + } + supports_op = true; // Q-type supports_op &= src0->type == GGML_TYPE_F32; // KV-type supports_op &= src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16; - // qkv sequence 
length - supports_op &= op->ne[0] <= 128 && src0->ne[0] <= 128; break; } case GGML_OP_RMS_NORM: @@ -2879,7 +2901,7 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t #endif // For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate. // Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter. - ctx->subgroup_size = info.subgroupMaxSize; + ctx->max_subgroup_size = info.subgroupMaxSize; // Initialize device std::vector required_features = { wgpu::FeatureName::ShaderF16 }; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl index edf42dc02d2..2c42ca13d37 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl @@ -11,6 +11,7 @@ enable chromium_experimental_subgroup_matrix; #define Q_TILE 16 #define KV_TILE 16 +#define WG_SIZE 64 // The number of rows/columns/k in a subgroup matrix. MxK * KxN = MxN // Note that the "K" here does not correspond to the K in attention's Q/K/V, it's just the common dimension. @@ -113,8 +114,6 @@ const Q_BLOCKS = Q_TILE / SG_MAT_M; // Number of subgroup-matrix-width blocks that span the KV tile. SG_MAT_N must divide KV_TILE. 
const KV_BLOCKS = KV_TILE / SG_MAT_N; -const WG_SIZE = 64u; - @compute @workgroup_size(WG_SIZE) fn main(@builtin(workgroup_id) wg_id: vec3, @builtin(local_invocation_id) local_id: vec3, From d523a40f39517100674dfc9f8b372a62243b86f2 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Wed, 31 Dec 2025 09:11:38 -0800 Subject: [PATCH 34/40] subgroup size agnostic online softmax --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 1 - .../ggml-webgpu/wgsl-shaders/flash_attn.wgsl | 66 ++++++++++++------- 2 files changed, 42 insertions(+), 25 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 6654c4e4d6f..1b378e90f77 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -136,7 +136,6 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader( // Add chosen Q/KV tile sizes auto [q_tile, kv_tile] = ggml_webgpu_flash_attn_tile_sizes(context); - kv_tile = std::min(kv_tile, 32); // TODO: temporary cap to 32 for testing defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile)); defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile)); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl index 2c42ca13d37..6f2e043395f 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl @@ -109,6 +109,21 @@ var inter_shmem: array; var row_max_shmem: array; var exp_sum_shmem: array; +fn calc_softmax_term(kv_idx: u32, q_tile_row: u32, slope: f16) -> f16 { + var v = select(FLOAT_MIN, + inter_shmem[kv_idx + q_tile_row * KV_TILE] * f16(params.scale), + kv_idx < KV_TILE); +#ifdef LOGIT_SOFTCAP + v = f16(params.logit_softcap) * tanh(v); +#endif +#ifdef MASK + let mask_val = select(0.0, mask_shmem[q_tile_row * KV_TILE + kv_idx], kv_idx < KV_TILE); + let mask_term = slope * mask_val; + v += mask_term; 
+#endif + return v; +} + // Number of blocks this workgroup handles at the subgroup matrix level. SG_MAT_M must divide Q_TILE. const Q_BLOCKS = Q_TILE / SG_MAT_M; // Number of subgroup-matrix-width blocks that span the KV tile. SG_MAT_N must divide KV_TILE. @@ -220,7 +235,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, acc = subgroupMatrixMultiplyAccumulate(q_sg_mat, k_sg_mat, acc); } - // store acc to shared memory for softmax + // store acc to shared memory for softmax (S matrix from paper) let inter_offset = sg_block * SG_MAT_M * KV_TILE + kv_block * SG_MAT_N; subgroupMatrixStore(&inter_shmem, inter_offset, acc, false, KV_TILE); } @@ -254,42 +269,45 @@ fn main(@builtin(workgroup_id) wg_id: vec3, break; } - // calculate running max + // initialize running max for this row // only the first thread in the subgroup needs to read from shared memory. // TODO: is this faster than having all threads read shared memory? var prev_max = select(0.0, row_max_shmem[q_tile_row], sg_inv_id == 0); prev_max = subgroupBroadcastFirst(prev_max); - var thread_tile_row_max = select(FLOAT_MIN, inter_shmem[sg_inv_id + q_tile_row * KV_TILE] * f16(params.scale), sg_inv_id < KV_TILE); -#ifdef LOGIT_SOFTCAP - thread_tile_row_max = f16(params.logit_softcap) * tanh(thread_tile_row_max); -#endif -#ifdef MASK - // The mask value for this Q row and K col - let mask_val = select(0.0, mask_shmem[q_tile_row * KV_TILE + sg_inv_id], sg_inv_id < KV_TILE); - let mask_term = slope * mask_val; - thread_tile_row_max += mask_term; -#endif - let new_max = subgroupMax(max(prev_max, thread_tile_row_max)); + var final_max = prev_max; - // calculate running exp sum - let cur_p = select(0.0, exp(thread_tile_row_max - new_max), kv_tile + sg_inv_id < params.seq_len_kv && sg_inv_id < KV_TILE); - let new_exp_term = subgroupAdd(cur_p); + // pass 1: compute final max across the full KV tile in chunks + for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) { + let kv_idx = kv_offset + sg_inv_id; + 
let softmax_term = calc_softmax_term(kv_idx, q_tile_row, slope); + final_max = subgroupMax(max(final_max, softmax_term)); + } - // store back to shared memory (P matrix) - if (sg_inv_id < KV_TILE) { - inter_shmem[sg_inv_id + q_tile_row * KV_TILE] = cur_p; + var total_exp_term: f16 = 0.0; + + // pass 2: compute exp sum and write P using final_max + for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) { + let kv_idx = kv_offset + sg_inv_id; + let softmax_term = calc_softmax_term(kv_idx, q_tile_row, slope); + let cur_p = select(0.0, + exp(softmax_term - final_max), + kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE); + total_exp_term += subgroupAdd(cur_p); + if (kv_idx < KV_TILE) { + inter_shmem[kv_idx + q_tile_row * KV_TILE] = cur_p; + } } - let cur_exp = exp(prev_max - new_max); + let cur_exp = exp(prev_max - final_max); + if (sg_inv_id == 0) { - row_max_shmem[q_tile_row] = new_max; - exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * cur_exp + new_exp_term; + row_max_shmem[q_tile_row] = final_max; + exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * cur_exp + total_exp_term; } for (var elem_idx = sg_inv_id; elem_idx < params.head_dim_v; elem_idx += subgroup_size) { let idx = q_tile_row * params.head_dim_v + elem_idx; - let val = o_shmem[idx] * cur_exp; - o_shmem[idx] = val; + o_shmem[idx] *= cur_exp; } } } From e72c0e4b39d16333b61f17df292da3f69c2458bc Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Wed, 31 Dec 2025 16:05:29 -0800 Subject: [PATCH 35/40] Cleanups, quantization types --- ggml/include/ggml.h | 6 - ggml/src/ggml-cpu/ggml-cpu.cpp | 6 - ggml/src/ggml-cpu/ops.cpp | 2 +- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 25 +- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 112 +- ggml/src/ggml-webgpu/pre_wgsl.hpp | 1188 +++++++++-------- .../ggml-webgpu/wgsl-shaders/flash_attn.wgsl | 139 ++ 7 files changed, 876 insertions(+), 602 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 145f37781f3..20c912d0e9b 
100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -229,13 +229,7 @@ # define GGML_MAX_NAME 64 #endif -// For single-thread WASM builds, only use 1 thread -#if !defined(__EMSCRIPTEN__) || defined(__EMSCRIPTEN_PTHREADS__) #define GGML_DEFAULT_N_THREADS 4 -#else -#define GGML_DEFAULT_N_THREADS 1 -#endif - #define GGML_DEFAULT_GRAPH_SIZE 2048 #if UINTPTR_MAX == 0xFFFFFFFF diff --git a/ggml/src/ggml-cpu/ggml-cpu.cpp b/ggml/src/ggml-cpu/ggml-cpu.cpp index a8f098252dc..939848a6a90 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.cpp +++ b/ggml/src/ggml-cpu/ggml-cpu.cpp @@ -246,11 +246,8 @@ bool ggml_backend_is_cpu(ggml_backend_t backend) { void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads) { GGML_ASSERT(ggml_backend_is_cpu(backend_cpu)); -// For single-thread WASM builds, do not allow changing the number of threads -#if !defined(_EMSCRIPTEN_) || defined(__EMSCRIPTEN_PTHREADS__) struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context; ctx->n_threads = n_threads; -#endif } void ggml_backend_cpu_set_threadpool(ggml_backend_t backend_cpu, ggml_threadpool_t threadpool) { @@ -630,13 +627,10 @@ static ggml_backend_feature * ggml_backend_cpu_get_features(ggml_backend_reg_t r static void * ggml_backend_cpu_get_proc_address(ggml_backend_reg_t reg, const char * name) { -// For single-thread WASM builds, do not expose a set_n_threads function -#if !defined(__EMSCRIPTEN__) || defined(__EMSCRIPTEN_PTHREADS__) if (strcmp(name, "ggml_backend_set_n_threads") == 0) { ggml_backend_set_n_threads_t fct = ggml_backend_cpu_set_n_threads; return (void *)fct; } -#endif if (strcmp(name, "ggml_backend_dev_get_extra_bufts") == 0) { ggml_backend_dev_get_extra_bufts_t fct = ggml_backend_cpu_device_get_extra_buffers_type; return (void *)fct; diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index e3e76950e10..3032783971d 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -8203,7 +8203,6 @@ 
static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( } S = S*ms + vs; // scale and increment sum with partial sum - } if (v->type == GGML_TYPE_F16) { @@ -8232,6 +8231,7 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( // V /= S const float S_inv = S == 0.0f ? 0.0f : 1.0f/S; ggml_vec_scale_f32(DV, VKQ32, S_inv); + // dst indices const int i1 = iq1; const int i2 = iq2; diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 1b378e90f77..af9f4e14de0 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -1,6 +1,7 @@ #ifndef GGML_WEBGPU_SHADER_LIB_HPP #define GGML_WEBGPU_SHADER_LIB_HPP +#include "ggml.h" #include "pre_wgsl.hpp" #include @@ -10,7 +11,7 @@ #define GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE 64u struct ggml_webgpu_flash_attn_shader_lib_context { - const char * kv_type; + ggml_type kv_type; uint32_t head_dim_qk; uint32_t head_dim_v; bool has_mask; @@ -100,6 +101,18 @@ static std::pair ggml_webgpu_flash_attn_tile_sizes( return best_pair; } +static const char * kv_shader_type(ggml_type kv_type) { + switch (kv_type) { + case GGML_TYPE_F32: return "f32"; + case GGML_TYPE_F16: return "f16"; + case GGML_TYPE_Q4_0: return "f16"; + case GGML_TYPE_Q8_0: return "f16"; + default: + GGML_ABORT("Unsupported KV type for flash attention shader"); + return ""; + } +} + inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader( pre_wgsl::Preprocessor & preprocessor, const char * shader_src, @@ -107,8 +120,14 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader( std::vector defines; std::string variant = "flash_attn"; - defines.push_back(std::string("KV_TYPE=") + context.kv_type); - variant += std::string("_") + context.kv_type; + defines.push_back(std::string("KV_TYPE=") + kv_shader_type(context.kv_type)); + variant += std::string("_") + ggml_type_name(context.kv_type); + + if 
(context.kv_type == GGML_TYPE_Q4_0) { + defines.push_back("KV_Q4_0"); + } else if (context.kv_type == GGML_TYPE_Q8_0) { + defines.push_back("KV_Q8_0"); + } if (context.has_mask) { defines.push_back("MASK"); diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 3607b9d6e79..7910ec962e1 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -15,17 +15,11 @@ # include #endif -#ifdef __EMSCRIPTEN__ -# include -#endif - #include #include #include #include -#include -#include #include #include #include @@ -118,6 +112,20 @@ /* End Constants */ +// This is a "fake" base pointer, since WebGPU buffers do not have pointers to their locations. +static void * const webgpu_ptr_base = (void *) (uintptr_t) 0x1000; // NOLINT + +// Always returns the base offset of a tensor, regardless of views. +static uint64_t webgpu_tensor_offset(const ggml_tensor * tensor) { + if (tensor->view_src) { + return (uint8_t *) tensor->view_src->data - (uint8_t *) webgpu_ptr_base; + } + return (uint8_t *) tensor->data - (uint8_t *) webgpu_ptr_base; +} + +/* Struct definitions */ + +// Forward reference void ggml_webgpu_create_buffer(wgpu::Device & device, wgpu::Buffer & buffer, size_t size, @@ -245,7 +253,7 @@ struct webgpu_gpu_profile_buf_pool { struct webgpu_pipeline { wgpu::ComputePipeline pipeline; std::string name; - void * user_data = nullptr; + void * context = nullptr; }; struct webgpu_command { @@ -264,19 +272,18 @@ struct flash_attn_pipeline_key { int dst_type; uint32_t head_dim_qk; uint32_t head_dim_v; - uint32_t n_heads; bool has_mask; bool has_sinks; bool uses_logit_softcap; bool operator==(const flash_attn_pipeline_key & other) const { return q_type == other.q_type && kv_type == other.kv_type && dst_type == other.dst_type && - head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v && n_heads == other.n_heads && - has_mask == other.has_mask && has_sinks == other.has_sinks && - uses_logit_softcap 
== other.uses_logit_softcap; + head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v && has_mask == other.has_mask && + has_sinks == other.has_sinks && uses_logit_softcap == other.uses_logit_softcap; } }; +// Same hash combine function as in boost template inline void ggml_webgpu_hash_combine(size_t & seed, const T & value) { seed ^= std::hash{}(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2); } @@ -289,7 +296,6 @@ struct flash_attn_pipeline_key_hash { ggml_webgpu_hash_combine(seed, key.dst_type); ggml_webgpu_hash_combine(seed, key.head_dim_qk); ggml_webgpu_hash_combine(seed, key.head_dim_v); - ggml_webgpu_hash_combine(seed, key.n_heads); ggml_webgpu_hash_combine(seed, key.has_mask); ggml_webgpu_hash_combine(seed, key.has_sinks); ggml_webgpu_hash_combine(seed, key.uses_logit_softcap); @@ -369,7 +375,7 @@ struct webgpu_context_struct { #endif }; -using webgpu_context = std::shared_ptr; +typedef std::shared_ptr webgpu_context; struct ggml_backend_webgpu_reg_context { webgpu_context webgpu_ctx; @@ -399,17 +405,6 @@ struct ggml_backend_webgpu_buffer_context { label(std::move(lbl)) {} }; -// This is a "fake" base pointer, since WebGPU buffers do not have pointers to their locations. -static void * const webgpu_ptr_base = (void *) (uintptr_t) 0x1000; // NOLINT - -// Always returns the base offset of a tensor, regardless of views. 
-static uint64_t webgpu_tensor_offset(const ggml_tensor * tensor) { - if (tensor->view_src) { - return (uint8_t *) tensor->view_src->data - (uint8_t *) webgpu_ptr_base; - } - return (uint8_t *) tensor->data - (uint8_t *) webgpu_ptr_base; -} - /* WebGPU object initializations */ void ggml_webgpu_create_buffer(wgpu::Device & device, @@ -763,8 +758,8 @@ static void ggml_backend_webgpu_free(ggml_backend_t backend) { #endif for (auto & kv : ctx->webgpu_ctx->flash_attn_pipelines) { - delete static_cast(kv.second.user_data); - kv.second.user_data = nullptr; + delete static_cast(kv.second.context); + kv.second.context = nullptr; } } @@ -1138,52 +1133,48 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, // .offset = 0, // .size = ctx->debug_dev_buf.GetSize() }); - GGML_ASSERT(K->type == V->type); - flash_attn_pipeline_key key = { .q_type = Q->type, .kv_type = K->type, .dst_type = dst->type, .head_dim_qk = (uint32_t) Q->ne[0], .head_dim_v = (uint32_t) V->ne[0], - .n_heads = (uint32_t) Q->ne[2], .has_mask = mask != nullptr, .has_sinks = sinks != nullptr, .uses_logit_softcap = logit_softcap != 0.0f, }; - webgpu_pipeline pipeline; + webgpu_pipeline pipeline; ggml_webgpu_flash_attn_shader_decisions decisions = {}; - auto it = ctx->flash_attn_pipelines.find(key); + auto it = ctx->flash_attn_pipelines.find(key); if (it != ctx->flash_attn_pipelines.end()) { - pipeline = it->second; - decisions = *static_cast(pipeline.user_data); + pipeline = it->second; + decisions = *static_cast(pipeline.context); } else { std::lock_guard lock(ctx->mutex); it = ctx->flash_attn_pipelines.find(key); if (it != ctx->flash_attn_pipelines.end()) { - pipeline = it->second; - decisions = *static_cast(pipeline.user_data); + pipeline = it->second; + decisions = *static_cast(pipeline.context); } else { - ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = { - .kv_type = ggml_type_name(K->type), - .head_dim_qk = (uint32_t) Q->ne[0], - .head_dim_v = (uint32_t) V->ne[0], - .has_mask = 
mask != nullptr, - .has_sinks = sinks != nullptr, - .uses_logit_softcap = logit_softcap != 0.0f, - .sg_mat_m = ctx->subgroup_matrix_config.M, - .sg_mat_n = ctx->subgroup_matrix_config.N, - .sg_mat_k = ctx->subgroup_matrix_config.K, - .wg_mem_limit_bytes = ctx->limits.maxComputeWorkgroupStorageSize, - .max_subgroup_size = ctx->max_subgroup_size - }; - ggml_webgpu_processed_shader processed = + ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = { .kv_type = K->type, + .head_dim_qk = (uint32_t) Q->ne[0], + .head_dim_v = (uint32_t) V->ne[0], + .has_mask = mask != nullptr, + .has_sinks = sinks != nullptr, + .uses_logit_softcap = logit_softcap != 0.0f, + .sg_mat_m = ctx->subgroup_matrix_config.M, + .sg_mat_n = ctx->subgroup_matrix_config.N, + .sg_mat_k = ctx->subgroup_matrix_config.K, + .wg_mem_limit_bytes = + ctx->limits.maxComputeWorkgroupStorageSize, + .max_subgroup_size = ctx->max_subgroup_size }; + ggml_webgpu_processed_shader processed = ggml_webgpu_preprocess_flash_attn_shader(ctx->p, wgsl_flash_attn, shader_lib_ctx); pipeline = ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); - pipeline.user_data = new ggml_webgpu_flash_attn_shader_decisions(processed.decisions); + pipeline.context = new ggml_webgpu_flash_attn_shader_decisions(processed.decisions); ctx->flash_attn_pipelines.emplace(key, pipeline); decisions = processed.decisions; } @@ -1191,7 +1182,6 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions.q_tile); uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches - //std::cout << "ggml_webgpu_flash_attn: wg_x: " << wg_x << "\n"; return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } @@ -2708,22 +2698,18 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const } // Head dimensions must fit in workgroup memory with minimum tile sizes size_t 
limit_bytes = webgpu_ctx->limits.maxComputeWorkgroupStorageSize; - const bool has_mask = op->src[3] != nullptr; - const size_t min_bytes = - ggml_webgpu_flash_attn_wg_mem_bytes(webgpu_ctx->subgroup_matrix_config.M, - webgpu_ctx->subgroup_matrix_config.N, - (uint32_t) src0->ne[0], - (uint32_t) src2->ne[0], - has_mask); + const bool has_mask = op->src[3] != nullptr; + const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes( + webgpu_ctx->subgroup_matrix_config.M, webgpu_ctx->subgroup_matrix_config.N, (uint32_t) src0->ne[0], + (uint32_t) src2->ne[0], has_mask); if (min_bytes > limit_bytes) { break; } - supports_op = true; - // Q-type - supports_op &= src0->type == GGML_TYPE_F32; - // KV-type - supports_op &= src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16; + supports_op = src0->type == GGML_TYPE_F32 && + (src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16 || + src1->type == GGML_TYPE_Q4_0 || src1->type == GGML_TYPE_Q8_0) && + src2->type == src1->type && op->type == GGML_TYPE_F32; break; } case GGML_OP_RMS_NORM: diff --git a/ggml/src/ggml-webgpu/pre_wgsl.hpp b/ggml/src/ggml-webgpu/pre_wgsl.hpp index 70dc19bf58c..f5cfc2aa1d1 100644 --- a/ggml/src/ggml-webgpu/pre_wgsl.hpp +++ b/ggml/src/ggml-webgpu/pre_wgsl.hpp @@ -1,15 +1,15 @@ #ifndef PRE_WGSL_HPP #define PRE_WGSL_HPP +#include +#include +#include +#include #include #include #include #include #include -#include -#include -#include -#include namespace pre_wgsl { @@ -17,25 +17,99 @@ namespace pre_wgsl { // Options //============================================================== struct Options { - std::string include_path = "."; - std::vector macros; + std::string include_path = "."; + std::vector macros; }; //============================================================== // Utility: trim //============================================================== -static inline std::string trim(const std::string& s) { - size_t a = 0; - while (a < s.size() && std::isspace((unsigned char)s[a])) a++; - size_t b 
= s.size(); - while (b > a && std::isspace((unsigned char)s[b - 1])) b--; - return s.substr(a, b - a); +static std::string trim(const std::string &s) { + size_t a = 0; + while (a < s.size() && std::isspace((unsigned char)s[a])) + a++; + size_t b = s.size(); + while (b > a && std::isspace((unsigned char)s[b - 1])) + b--; + return s.substr(a, b - a); +} + +static std::string trim_value(std::istream &is) { + std::string str; + std::getline(is, str); + return trim(str); +} + +static bool isIdentChar(char c) { + return std::isalnum(static_cast(c)) || c == '_'; +} + +static std::string expandMacrosRecursiveInternal( + const std::string &line, + const std::unordered_map ¯os, + std::unordered_set &visiting); + +static std::string +expandMacroValue(const std::string &name, + const std::unordered_map ¯os, + std::unordered_set &visiting) { + if (visiting.count(name)) + throw std::runtime_error("Recursive macro: " + name); + visiting.insert(name); + + auto it = macros.find(name); + if (it == macros.end()) { + visiting.erase(name); + return name; + } + + const std::string &value = it->second; + if (value.empty()) { + visiting.erase(name); + return ""; + } + + std::string expanded = expandMacrosRecursiveInternal(value, macros, visiting); + visiting.erase(name); + return expanded; } -static inline std::string trim_value(std::istream& is) { - std::string str; - std::getline(is, str); - return trim(str); +static std::string expandMacrosRecursiveInternal( + const std::string &line, + const std::unordered_map ¯os, + std::unordered_set &visiting) { + std::string result; + result.reserve(line.size()); + + size_t i = 0; + while (i < line.size()) { + if (isIdentChar(line[i])) { + size_t start = i; + while (i < line.size() && isIdentChar(line[i])) { + i++; + } + std::string token = line.substr(start, i - start); + + auto it = macros.find(token); + if (it != macros.end()) { + result += expandMacroValue(token, macros, visiting); + } else { + result += token; + } + } else { + result += 
line[i]; + i++; + } + } + + return result; +} + +static std::string expandMacrosRecursive( + const std::string &line, + const std::unordered_map ¯os) { + std::unordered_set visiting; + return expandMacrosRecursiveInternal(line, macros, visiting); } //============================================================== @@ -43,74 +117,77 @@ static inline std::string trim_value(std::istream& is) { //============================================================== class ExprLexer { public: - enum Kind { - END, IDENT, NUMBER, - OP, LPAREN, RPAREN - }; - - struct Tok { - Kind kind; - std::string text; - }; - - explicit ExprLexer(std::string_view sv) : src(sv), pos(0) {} - - Tok next() { - skipWS(); - if (pos >= src.size()) return { END, "" }; - - char c = src[pos]; - - // number - if (std::isdigit((unsigned char)c)) { - size_t start = pos; - while (pos < src.size() && std::isdigit((unsigned char)src[pos])) - pos++; - return { NUMBER, std::string(src.substr(start, pos - start)) }; - } + enum Kind { END, IDENT, NUMBER, OP, LPAREN, RPAREN }; - // identifier - if (std::isalpha((unsigned char)c) || c == '_') { - size_t start = pos; - while (pos < src.size() && - (std::isalnum((unsigned char)src[pos]) || src[pos] == '_')) - pos++; - return { IDENT, std::string(src.substr(start, pos - start)) }; - } + struct Tok { + Kind kind; + std::string text; + }; - if (c == '(') { pos++; return { LPAREN, "(" }; } - if (c == ')') { pos++; return { RPAREN, ")" }; } - - // multi-char operators - static const char* two_ops[] = { - "==","!=", "<=", ">=", "&&","||", "<<",">>" - }; - for (auto op : two_ops) { - if (src.substr(pos, 2) == op) { - pos += 2; - return { OP, std::string(op) }; - } - } + explicit ExprLexer(std::string_view sv) : src(sv), pos(0) {} - // single-char operators - if (std::string("+-*/%<>!").find(c) != std::string::npos) { - pos++; - return { OP, std::string(1, c) }; - } + Tok next() { + skipWS(); + if (pos >= src.size()) + return {END, ""}; + + char c = src[pos]; - // unexpected + 
// number + if (std::isdigit((unsigned char)c)) { + size_t start = pos; + while (pos < src.size() && std::isdigit((unsigned char)src[pos])) pos++; - return { END, "" }; + return {NUMBER, std::string(src.substr(start, pos - start))}; } -private: - std::string_view src; - size_t pos; + // identifier + if (std::isalpha((unsigned char)c) || c == '_') { + size_t start = pos; + while (pos < src.size() && + (std::isalnum((unsigned char)src[pos]) || src[pos] == '_')) + pos++; + return {IDENT, std::string(src.substr(start, pos - start))}; + } - void skipWS() { - while (pos < src.size() && std::isspace((unsigned char)src[pos])) - pos++; + if (c == '(') { + pos++; + return {LPAREN, "("}; } + if (c == ')') { + pos++; + return {RPAREN, ")"}; + } + + // multi-char operators + static const char *two_ops[] = { + "==", "!=", "<=", ">=", "&&", "||", "<<", ">>"}; + for (auto op : two_ops) { + if (src.substr(pos, 2) == op) { + pos += 2; + return {OP, std::string(op)}; + } + } + + // single-char operators + if (std::string("+-*/%<>!").find(c) != std::string::npos) { + pos++; + return {OP, std::string(1, c)}; + } + + // unexpected + pos++; + return {END, ""}; + } + +private: + std::string_view src; + size_t pos; + + void skipWS() { + while (pos < src.size() && std::isspace((unsigned char)src[pos])) + pos++; + } }; //============================================================== @@ -118,172 +195,214 @@ class ExprLexer { //============================================================== class ExprParser { public: - ExprParser(std::string_view expr, - const std::unordered_map& macros) - : lex(expr), macros(macros) - { - advance(); - } + ExprParser(std::string_view expr, + const std::unordered_map ¯os, + std::unordered_set &visiting) + : lex(expr), macros(macros), visiting(visiting) { + advance(); + } - int parse() { - return parseLogicalOr(); - } + int parse() { return parseLogicalOr(); } private: - ExprLexer lex; - ExprLexer::Tok tok; - const std::unordered_map& macros; + ExprLexer lex; + 
ExprLexer::Tok tok; + const std::unordered_map ¯os; + std::unordered_set &visiting; - void advance() { tok = lex.next(); } + void advance() { tok = lex.next(); } - bool acceptOp(const std::string& s) { - if (tok.kind == ExprLexer::OP && tok.text == s) { - advance(); - return true; - } - return false; + bool acceptOp(const std::string &s) { + if (tok.kind == ExprLexer::OP && tok.text == s) { + advance(); + return true; } + return false; + } - bool acceptKind(ExprLexer::Kind k) { - if (tok.kind == k) { - advance(); - return true; - } - return false; + bool acceptKind(ExprLexer::Kind k) { + if (tok.kind == k) { + advance(); + return true; } - - int parseLogicalOr() { - int v = parseLogicalAnd(); - while (acceptOp("||")) { - int rhs = parseLogicalAnd(); - v = (v || rhs); - } - return v; + return false; + } + + int parseLogicalOr() { + int v = parseLogicalAnd(); + while (acceptOp("||")) { + int rhs = parseLogicalAnd(); + v = (v || rhs); } - - int parseLogicalAnd() { - int v = parseEquality(); - while (acceptOp("&&")) { - int rhs = parseEquality(); - v = (v && rhs); - } - return v; - } - - int parseEquality() { - int v = parseRelational(); - for (;;) { - if (acceptOp("==")) { - int rhs = parseRelational(); - v = (v == rhs); - } else if (acceptOp("!=")) { - int rhs = parseRelational(); - v = (v != rhs); - } else break; - } - return v; + return v; + } + + int parseLogicalAnd() { + int v = parseEquality(); + while (acceptOp("&&")) { + int rhs = parseEquality(); + v = (v && rhs); } - - int parseRelational() { - int v = parseShift(); - for (;;) { - if (acceptOp("<")) { int rhs = parseShift(); v = (v < rhs); } - else if (acceptOp(">")) { int rhs = parseShift(); v = (v > rhs); } - else if (acceptOp("<=")){ int rhs = parseShift(); v = (v <= rhs); } - else if (acceptOp(">=")){ int rhs = parseShift(); v = (v >= rhs); } - else break; - } - return v; + return v; + } + + int parseEquality() { + int v = parseRelational(); + for (;;) { + if (acceptOp("==")) { + int rhs = 
parseRelational(); + v = (v == rhs); + } else if (acceptOp("!=")) { + int rhs = parseRelational(); + v = (v != rhs); + } else + break; } - - int parseShift() { - int v = parseAdd(); - for (;;) { - if (acceptOp("<<")) { int rhs = parseAdd(); v = (v << rhs); } - else if (acceptOp(">>")) { int rhs = parseAdd(); v = (v >> rhs); } - else break; - } - return v; + return v; + } + + int parseRelational() { + int v = parseShift(); + for (;;) { + if (acceptOp("<")) { + int rhs = parseShift(); + v = (v < rhs); + } else if (acceptOp(">")) { + int rhs = parseShift(); + v = (v > rhs); + } else if (acceptOp("<=")) { + int rhs = parseShift(); + v = (v <= rhs); + } else if (acceptOp(">=")) { + int rhs = parseShift(); + v = (v >= rhs); + } else + break; } - - int parseAdd() { - int v = parseMult(); - for (;;) { - if (acceptOp("+")) { int rhs = parseMult(); v = (v + rhs); } - else if (acceptOp("-")) { int rhs = parseMult(); v = (v - rhs); } - else break; - } - return v; + return v; + } + + int parseShift() { + int v = parseAdd(); + for (;;) { + if (acceptOp("<<")) { + int rhs = parseAdd(); + v = (v << rhs); + } else if (acceptOp(">>")) { + int rhs = parseAdd(); + v = (v >> rhs); + } else + break; } - - int parseMult() { - int v = parseUnary(); - for (;;) { - if (acceptOp("*")) { int rhs = parseUnary(); v = (v * rhs); } - else if (acceptOp("/")) { int rhs = parseUnary(); v = (rhs == 0 ? 0 : v / rhs); } - else if (acceptOp("%")) { int rhs = parseUnary(); v = (rhs == 0 ? 0 : v % rhs); } - else break; - } - return v; + return v; + } + + int parseAdd() { + int v = parseMult(); + for (;;) { + if (acceptOp("+")) { + int rhs = parseMult(); + v = (v + rhs); + } else if (acceptOp("-")) { + int rhs = parseMult(); + v = (v - rhs); + } else + break; + } + return v; + } + + int parseMult() { + int v = parseUnary(); + for (;;) { + if (acceptOp("*")) { + int rhs = parseUnary(); + v = (v * rhs); + } else if (acceptOp("/")) { + int rhs = parseUnary(); + v = (rhs == 0 ? 
0 : v / rhs); + } else if (acceptOp("%")) { + int rhs = parseUnary(); + v = (rhs == 0 ? 0 : v % rhs); + } else + break; + } + return v; + } + + int parseUnary() { + if (acceptOp("!")) + return !parseUnary(); + if (acceptOp("-")) + return -parseUnary(); + if (acceptOp("+")) + return +parseUnary(); + return parsePrimary(); + } + + int parsePrimary() { + // '(' expr ')' + if (acceptKind(ExprLexer::LPAREN)) { + int v = parse(); + if (!acceptKind(ExprLexer::RPAREN)) + throw std::runtime_error("missing ')'"); + return v; } - int parseUnary() { - if (acceptOp("!")) return !parseUnary(); - if (acceptOp("-")) return -parseUnary(); - if (acceptOp("+")) return +parseUnary(); - return parsePrimary(); + // number + if (tok.kind == ExprLexer::NUMBER) { + int v = std::stoi(tok.text); + advance(); + return v; } - int parsePrimary() { - // '(' expr ')' - if (acceptKind(ExprLexer::LPAREN)) { - int v = parse(); - if (!acceptKind(ExprLexer::RPAREN)) - throw std::runtime_error("missing ')'"); - return v; - } + // defined(identifier) + if (tok.kind == ExprLexer::IDENT && tok.text == "defined") { + advance(); + if (acceptKind(ExprLexer::LPAREN)) { + if (tok.kind != ExprLexer::IDENT) + throw std::runtime_error("expected identifier in defined()"); + std::string name = tok.text; + advance(); + if (!acceptKind(ExprLexer::RPAREN)) + throw std::runtime_error("missing ) in defined()"); + return macros.count(name) ? 1 : 0; + } else { + // defined NAME + if (tok.kind != ExprLexer::IDENT) + throw std::runtime_error("expected identifier in defined NAME"); + std::string name = tok.text; + advance(); + return macros.count(name) ? 
1 : 0; + } + } - // number - if (tok.kind == ExprLexer::NUMBER) { - int v = std::stoi(tok.text); - advance(); - return v; - } + // identifier -> treat as integer, if defined use its value else 0 + if (tok.kind == ExprLexer::IDENT) { + std::string name = tok.text; + advance(); + auto it = macros.find(name); + if (it == macros.end()) + return 0; + if (it->second.empty()) + return 1; + return evalMacroExpression(name, it->second); + } - // defined(identifier) - if (tok.kind == ExprLexer::IDENT && tok.text == "defined") { - advance(); - if (acceptKind(ExprLexer::LPAREN)) { - if (tok.kind != ExprLexer::IDENT) - throw std::runtime_error("expected identifier in defined()"); - std::string name = tok.text; - advance(); - if (!acceptKind(ExprLexer::RPAREN)) - throw std::runtime_error("missing ) in defined()"); - return macros.count(name) ? 1 : 0; - } else { - // defined NAME - if (tok.kind != ExprLexer::IDENT) - throw std::runtime_error("expected identifier in defined NAME"); - std::string name = tok.text; - advance(); - return macros.count(name) ? 
1 : 0; - } - } + // unexpected + return 0; + } - // identifier -> treat as integer, if defined use its value else 0 - if (tok.kind == ExprLexer::IDENT) { - std::string name = tok.text; - advance(); - auto it = macros.find(name); - if (it == macros.end()) return 0; - if (it->second.empty()) return 1; - return std::stoi(it->second); - } + int evalMacroExpression(const std::string &name, const std::string &value) { + if (visiting.count(name)) + throw std::runtime_error("Recursive macro: " + name); - // unexpected - return 0; - } + visiting.insert(name); + ExprParser ep(value, macros, visiting); + int v = ep.parse(); + visiting.erase(name); + return v; + } }; //============================================================== @@ -291,327 +410,350 @@ class ExprParser { //============================================================== class Preprocessor { public: - explicit Preprocessor(Options opts = {}) - : opts_(std::move(opts)) { - // Treat empty include path as current directory - if (opts_.include_path.empty()) { - opts_.include_path = "."; - } - parseMacroDefinitions(opts_.macros); - } - - std::string preprocess_file(const std::string& filename, - const std::vector& additional_macros = {}) { - std::unordered_map macros; - std::unordered_set predefined; - std::unordered_set include_stack; - buildMacros(additional_macros, macros, predefined); - - std::string result = processFile(filename, macros, predefined, include_stack); - return result; - } - - std::string preprocess(const std::string& contents, - const std::vector& additional_macros = {}) - { - std::unordered_map macros; - std::unordered_set predefined; - std::unordered_set include_stack; - buildMacros(additional_macros, macros, predefined); - - std::string result = processString(contents, macros, predefined, include_stack); - return result; + explicit Preprocessor(Options opts = {}) : opts_(std::move(opts)) { + // Treat empty include path as current directory + if (opts_.include_path.empty()) { + 
opts_.include_path = "."; } + parseMacroDefinitions(opts_.macros); + } + + std::string + preprocess_file(const std::string &filename, + const std::vector &additional_macros = {}) { + std::unordered_map macros; + std::unordered_set predefined; + std::unordered_set include_stack; + buildMacros(additional_macros, macros, predefined); + + std::string result = processFile(filename, macros, predefined, + include_stack, DirectiveMode::All); + return result; + } + + std::string + preprocess(const std::string &contents, + const std::vector &additional_macros = {}) { + std::unordered_map macros; + std::unordered_set predefined; + std::unordered_set include_stack; + buildMacros(additional_macros, macros, predefined); + + std::string result = processString(contents, macros, predefined, + include_stack, DirectiveMode::All); + return result; + } + + std::string preprocess_includes_file(const std::string &filename) { + std::unordered_map macros; + std::unordered_set predefined; + std::unordered_set include_stack; + std::string result = + processFile(filename, macros, predefined, include_stack, + DirectiveMode::IncludesOnly); + return result; + } + + std::string preprocess_includes(const std::string &contents) { + std::unordered_map macros; + std::unordered_set predefined; + std::unordered_set include_stack; + std::string result = + processString(contents, macros, predefined, include_stack, + DirectiveMode::IncludesOnly); + return result; + } private: - Options opts_; - std::unordered_map global_macros; - - struct Cond { - bool parent_active; - bool active; - bool taken; - }; - - //---------------------------------------------------------- - // Parse macro definitions into global_macros - //---------------------------------------------------------- - void parseMacroDefinitions(const std::vector& macro_defs) { - for (const auto& def : macro_defs) { - size_t eq_pos = def.find('='); - if (eq_pos != std::string::npos) { - // Format: NAME=VALUE - std::string name = trim(def.substr(0, 
eq_pos)); - std::string value = trim(def.substr(eq_pos + 1)); - global_macros[name] = value; - } else { - // Format: NAME - std::string name = trim(def); - global_macros[name] = ""; - } - } + Options opts_; + std::unordered_map global_macros; + + enum class DirectiveMode { All, IncludesOnly }; + + struct Cond { + bool parent_active; + bool active; + bool taken; + }; + + //---------------------------------------------------------- + // Parse macro definitions into global_macros + //---------------------------------------------------------- + void parseMacroDefinitions(const std::vector ¯o_defs) { + for (const auto &def : macro_defs) { + size_t eq_pos = def.find('='); + if (eq_pos != std::string::npos) { + // Format: NAME=VALUE + std::string name = trim(def.substr(0, eq_pos)); + std::string value = trim(def.substr(eq_pos + 1)); + global_macros[name] = value; + } else { + // Format: NAME + std::string name = trim(def); + global_macros[name] = ""; + } } - - //---------------------------------------------------------- - // Build combined macro map and predefined set for a preprocessing operation - //---------------------------------------------------------- - void buildMacros( - const std::vector& additional_macros, - std::unordered_map& macros, - std::unordered_set& predefined) { - macros = global_macros; - predefined.clear(); - - for (const auto& [name, value] : global_macros) { - predefined.insert(name); - } - - for (const auto& def : additional_macros) { - size_t eq_pos = def.find('='); - std::string name, value; - if (eq_pos != std::string::npos) { - name = trim(def.substr(0, eq_pos)); - value = trim(def.substr(eq_pos + 1)); - } else { - name = trim(def); - value = ""; - } - - // Add to macros map (will override global if same name) - macros[name] = value; - predefined.insert(name); - } + } + + //---------------------------------------------------------- + // Build combined macro map and predefined set for a preprocessing operation + 
//---------------------------------------------------------- + void buildMacros(const std::vector &additional_macros, + std::unordered_map ¯os, + std::unordered_set &predefined) { + macros = global_macros; + predefined.clear(); + + for (const auto &[name, value] : global_macros) { + predefined.insert(name); } - //---------------------------------------------------------- - // Helpers - //---------------------------------------------------------- - std::string loadFile(const std::string& fname) { - std::ifstream f(fname); - if (!f.is_open()) - throw std::runtime_error("Could not open file: " + fname); - std::stringstream ss; - ss << f.rdbuf(); - return ss.str(); - } - - bool condActive(const std::vector& cond) const { - if (cond.empty()) return true; - return cond.back().active; - } - - //---------------------------------------------------------- - // Helper to check if a character can be part of an identifier - //---------------------------------------------------------- - static bool isIdent(char c) { - return std::isalnum(static_cast(c)) || c == '_'; - } - - //---------------------------------------------------------- - // Expand macros in a line of code - //---------------------------------------------------------- - std::string expandMacros(const std::string& line, - const std::unordered_map& macros) { - std::string result; - result.reserve(line.size()); - - size_t i = 0; - while (i < line.size()) { - if (isIdent(line[i])) { - size_t start = i; - while (i < line.size() && isIdent(line[i])) { - i++; - } - std::string token = line.substr(start, i - start); - - auto it = macros.find(token); - if (it != macros.end()) { - result += it->second; - } else { - result += token; - } - } else { - result += line[i]; - i++; - } - } - - return result; - } - - //---------------------------------------------------------- - // Process a file - //---------------------------------------------------------- - std::string processFile(const std::string& name, - std::unordered_map& 
macros, - const std::unordered_set& predefined_macros, - std::unordered_set& include_stack) { - if (include_stack.count(name)) - throw std::runtime_error("Recursive include: " + name); - - include_stack.insert(name); - std::string shader_code = loadFile(name); - std::string out = processString(shader_code, macros, predefined_macros, include_stack); - include_stack.erase(name); - return out; - } - - std::string processIncludeFile(const std::string& fname, - std::unordered_map& macros, - const std::unordered_set& predefined_macros, - std::unordered_set& include_stack) { - std::string full_path = opts_.include_path + "/" + fname; - return processFile(full_path, macros, predefined_macros, include_stack); - } - - //---------------------------------------------------------- - // Process text - //---------------------------------------------------------- - std::string processString(const std::string& shader_code, - std::unordered_map& macros, - const std::unordered_set& predefined_macros, - std::unordered_set& include_stack) - { - std::vector cond; // Conditional stack for this shader - std::stringstream out; - std::istringstream in(shader_code); - std::string line; - - while (std::getline(in, line)) { - std::string t = trim(line); - - if (!t.empty() && t[0] == '#') { - handleDirective(t, out, macros, predefined_macros, cond, include_stack); - } else { - if (condActive(cond)) { - // Expand macros in the line before outputting - std::string expanded = expandMacros(line, macros); - out << expanded << "\n"; - } - } - } - - if (!cond.empty()) - throw std::runtime_error("Unclosed #if directive"); - - return out.str(); - } - - //---------------------------------------------------------- - // Directive handler - //---------------------------------------------------------- - void handleDirective(const std::string& t, std::stringstream& out, - std::unordered_map& macros, - const std::unordered_set& predefined_macros, - std::vector& cond, - std::unordered_set& include_stack) { - // 
split into tokens - std::string body = t.substr(1); - std::istringstream iss(body); - std::string cmd; - iss >> cmd; - - if (cmd == "include") { - if (!condActive(cond)) return; - std::string file; - iss >> file; - if (file.size() >= 2 && file.front()=='"' && file.back()=='"') - file = file.substr(1, file.size()-2); - out << processIncludeFile(file, macros, predefined_macros, include_stack); - return; - } - - if (cmd == "define") { - if (!condActive(cond)) return; - std::string name; - iss >> name; - // Don't override predefined macros from options - if (predefined_macros.count(name)) return; - std::string value = trim_value(iss); - macros[name] = value; - return; - } - - if (cmd == "ifdef") { - std::string name; iss >> name; - bool p = condActive(cond); - bool v = macros.count(name); - cond.push_back({p, p && v, p && v}); - return; + for (const auto &def : additional_macros) { + size_t eq_pos = def.find('='); + std::string name, value; + if (eq_pos != std::string::npos) { + name = trim(def.substr(0, eq_pos)); + value = trim(def.substr(eq_pos + 1)); + } else { + name = trim(def); + value = ""; + } + + // Add to macros map (will override global if same name) + macros[name] = value; + predefined.insert(name); + } + } + + //---------------------------------------------------------- + // Helpers + //---------------------------------------------------------- + std::string loadFile(const std::string &fname) { + std::ifstream f(fname); + if (!f.is_open()) + throw std::runtime_error("Could not open file: " + fname); + std::stringstream ss; + ss << f.rdbuf(); + return ss.str(); + } + + bool condActive(const std::vector &cond) const { + if (cond.empty()) + return true; + return cond.back().active; + } + + //---------------------------------------------------------- + // Process a file + //---------------------------------------------------------- + std::string + processFile(const std::string &name, + std::unordered_map ¯os, + const std::unordered_set &predefined_macros, + 
std::unordered_set &include_stack, + DirectiveMode mode) { + if (include_stack.count(name)) + throw std::runtime_error("Recursive include: " + name); + + include_stack.insert(name); + std::string shader_code = loadFile(name); + std::string out = processString(shader_code, macros, predefined_macros, + include_stack, mode); + include_stack.erase(name); + return out; + } + + std::string + processIncludeFile(const std::string &fname, + std::unordered_map ¯os, + const std::unordered_set &predefined_macros, + std::unordered_set &include_stack, + DirectiveMode mode) { + std::string full_path = opts_.include_path + "/" + fname; + return processFile(full_path, macros, predefined_macros, include_stack, + mode); + } + + //---------------------------------------------------------- + // Process text + //---------------------------------------------------------- + std::string + processString(const std::string &shader_code, + std::unordered_map ¯os, + const std::unordered_set &predefined_macros, + std::unordered_set &include_stack, + DirectiveMode mode) { + std::vector cond; // Conditional stack for this shader + std::stringstream out; + std::istringstream in(shader_code); + std::string line; + + while (std::getline(in, line)) { + std::string t = trim(line); + + if (!t.empty() && t[0] == '#') { + bool handled = handleDirective(t, out, macros, predefined_macros, cond, + include_stack, mode); + if (mode == DirectiveMode::IncludesOnly && !handled) { + out << line << "\n"; } - - if (cmd == "ifndef") { - std::string name; iss >> name; - bool p = condActive(cond); - bool v = !macros.count(name); - cond.push_back({p, p && v, p && v}); - return; + } else { + if (mode == DirectiveMode::IncludesOnly) { + out << line << "\n"; + } else if (condActive(cond)) { + // Expand macros in the line before outputting + std::string expanded = expandMacrosRecursive(line, macros); + out << expanded << "\n"; } + } + } - if (cmd == "if") { - std::string expr = trim_value(iss); - bool p = condActive(cond); 
- bool v = false; - if (p) { - ExprParser ep(expr, macros); - v = ep.parse() != 0; - } - cond.push_back({p, p && v, p && v}); - return; - } + if (mode == DirectiveMode::All && !cond.empty()) + throw std::runtime_error("Unclosed #if directive"); + + return out.str(); + } + + //---------------------------------------------------------- + // Directive handler + //---------------------------------------------------------- + bool handleDirective(const std::string &t, std::stringstream &out, + std::unordered_map ¯os, + const std::unordered_set &predefined_macros, + std::vector &cond, + std::unordered_set &include_stack, + DirectiveMode mode) { + // split into tokens + std::string body = t.substr(1); + std::istringstream iss(body); + std::string cmd; + iss >> cmd; + + if (cmd == "include") { + if (mode == DirectiveMode::All && !condActive(cond)) + return true; + std::string file; + iss >> file; + if (file.size() >= 2 && file.front() == '"' && file.back() == '"') + file = file.substr(1, file.size() - 2); + out << processIncludeFile(file, macros, predefined_macros, include_stack, + mode); + return true; + } - if (cmd == "elif") { - std::string expr = trim_value(iss); + if (mode == DirectiveMode::IncludesOnly) + return false; + + if (cmd == "define") { + if (!condActive(cond)) + return true; + std::string name; + iss >> name; + // Don't override predefined macros from options + if (predefined_macros.count(name)) + return true; + std::string value = trim_value(iss); + macros[name] = value; + return true; + } - if (cond.empty()) - throw std::runtime_error("#elif without #if"); + if (cmd == "undef") { + if (!condActive(cond)) + return true; + std::string name; + iss >> name; + // Don't undef predefined macros from options + if (predefined_macros.count(name)) + return true; + macros.erase(name); + return true; + } - Cond& c = cond.back(); - if (!c.parent_active) { - c.active = false; - return; - } + if (cmd == "ifdef") { + std::string name; + iss >> name; + bool p = 
condActive(cond); + bool v = macros.count(name); + cond.push_back({p, p && v, p && v}); + return true; + } - if (c.taken) { - c.active = false; - return; - } + if (cmd == "ifndef") { + std::string name; + iss >> name; + bool p = condActive(cond); + bool v = !macros.count(name); + cond.push_back({p, p && v, p && v}); + return true; + } - ExprParser ep(expr, macros); - bool v = ep.parse() != 0; - c.active = v; - if (v) c.taken = true; - return; - } + if (cmd == "if") { + std::string expr = trim_value(iss); + bool p = condActive(cond); + bool v = false; + if (p) { + std::unordered_set visiting; + ExprParser ep(expr, macros, visiting); + v = ep.parse() != 0; + } + cond.push_back({p, p && v, p && v}); + return true; + } - if (cmd == "else") { - if (cond.empty()) - throw std::runtime_error("#else without #if"); - - Cond& c = cond.back(); - if (!c.parent_active) { - c.active = false; - return; - } - if (c.taken) { - c.active = false; - } else { - c.active = true; - c.taken = true; - } - return; - } + if (cmd == "elif") { + std::string expr = trim_value(iss); + + if (cond.empty()) + throw std::runtime_error("#elif without #if"); + + Cond &c = cond.back(); + if (!c.parent_active) { + c.active = false; + return true; + } + + if (c.taken) { + c.active = false; + return true; + } + + std::unordered_set visiting; + ExprParser ep(expr, macros, visiting); + bool v = ep.parse() != 0; + c.active = v; + if (v) + c.taken = true; + return true; + } - if (cmd == "endif") { - if (cond.empty()) - throw std::runtime_error("#endif without #if"); - cond.pop_back(); - return; - } + if (cmd == "else") { + if (cond.empty()) + throw std::runtime_error("#else without #if"); + + Cond &c = cond.back(); + if (!c.parent_active) { + c.active = false; + return true; + } + if (c.taken) { + c.active = false; + } else { + c.active = true; + c.taken = true; + } + return true; + } - // Unknown directive - throw std::runtime_error("Unknown directive: #" + cmd); + if (cmd == "endif") { + if (cond.empty()) + 
throw std::runtime_error("#endif without #if"); + cond.pop_back(); + return true; } + + // Unknown directive + throw std::runtime_error("Unknown directive: #" + cmd); + } }; } // namespace pre_wgsl diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl index 6f2e043395f..9a19482284a 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl @@ -19,6 +19,32 @@ enable chromium_experimental_subgroup_matrix; #define SG_MAT_N 8 #define SG_MAT_K 8 +// Quantization constants/helpers +#define BLOCK_SIZE 32 +#define BLOCKS_K ((HEAD_DIM_QK + BLOCK_SIZE - 1) / BLOCK_SIZE) +#define BLOCKS_V ((HEAD_DIM_V + BLOCK_SIZE - 1) / BLOCK_SIZE) +// number of quantized elements processed per thread +#if defined(KV_Q4_0) +#define NQ 16 +// Q4_0 has 32 elements, 1 f16 for scale, 8 f16 for 4-bit weights +#define F16_PER_BLOCK 9 +#define WEIGHTS_PER_F16 4 +#elif defined(KV_Q8_0) +#define NQ 8 +// Q8_0 has 32 elements, 1 f16 for scale, 16 f16 for 8-bit weights +#define F16_PER_BLOCK 17 +#define WEIGHTS_PER_F16 2 +#endif +#define F16_PER_THREAD (NQ / WEIGHTS_PER_F16) + +fn get_byte(value: u32, index: u32) -> u32 { + return (value >> (index * 8)) & 0xFF; +} + +fn get_byte_i32(value: u32, index: u32) -> i32 { + return bitcast(((value >> (index * 8)) & 0xFF) << 24) >> 24; +} + struct Params { offset_q: u32, offset_k: u32, @@ -195,7 +221,63 @@ fn main(@builtin(workgroup_id) wg_id: vec3, } for (var kv_tile = 0u; kv_tile < params.seq_len_kv; kv_tile += KV_TILE) { + // load k tile into shared memory +#if defined(KV_Q4_0) + for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * params.head_dim_qk; elem_idx += WG_SIZE * NQ) { + let blck_idx = elem_idx / BLOCK_SIZE; + let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; + let k_row = blck_idx / BLOCKS_K; + let global_k_row = kv_tile + k_row; + let block_k = blck_idx % BLOCKS_K; + let row_offset = k_row * params.head_dim_qk; 
+ + if (global_k_row < params.seq_len_kv) { + let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; + let base_idx = global_block_idx * F16_PER_BLOCK; + let d = K[base_idx]; // scale + for (var j = 0u; j < F16_PER_THREAD; j += 2) { + let q_0 = K[base_idx + 1u + block_offset + j]; + let q_1 = K[base_idx + 1u + block_offset + j + 1]; + let q_packed = bitcast(vec2(q_0, q_1)); + for (var k = 0u; k < 4u; k++) { + let q_byte = get_byte(q_packed, k); + let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; + let q_lo = (f16(q_byte & 0xF) - 8.0) * d; + let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; + kv_shmem[row_offset + idx] = q_lo; + kv_shmem[row_offset + idx + 16u] = q_hi; + } + } + } + } +#elif defined(KV_Q8_0) + for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * params.head_dim_qk; elem_idx += WG_SIZE * NQ) { + let blck_idx = elem_idx / BLOCK_SIZE; + let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; + let k_row = blck_idx / BLOCKS_K; + let global_k_row = kv_tile + k_row; + let block_k = blck_idx % BLOCKS_K; + let row_offset = k_row * params.head_dim_qk; + + if (global_k_row < params.seq_len_kv) { + let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; + let base_idx = global_block_idx * F16_PER_BLOCK; + let d = K[base_idx]; // scale + for (var j = 0u; j < F16_PER_THREAD; j += 2) { + let q_0 = K[base_idx + 1u + block_offset + j]; + let q_1 = K[base_idx + 1u + block_offset + j + 1]; + let q_packed = bitcast(vec2(q_0, q_1)); + for (var k = 0u; k < 4u; k++) { + let q_byte = get_byte_i32(q_packed, k); + let q_val = f16(q_byte) * d; + let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; + kv_shmem[row_offset + idx] = q_val; + } + } + } + } +#else for (var elem_idx = local_id.x; elem_idx < KV_TILE * params.head_dim_qk; elem_idx += WG_SIZE) { let k_row = elem_idx / params.head_dim_qk; let k_col = elem_idx % params.head_dim_qk; @@ -206,6 +288,7 @@ fn main(@builtin(workgroup_id) 
wg_id: vec3, K[global_k_row_offset + k_col], global_k_row < params.seq_len_kv && k_col < params.head_dim_qk)); } +#endif workgroupBarrier(); @@ -313,6 +396,61 @@ fn main(@builtin(workgroup_id) wg_id: vec3, } // load v tile into shared memory +#if defined(KV_Q4_0) + for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * params.head_dim_v; elem_idx += WG_SIZE * NQ) { + let blck_idx = elem_idx / BLOCK_SIZE; + let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; + let v_row = blck_idx / BLOCKS_V; + let global_v_row = kv_tile + v_row; + let block_k = blck_idx % BLOCKS_V; + let row_offset = v_row * params.head_dim_v; + + if (global_v_row < params.seq_len_kv) { + let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; + let base_idx = global_block_idx * F16_PER_BLOCK; + let d = V[base_idx]; // scale + for (var j = 0u; j < F16_PER_THREAD; j += 2) { + let q_0 = V[base_idx + 1u + block_offset + j]; + let q_1 = V[base_idx + 1u + block_offset + j + 1]; + let q_packed = bitcast(vec2(q_0, q_1)); + for (var k = 0u; k < 4u; k++) { + let q_byte = get_byte(q_packed, k); + let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; + let q_lo = (f16(q_byte & 0xF) - 8.0) * d; + let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; + kv_shmem[row_offset + idx] = q_lo; + kv_shmem[row_offset + idx + 16u] = q_hi; + } + } + } + } +#elif defined(KV_Q8_0) + for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * params.head_dim_v; elem_idx += WG_SIZE * NQ) { + let blck_idx = elem_idx / BLOCK_SIZE; + let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; + let v_row = blck_idx / BLOCKS_V; + let global_v_row = kv_tile + v_row; + let block_k = blck_idx % BLOCKS_V; + let row_offset = v_row * params.head_dim_v; + + if (global_v_row < params.seq_len_kv) { + let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; + let base_idx = global_block_idx * F16_PER_BLOCK; + let d = V[base_idx]; // scale + for (var j = 0u; j < 
F16_PER_THREAD; j += 2) { + let q_0 = V[base_idx + 1u + block_offset + j]; + let q_1 = V[base_idx + 1u + block_offset + j + 1]; + let q_packed = bitcast(vec2(q_0, q_1)); + for (var k = 0u; k < 4u; k++) { + let q_byte = get_byte_i32(q_packed, k); + let q_val = f16(q_byte) * d; + let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; + kv_shmem[row_offset + idx] = q_val; + } + } + } + } +#else for (var elem_idx = local_id.x; elem_idx < KV_TILE * params.head_dim_v; elem_idx += WG_SIZE) { let v_row = elem_idx / params.head_dim_v; let v_col = elem_idx % params.head_dim_v; @@ -323,6 +461,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, V[global_v_row_offset + v_col], global_v_row < params.seq_len_kv && v_col < params.head_dim_v)); } +#endif workgroupBarrier(); From e36c9cd225aeda10155d2eef5e43e338daeb1e46 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Wed, 31 Dec 2025 16:13:22 -0800 Subject: [PATCH 36/40] more cleanup --- ggml/src/ggml-cpu/ggml-cpu.cpp | 1 - ggml/src/ggml-webgpu/ggml-webgpu.cpp | 74 ++++++------------- .../ggml-webgpu/wgsl-shaders/embed_wgsl.py | 7 -- .../ggml-webgpu/wgsl-shaders/flash_attn.wgsl | 5 +- tests/test-backend-ops.cpp | 6 -- 5 files changed, 23 insertions(+), 70 deletions(-) diff --git a/ggml/src/ggml-cpu/ggml-cpu.cpp b/ggml/src/ggml-cpu/ggml-cpu.cpp index 939848a6a90..f4713a42185 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.cpp +++ b/ggml/src/ggml-cpu/ggml-cpu.cpp @@ -626,7 +626,6 @@ static ggml_backend_feature * ggml_backend_cpu_get_features(ggml_backend_reg_t r } static void * ggml_backend_cpu_get_proc_address(ggml_backend_reg_t reg, const char * name) { - if (strcmp(name, "ggml_backend_set_n_threads") == 0) { ggml_backend_set_n_threads_t fct = ggml_backend_cpu_set_n_threads; return (void *)fct; diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 7910ec962e1..10fe811e1a2 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -126,11 +126,11 @@ static 
uint64_t webgpu_tensor_offset(const ggml_tensor * tensor) { /* Struct definitions */ // Forward reference -void ggml_webgpu_create_buffer(wgpu::Device & device, - wgpu::Buffer & buffer, - size_t size, - wgpu::BufferUsage usage, - const char * label); +static void ggml_webgpu_create_buffer(wgpu::Device & device, + wgpu::Buffer & buffer, + size_t size, + wgpu::BufferUsage usage, + const char * label); struct webgpu_pool_bufs { wgpu::Buffer host_buf; @@ -407,21 +407,6 @@ struct ggml_backend_webgpu_buffer_context { /* WebGPU object initializations */ -void ggml_webgpu_create_buffer(wgpu::Device & device, - wgpu::Buffer & buffer, - size_t size, - wgpu::BufferUsage usage, - const char * label) { - wgpu::BufferDescriptor buffer_desc; - buffer_desc.size = size; - buffer_desc.usage = usage; - buffer_desc.label = label; - buffer_desc.mappedAtCreation = false; - - // TODO: error handling - buffer = device.CreateBuffer(&buffer_desc); -} - // Process a WGSL shader string, replacing tokens of the form {{KEY}} with // the corresponding values provided in `repls`. 
static std::string ggml_webgpu_process_shader_repls(const char * src, @@ -465,6 +450,21 @@ static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & return { device.CreateComputePipeline(&pipeline_desc), label }; } +static void ggml_webgpu_create_buffer(wgpu::Device & device, + wgpu::Buffer & buffer, + size_t size, + wgpu::BufferUsage usage, + const char * label) { + wgpu::BufferDescriptor buffer_desc; + buffer_desc.size = size; + buffer_desc.usage = usage; + buffer_desc.label = label; + buffer_desc.mappedAtCreation = false; + + // TODO: error handling + buffer = device.CreateBuffer(&buffer_desc); +} + /** End WebGPU object initializations */ /** WebGPU Actions */ @@ -1050,20 +1050,6 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, const int has_mask = (mask != nullptr); const int has_sinks = (sinks != nullptr); - // print type and dimensions of Q/K/V/mask/sinks/dst - // std::cout << "ggml_webgpu_flash_attn: Q type: " << ggml_type_name(Q->type) << ", ne: [" << Q->ne[0] << ", " << Q->ne[1] << ", " << Q->ne[2] - // << ", " << Q->ne[3] << "]\n"; - // std::cout << "ggml_webgpu_flash_attn: K type: " << ggml_type_name(K->type) << ", ne: [" << K->ne[0] << ", " << K->ne[1] << ", " << K->ne[2] - // << ", " << K->ne[3] << "]\n"; - // std::cout << "ggml_webgpu_flash_attn: V type: " << ggml_type_name(V->type) << ", ne: [" << V->ne[0] << ", " << V->ne[1] << ", " << V->ne[2] - // << ", " << V->ne[3] << "]\n"; - // std::cout << "ggml_webgpu_flash_attn: mask type: " << ggml_type_name(mask->type) << ", ne: [" << mask->ne[0] << ", " << mask->ne[1] << ", " << mask->ne[2] - // << ", " << mask->ne[3] << "]\n"; - // std::cout << "ggml_webgpu_flash_attn: sinks type: " << ggml_type_name(sinks->type) << ", ne: [" << sinks->ne[0] << ", " << sinks->ne[1] << ", " << sinks->ne[2] - // << ", " << sinks->ne[3] << "]\n"; - // std::cout << "ggml_webgpu_flash_attn: dst type: " << ggml_type_name(dst->type) << ", ne: [" << dst->ne[0] << ", " << dst->ne[1] << ", " << 
dst->ne[2] - // << ", " << dst->ne[3] << "]\n"; - std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, Q) / ggml_type_size(Q->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type)), @@ -1126,12 +1112,6 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, .buffer = ggml_webgpu_tensor_buf(dst), .offset = ggml_webgpu_tensor_align_offset(ctx, dst), .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); - // Debug buffer binding (for development only) - // entries.push_back( - // { .binding = binding_index, - // .buffer = ctx->debug_dev_buf, - // .offset = 0, - // .size = ctx->debug_dev_buf.GetSize() }); flash_attn_pipeline_key key = { .q_type = Q->type, @@ -1170,9 +1150,9 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, .wg_mem_limit_bytes = ctx->limits.maxComputeWorkgroupStorageSize, .max_subgroup_size = ctx->max_subgroup_size }; - ggml_webgpu_processed_shader processed = - ggml_webgpu_preprocess_flash_attn_shader(ctx->p, wgsl_flash_attn, shader_lib_ctx); + ggml_webgpu_processed_shader processed = + ggml_webgpu_preprocess_flash_attn_shader(ctx->p, wgsl_flash_attn, shader_lib_ctx); pipeline = ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); pipeline.context = new ggml_webgpu_flash_attn_shader_decisions(processed.decisions); ctx->flash_attn_pipelines.emplace(key, pipeline); @@ -1647,12 +1627,8 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str std::vector commands; std::vector futures; - bool contains_flash_attn = false; for (int i = 0; i < cgraph->n_nodes; i++) { if (auto cmd = ggml_webgpu_encode_node(ctx, cgraph->nodes[i])) { - if (cgraph->nodes[i]->op == GGML_OP_FLASH_ATTN_EXT) { - contains_flash_attn = true; - } commands.push_back(*cmd); } // compute the batch size based on the number of inflight threads @@ -1672,12 +1648,6 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, 
str futures.push_back(new_futures); } -#ifdef GGML_WEBGPU_DEBUG - if (contains_flash_attn) { - ggml_backend_webgpu_debug(ctx); - } -#endif - ggml_backend_webgpu_wait(ctx, futures); ctx->inflight_threads--; WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py b/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py index 4df49759f7b..d61df5bb9e5 100755 --- a/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +++ b/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py @@ -18,13 +18,6 @@ def parse_decls(decls_text): decls[name.strip()] = code.strip() return decls -def replace_repl_placeholders(variant, template_map): - for repl, code in variant["REPLS"].items(): - for key, val in template_map.items(): - # Match "key" and avoid matching subsequences using by using \b - code = re.sub(rf'\b{re.escape(str(key))}\b', str(val), code) - variant["REPLS"][repl] = code - return variant def replace_repl_placeholders(variant, template_map): for repl, code in variant["REPLS"].items(): diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl index 9a19482284a..963a485d877 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl @@ -107,7 +107,6 @@ struct Params { #endif @group(0) @binding(DST_BINDING) var dst: array; -//@group(0) @binding(6) var debug: array; @group(0) @binding(PARAMS_BINDING) var params: Params; const FLOAT_MIN: f16 = -65504.0; @@ -163,8 +162,6 @@ fn main(@builtin(workgroup_id) wg_id: vec3, @builtin(num_subgroups) num_subgroups: u32, @builtin(subgroup_invocation_id) sg_inv_id: u32) { - //debug[0] = 42; - // initialize row max for online softmax for (var i = local_id.x; i < Q_TILE; i += WG_SIZE) { row_max_shmem[i] = FLOAT_MIN; @@ -523,7 +520,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, var prev_max = select(0.0, row_max_shmem[q_tile_row], sg_inv_id == 0); prev_max = subgroupBroadcastFirst(prev_max); - 
// for non-sink threads, exp(-65504) effectively zeroes out their contrinbution to the sum + // for non-sink threads, exp(-65504) effectively zeroes out their contribution to the sum let sink_val = select(FLOAT_MIN, f16(sinks[params.offset_sinks + head_idx]), sg_inv_id == 0); let new_max = subgroupMax(max(prev_max, sink_val)); let max_exp = exp(prev_max - new_max); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index d0044c57405..0b981b17883 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -47,12 +47,6 @@ # define N_THREADS std::thread::hardware_concurrency() #endif -#ifdef __EMSCRIPTEN__ -# define N_THREADS 1 -#else -# define N_THREADS std::thread::hardware_concurrency() -#endif - static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) { size_t nels = ggml_nelements(tensor); std::vector data(nels); From ef5fd1be0d643b5d676c375b3e877ff54fe4a993 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Wed, 31 Dec 2025 16:24:59 -0800 Subject: [PATCH 37/40] fix wasm build --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 30 +++++++++++++++------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 10fe811e1a2..4204b41c518 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -313,10 +313,10 @@ struct webgpu_context_struct { uint32_t max_subgroup_size; -#ifndef __EMSCRIPTEN__ bool supports_subgroup_matrix = false; - wgpu::SubgroupMatrixConfig subgroup_matrix_config; -#endif + uint32_t sg_mat_m; + uint32_t sg_mat_n; + uint32_t sg_mat_k; std::recursive_mutex mutex; std::atomic_uint inflight_threads = 0; @@ -1007,10 +1007,10 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, if (ctx->supports_subgroup_matrix) { // The total number of subgroups/workgroups needed per matrix. 
uint32_t wg_m_sg_tile = - WEBGPU_MUL_MAT_SUBGROUP_M * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M * ctx->subgroup_matrix_config.M; + WEBGPU_MUL_MAT_SUBGROUP_M * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M * ctx->sg_mat_m; wg_m = CEIL_DIV(dst->ne[0], wg_m_sg_tile); uint32_t wg_n_sg_tile = - WEBGPU_MUL_MAT_SUBGROUP_N * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N * ctx->subgroup_matrix_config.N; + WEBGPU_MUL_MAT_SUBGROUP_N * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N * ctx->sg_mat_n; wg_n = CEIL_DIV(dst->ne[1], wg_n_sg_tile); } else { #endif @@ -1095,7 +1095,7 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, .offset = ggml_webgpu_tensor_align_offset(ctx, V), .size = ggml_webgpu_tensor_binding_size(ctx, V) } }; - uint binding_index = 3; + uint32_t binding_index = 3; if (has_mask) { entries.push_back({ .binding = binding_index++, .buffer = ggml_webgpu_tensor_buf(mask), @@ -1144,9 +1144,9 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, .has_mask = mask != nullptr, .has_sinks = sinks != nullptr, .uses_logit_softcap = logit_softcap != 0.0f, - .sg_mat_m = ctx->subgroup_matrix_config.M, - .sg_mat_n = ctx->subgroup_matrix_config.N, - .sg_mat_k = ctx->subgroup_matrix_config.K, + .sg_mat_m = ctx->sg_mat_m, + .sg_mat_n = ctx->sg_mat_n, + .sg_mat_k = ctx->sg_mat_k, .wg_mem_limit_bytes = ctx->limits.maxComputeWorkgroupStorageSize, .max_subgroup_size = ctx->max_subgroup_size }; @@ -1996,9 +1996,9 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { sg_matrix_repls["WEBGPU_SUBGROUP_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N); sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M); sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N); - sg_matrix_repls["WEBGPU_SG_MAT_M_SIZE"] = std::to_string(webgpu_ctx->subgroup_matrix_config.M); - sg_matrix_repls["WEBGPU_SG_MAT_N_SIZE"] = std::to_string(webgpu_ctx->subgroup_matrix_config.N); - sg_matrix_repls["WEBGPU_SG_MAT_K_SIZE"] = 
std::to_string(webgpu_ctx->subgroup_matrix_config.K); + sg_matrix_repls["WEBGPU_SG_MAT_M_SIZE"] = std::to_string(webgpu_ctx->sg_mat_m); + sg_matrix_repls["WEBGPU_SG_MAT_N_SIZE"] = std::to_string(webgpu_ctx->sg_mat_n); + sg_matrix_repls["WEBGPU_SG_MAT_K_SIZE"] = std::to_string(webgpu_ctx->sg_mat_k); proc_mul_mat_f32_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32, sg_matrix_repls); proc_mul_mat_f32_f32_vec = @@ -2670,7 +2670,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const size_t limit_bytes = webgpu_ctx->limits.maxComputeWorkgroupStorageSize; const bool has_mask = op->src[3] != nullptr; const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes( - webgpu_ctx->subgroup_matrix_config.M, webgpu_ctx->subgroup_matrix_config.N, (uint32_t) src0->ne[0], + webgpu_ctx->sg_mat_m, webgpu_ctx->sg_mat_n, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask); if (min_bytes > limit_bytes) { break; @@ -2857,7 +2857,9 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t if (config.M == config.N && config.N == config.K && (config.K == 8 || config.K == 16) && config.componentType == wgpu::SubgroupMatrixComponentType::F16 && config.resultComponentType == wgpu::SubgroupMatrixComponentType::F16) { - ctx->subgroup_matrix_config = config; + ctx->sg_mat_m = config.M; + ctx->sg_mat_n = config.N; + ctx->sg_mat_k = config.K; valid_subgroup_matrix_config = true; break; } From 2fc8060794ec910aac1eec5f90b66488ec9be4a8 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Fri, 2 Jan 2026 22:48:33 -0800 Subject: [PATCH 38/40] Refactor flashattention to increase parallelism, use direct loads for KV in somce cases --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 123 ++++------ ggml/src/ggml-webgpu/ggml-webgpu.cpp | 47 ++-- .../ggml-webgpu/wgsl-shaders/flash_attn.wgsl | 212 ++++++++++-------- 3 files changed, 195 insertions(+), 187 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp 
b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index af9f4e14de0..aad88d8a3d3 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -7,21 +7,23 @@ #include #include -#define GGML_WEBGPU_F16_SIZE_BYTES 2 -#define GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE 64u +#define GGML_WEBGPU_F16_SIZE_BYTES 2 +#define GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES 8u +#define GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE 128u struct ggml_webgpu_flash_attn_shader_lib_context { - ggml_type kv_type; - uint32_t head_dim_qk; - uint32_t head_dim_v; - bool has_mask; - bool has_sinks; - bool uses_logit_softcap; - uint32_t sg_mat_m; - uint32_t sg_mat_n; - uint32_t sg_mat_k; - size_t wg_mem_limit_bytes; - uint32_t max_subgroup_size; + ggml_type kv_type; + uint32_t head_dim_qk; + uint32_t head_dim_v; + bool kv_direct; + bool has_mask; + bool has_sinks; + bool uses_logit_softcap; + uint32_t sg_mat_m; + uint32_t sg_mat_n; + uint32_t sg_mat_k; + size_t wg_mem_limit_bytes; + uint32_t max_subgroup_size; }; struct ggml_webgpu_flash_attn_shader_decisions { @@ -56,61 +58,15 @@ inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile, return elems * GGML_WEBGPU_F16_SIZE_BYTES; } -// Returns a pair of (q_tile, kv_tile) that best fits within the workgroup memory limit -// Currently set to prefer the configuration that comes closest to using half of the limit -// Assumes that the base minimum tile sizes fits within the limit -static std::pair ggml_webgpu_flash_attn_tile_sizes( - const ggml_webgpu_flash_attn_shader_lib_context & context) { - std::pair best_pair = { 0, 0 }; - size_t best_delta = 0; - - const uint32_t min_q_tile = context.sg_mat_m; - const uint32_t min_kv_tile = context.sg_mat_n; - const size_t limit_bytes = context.wg_mem_limit_bytes; - const size_t target_bytes = limit_bytes / 2; - const uint32_t max_head_dim = std::max(context.head_dim_qk, context.head_dim_v); - - // These sizes come from the equations for 
wg_mem_bytes, solving for q_tile or kv_tile respectively - const size_t base_kv_bytes = min_kv_tile * max_head_dim * GGML_WEBGPU_F16_SIZE_BYTES; - const size_t bytes_per_q = - (context.head_dim_qk + context.head_dim_v + (context.has_mask ? min_kv_tile : 0) + min_kv_tile + 2) * +static uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_flash_attn_shader_lib_context & context) { + const size_t limit_bytes = context.wg_mem_limit_bytes; + const size_t q_tile = context.sg_mat_m; + const size_t base_q_bytes = (context.head_dim_qk + context.head_dim_v + 2) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES; + const size_t bytes_per_kv = + (std::max(context.head_dim_qk, context.head_dim_v) + (context.has_mask ? q_tile : 0) + q_tile) * GGML_WEBGPU_F16_SIZE_BYTES; - const uint32_t max_q_tile = (limit_bytes - base_kv_bytes) / bytes_per_q; - - const size_t base_q_bytes = - (context.head_dim_qk + context.head_dim_v + 2) * min_q_tile * GGML_WEBGPU_F16_SIZE_BYTES; - const size_t bytes_per_kv = - (max_head_dim + (context.has_mask ? min_q_tile : 0) + min_q_tile) * GGML_WEBGPU_F16_SIZE_BYTES; - const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv; - - // step by minimum tile sizes - for (uint32_t q = min_q_tile; q <= max_q_tile; q += min_q_tile) { - for (uint32_t kv = min_kv_tile; kv <= max_kv_tile; kv += min_kv_tile) { - size_t bytes = - ggml_webgpu_flash_attn_wg_mem_bytes(q, kv, context.head_dim_qk, context.head_dim_v, context.has_mask); - if (bytes <= limit_bytes) { - size_t delta = bytes > target_bytes ? 
bytes - target_bytes : target_bytes - bytes; - if (best_pair.first == 0 || delta < best_delta) { - best_pair = { q, kv }; - best_delta = delta; - } - } - } - } - - return best_pair; -} - -static const char * kv_shader_type(ggml_type kv_type) { - switch (kv_type) { - case GGML_TYPE_F32: return "f32"; - case GGML_TYPE_F16: return "f16"; - case GGML_TYPE_Q4_0: return "f16"; - case GGML_TYPE_Q8_0: return "f16"; - default: - GGML_ABORT("Unsupported KV type for flash attention shader"); - return ""; - } + const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv; + return (max_kv_tile / context.sg_mat_n) * context.sg_mat_n; } inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader( @@ -120,14 +76,23 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader( std::vector defines; std::string variant = "flash_attn"; - defines.push_back(std::string("KV_TYPE=") + kv_shader_type(context.kv_type)); - variant += std::string("_") + ggml_type_name(context.kv_type); - - if (context.kv_type == GGML_TYPE_Q4_0) { - defines.push_back("KV_Q4_0"); - } else if (context.kv_type == GGML_TYPE_Q8_0) { - defines.push_back("KV_Q8_0"); + switch (context.kv_type) { + case GGML_TYPE_F32: + defines.push_back("KV_F32"); + break; + case GGML_TYPE_F16: + defines.push_back("KV_F16"); + break; + case GGML_TYPE_Q4_0: + defines.push_back("KV_Q4_0"); + break; + case GGML_TYPE_Q8_0: + defines.push_back("KV_Q8_0"); + break; + default: + GGML_ABORT("Unsupported KV type for flash attention shader"); } + variant += std::string("_") + ggml_type_name(context.kv_type); if (context.has_mask) { defines.push_back("MASK"); @@ -142,6 +107,11 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader( variant += "_lgsc"; } + if (context.kv_direct) { + defines.push_back("KV_DIRECT"); + variant += "_kvdirect"; + } + defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(context.head_dim_qk)); variant += std::string("_hsqk") + 
std::to_string(context.head_dim_qk); @@ -154,12 +124,15 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader( defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k)); // Add chosen Q/KV tile sizes - auto [q_tile, kv_tile] = ggml_webgpu_flash_attn_tile_sizes(context); + uint32_t q_tile = context.sg_mat_m; + uint32_t kv_tile = std::min(ggml_webgpu_flash_attn_max_kv_tile(context), + context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES); defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile)); defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile)); // workgroup size uint32_t wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE); + defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); ggml_webgpu_processed_shader result; diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 4204b41c518..6d2ca5c0212 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -272,14 +272,16 @@ struct flash_attn_pipeline_key { int dst_type; uint32_t head_dim_qk; uint32_t head_dim_v; + bool kv_direct; bool has_mask; bool has_sinks; bool uses_logit_softcap; bool operator==(const flash_attn_pipeline_key & other) const { return q_type == other.q_type && kv_type == other.kv_type && dst_type == other.dst_type && - head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v && has_mask == other.has_mask && - has_sinks == other.has_sinks && uses_logit_softcap == other.uses_logit_softcap; + head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v && kv_direct == other.kv_direct && + has_mask == other.has_mask && has_sinks == other.has_sinks && + uses_logit_softcap == other.uses_logit_softcap; } }; @@ -296,6 +298,7 @@ struct flash_attn_pipeline_key_hash { ggml_webgpu_hash_combine(seed, key.dst_type); ggml_webgpu_hash_combine(seed, key.head_dim_qk); 
ggml_webgpu_hash_combine(seed, key.head_dim_v); + ggml_webgpu_hash_combine(seed, key.kv_direct); ggml_webgpu_hash_combine(seed, key.has_mask); ggml_webgpu_hash_combine(seed, key.has_sinks); ggml_webgpu_hash_combine(seed, key.uses_logit_softcap); @@ -313,10 +316,10 @@ struct webgpu_context_struct { uint32_t max_subgroup_size; - bool supports_subgroup_matrix = false; - uint32_t sg_mat_m; - uint32_t sg_mat_n; - uint32_t sg_mat_k; + bool supports_subgroup_matrix = false; + uint32_t sg_mat_m; + uint32_t sg_mat_n; + uint32_t sg_mat_k; std::recursive_mutex mutex; std::atomic_uint inflight_threads = 0; @@ -1006,12 +1009,10 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, #ifndef __EMSCRIPTEN__ if (ctx->supports_subgroup_matrix) { // The total number of subgroups/workgroups needed per matrix. - uint32_t wg_m_sg_tile = - WEBGPU_MUL_MAT_SUBGROUP_M * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M * ctx->sg_mat_m; - wg_m = CEIL_DIV(dst->ne[0], wg_m_sg_tile); - uint32_t wg_n_sg_tile = - WEBGPU_MUL_MAT_SUBGROUP_N * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N * ctx->sg_mat_n; - wg_n = CEIL_DIV(dst->ne[1], wg_n_sg_tile); + uint32_t wg_m_sg_tile = WEBGPU_MUL_MAT_SUBGROUP_M * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M * ctx->sg_mat_m; + wg_m = CEIL_DIV(dst->ne[0], wg_m_sg_tile); + uint32_t wg_n_sg_tile = WEBGPU_MUL_MAT_SUBGROUP_N * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N * ctx->sg_mat_n; + wg_n = CEIL_DIV(dst->ne[1], wg_n_sg_tile); } else { #endif uint32_t tile_m_s = WEBGPU_MUL_MAT_TILE_M * WEBGPU_MUL_MAT_WG_SIZE_M; @@ -1113,12 +1114,15 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, .offset = ggml_webgpu_tensor_align_offset(ctx, dst), .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + bool kv_direct = (K->type == GGML_TYPE_F16) && (Q->ne[0] % ctx->sg_mat_k == 0) && (K->ne[1] % ctx->sg_mat_n == 0); + flash_attn_pipeline_key key = { .q_type = Q->type, .kv_type = K->type, .dst_type = dst->type, .head_dim_qk = (uint32_t) Q->ne[0], .head_dim_v = (uint32_t) V->ne[0], + .kv_direct = 
kv_direct, .has_mask = mask != nullptr, .has_sinks = sinks != nullptr, .uses_logit_softcap = logit_softcap != 0.0f, @@ -1141,12 +1145,13 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = { .kv_type = K->type, .head_dim_qk = (uint32_t) Q->ne[0], .head_dim_v = (uint32_t) V->ne[0], + .kv_direct = kv_direct, .has_mask = mask != nullptr, .has_sinks = sinks != nullptr, .uses_logit_softcap = logit_softcap != 0.0f, - .sg_mat_m = ctx->sg_mat_m, - .sg_mat_n = ctx->sg_mat_n, - .sg_mat_k = ctx->sg_mat_k, + .sg_mat_m = ctx->sg_mat_m, + .sg_mat_n = ctx->sg_mat_n, + .sg_mat_k = ctx->sg_mat_k, .wg_mem_limit_bytes = ctx->limits.maxComputeWorkgroupStorageSize, .max_subgroup_size = ctx->max_subgroup_size }; @@ -2669,9 +2674,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const // Head dimensions must fit in workgroup memory with minimum tile sizes size_t limit_bytes = webgpu_ctx->limits.maxComputeWorkgroupStorageSize; const bool has_mask = op->src[3] != nullptr; - const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes( - webgpu_ctx->sg_mat_m, webgpu_ctx->sg_mat_n, (uint32_t) src0->ne[0], - (uint32_t) src2->ne[0], has_mask); + const size_t min_bytes = + ggml_webgpu_flash_attn_wg_mem_bytes(webgpu_ctx->sg_mat_m, webgpu_ctx->sg_mat_n, + (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask); if (min_bytes > limit_bytes) { break; } @@ -2857,9 +2862,9 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t if (config.M == config.N && config.N == config.K && (config.K == 8 || config.K == 16) && config.componentType == wgpu::SubgroupMatrixComponentType::F16 && config.resultComponentType == wgpu::SubgroupMatrixComponentType::F16) { - ctx->sg_mat_m = config.M; - ctx->sg_mat_n = config.N; - ctx->sg_mat_k = config.K; + ctx->sg_mat_m = config.M; + ctx->sg_mat_n = config.N; + ctx->sg_mat_k = config.K; valid_subgroup_matrix_config = true; break; } diff 
--git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl index 963a485d877..f6970ceeac3 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl @@ -4,21 +4,27 @@ enable f16; enable subgroups; enable chromium_experimental_subgroup_matrix; -// Default values +#ifdef KV_F32 #define KV_TYPE f32 +#else +#define KV_TYPE f16 +#endif + +// Default values #define HEAD_DIM_QK 64 #define HEAD_DIM_V 64 -#define Q_TILE 16 -#define KV_TILE 16 -#define WG_SIZE 64 - // The number of rows/columns/k in a subgroup matrix. MxK * KxN = MxN // Note that the "K" here does not correspond to the K in attention's Q/K/V, it's just the common dimension. #define SG_MAT_M 8 #define SG_MAT_N 8 #define SG_MAT_K 8 +// Each workgroup processes one subgroup matrix of Q rows +#define Q_TILE SG_MAT_M +#define KV_TILE 16 +#define WG_SIZE 64 + // Quantization constants/helpers #define BLOCK_SIZE 32 #define BLOCKS_K ((HEAD_DIM_QK + BLOCK_SIZE - 1) / BLOCK_SIZE) @@ -37,6 +43,7 @@ enable chromium_experimental_subgroup_matrix; #endif #define F16_PER_THREAD (NQ / WEIGHTS_PER_F16) +// Ok not to put these in a define block, compiler will remove if unused fn get_byte(value: u32, index: u32) -> u32 { return (value >> (index * 8)) & 0xFF; } @@ -114,9 +121,11 @@ const FLOAT_MIN: f16 = -65504.0; // The number of Q rows processed per workgroup var q_shmem: array; -// we can reuse the same shmem for K and V since we only need one at a time +#ifndef KV_DIRECT const kv_shmem_size = KV_TILE * max(HEAD_DIM_QK, HEAD_DIM_V); +// we can reuse the same shmem for K and V since we only need one at a time var kv_shmem: array; +#endif var o_shmem: array; // output shmem @@ -149,8 +158,7 @@ fn calc_softmax_term(kv_idx: u32, q_tile_row: u32, slope: f16) -> f16 { return v; } -// Number of blocks this workgroup handles at the subgroup matrix level. SG_MAT_M must divide Q_TILE. 
-const Q_BLOCKS = Q_TILE / SG_MAT_M; +// Q_TILE is assumed to match SG_MAT_M, so we process a single Q block per workgroup. // Number of subgroup-matrix-width blocks that span the KV tile. SG_MAT_N must divide KV_TILE. const KV_BLOCKS = KV_TILE / SG_MAT_N; @@ -274,6 +282,8 @@ fn main(@builtin(workgroup_id) wg_id: vec3, } } } +#elif defined(KV_DIRECT) + // Direct global loads for KV #else for (var elem_idx = local_id.x; elem_idx < KV_TILE * params.head_dim_qk; elem_idx += WG_SIZE) { let k_row = elem_idx / params.head_dim_qk; @@ -289,36 +299,51 @@ fn main(@builtin(workgroup_id) wg_id: vec3, workgroupBarrier(); +#ifdef KV_DIRECT + let tile_kv = min(params.seq_len_kv - kv_tile, KV_TILE); +#else + let tile_kv: u32 = KV_TILE; +#endif + let valid_kv_blocks = tile_kv / SG_MAT_N; + // accumulate q block * k block into registers across the entire KV tile - for (var sg_block = subgroup_id; sg_block < Q_BLOCKS; sg_block += num_subgroups) { - let q_block_offset = sg_block * SG_MAT_M * params.head_dim_qk; - for (var kv_block = 0u; kv_block < KV_BLOCKS; kv_block++) { - var acc: subgroup_matrix_result; - let k_block_offset = kv_block * SG_MAT_N * params.head_dim_qk; - for (var head_dim_block = 0u; head_dim_block < params.head_dim_qk; head_dim_block += SG_MAT_K) { - // load q submatrix from shared memory - var q_sg_mat: subgroup_matrix_left = subgroupMatrixLoad>( - &q_shmem, - q_block_offset + head_dim_block, - false, - params.head_dim_qk - ); + for (var kv_block = subgroup_id; kv_block < valid_kv_blocks; kv_block += num_subgroups) { + var acc: subgroup_matrix_result; + let k_block_offset = kv_block * SG_MAT_N * params.head_dim_qk; + for (var head_dim_block = 0u; head_dim_block < params.head_dim_qk; head_dim_block += SG_MAT_K) { + // load q submatrix from shared memory + var q_sg_mat: subgroup_matrix_left = subgroupMatrixLoad>( + &q_shmem, + head_dim_block, + false, + params.head_dim_qk + ); - // load k submatrix from shared memory + // load k submatrix from device or shared memory 
+#ifdef KV_DIRECT + let k_block_row = kv_tile + kv_block * SG_MAT_N; + let k_global_offset = k_head_offset + k_block_row * params.stride_k1 + head_dim_block; + var k_sg_mat: subgroup_matrix_right = subgroupMatrixLoad>( + &K, + k_global_offset, + true, + params.stride_k1 + ); +#else var k_sg_mat: subgroup_matrix_right = subgroupMatrixLoad>( &kv_shmem, k_block_offset + head_dim_block, true, params.head_dim_qk ); +#endif - acc = subgroupMatrixMultiplyAccumulate(q_sg_mat, k_sg_mat, acc); - } - - // store acc to shared memory for softmax (S matrix from paper) - let inter_offset = sg_block * SG_MAT_M * KV_TILE + kv_block * SG_MAT_N; - subgroupMatrixStore(&inter_shmem, inter_offset, acc, false, KV_TILE); + acc = subgroupMatrixMultiplyAccumulate(q_sg_mat, k_sg_mat, acc); } + + // store acc to shared memory for softmax (S matrix from paper) + let inter_offset = kv_block * SG_MAT_N; + subgroupMatrixStore(&inter_shmem, inter_offset, acc, false, KV_TILE); } #ifdef MASK @@ -339,56 +364,52 @@ fn main(@builtin(workgroup_id) wg_id: vec3, workgroupBarrier(); // online softmax - for (var sg_block = subgroup_id; sg_block < Q_BLOCKS; sg_block += num_subgroups) { - let block_row_start = sg_block * SG_MAT_M; - let block_row_end = block_row_start + SG_MAT_M; - for (var q_tile_row = block_row_start; q_tile_row < block_row_end; q_tile_row++) { - // no need to process rows beyond seq_len_q - let global_q_row = q_row_start + q_tile_row; - if (global_q_row >= params.seq_len_q) { - break; - } + for (var q_tile_row = subgroup_id; q_tile_row < Q_TILE; q_tile_row += num_subgroups) { + var row_iter = q_tile_row / num_subgroups; - // initialize running max for this row - // only the first thread in the subgroup needs to read from shared memory. - // TODO: is this faster than having all threads read shared memory? 
- var prev_max = select(0.0, row_max_shmem[q_tile_row], sg_inv_id == 0); - prev_max = subgroupBroadcastFirst(prev_max); - var final_max = prev_max; - - // pass 1: compute final max across the full KV tile in chunks - for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) { - let kv_idx = kv_offset + sg_inv_id; - let softmax_term = calc_softmax_term(kv_idx, q_tile_row, slope); - final_max = subgroupMax(max(final_max, softmax_term)); - } + let global_q_row = q_row_start + q_tile_row; + if (global_q_row >= params.seq_len_q) { + break; + } - var total_exp_term: f16 = 0.0; - - // pass 2: compute exp sum and write P using final_max - for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) { - let kv_idx = kv_offset + sg_inv_id; - let softmax_term = calc_softmax_term(kv_idx, q_tile_row, slope); - let cur_p = select(0.0, - exp(softmax_term - final_max), - kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE); - total_exp_term += subgroupAdd(cur_p); - if (kv_idx < KV_TILE) { - inter_shmem[kv_idx + q_tile_row * KV_TILE] = cur_p; - } - } + // initialize running max for this row + // only the first thread in the subgroup needs to read from shared memory. + // TODO: is this faster than having all threads read shared memory? 
+ var prev_max = select(0.0, row_max_shmem[q_tile_row], sg_inv_id == 0); + prev_max = subgroupBroadcastFirst(prev_max); + var final_max = prev_max; + // pass 1: compute final max across the full KV tile in chunks + for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) { + let kv_idx = kv_offset + sg_inv_id; + let softmax_term = calc_softmax_term(kv_idx, q_tile_row, slope); + final_max = subgroupMax(max(final_max, softmax_term)); + } - let cur_exp = exp(prev_max - final_max); - if (sg_inv_id == 0) { - row_max_shmem[q_tile_row] = final_max; - exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * cur_exp + total_exp_term; + var total_exp_term: f16 = 0.0; + // pass 2: compute exp sum and write P using final_max + for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) { + let kv_idx = kv_offset + sg_inv_id; + let softmax_term = calc_softmax_term(kv_idx, q_tile_row, slope); + let cur_p = select(0.0, + exp(softmax_term - final_max), + kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE); + total_exp_term += subgroupAdd(cur_p); + if (kv_idx < KV_TILE) { + inter_shmem[kv_idx + q_tile_row * KV_TILE] = cur_p; } + } - for (var elem_idx = sg_inv_id; elem_idx < params.head_dim_v; elem_idx += subgroup_size) { - let idx = q_tile_row * params.head_dim_v + elem_idx; - o_shmem[idx] *= cur_exp; - } + let cur_exp = exp(prev_max - final_max); + + if (sg_inv_id == 0) { + row_max_shmem[q_tile_row] = final_max; + exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * cur_exp + total_exp_term; + } + + for (var elem_idx = sg_inv_id; elem_idx < params.head_dim_v; elem_idx += subgroup_size) { + let idx = q_tile_row * params.head_dim_v + elem_idx; + o_shmem[idx] *= cur_exp; } } @@ -447,6 +468,8 @@ fn main(@builtin(workgroup_id) wg_id: vec3, } } } +#elif defined(KV_DIRECT) + // Direct global loads for KV #else for (var elem_idx = local_id.x; elem_idx < KV_TILE * params.head_dim_v; elem_idx += WG_SIZE) { let v_row = elem_idx / params.head_dim_v; @@ 
-464,19 +487,20 @@ fn main(@builtin(workgroup_id) wg_id: vec3, // we have P (Q_TILE x KV_TILE) in inter_shmem and V (KV_TILE x head_dim_v) in kv_shmem // we want to compute O += P * V across the full KV tile - for (var sg_block = subgroup_id; sg_block < Q_BLOCKS; sg_block += num_subgroups) { - let o_row_offset = sg_block * SG_MAT_M * params.head_dim_v; - for (var head_dim_block = 0u; head_dim_block < params.head_dim_v; head_dim_block += SG_MAT_N) { + for (var head_dim_block = subgroup_id * SG_MAT_N; + head_dim_block < params.head_dim_v; + head_dim_block += num_subgroups * SG_MAT_N) { // load O submatrix from shared memory var o_sg_mat: subgroup_matrix_result = subgroupMatrixLoad>( &o_shmem, - o_row_offset + head_dim_block, + head_dim_block, false, params.head_dim_v ); - for (var kv_block = 0u; kv_block < KV_BLOCKS; kv_block++) { - let p_offset = sg_block * SG_MAT_M * KV_TILE + kv_block * SG_MAT_N; + //for (var kv_block = 0u; kv_block < KV_BLOCKS; kv_block++) { + for (var kv_block = 0u; kv_block < valid_kv_blocks; kv_block++) { + let p_offset = kv_block * SG_MAT_N; var p_sg_mat: subgroup_matrix_left = subgroupMatrixLoad>( &inter_shmem, p_offset, @@ -484,7 +508,17 @@ fn main(@builtin(workgroup_id) wg_id: vec3, KV_TILE ); - // load V submatrix from shared memory + // load V submatrix from global or shared memory +#ifdef KV_DIRECT + let v_block_row = kv_tile + kv_block * SG_MAT_N; + let v_global_offset = v_head_offset + v_block_row * params.stride_v1 + head_dim_block; + var v_sg_mat: subgroup_matrix_right = subgroupMatrixLoad>( + &V, + v_global_offset, + false, + params.stride_v1 + ); +#else let v_block_offset = kv_block * SG_MAT_N * params.head_dim_v; var v_sg_mat: subgroup_matrix_right = subgroupMatrixLoad>( &kv_shmem, @@ -492,14 +526,14 @@ fn main(@builtin(workgroup_id) wg_id: vec3, false, params.head_dim_v ); +#endif // O += P * V o_sg_mat = subgroupMatrixMultiplyAccumulate(p_sg_mat, v_sg_mat, o_sg_mat); } // store O back to shared memory - 
subgroupMatrixStore(&o_shmem, o_row_offset + head_dim_block, o_sg_mat, false, params.head_dim_v); - } + subgroupMatrixStore(&o_shmem, head_dim_block, o_sg_mat, false, params.head_dim_v); } workgroupBarrier(); @@ -507,10 +541,9 @@ fn main(@builtin(workgroup_id) wg_id: vec3, #ifdef SINKS // add sinks (applied once after processing all KV tiles) - for (var sg_block = subgroup_id; sg_block < Q_BLOCKS; sg_block += num_subgroups) { - let block_row_start = sg_block * SG_MAT_M; - let block_row_end = block_row_start + SG_MAT_M; - for (var q_tile_row = block_row_start; q_tile_row < block_row_end; q_tile_row++) { + for (var q_tile_row = subgroup_id; + q_tile_row < Q_TILE; + q_tile_row += num_subgroups) { // no need to process rows beyond seq_len_q let global_q_row = q_row_start + q_tile_row; if (global_q_row >= params.seq_len_q) { @@ -537,17 +570,15 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let val = o_shmem[idx] * max_exp; o_shmem[idx] = val; } - } } workgroupBarrier(); #endif // write output back to global memory - for (var sg_block = subgroup_id; sg_block < Q_BLOCKS; sg_block += num_subgroups) { - let block_row_start = sg_block * SG_MAT_M; - let block_row_end = block_row_start + SG_MAT_M; - for (var q_tile_row = block_row_start; q_tile_row < block_row_end; q_tile_row++) { + for (var q_tile_row = subgroup_id; + q_tile_row < Q_TILE; + q_tile_row += num_subgroups) { let global_q_row = q_row_start + q_tile_row; if (global_q_row >= params.seq_len_q) { break; @@ -563,6 +594,5 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let scaled = o_val * scale; dst[dst_global_offset + q_tile_row * dst2_stride + elem_idx] = f32(scaled); } - } } } From 4070a04954792a1ea4f62024fd36e915dccda6ec Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Sat, 3 Jan 2026 11:39:07 -0800 Subject: [PATCH 39/40] Checkpoint --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 19 +++- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 13 +-- .../ggml-webgpu/wgsl-shaders/flash_attn.wgsl | 106 ++++++++---------- 3 files 
changed, 68 insertions(+), 70 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index aad88d8a3d3..e8634a2e628 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -43,11 +43,14 @@ inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile, uint32_t kv_tile, uint32_t head_dim_qk, uint32_t head_dim_v, - bool has_mask) { + bool has_mask, + bool kv_direct) { const uint32_t max_head_dim = std::max(head_dim_qk, head_dim_v); size_t elems = 0; elems += q_tile * head_dim_qk; // q_shmem - elems += kv_tile * max_head_dim; // kv_shmem + if (!kv_direct) { + elems += kv_tile * max_head_dim; // kv_shmem + } elems += q_tile * head_dim_v; // o_shmem if (has_mask) { elems += q_tile * kv_tile; // mask_shmem @@ -62,9 +65,15 @@ static uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_flash_attn_ const size_t limit_bytes = context.wg_mem_limit_bytes; const size_t q_tile = context.sg_mat_m; const size_t base_q_bytes = (context.head_dim_qk + context.head_dim_v + 2) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES; - const size_t bytes_per_kv = - (std::max(context.head_dim_qk, context.head_dim_v) + (context.has_mask ? 
q_tile : 0) + q_tile) * - GGML_WEBGPU_F16_SIZE_BYTES; + size_t bytes_per_kv = 0; + if (!context.kv_direct) { + bytes_per_kv += std::max(context.head_dim_qk, context.head_dim_v); + } + if (context.has_mask) { + bytes_per_kv += q_tile; + } + bytes_per_kv += q_tile; + bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES; const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv; return (max_kv_tile / context.sg_mat_n) * context.sg_mat_n; } diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 6d2ca5c0212..f76ae85d0ff 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -759,11 +759,6 @@ static void ggml_backend_webgpu_free(ggml_backend_t backend) { #if !defined(GGML_WEBGPU_CPU_PROFILE) && !defined(GGML_WEBGPU_GPU_PROFILE) GGML_UNUSED(ctx); #endif - - for (auto & kv : ctx->webgpu_ctx->flash_attn_pipelines) { - delete static_cast(kv.second.context); - kv.second.context = nullptr; - } } static size_t ggml_webgpu_tensor_offset(const ggml_tensor * tensor) { @@ -1058,8 +1053,6 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, has_mask ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)) : 0, has_sinks ? 
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, sinks) / ggml_type_size(sinks->type)) : 0, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), - (uint32_t) Q->ne[0], // head dimension (Q/K) - (uint32_t) V->ne[0], // head dimension (V) (uint32_t) Q->ne[2], // number of heads (uint32_t) Q->ne[1], // sequence length (Q) (uint32_t) K->ne[1], // sequence length (K/V) @@ -2674,9 +2667,13 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const // Head dimensions must fit in workgroup memory with minimum tile sizes size_t limit_bytes = webgpu_ctx->limits.maxComputeWorkgroupStorageSize; const bool has_mask = op->src[3] != nullptr; + const bool kv_direct = src1->type == GGML_TYPE_F16 && + (src0->ne[0] % webgpu_ctx->sg_mat_k) == 0 && + (src1->ne[1] % webgpu_ctx->sg_mat_n) == 0; const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(webgpu_ctx->sg_mat_m, webgpu_ctx->sg_mat_n, - (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask); + (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, + kv_direct); if (min_bytes > limit_bytes) { break; } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl index f6970ceeac3..5e00d00ebbc 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl @@ -61,8 +61,6 @@ struct Params { offset_dst: u32, // shapes of Q/K/V - head_dim_qk: u32, - head_dim_v: u32, n_heads: u32, seq_len_q: u32, seq_len_kv: u32, @@ -160,6 +158,7 @@ fn calc_softmax_term(kv_idx: u32, q_tile_row: u32, slope: f16) -> f16 { // Q_TILE is assumed to match SG_MAT_M, so we process a single Q block per workgroup. // Number of subgroup-matrix-width blocks that span the KV tile. SG_MAT_N must divide KV_TILE. 
+// TODO: if this can be used instead of valid_kv_blocks, performance increases const KV_BLOCKS = KV_TILE / SG_MAT_N; @compute @workgroup_size(WG_SIZE) @@ -179,7 +178,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let wg_per_head = (params.seq_len_q + Q_TILE - 1u) / Q_TILE; let wg_per_batch = wg_per_head * params.n_heads; - let dst2_stride = params.head_dim_v * params.n_heads; + let dst2_stride = HEAD_DIM_V * params.n_heads; let dst3_stride = dst2_stride * params.seq_len_q; // batch index @@ -208,34 +207,34 @@ fn main(@builtin(workgroup_id) wg_id: vec3, #endif // note that the output is permuted, the layout is [head_dim_v, n_heads, seq_len_q, batch_size] - let dst_global_offset = dst_batch_offset + q_row_start * dst2_stride + head_idx * params.head_dim_v; + let dst_global_offset = dst_batch_offset + q_row_start * dst2_stride + head_idx * HEAD_DIM_V; let head = f32(head_idx); let slope = f16(select(1.0, select(pow(params.m1, 2.0 * (head - params.n_head_log2) + 1.0), pow(params.m0, head + 1.0), head < params.n_head_log2), params.max_bias > 0)); // load q tile into shared memory - for (var elem_idx = local_id.x; elem_idx < Q_TILE * params.head_dim_qk; elem_idx += WG_SIZE) { - let q_row = elem_idx / params.head_dim_qk; - let q_col = elem_idx % params.head_dim_qk; + for (var elem_idx = local_id.x; elem_idx < Q_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) { + let q_row = elem_idx / HEAD_DIM_QK; + let q_col = elem_idx % HEAD_DIM_QK; let head_q_row = q_row_start + q_row; let global_q_row_offset = q_head_offset + head_q_row * params.stride_q1; q_shmem[elem_idx] = f16(select( 0.0, Q[global_q_row_offset + q_col], - head_q_row < params.seq_len_q && q_col < params.head_dim_qk)); + head_q_row < params.seq_len_q && q_col < HEAD_DIM_QK)); } for (var kv_tile = 0u; kv_tile < params.seq_len_kv; kv_tile += KV_TILE) { // load k tile into shared memory #if defined(KV_Q4_0) - for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * params.head_dim_qk; elem_idx += WG_SIZE * NQ) { + for 
(var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) { let blck_idx = elem_idx / BLOCK_SIZE; let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; let k_row = blck_idx / BLOCKS_K; let global_k_row = kv_tile + k_row; let block_k = blck_idx % BLOCKS_K; - let row_offset = k_row * params.head_dim_qk; + let row_offset = k_row * HEAD_DIM_QK; if (global_k_row < params.seq_len_kv) { let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; @@ -257,13 +256,13 @@ fn main(@builtin(workgroup_id) wg_id: vec3, } } #elif defined(KV_Q8_0) - for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * params.head_dim_qk; elem_idx += WG_SIZE * NQ) { + for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) { let blck_idx = elem_idx / BLOCK_SIZE; let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; let k_row = blck_idx / BLOCKS_K; let global_k_row = kv_tile + k_row; let block_k = blck_idx % BLOCKS_K; - let row_offset = k_row * params.head_dim_qk; + let row_offset = k_row * HEAD_DIM_QK; if (global_k_row < params.seq_len_kv) { let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; @@ -285,15 +284,15 @@ fn main(@builtin(workgroup_id) wg_id: vec3, #elif defined(KV_DIRECT) // Direct global loads for KV #else - for (var elem_idx = local_id.x; elem_idx < KV_TILE * params.head_dim_qk; elem_idx += WG_SIZE) { - let k_row = elem_idx / params.head_dim_qk; - let k_col = elem_idx % params.head_dim_qk; + for (var elem_idx = local_id.x; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) { + let k_row = elem_idx / HEAD_DIM_QK; + let k_col = elem_idx % HEAD_DIM_QK; let global_k_row = kv_tile + k_row; let global_k_row_offset = k_head_offset + global_k_row * params.stride_k1; kv_shmem[elem_idx] = f16(select( 0.0, K[global_k_row_offset + k_col], - global_k_row < params.seq_len_kv && k_col < params.head_dim_qk)); + global_k_row < params.seq_len_kv && k_col < 
HEAD_DIM_QK)); } #endif @@ -307,25 +306,30 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let valid_kv_blocks = tile_kv / SG_MAT_N; // accumulate q block * k block into registers across the entire KV tile + // TODO: this loop seems to be the current largest bottleneck for (var kv_block = subgroup_id; kv_block < valid_kv_blocks; kv_block += num_subgroups) { var acc: subgroup_matrix_result; - let k_block_offset = kv_block * SG_MAT_N * params.head_dim_qk; - for (var head_dim_block = 0u; head_dim_block < params.head_dim_qk; head_dim_block += SG_MAT_K) { +#ifdef KV_DIRECT + let k_block_row = kv_tile + kv_block * SG_MAT_N; + let k_global_offset = k_head_offset + k_block_row * params.stride_k1; +#else + let k_block_offset = kv_block * SG_MAT_N * HEAD_DIM_QK; +#endif + let inter_offset = kv_block * SG_MAT_N; + for (var head_dim_block = 0u; head_dim_block < HEAD_DIM_QK; head_dim_block += SG_MAT_K) { // load q submatrix from shared memory var q_sg_mat: subgroup_matrix_left = subgroupMatrixLoad>( &q_shmem, head_dim_block, false, - params.head_dim_qk + HEAD_DIM_QK ); // load k submatrix from device or shared memory #ifdef KV_DIRECT - let k_block_row = kv_tile + kv_block * SG_MAT_N; - let k_global_offset = k_head_offset + k_block_row * params.stride_k1 + head_dim_block; var k_sg_mat: subgroup_matrix_right = subgroupMatrixLoad>( &K, - k_global_offset, + k_global_offset + head_dim_block, true, params.stride_k1 ); @@ -334,15 +338,13 @@ fn main(@builtin(workgroup_id) wg_id: vec3, &kv_shmem, k_block_offset + head_dim_block, true, - params.head_dim_qk + HEAD_DIM_QK ); #endif - acc = subgroupMatrixMultiplyAccumulate(q_sg_mat, k_sg_mat, acc); } // store acc to shared memory for softmax (S matrix from paper) - let inter_offset = kv_block * SG_MAT_N; subgroupMatrixStore(&inter_shmem, inter_offset, acc, false, KV_TILE); } @@ -365,18 +367,13 @@ fn main(@builtin(workgroup_id) wg_id: vec3, // online softmax for (var q_tile_row = subgroup_id; q_tile_row < Q_TILE; q_tile_row += num_subgroups) 
{ - var row_iter = q_tile_row / num_subgroups; - let global_q_row = q_row_start + q_tile_row; if (global_q_row >= params.seq_len_q) { break; } // initialize running max for this row - // only the first thread in the subgroup needs to read from shared memory. - // TODO: is this faster than having all threads read shared memory? - var prev_max = select(0.0, row_max_shmem[q_tile_row], sg_inv_id == 0); - prev_max = subgroupBroadcastFirst(prev_max); + var prev_max = row_max_shmem[q_tile_row]; var final_max = prev_max; // pass 1: compute final max across the full KV tile in chunks for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) { @@ -385,7 +382,6 @@ fn main(@builtin(workgroup_id) wg_id: vec3, final_max = subgroupMax(max(final_max, softmax_term)); } - var total_exp_term: f16 = 0.0; // pass 2: compute exp sum and write P using final_max for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) { @@ -407,21 +403,21 @@ fn main(@builtin(workgroup_id) wg_id: vec3, exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * cur_exp + total_exp_term; } - for (var elem_idx = sg_inv_id; elem_idx < params.head_dim_v; elem_idx += subgroup_size) { - let idx = q_tile_row * params.head_dim_v + elem_idx; + for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) { + let idx = q_tile_row * HEAD_DIM_V + elem_idx; o_shmem[idx] *= cur_exp; } } // load v tile into shared memory #if defined(KV_Q4_0) - for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * params.head_dim_v; elem_idx += WG_SIZE * NQ) { + for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) { let blck_idx = elem_idx / BLOCK_SIZE; let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; let v_row = blck_idx / BLOCKS_V; let global_v_row = kv_tile + v_row; let block_k = blck_idx % BLOCKS_V; - let row_offset = v_row * params.head_dim_v; + let row_offset = v_row * HEAD_DIM_V; if (global_v_row < params.seq_len_kv) { let 
global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; @@ -443,13 +439,13 @@ fn main(@builtin(workgroup_id) wg_id: vec3, } } #elif defined(KV_Q8_0) - for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * params.head_dim_v; elem_idx += WG_SIZE * NQ) { + for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) { let blck_idx = elem_idx / BLOCK_SIZE; let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; let v_row = blck_idx / BLOCKS_V; let global_v_row = kv_tile + v_row; let block_k = blck_idx % BLOCKS_V; - let row_offset = v_row * params.head_dim_v; + let row_offset = v_row * HEAD_DIM_V; if (global_v_row < params.seq_len_kv) { let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; @@ -471,15 +467,15 @@ fn main(@builtin(workgroup_id) wg_id: vec3, #elif defined(KV_DIRECT) // Direct global loads for KV #else - for (var elem_idx = local_id.x; elem_idx < KV_TILE * params.head_dim_v; elem_idx += WG_SIZE) { - let v_row = elem_idx / params.head_dim_v; - let v_col = elem_idx % params.head_dim_v; + for (var elem_idx = local_id.x; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE) { + let v_row = elem_idx / HEAD_DIM_V; + let v_col = elem_idx % HEAD_DIM_V; let global_v_row = kv_tile + v_row; let global_v_row_offset = v_head_offset + global_v_row * params.stride_v1; kv_shmem[elem_idx] = f16(select( 0.0, V[global_v_row_offset + v_col], - global_v_row < params.seq_len_kv && v_col < params.head_dim_v)); + global_v_row < params.seq_len_kv && v_col < HEAD_DIM_V)); } #endif @@ -488,17 +484,16 @@ fn main(@builtin(workgroup_id) wg_id: vec3, // we have P (Q_TILE x KV_TILE) in inter_shmem and V (KV_TILE x head_dim_v) in kv_shmem // we want to compute O += P * V across the full KV tile for (var head_dim_block = subgroup_id * SG_MAT_N; - head_dim_block < params.head_dim_v; + head_dim_block < HEAD_DIM_V; head_dim_block += num_subgroups * SG_MAT_N) { // load O submatrix from shared memory 
var o_sg_mat: subgroup_matrix_result = subgroupMatrixLoad>( &o_shmem, head_dim_block, false, - params.head_dim_v + HEAD_DIM_V ); - //for (var kv_block = 0u; kv_block < KV_BLOCKS; kv_block++) { for (var kv_block = 0u; kv_block < valid_kv_blocks; kv_block++) { let p_offset = kv_block * SG_MAT_N; var p_sg_mat: subgroup_matrix_left = subgroupMatrixLoad>( @@ -519,12 +514,12 @@ fn main(@builtin(workgroup_id) wg_id: vec3, params.stride_v1 ); #else - let v_block_offset = kv_block * SG_MAT_N * params.head_dim_v; + let v_block_offset = kv_block * SG_MAT_N * HEAD_DIM_V; var v_sg_mat: subgroup_matrix_right = subgroupMatrixLoad>( &kv_shmem, v_block_offset + head_dim_block, false, - params.head_dim_v + HEAD_DIM_V ); #endif @@ -533,7 +528,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, } // store O back to shared memory - subgroupMatrixStore(&o_shmem, head_dim_block, o_sg_mat, false, params.head_dim_v); + subgroupMatrixStore(&o_shmem, head_dim_block, o_sg_mat, false, HEAD_DIM_V); } workgroupBarrier(); @@ -550,8 +545,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, break; } - var prev_max = select(0.0, row_max_shmem[q_tile_row], sg_inv_id == 0); - prev_max = subgroupBroadcastFirst(prev_max); + var prev_max = row_max_shmem[q_tile_row]; // for non-sink threads, exp(-65504) effectively zeroes out their contribution to the sum let sink_val = select(FLOAT_MIN, f16(sinks[params.offset_sinks + head_idx]), sg_inv_id == 0); @@ -565,8 +559,8 @@ fn main(@builtin(workgroup_id) wg_id: vec3, exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * max_exp + sink_exp_sum; } - for (var elem_idx = sg_inv_id; elem_idx < params.head_dim_v; elem_idx += subgroup_size) { - let idx = q_tile_row * params.head_dim_v + elem_idx; + for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) { + let idx = q_tile_row * HEAD_DIM_V + elem_idx; let val = o_shmem[idx] * max_exp; o_shmem[idx] = val; } @@ -584,13 +578,11 @@ fn main(@builtin(workgroup_id) wg_id: vec3, break; } - var exp_sum = 
select(0.0, exp_sum_shmem[q_tile_row], sg_inv_id == 0); - exp_sum = subgroupBroadcastFirst(exp_sum); - + let exp_sum = exp_sum_shmem[q_tile_row]; let scale = select(0.0, 1.0 / exp_sum, exp_sum != 0); - for (var elem_idx = sg_inv_id; elem_idx < params.head_dim_v; elem_idx += subgroup_size) { - let o_val = o_shmem[q_tile_row * params.head_dim_v + elem_idx]; + for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) { + let o_val = o_shmem[q_tile_row * HEAD_DIM_V + elem_idx]; let scaled = o_val * scale; dst[dst_global_offset + q_tile_row * dst2_stride + elem_idx] = f32(scaled); } From f71815f5a29548875fb2d2c6f02b63450615528b Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Sat, 3 Jan 2026 16:51:42 -0800 Subject: [PATCH 40/40] formatting --- ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp | 14 +++++++------- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 16 +++++++--------- 2 files changed, 14 insertions(+), 16 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index e8634a2e628..ff0aef36a99 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -47,17 +47,17 @@ inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile, bool kv_direct) { const uint32_t max_head_dim = std::max(head_dim_qk, head_dim_v); size_t elems = 0; - elems += q_tile * head_dim_qk; // q_shmem + elems += q_tile * head_dim_qk; // q_shmem if (!kv_direct) { elems += kv_tile * max_head_dim; // kv_shmem } - elems += q_tile * head_dim_v; // o_shmem + elems += q_tile * head_dim_v; // o_shmem if (has_mask) { - elems += q_tile * kv_tile; // mask_shmem + elems += q_tile * kv_tile; // mask_shmem } - elems += q_tile * kv_tile; // inter_shmem - elems += q_tile; // row_max_shmem - elems += q_tile; // exp_sum_shmem + elems += q_tile * kv_tile; // inter_shmem + elems += q_tile; // row_max_shmem + elems += q_tile; // exp_sum_shmem return elems * 
GGML_WEBGPU_F16_SIZE_BYTES; } @@ -65,7 +65,7 @@ static uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_flash_attn_ const size_t limit_bytes = context.wg_mem_limit_bytes; const size_t q_tile = context.sg_mat_m; const size_t base_q_bytes = (context.head_dim_qk + context.head_dim_v + 2) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES; - size_t bytes_per_kv = 0; + size_t bytes_per_kv = 0; if (!context.kv_direct) { bytes_per_kv += std::max(context.head_dim_qk, context.head_dim_v); } diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index f76ae85d0ff..ccc09d83778 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -2665,15 +2665,13 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const break; } // Head dimensions must fit in workgroup memory with minimum tile sizes - size_t limit_bytes = webgpu_ctx->limits.maxComputeWorkgroupStorageSize; - const bool has_mask = op->src[3] != nullptr; - const bool kv_direct = src1->type == GGML_TYPE_F16 && - (src0->ne[0] % webgpu_ctx->sg_mat_k) == 0 && - (src1->ne[1] % webgpu_ctx->sg_mat_n) == 0; - const size_t min_bytes = - ggml_webgpu_flash_attn_wg_mem_bytes(webgpu_ctx->sg_mat_m, webgpu_ctx->sg_mat_n, - (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, - kv_direct); + size_t limit_bytes = webgpu_ctx->limits.maxComputeWorkgroupStorageSize; + const bool has_mask = op->src[3] != nullptr; + const bool kv_direct = src1->type == GGML_TYPE_F16 && (src0->ne[0] % webgpu_ctx->sg_mat_k) == 0 && + (src1->ne[1] % webgpu_ctx->sg_mat_n) == 0; + const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes( + webgpu_ctx->sg_mat_m, webgpu_ctx->sg_mat_n, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], + has_mask, kv_direct); if (min_bytes > limit_bytes) { break; }