ggml-webgpu: FlashAttention refactor + standardize quantization support#23834

Merged
ggerganov merged 9 commits into
ggml-org:masterfrom
reeselevine:flash_attn_refactor
Jun 4, 2026
Conversation

@reeselevine
Contributor

@reeselevine reeselevine commented May 28, 2026

Overview

With three separate FlashAttention paths depending on sequence length and device capability, the code was getting messy. Quantized KV-caches also weren't supported by the tile path, which meant that quantized KV-caches wouldn't run with WebGPU in the browser. This PR does a number of refactors to clean up the paths and add the same quantized KV-cache functionality everywhere:

Support and Setup Refactors

In ggml-webgpu.cpp:

  • supports_op: Checks only whether the sg_matrix or tile shader paths will work. This is because the auto FlashAttention setting uses a sequence length of 1 to probe support, but we want to ensure that FlashAttention will also work for larger sequence lengths, e.g., during prefill. Otherwise, we may end up in scenarios where the FlashAttention tensor used at runtime (with a larger sequence length than the initial check) can't fit on the GPU and runs on the CPU instead, which would be slower than not using FlashAttention to begin with.
  • get_alloc_size: First checks if the tensor will run using the vec path, because only that path needs extra buffer space. No longer relies on the full decisions code logic for the FlashAttention shaders.
  • during graph_compute: Does a common setup (ggml_webgpu_flash_attn_prepare), and then splits into vec/non-vec specific paths, instead of mixing these together.

In ggml-webgpu-shader-lib.cpp:

  • Splits the vec/non-vec paths into separate cached pipelines and creation functions, since the combined path ended up mixing information and required a bunch of conditionals.
  • Separates out functions required for supports_op and splitting into vec/non-vec paths, so that they can be called from ggml-webgpu.cpp without also dragging along a bunch of decision information that isn't needed. This also allows the decision information to be computed once when needed during pipeline creation.

In the shaders:

  • flash_attn_quant_staging.tmpl is added, which includes common helpers and dequantization to shared memory that is used by all of the FlashAttention shaders.
  • quant_inner_loops.tmpl is added, which abstracts some of the common dequantization code between mul_mat and flash_attn. This format can be followed for future sharing of code between the two operations too.
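
As a rough CPU-side illustration of what the shared dequantization helpers have to compute, here is a minimal C++ sketch of q8_0 and q4_0 block dequantization. It is not the WGSL from this PR: the struct and function names are hypothetical, the scale is stored as a `float` rather than the f16 that ggml uses, and the block layout follows ggml's q4_0/q8_0 formats as I understand them (32 elements per block, one scale per block).

```cpp
#include <array>
#include <cstddef>
#include <cstdint>

// Hypothetical, simplified stand-ins for ggml's block layouts. ggml stores
// the scale as f16; a float is used here so the sketch needs no half type.
constexpr std::size_t QK = 32;  // elements per block for both formats

struct BlockQ8_0 {
    float                       d;   // per-block scale
    std::array<int8_t, QK>      qs;  // one signed byte per element
};

struct BlockQ4_0 {
    float                       d;   // per-block scale
    std::array<uint8_t, QK / 2> qs;  // two 4-bit values packed per byte
};

// q8_0: each element is scale * signed byte.
inline std::array<float, QK> dequantize_q8_0(const BlockQ8_0 & b) {
    std::array<float, QK> out{};
    for (std::size_t i = 0; i < QK; ++i) {
        out[i] = b.d * static_cast<float>(b.qs[i]);
    }
    return out;
}

// q4_0: low nibbles hold elements 0..15, high nibbles hold elements 16..31.
// Each nibble is offset by 8 so that it can represent negative values.
inline std::array<float, QK> dequantize_q4_0(const BlockQ4_0 & b) {
    std::array<float, QK> out{};
    for (std::size_t i = 0; i < QK / 2; ++i) {
        const int lo = (b.qs[i] & 0x0F) - 8;
        const int hi = (b.qs[i] >> 4) - 8;
        out[i]          = b.d * static_cast<float>(lo);
        out[i + QK / 2] = b.d * static_cast<float>(hi);
    }
    return out;
}
```

In the shaders the same arithmetic would write into workgroup shared memory rather than returning an array, which is presumably why a single staging template can serve every FlashAttention path.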

New functionality

  • pre_wgsl.hpp is updated to allow multi-line macros, used for the shared dequantization to shared memory in the FlashAttention shaders.
  • The K and V cache formats are no longer coupled. For example, f16 can be used for one and q8_0 for the other, which wasn't possible previously.
  • flash_attn_tile.wgsl, which is the long sequence length FlashAttention implementation used in the browser, now supports q4_0 and q8_0 KV-cache quantization.
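
To make the multi-line macro point concrete, the sketch below shows one common way a preprocessor can support them: joining every line that ends in a backslash with the line that follows before macro definitions are parsed. This is an assumption about the general technique, not the code in pre_wgsl.hpp, and `join_continuations` is a hypothetical name.

```cpp
#include <sstream>
#include <string>
#include <vector>

// Hypothetical helper: splits source text into logical lines, merging any
// physical line that ends in a backslash with the line after it. The
// backslash is dropped and a newline is kept so the macro body stays
// multi-line when it is later expanded.
inline std::vector<std::string> join_continuations(const std::string & src) {
    std::vector<std::string> logical_lines;
    std::istringstream       in(src);
    std::string              line;
    std::string              pending;
    bool                     continuing = false;

    while (std::getline(in, line)) {
        // Ignore trailing spaces, tabs, and carriage returns.
        while (!line.empty() &&
               (line.back() == ' ' || line.back() == '\t' || line.back() == '\r')) {
            line.pop_back();
        }
        const bool has_backslash = !line.empty() && line.back() == '\\';
        if (has_backslash) {
            line.pop_back();
        }
        pending += line;
        if (has_backslash) {
            pending += '\n';
            continuing = true;
        } else {
            logical_lines.push_back(pending);
            pending.clear();
            continuing = false;
        }
    }
    if (continuing) {
        // Input ended in the middle of a continuation; keep what was read.
        logical_lines.push_back(pending);
    }
    return logical_lines;
}
```

With this in place, a `#define` whose body spans several backslash-terminated lines arrives at the macro parser as a single logical line.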

I've verified that using quantized KV-caches works with this change on the command line and in Chrome on an Apple M3. Safari does not support subgroups, which are required for the WebGPU FlashAttention shaders.

Requirements

  • I have read and agree with the contributing guidelines
  • AI usage disclosure: yes, for some of the mechanical refactors. Design was all done by me.

@github-actions github-actions Bot added ggml changes relating to the ggml tensor library for machine learning WebGPU labels May 28, 2026
@reeselevine reeselevine marked this pull request as ready for review June 2, 2026 04:46
@reeselevine reeselevine requested a review from a team as a code owner June 2, 2026 04:46
@reeselevine
Contributor Author

Any reviews from other WebGPU contributors are appreciated :)

@yomaytk @Constannnnnt @ArberSephirotheca

}

inline bool ggml_webgpu_flash_attn_kv_direct(const ggml_tensor * Q, const ggml_tensor * K, uint32_t kv_direct_align) {
return K->type == GGML_TYPE_F16 && (Q->ne[0] % std::max(1u, kv_direct_align) == 0) &&
Contributor

Checking only the K type was enough before this PR, but with K/V decoupled in this PR, I think that ggml_webgpu_flash_attn_kv_direct needs to check the V type too.

For example, the following test case currently fails, but passes after adding V type checking:
test_cases.emplace_back(new test_flash_attn_ext(128, 128, 4, {1, 1}, 256, 1, false, false, 0, 0, GGML_PREC_F32, GGML_TYPE_F16, GGML_TYPE_Q4_0));

Contributor Author

yep good catch, fixed this and added that test case too
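
For readers following the thread, here is a minimal sketch of the kind of check being discussed, assuming the fix simply requires both K and V to be f16. The `Tensor` and `TensorType` types are hypothetical stand-ins for `ggml_tensor` and `ggml_type`, and the remaining conditions of the real function (cut off in the snippet above) are not reproduced.

```cpp
#include <algorithm>
#include <cstdint>

// Hypothetical stand-ins for ggml_type and ggml_tensor.
enum class TensorType { F16, Q4_0, Q8_0 };

struct Tensor {
    TensorType type;
    int64_t    ne[4];  // dimensions; ne[0] is the head dimension here
};

// Sketch only: the direct path is taken when BOTH caches are f16 and the
// head dimension of Q is a multiple of the alignment. An alignment of 0 is
// clamped to 1 so the modulo is always well defined.
inline bool flash_attn_kv_direct(const Tensor & Q, const Tensor & K, const Tensor & V,
                                 uint32_t kv_direct_align) {
    const int64_t align = std::max<uint32_t>(1u, kv_direct_align);
    return K.type == TensorType::F16 &&
           V.type == TensorType::F16 &&
           (Q.ne[0] % align == 0);
}
```

With an f16 K cache and a q4_0 V cache, as in the test case quoted above, this version returns false, so the quantized V cache would not be routed to the direct path.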

}

inline bool ggml_webgpu_flash_attn_float_vec4_aligned(const ggml_tensor * K, size_t storage_offset_alignment) {
const size_t alignment = std::max<size_t>(1u, storage_offset_alignment);
Contributor

This is not related to this PR change, but is std::max needed here?

Contributor Author

nope, removed

}

#define LOAD_K_Q4_0_TILE_BLOCK \
for (var elem_idx = local_id.x * K_NQ; elem_idx < kv_count * HEAD_DIM_QK; elem_idx += WG_SIZE * K_NQ) { \
Contributor

Sharing the logic between multiple flash_attn wgsl files is nice! Also, the multi-line macro support would be useful for several cases. But any reason to use macros instead of functions here? Personally I'd prefer functions, but wdyt?

Contributor Author

I was worried there would be a lot of arguments, but turns out that's not the case, so converted to functions. Was also able to remove some more code duplication.

I'll leave the multi-line macro ability in pre-wgsl though. Sometimes function calls, especially in hot inner loops, might reduce performance, so could be useful to use them in other places at some point.

@reeselevine reeselevine requested a review from ggerganov as a code owner June 2, 2026 16:06
@github-actions github-actions Bot added the testing Everything test related label Jun 2, 2026
@Constannnnnt
Contributor

Thanks! These changes look good to me.

Contributor

@yomaytk yomaytk left a comment

good work, thanks! looks good to me.

@reeselevine reeselevine added the merge ready A maintainer can use this label to indicate that they consider the changes final and ready to merge. label Jun 3, 2026
@ggerganov ggerganov merged commit e8c5489 into ggml-org:master Jun 4, 2026
27 of 28 checks passed
gabe-l-hart added a commit to gabe-l-hart/llama.cpp that referenced this pull request Jun 4, 2026
* origin/master: (57 commits)
server : disable on-device spec checkpoints (ggml-org#24108)
arg: fix double mtp downloads (ggml-org#24128)
webui: [a11y] fix keyboard navigation issues in chat interface and sidebar (ggml-org#23132)
Move duplicated imatrix code into single common imatrix-loader.cpp (ggml-org#22445)
ui: Fixed packages (ggml-org#24119)
ui: added single line reasoning preview (ggml-org#23601)
return filter to save memory (ggml-org#24125)
convert: Fix Gemma 4 Unified conversion (ggml-org#24118)
ggml: vectorize ggml_vec_dot_q4_1_q8_1 with WASM SIMD128 (ggml-org#22209)
server: avoid unnecessary checkpoint restore when new tokens are present (ggml-org#24110)
agents: refactor, include more guidelines (ggml-org#24111)
webui: fix tool selector toggle/counter, key tools by stable identity (ggml-org#24065)
build : use umbrella Headers directory for XCFramework module map (ggml-org#23974)
server : add header to tools/server/server-http.h (ggml-org#24089)
cmake: skip cvector-generator and export-lora when CPU backend is disabled (ggml-org#24053)
fix(mtmd): handle Gemma 4 audio projector embedding size (ggml-org#24091)
readme : add status badges (ggml-org#24104)
tests : refactor test-save-load-state to accept token input (ggml-org#24073)
metal : reduce rset heartbeat from 500ms -> 5ms (ggml-org#24074)
ggml-webgpu: FlashAttention refactor + standardize quantization support (ggml-org#23834)
...

Labels

ggml (changes relating to the ggml tensor library for machine learning), merge ready (a maintainer can use this label to indicate that they consider the changes final and ready to merge), testing (everything test related), WebGPU

4 participants