ggml-webgpu: address precision issues for multimodal #22808
Conversation
…calculation logic for f32
There was a problem hiding this comment.
Looks like a deep investigation, cool to see how much better the description is with the changes. The PR doesn't say, I'm guessing this is on an NVIDIA GPU?
I do have a couple of high-level comments beyond the minor comments on the code.
One that I think needs to be addressed in this PR:
- The idea behind the vec path for flash attention is that it should increase performance during typical decode scenarios, when the Q sequence length is only 1. But if I'm understanding correctly, this PR changes to always prefer the tile path over the vec path, which I think means the vec path will basically never run?
- The way this is supposed to work is that vec should be the priority if sequence length is 1, then subgroup matrix should be preferred to tiling if sequence length is > 1 and subgroup matrices are supported. This PR inverts some of that ordering. I realize the subgroup matrix path in particular may have precision issues, which I address in my point below. But I think we should strive to keep the priority in performance order if at all possible, and make adjustments to the shaders for precision to really try and maintain that order.
One that might not need to be addressed in this PR, but which I think will be important moving forward:
- The logic around path selection, tile size, and decisions for flash attention has gotten quite complicated, and I think this is at least in part due to the mixing of pipelines for the vec and non-vec pipelines. I think if we split up the logic similar to how it is done for matrix-matrix vs. matrix-vector multiplication, that will make the code much clearer and allow for easier changes moving forward.
One that should not be addressed in this PR but we should think about:
- Precision of intermediate states clearly can make a large difference depending on the device and model. For example, it's not clear that the subgroup matrix path will even work satisfactorily on many devices, because it computes in f16 precision. It's also not clear that f32 is really needed everywhere, but today you pay the memory overhead/performance cost of it no matter what.
- Ideally, the precision would be chosen dynamically in a way that maximizes speed and stability across devices, but I realize this is probably a larger (research?) project. But I think it's worth keeping in mind as we make changes.
@ArberSephirotheca for visibility, and also if you have thoughts or comments on this PR since you wrote the vec and tile attention shaders.
| 0.044715 * src[params.offset_src + src_idx] * | ||
| src[params.offset_src + src_idx] * src[params.offset_src + src_idx]), | ||
| -9.010913, 9.010913))); | ||
| let res = 0.5 * src[params.offset_src + src_idx] * (1.0 + erf_approx(src[params.offset_src + src_idx] * 0.70710678)); |
There was a problem hiding this comment.
any reason the multiplication constant 0.707 is in the call to erf_approx rather than within it?
There was a problem hiding this comment.
I follow the equation on most of the websites, e.g., https://alaaalatif.github.io/2019-04-11-gelu/, the 1/sqrt(2) is outside of the erf function.
There was a problem hiding this comment.
I noticed that the ci did not pass:
Error while parsing WGSL: :64:74 error: type mismatch for argument 1 in call to 'erf_approx', expected 'f32', got 'f16' let res = 0.5 * src[params.offset_src + src_idx] * (1.0 + erf_approx(src[params.offset_src + src_idx] * 0.70710678));
I will correct this.
| uint32_t q_tile; | ||
| uint32_t kv_tile; | ||
| uint32_t wg_size; | ||
| uint32_t min_subgroup_size; |
There was a problem hiding this comment.
are these subgroup and the following sg_mat fields necessary in the key? They are fixed for a given WebGPU device so I don't think they should affect which pipeline is chosen, at least in principle.
Although they might be just proxying whether subgroups/subgroup matrices are are actually supported?
| inline constexpr uint32_t GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH = 4u; | ||
| inline constexpr uint32_t GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE = 4u; | ||
|
|
||
| inline uint32_t ggml_webgpu_effective_min_subgroup_size(const ggml_webgpu_shader_lib_context & context) { |
There was a problem hiding this comment.
why are these functions needed? If subgroups are supported, I think we can assume that the reported subgroup size is > 0, otherwise that seems like a bug in WebGPU itself.
There was a problem hiding this comment.
The initial motivation was to use min subgroup for calculating the per-lane arrays in tile, i.e. line 122-123 in flash_attn_tile.wgsl. because I am not familiar with this, so I did some research about min/max numbers on this topic, and I found this reference on different hardware: https://docs.rs/wgpu/latest/wgpu/struct.AdapterInfo.html. The 0 here is simply for a check. You can notice that these two functions are basically "cross-fallback" logic.
I might overthink this; it seems reasonable to remove them after second thoughts.
|
Thank you for these insights!! I think I can answer the first question on vec or tile path; actually, before this fix, my workaround entirely relied on the vec path because the vec path did not have the partial kv problem, as it does not need subgroup support (correct me if I am wrong). As discussed with @ArberSephirotheca in #22199, in some of my test cases, the sequence length of a multimodal request is sometimes longer; I thought it would be good to know what happened within the tile path and reverted the order for testing purposes. I did not notice the design of the performance priority, but now I get it. And one question here: why is |
|
yeah sorry I was misremembering the vec path, it works for sequence lengths (of Q) greater than 1 as well. Hopefully the performance of the vec path is faster in the short sequence cases, if it's not on any of the machines you're testing we might want to revisit it's design :). But thanks for the fix on the tile path, always nice to fix bugs and fix stability issues. |
|
I updated the unary shader using the editorconfig, removed the inline functions and redundant pipeline keys, and reverted the flash attn path order. |
* fix(mixed-types): use f32 for precision and update the shared memory calculation logic for f32 * fix(unary): correct the gelu, gelu quick and gelu erf functions * fix(flash-attn-tile): fix the hardcode v type * fix(flash_attn): fix tile path * fix: pass editorconfig and address the type conflicts * fix: remove reduant pipeline keys * fix: remove inline min/max group size functions and revert the flash attn path order * fix: use clamp to avoid NaN for GELU * fix: use the right range for exp, 80 is safer for f32 exp
|
@Constannnnnt Sorry, I missed this discussion earlier. The changes look great overall! I have one question about the tile-path precision change: after the I’m trying to understand whether |
|
@ArberSephirotheca Actually, I am not sure about this, as I did not test this. I can test this by this week and share the results with you. |
* fix(mixed-types): use f32 for precision and update the shared memory calculation logic for f32 * fix(unary): correct the gelu, gelu quick and gelu erf functions * fix(flash-attn-tile): fix the hardcode v type * fix(flash_attn): fix tile path * fix: pass editorconfig and address the type conflicts * fix: remove reduant pipeline keys * fix: remove inline min/max group size functions and revert the flash attn path order * fix: use clamp to avoid NaN for GELU * fix: use the right range for exp, 80 is safer for f32 exp
|
Hey @ArberSephirotheca , finally got time to report back. So here are my changes: So basically, use the macro types for most values, and as for the calculation for dot and acc, use f32. I did not see any accuracy differences between f32 and KV_TYPE directly here. However, when I set And results look quite similar (as I did not have a GT for this, so no scores). vs
Let me know if you want to have more details. Thanks. |
* fix(mixed-types): use f32 for precision and update the shared memory calculation logic for f32 * fix(unary): correct the gelu, gelu quick and gelu erf functions * fix(flash-attn-tile): fix the hardcode v type * fix(flash_attn): fix tile path * fix: pass editorconfig and address the type conflicts * fix: remove reduant pipeline keys * fix: remove inline min/max group size functions and revert the flash attn path order * fix: use clamp to avoid NaN for GELU * fix: use the right range for exp, 80 is safer for f32 exp
|
If we can keep the KV-cache shared memory as f16 with no meaningful reduction in accuracy, that would be great, since it reduces memory requirements and bandwidth across the board. |
|
Yeah, sg. I will create a PR later tonight for this. |
* fix(mixed-types): use f32 for precision and update the shared memory calculation logic for f32 * fix(unary): correct the gelu, gelu quick and gelu erf functions * fix(flash-attn-tile): fix the hardcode v type * fix(flash_attn): fix tile path * fix: pass editorconfig and address the type conflicts * fix: remove reduant pipeline keys * fix: remove inline min/max group size functions and revert the flash attn path order * fix: use clamp to avoid NaN for GELU * fix: use the right range for exp, 80 is safer for f32 exp
* fix(mixed-types): use f32 for precision and update the shared memory calculation logic for f32 * fix(unary): correct the gelu, gelu quick and gelu erf functions * fix(flash-attn-tile): fix the hardcode v type * fix(flash_attn): fix tile path * fix: pass editorconfig and address the type conflicts * fix: remove reduant pipeline keys * fix: remove inline min/max group size functions and revert the flash attn path order * fix: use clamp to avoid NaN for GELU * fix: use the right range for exp, 80 is safer for f32 exp
* fix(mixed-types): use f32 for precision and update the shared memory calculation logic for f32 * fix(unary): correct the gelu, gelu quick and gelu erf functions * fix(flash-attn-tile): fix the hardcode v type * fix(flash_attn): fix tile path * fix: pass editorconfig and address the type conflicts * fix: remove reduant pipeline keys * fix: remove inline min/max group size functions and revert the flash attn path order * fix: use clamp to avoid NaN for GELU * fix: use the right range for exp, 80 is safer for f32 exp
Overview
In this PR, I addressed the precision issues for multimodal. More specifically, when mixed types are used in models and projectors, I use f32 for precision in the flash attention (more specifically, in the tile path) for the browser. I did not edit
flash_attn.wgslsincesubgroup_matrixisn't enabled in my test environment.Additional information
Inputs:

Tested model: LFM2.5-VL-450M-F16 with F16 mmproj.
Tested images:
Tested prompts: Describe this image in detail.
Here is the debugging process to help explain the editings. I calculated the cosine similarity between the embedding layers of the CPU backend and the WebGPU backend.
Without any changes on the master branch,
CLIP Vision Stage Comparison (WebGPU vs. C++ Parity) (Table formated by LLM)
Kcur-0f32[64, 12, 1024, 1]Vcur-0f32[64, 12, 1024, 1]attn_out-0f32[768, 1024, 1, 1]ffn_inp-0f32[768, 1024, 1, 1]ffn_inp_normed-0f32[768, 1024, 1, 1]ffn_out-0f32[768, 1024, 1, 1]Results from models:
The image shows a variety of people with different styles of clothing and accessories, but no specific details about the individuals. The image is primarily focused on a collection of abstract geometric shapes, including a group of people, that are not clearly defined or detailed. These shapes appear to be the main focus of the image.From these logs, we can see the cosine similarity discrepancy came from two main computation layers: attention, ffn. Related shaders include flash_attn, binary, norm, unary (like GELU). For example, f16 is more performance but less precise in online softmax operation for flash attention, and we noticed accumulated drifts. Therefore, first step was to use f32 and update the shared memory calculation logics for f32 buffers.
I first started with the flash_attn_tile and vec paths since I started debugging in the browser. And the results for attention layers (attn_out and also Vcur-0) had been increased from 0.98 to 1, which also increased the 1st layer precision (ffn_inp-0) after the attention layer.
CLIP Vision Stage Comparison (Updated Run)
Kcur-0f32[64, 12, 1024, 1]Vcur-0f32[64, 12, 1024, 1]attn_out-0f32[768, 1024, 1, 1]ffn_inp-0f32[768, 1024, 1, 1]ffn_inp_normed-0f32[768, 1024, 1, 1]ffn_out-0f32[768, 1024, 1, 1]layer_out-0f32[768, 1024, 1, 1]After some debugging and analysis, I then corrected gelu, gelu_quick and gelu_erf functions and used the pytorch implementation GELU — PyTorch 2.11 documentation
CLIP Vision Stage Comparison (Updated Run)
Kcur-0f32[64, 12, 1024, 1]Vcur-0f32[64, 12, 1024, 1]attn_out-0f32[768, 1024, 1, 1]ffn_inp-0f32[768, 1024, 1, 1]ffn_inp_normed-0f32[768, 1024, 1, 1]ffn_out-0f32[768, 1024, 1, 1]layer_out-0f32[768, 1024, 1, 1]Results:
The image features a character from the video game "The Legend of Zelda: Breath of the Wild". The character is depicted in a fantasy setting with a mystical ambiance. The character is standing in front of ancient ruins and surrounded by lush greenery and blue-lit trees, suggesting a sereneCorrecting the gelu functions improved the accuracy, but the final result was still incorrect, as we noticed that there were still some small offsets in the
Vcur-*layer, and these accumulated drifts caused final errors.|
embedding|f32|[768, 1024, 1, 1]| 0.74568836 | ❌ Critical Failure |The final root cause for this issue was that
flash_attn_tile.wgslsizedSCORE_REGS_PER_LANEfromMAX_SUBGROUP_SIZE, but the browser can run with a smaller runtimesubgroup_size. ForKV_TILE=64, that can make the tile process only part of K/V. I think this was why we saw some slight offsets onVcur-*andattn-out. So in this path, the shader now sizes per-lane arrays from MIN_SUBGROUP_SIZE.Also fixed the flash-attn pipeline cache key to include tile compile constants (q_tile, kv_tile, wg_size, subgroup sizes, SG matrix dims), and tile is now preferred over vec when tile is valid.
Results:
The image showcases a detailed digital illustration of a female warrior clad in elaborate, dark-toned armor. She wields a sword with a glowing blue blade, suggesting a supernatural or magical element. The setting is an ancient, possibly mystical, stone structure with columns and arches that frame her figure,|
Vcur-0|f32|[64, 12, 1024, 1]| 0.9999999 | ✅ Pass |...
|
layer_out-11|f32|[768, 1024, 1, 1]| 0.9999999 | ✅ Pass ||
embedding|f32|[768, 1024, 1, 1]| 0.99999999 | ✅ Pass |Requirements