feat: FA4 flash att supports fused fp8 output#135
Conversation
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: bd91f6e4dc
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
bd91f6e to
b8b18ba
Compare
afa774a to
b8646d3
Compare
|
Hi @LucasWilkinson @MatthewBonanni, would like your reviews on this PR. |
ProExpertProg
left a comment
There was a problem hiding this comment.
Nice work! Two high-level questions:
- We shouldn't pass the scale around as a float32, it will cause an unnecessary GPU-CPU sync.
- Can we benchmark what happens if we read the scale from global memory and don't invert it ahead of time? If it's the same, that might be better as ahead-of-time inversion costs more if done in the forward pass.
- I'm not sure adding
quant_kwargsis the best idea, we're tying flash-attn source to vLLM structure. I'm actually wondering if we can just communicate the necessary parameters via scale metadata; if we pass the scale tensor output as a parameter, its dimensions and strides should be enough to infer the other params: group size from shape, column major and tma alignment from strides, which only leavesue8m0, so just a single boolean. That way we don't need to passquant_keyaround either
61a1d0b to
f96d04f
Compare
|
@ProExpertProg thanks for the comments, updated! also compared using rcp_approx, and no perf difference |
6e973a9 to
631e1f2
Compare
MatthewBonanni
left a comment
There was a problem hiding this comment.
This looks very clean to me, thanks! Please make the corresponding vllm-side PR so we can use vLLM CI to verify this
Squashed from 10 commits of PR vllm-project#135 for rebase onto main. Signed-off-by: Carl You <4531192+carlyou@users.noreply.github.com>
91fb309 to
96e8957
Compare
| assert qv is None, "fused FP8 output + MLA (qv) not supported yet" | ||
| assert not use_dedicated_hd256_kernel, ( | ||
| "fused FP8 output + head_dim=256 kernel not supported yet" | ||
| ) |
There was a problem hiding this comment.
there are two new SM100 kernels after rebase, leave them out of scope for now -- will be fast followed once current PR nails the overall structure.
Signed-off-by: Carl You <4531192+carlyou@users.noreply.github.com>
28b41f6 to
e577a7a
Compare
Signed-off-by: Carl You <4531192+carlyou@users.noreply.github.com>
Purpose
Part of vLLM issue#35792 of MLA + quant fusion. Specifically,
quant_keyandquant_kwargsfrom vllm.Test Plan
Test Results
benchmark result
note: the "fused fp8" is slighter quicker than "bf16 attn", likely due to less mem write.
pytest result
cute/test_flash_attn_fp8_output.py: 49 passed, 1117 warnings in 84.39s (0:01:24)
``` ============================= test session starts ============================== platform linux -- Python 3.12.3, pytest-9.0.3, pluggy-1.6.0 -- /root/flash-attention/.venv/bin/python cachedir: .pytest_cache rootdir: /root/flash-attention/tests configfile: pyproject.toml collecting ... collected 49 itemscute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[mha-64-64-False-dtype0] PASSED [ 2%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[mha-64-64-False-dtype1] PASSED [ 4%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[mha-64-64-True-dtype0] PASSED [ 6%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[mha-64-64-True-dtype1] PASSED [ 8%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[mha-128-128-False-dtype0] PASSED [ 10%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[mha-128-128-False-dtype1] PASSED [ 12%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[mha-128-128-True-dtype0] PASSED [ 14%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[mha-128-128-True-dtype1] PASSED [ 16%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[mha-192-128-False-dtype0] PASSED [ 18%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[mha-192-128-False-dtype1] PASSED [ 20%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[mha-192-128-True-dtype0] PASSED [ 22%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[mha-192-128-True-dtype1] PASSED [ 24%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[mqa-64-64-False-dtype0] PASSED [ 26%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[mqa-64-64-False-dtype1] PASSED [ 28%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[mqa-64-64-True-dtype0] PASSED [ 30%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[mqa-64-64-True-dtype1] PASSED [ 32%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[mqa-128-128-False-dtype0] PASSED [ 34%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[mqa-128-128-False-dtype1] PASSED [ 36%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[mqa-128-128-True-dtype0] PASSED [ 38%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[mqa-128-128-True-dtype1] PASSED [ 40%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[mqa-192-128-False-dtype0] PASSED [ 42%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[mqa-192-128-False-dtype1] PASSED [ 44%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[mqa-192-128-True-dtype0] PASSED [ 46%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[mqa-192-128-True-dtype1] PASSED [ 48%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[gqa-64-64-False-dtype0] PASSED [ 51%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[gqa-64-64-False-dtype1] PASSED [ 53%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[gqa-64-64-True-dtype0] PASSED [ 55%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[gqa-64-64-True-dtype1] PASSED [ 57%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[gqa-128-128-False-dtype0] PASSED [ 59%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[gqa-128-128-False-dtype1] PASSED [ 61%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[gqa-128-128-True-dtype0] PASSED [ 63%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[gqa-128-128-True-dtype1] PASSED [ 65%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[gqa-192-128-False-dtype0] PASSED [ 67%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[gqa-192-128-False-dtype1] PASSED [ 69%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[gqa-192-128-True-dtype0] PASSED [ 71%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[gqa-192-128-True-dtype1] PASSED [ 73%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_varlen_deepseek_mla PASSED [ 75%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_auto_allocate PASSED [ 77%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_scale_as_tensor PASSED [ 79%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_sliding_window[causal_local_left] PASSED [ 81%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_sliding_window[symmetric_local] PASSED [ 83%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_sliding_window[causal_full] PASSED [ 85%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_softcap[15.0] PASSED [ 87%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_softcap[30.0] PASSED [ 89%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_scale_extremes[scale_underuses_range] PASSED [ 91%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_scale_extremes[scale_matches_peak] PASSED [ 93%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_scale_extremes[scale_overuses_range] PASSED [ 95%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_split_kv PASSED [ 97%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_validation_errors PASSED [100%]
=============================== warnings summary ===============================
cute/test_flash_attn_fp8_output.py: 1117 warnings
/root/flash-attention/.venv/lib/python3.12/site-packages/nvidia_cutlass_dsl/python_packages/cutlass/base_dsl/_mlir_helpers/op.py:63: DeprecationWarning:
make_fragmentis deprecated, usemake_rmem_tensorinsteadres_or_list = opFunc(*args, **kwargs, loc=loc)
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
================= 49 passed, 1117 warnings in 84.39s (0:01:24) =================