
UPSTREAM PR #18610: ggml webgpu: initial flashattention implementation #821

Open
loci-dev wants to merge 3 commits into main from upstream-PR18610-branch_reeselevine-master

Conversation

@loci-dev commented Jan 5, 2026

Mirrored from ggml-org/llama.cpp#18610

This PR adds an initial version of FlashAttention2 in WebGPU. Along with the GPU code itself, this PR also adds a new preprocessor for WGSL shaders that should make defining new shaders easier and less brittle going forward. Details below:

Shader setup

  • Most of the shaders right now are generated at build time using a relatively hacky Python script and template syntax that I wrote during initial development of the WebGPU backend. This probably won't continue to scale very well, especially with the number of options for FlashAttention, so I decided it was time for a more general solution. Since there wasn't an existing preprocessor for WGSL that would work with C++ code, I wrote one here: https://github.com/reeselevine/pre-wgsl. The preprocessor itself is one file, pre_wgsl.hpp, and should continue to track any changes/features added to the main preprocessor repository.
  • To accommodate the various options for FlashAttention when compiling WGSL shaders, I added another new file to the WebGPU backend, ggml-webgpu-shader-lib.hpp, which generates the shader using a combination of structural parameters, e.g., head sizes, and performance parameters, like KV tile sizes. My idea is to expand this library approach to use more sophisticated strategies going forward and to move other shaders over to the preprocessor. From a performance perspective, shaders compile pretty fast, are mostly fixed for a given model, and are cached, so JIT compilation, at least for FlashAttention, seems to be the right call in my opinion.
  • Some other minor changes to ggml-webgpu.cpp to handle the new FlashAttention code and JIT compilation.
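To make the JIT-plus-caching idea above concrete, here is a minimal sketch of a pipeline cache keyed on the shader's structural and performance parameters, so each FlashAttention variant is generated and compiled at most once. All names here (`fa_params`, `compile_shader`, `get_pipeline`) are illustrative, not the actual ggml-webgpu API, and the real code would invoke the pre-wgsl preprocessor and `wgpuDeviceCreateShaderModule` where this sketch just builds a string.

```cpp
// Hypothetical sketch of JIT shader compilation with caching, keyed on
// the parameters that select a FlashAttention shader variant.
#include <cassert>
#include <map>
#include <string>
#include <tuple>

struct fa_params {
    int  head_size;  // structural: fixed for a given model
    int  kv_tile;    // performance: tunable tile size
    bool has_mask;   // structural: whether a mask tensor is present

    bool operator<(const fa_params & o) const {
        return std::tie(head_size, kv_tile, has_mask) <
               std::tie(o.head_size, o.kv_tile, o.has_mask);
    }
};

// Stand-in for: run the WGSL preprocessor with these parameters, then
// create a shader module / compute pipeline from the generated source.
static std::string compile_shader(const fa_params & p) {
    return "fa_hs" + std::to_string(p.head_size) +
           "_kv"   + std::to_string(p.kv_tile) +
           (p.has_mask ? "_mask" : "");
}

static std::map<fa_params, std::string> pipeline_cache;

// Returns the cached pipeline for these parameters, compiling on first use.
const std::string & get_pipeline(const fa_params & p) {
    auto it = pipeline_cache.find(p);
    if (it == pipeline_cache.end()) {
        it = pipeline_cache.emplace(p, compile_shader(p)).first;
    }
    return it->second;
}
```

Since the parameter set is mostly fixed per model, the cache stays small and compilation cost is paid once per variant rather than per graph evaluation.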

FlashAttention shader itself

  • For the most part this follows the FlashAttention2 paper, with a change for online softmax to make it subgroup-size agnostic. It also uses global KV loads if sizes are nicely divisible, since the KV tiles are not reused and pre-loading into shared memory really slows things down (at least on my M3).
  • Even then, performance is not great right now (under 50% of the speed of the equivalent Metal code). My testing shows that most of this slowdown comes down to the initial Q * K^T accumulation loop. I still need to do some debugging here to figure out whether the issue is something I'm doing wrong structurally, or whether it's in the compilation from WGSL to Metal. This code also needs to be more thoroughly tested on other platforms. Perhaps someone who has written FlashAttention for one of the other backends could take a look at the shader and see if it looks reasonable, e.g., @jeffbolznv?
  • Otherwise, this passes all the backend tests on my machine, so I think it's in a good state to merge as an initial implementation that can be improved upon as time goes on.

* Add inplace softmax

* Move rms_norm to split row approach

* Update debug for supports_op

* clean up debug statements

* neg f16xf32xip builds and runs, haven't actually run a model that uses the neg kernel yet though

* neg passes backend test

* unary operators pass ggml tests

* rms_norm double declaration bug atoned

* abides by editor-config

* removed vestigial files

* fixed autoconfig

* All operators (including xielu) working

* removed unnecessary checking if node->src[1] exists for unary operators

* responded and dealt with PR comments

* implemented REPL_Template support and removed bug in unary operators kernel

* formatted embed wgsl and ggml-webgpu.cpp

* Faster tensors (#8)

Add fast matrix and matrix/vector multiplication.

* Use map for shader replacements instead of pair of strings

* Wasm (#9)

* webgpu : fix build on emscripten

* more debugging stuff

* test-backend-ops: force single thread on wasm

* fix single-thread case for init_tensor_uniform

* use jspi

* add pthread

* test: remember to set n_thread for cpu backend

* Add buffer label and enable dawn-specific toggles to turn off some checks

* Intermediate state

* Fast working f16/f32 vec4

* Working float fast mul mat

* Clean up naming of mul_mat to match logical model, start work on q mul_mat

* Setup for subgroup matrix mat mul

* Basic working subgroup matrix

* Working subgroup matrix tiling

* Handle weirder sg matrix sizes (but still % sg matrix size)

* Working start to gemv

* working f16 accumulation with shared memory staging

* Print out available subgroup matrix configurations

* Vectorize dst stores for sg matrix shader

* Gemv working scalar

* Minor set_rows optimization (#4)

* updated optimization, fixed errors

* non vectorized version now dispatches one thread per element

* Simplify

* Change logic for set_rows pipelines

---------

Co-authored-by: Neha Abbas <nehaabbas@macbookpro.lan>
Co-authored-by: Neha Abbas <nehaabbas@ReeseLevines-MacBook-Pro.local>
Co-authored-by: Reese Levine <reeselevine1@gmail.com>

* Comment on dawn toggles

* Working subgroup matrix code for (semi)generic sizes

* Remove some comments

* Cleanup code

* Update dawn version and move to portable subgroup size

* Try to fix new dawn release

* Update subgroup size comment

* Only check for subgroup matrix configs if they are supported

* Add toggles for subgroup matrix/f16 support on nvidia+vulkan

* Make row/col naming consistent

* Refactor shared memory loading

* Move sg matrix stores to correct file

* Working q4_0

* Formatting

* Work with emscripten builds

* Fix test-backend-ops emscripten for f16/quantized types

* Use emscripten memory64 to support get_memory

* Add build flags and try ci

---------

Co-authored-by: Xuan Son Nguyen <son@huggingface.co>

* Remove extra whitespace

* Move wasm single-thread logic out of test-backend-ops for cpu backend

* Disable multiple threads for emscripten single-thread builds in ggml_graph_plan

* Refactored pipelines and workgroup calculations (#10)

* refactored pipelines

* refactored workgroup calculation

* removed commented out block of prior maps

* Clean up ceiling division pattern

---------

Co-authored-by: Neha Abbas <nehaabbas@eduroam-169-233-141-223.ucsc.edu>
Co-authored-by: Reese Levine <reeselevine1@gmail.com>

* Start work on flash attention

* Shader structure set up (many bugs still)

* debugging

* Working first test

* Working with head grouping, head sizes to 128, logit softcap, mask/sinks enabled, f32

* Generalize softmax to work with multiple subgroups, f16 accumulation, mask shared memory tiling

* Start work on integrating pre-wgsl

* Separate structs/initial shader compilation library into separate files

* Work on compilation choices for flashattention

* Work on subgroup matrix/tile size portability

* subgroup size agnostic online softmax

* Cleanups, quantization types

* more cleanup

* fix wasm build

* Refactor flashattention to increase parallelism, use direct loads for KV in some cases

* Checkpoint

* formatting
* formatting

loci-review bot commented Jan 5, 2026

Explore the complete analysis inside the Version Insights

Perfect! I've generated the summary report for your project. Here's what the analysis shows:

Key Findings:

✅ No Significant Performance Impact Detected

For the llama.cpp repository (Pull Request #821 by auroralabs-loci):

  • Response Time: No modified functions showed changes greater than 2%
  • Throughput: No modified functions showed changes greater than 2%

This indicates that the pull request is performance-neutral and safe to merge from a performance perspective. The changes appear to focus on functionality, bug fixes, or code quality improvements without affecting the performance characteristics of the codebase.

The comparison was made between:

  • Base version: a7cb4ab1-e9f6-11f0-81f2-dbb430499cb5
  • Target version: a6976c51-e9fb-11f0-81f2-dbb430499cb5

@loci-dev loci-dev force-pushed the main branch 25 times, most recently from 6f813dc to f85d458 Compare January 8, 2026 07:12
@loci-dev loci-dev force-pushed the main branch 30 times, most recently from 2d2b258 to 78dd122 Compare January 14, 2026 10:10