ggml-webgpu: FlashAttention refactor + standardize quantization support #23834
Any reviews from other WebGPU contributors are appreciated :)
inline bool ggml_webgpu_flash_attn_kv_direct(const ggml_tensor * Q, const ggml_tensor * K, uint32_t kv_direct_align) {
    return K->type == GGML_TYPE_F16 && (Q->ne[0] % std::max(1u, kv_direct_align) == 0) &&
Checking only K type was enough before this PR, but with K/V decoupled in this PR, I think that ggml_webgpu_flash_attn_kv_direct needs to check V type too.
For example, the following test case currently fails, but passes after adding V type checking:
test_cases.emplace_back(new test_flash_attn_ext(128, 128, 4, {1, 1}, 256, 1, false, false, 0, 0, GGML_PREC_F32, GGML_TYPE_F16, GGML_TYPE_Q4_0));
yep good catch, fixed this and added that test case too
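For readers following the thread, here is a minimal sketch of the kind of check being discussed. It is not the code from the PR: `sketch_tensor`, `sketch_type`, and `sketch_flash_attn_kv_direct` are hypothetical stand-ins, and the real helper has further conditions that are not visible in the snippet above and are not reproduced here.

```cpp
#include <algorithm>
#include <cstdint>

// Hypothetical stand-ins for ggml's types; only the fields used below.
enum sketch_type { SKETCH_TYPE_F16, SKETCH_TYPE_Q4_0, SKETCH_TYPE_Q8_0 };

struct sketch_tensor {
    sketch_type type;
    int64_t     ne[4];
};

// With K and V types decoupled, the direct K/V path is only considered
// when BOTH tensors are f16, not just K.
inline bool sketch_flash_attn_kv_direct(const sketch_tensor & Q,
                                        const sketch_tensor & K,
                                        const sketch_tensor & V,
                                        uint32_t              kv_direct_align) {
    const bool types_ok = K.type == SKETCH_TYPE_F16 && V.type == SKETCH_TYPE_F16;
    // Clamp to 1 so a zero alignment never becomes a modulo-by-zero.
    const bool align_ok = Q.ne[0] % std::max<uint32_t>(1u, kv_direct_align) == 0;
    return types_ok && align_ok;
}
```

With an f16 K and a q4_0 V, as in the failing test case, this sketch returns false, so the quantized V tensor would not be routed to a path that assumes f16 data.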
inline bool ggml_webgpu_flash_attn_float_vec4_aligned(const ggml_tensor * K, size_t storage_offset_alignment) {
    const size_t alignment = std::max<size_t>(1u, storage_offset_alignment);
This is not related to this PR change, but is std::max needed here?
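The thread does not answer this question. One plausible reading, sketched below, assumes the clamped value is later used as a divisor, which the snippet above does not show. `sketch_offset_aligned` is a hypothetical helper, not a function from the backend.

```cpp
#include <algorithm>
#include <cstddef>

// Hypothetical helper: true when byte_offset is a multiple of the alignment.
// Clamping to 1 treats a reported alignment of 0 as "no constraint"
// instead of triggering a division by zero in the modulo below.
inline bool sketch_offset_aligned(size_t byte_offset, size_t storage_offset_alignment) {
    const size_t alignment = std::max<size_t>(1u, storage_offset_alignment);
    return byte_offset % alignment == 0;
}
```

If the alignment always comes from a device limit that is guaranteed to be non-zero, the clamp is purely defensive and could be dropped without changing behavior.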
#define LOAD_K_Q4_0_TILE_BLOCK \
    for (var elem_idx = local_id.x * K_NQ; elem_idx < kv_count * HEAD_DIM_QK; elem_idx += WG_SIZE * K_NQ) { \
Sharing the logic between multiple flash_attn wgsl files is nice! Also, the multi-line macro support would be useful for several cases. But any reason to use macros instead of functions here? Personally I'd prefer functions, but wdyt?
I was worried there would be a lot of arguments, but turns out that's not the case, so converted to functions. Was also able to remove some more code duplication.
I'll leave the multi-line macro ability in pre-wgsl though. Sometimes function calls, especially in hot inner loops, might reduce performance, so could be useful to use them in other places at some point.
Thanks! These changes look good to me.
yomaytk left a comment
good work, thanks! looks good to me.
Overview
With three separate FlashAttention paths depending on sequence length and device capability, the code was getting messy. Quantized KV-caches also weren't supported by the `tile` path, which means that quantized KV-caches wouldn't run in WebGPU in the browser. This PR does a number of refactors to clean up the paths and add the same quantized KV-cache functionality everywhere:

Support and Setup Refactors
In `ggml-webgpu.cpp`:
- `supports_op`: checks only whether the `sg_matrix` or `tile` shader paths will work. This is because the `auto` FlashAttention setting uses a sequence length of 1 to probe support, but we want to ensure that FlashAttention will also work for larger sequence lengths, e.g., during prefill. Otherwise, we may end up in scenarios where the FlashAttention tensor used at runtime (with a larger sequence length than the initial check) can't fit on the GPU and runs on the CPU instead, which would be slower than not using FlashAttention to begin with.
- `get_alloc_size`: First checks if the tensor will run using the `vec` path, because only that path needs extra buffer space. No longer relies on the full decision logic for the FlashAttention shaders.
- `graph_compute`: Does a common setup (`ggml_webgpu_flash_attn_prepare`), and then splits into vec/non-vec specific paths, instead of mixing these together.

In `ggml-webgpu-shader-lib.cpp`:
- Separates out the helpers for `supports_op` and splitting into vec/non-vec paths, so that they can be called from `ggml-webgpu.cpp` without also dragging along a bunch of decision information that isn't needed. This also allows the decision information to be computed once when needed during pipeline creation.

In the shaders:
- `flash_attn_quant_staging.tmpl` is added, which includes common helpers and dequantization to shared memory that is used by all of the FlashAttention shaders.
- `quant_inner_loops.tmpl` is added, which abstracts some of the common dequantization code between `mul_mat` and `flash_attn`. This format can be followed for future sharing of code between the two operations too.

New functionality
- `pre_wgsl.hpp` is updated to allow multi-line macros, used for the shared dequantization to shared memory in the FlashAttention shaders.
- K and V types are decoupled: `f16` can be used for one and `q8_0` for another, which wasn't possible previously.
- `flash_attn_tile.wgsl`, which is the long sequence length FlashAttention implementation used in the browser, now supports `q4_0` and `q8_0` KV-cache quantization.

I've verified that using quantized KV-caches works with this change on the command line and in Chrome on an Apple M3. Safari does not support subgroups, which are required for the WebGPU FlashAttention shaders.
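As a rough illustration of the `supports_op` rule described above, the sketch below reports support from the paths that handle longer sequences only, regardless of the sequence length used to probe. The struct, its fields, and the function name are hypothetical; the real backend derives its decisions from device limits and tensor shapes that are not shown here.

```cpp
// Hypothetical capability flags for one FlashAttention op on one device.
struct sketch_fa_caps {
    bool sg_matrix_ok;  // subgroup-matrix path usable
    bool tile_ok;       // tile path usable
    bool vec_ok;        // vec path usable
};

// The support answer deliberately ignores the vec path: a probe with
// sequence length 1 must not report support that would then be missing
// for longer sequences such as prefill.
inline bool sketch_supports_flash_attn(const sketch_fa_caps & caps) {
    return caps.sg_matrix_ok || caps.tile_ok;
}
```

Under this sketch, a device where only the vec path would work is reported as unsupported up front, rather than reporting support during the probe and then falling back to the CPU at runtime.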
Requirements