[Quantization] FP8 quantization framework for diffusion attention #1413
lishunyang12 wants to merge 48 commits into vllm-project:main
Conversation
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: eb969b5a0b
```python
if HAS_FA3 and fa3_attn_func is not None:
    out = fa3_attn_func(
        query,
        key,
```
Respect padding masks in native FP8 FlashAttention
When KV tensors are FP8 and FA3 is present, this branch calls fa3_attn_func directly and bypasses the existing masked/unpadded path in forward_cuda. That means attn_metadata.attn_mask is not applied for padded batches, so enabling KV FP8 can change attention semantics (queries attend to padding tokens) and produce incorrect outputs for variable-length prompts.
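One way to address this, as a sketch only: gate the native FP8 path on the absence of a mask. It reuses names described in this PR (`dequantize_fp8`, `k_scale`/`v_scale`, the `descale_k`/`descale_v` kwargs from the commit message); the masked-fallback helper name `_forward_masked` is hypothetical.

```python
# Sketch: only take the native FP8 FA3 path when no padding mask exists.
if HAS_FA3 and fa3_attn_func is not None and attn_metadata.attn_mask is None:
    # Unpadded batch: FA3's non-varlen FP8 kernel is safe to call directly.
    out = fa3_attn_func(
        query, key, value,
        descale_k=attn_metadata.k_scale,  # kwarg names per this PR's commit message
        descale_v=attn_metadata.v_scale,
    )
else:
    # Padded batch: dequantize K/V and reuse the masked/unpadded path in
    # forward_cuda so attn_metadata.attn_mask is still honored.
    key = dequantize_fp8(key, attn_metadata.k_scale, dtype=query.dtype)
    value = dequantize_fp8(value, attn_metadata.v_scale, dtype=query.dtype)
    out = self._forward_masked(query, key, value, attn_metadata)
```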
```python
if self._kv_quant_enabled:
    key, value, attn_metadata = self._quantize_kv(key, value, attn_metadata)
```
Gate FP8 KV quantization before ring attention
KV tensors are quantized to FP8 before the ring/local dispatch, so ring mode (ring_degree > 1) receives FP8 K/V even though the ring kernels consume raw q/k/v and do not use k_scale/v_scale to descale or dequantize. In this configuration, values are interpreted at the wrong scale (or can fail in non-FP8 kernels), which can corrupt ring-attention results when KV quantization is enabled.
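A possible fix, sketched with this PR's names (`_kv_quant_enabled`, `_quantize_kv`) plus an assumed `ring_degree` field on the parallel config: make the quantization decision after the ring/local dispatch check.

```python
# Sketch: never hand FP8 K/V to ring kernels, which ignore k_scale/v_scale.
use_ring = self.parallel_config.ring_degree > 1  # field location assumed
if self._kv_quant_enabled:
    if use_ring:
        logger.warning(
            "FP8 KV quantization is incompatible with ring attention; "
            "keeping K/V in the original dtype for this forward pass.")
    else:
        key, value, attn_metadata = self._quantize_kv(key, value, attn_metadata)
```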
KV quant may need more discussion.
I will give more context to show my decision-making process. There are a lot of options for implementing this PR, each with trade-offs, and I am not sure which one serves the purpose best.
I've posted a detailed design rationale and open questions as a separate issue: #1454. It covers the decision-making process for FP8 KV quantization (why dynamic per-tensor, the dual FA3/fallback strategy, and the quantization point), acknowledges the two P1 correctness issues (padding mask bypass and ring attention incompatibility), and proposes fixes for each. @hsliuustc0106 I would appreciate your input on the open questions there, especially whether KV quant should be a separate config or tied to the existing config.
Hi @lishunyang12 👋 Checking in on this FP8 KV quantization PR: it's been 15 days since the last update. Any progress to share? Thanks!
@hsliuustc0106 Sorry for the long delay; got held up on some other work. Just force-pushed a rebased version on top of current main. Main changes since the original:
Still need to run the full test plan (roundtrip, SDPA smoke, FA3 native, memory profiling). Would appreciate any feedback on the overall approach, especially whether …
Reduce attention K/V memory by ~50% via per-tensor dynamic FP8 quantization. On Hopper GPUs with FA3, this also accelerates attention via native FP8 tensor cores; on FA2/SDPA backends, K/V are dequantized before the kernel (memory-only benefit).
- Add quantize_kv_fp8() / dequantize_fp8() utilities in vllm_omni/quantization/
- Add kv_quantization field to OmniDiffusionConfig
- Add k_scale / v_scale fields to AttentionMetadata
- Quantize K/V (+ joint K/V) in Attention.forward() after pre_attention
- FA3 native FP8 path with descale_k/descale_v in FlashAttentionImpl
- Dequant fallback for padded batches (varlen path) and SDPA backend
- Guard against ring attention + FP8 KV (incompatible)
- Add --kv-quantization CLI flag to text_to_image.py example
- Add unit tests for roundtrip, scales, zero tensor, config integration
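For reference, a minimal sketch of what the `quantize_kv_fp8()` / `dequantize_fp8()` pair described above could look like (per-tensor dynamic scale with a zero-tensor guard, matching the unit tests listed; the PR's exact signatures may differ):

```python
import torch

FP8_E4M3_MAX = 448.0  # max representable magnitude of float8_e4m3fn

def quantize_kv_fp8(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Per-tensor dynamic FP8 quantization (sketch of this PR's utility)."""
    amax = x.abs().amax().float()
    # Guard the all-zero-tensor case so we never divide by zero.
    scale = torch.where(amax > 0, amax / FP8_E4M3_MAX, torch.ones_like(amax))
    x_fp8 = (x.float() / scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX)
    return x_fp8.to(torch.float8_e4m3fn), scale

def dequantize_fp8(x_fp8: torch.Tensor, scale: torch.Tensor,
                   dtype: torch.dtype = torch.bfloat16) -> torch.Tensor:
    """Invert quantize_kv_fp8 for backends without native FP8 kernels."""
    return x_fp8.to(torch.float32).mul(scale).to(dtype)
```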
Should we change kv_quantization to kv_cache_dtype?
Good idea. I'll rename kv_quantization to kv_cache_dtype.
Did some investigation on how upstream vLLM implements FP8 KV cache. Here's what's relevant for alignment:
Upstream implementation:
What we should align:
What we intentionally diverge on:
Relevant upstream files for reference:
@lyj-jjj Thanks for the detailed feedback; both points are now addressed in the latest push.
FP8 conversion moved into the attention backend. The layer now only sets …
Per-platform extensibility for NPU. To add FP8 on NPU, you would follow the steps sketched below:
No changes to the layer, base class, or other backends needed. Unsupported platform+dtype combos are caught automatically with a warning.
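Roughly, the extension point could look like this. A sketch only, using the registry names described in this thread; the base class name `AttentionImpl` and the `npu_quantize_kv` / `npu_fused_attention` calls are hypothetical placeholders for the NPU's own kernels.

```python
# 1. Declare FP8 support for the platform (one-line registry change).
AttentionBackend._supported_kv_cache_dtypes["npu"] = {"auto", "fp8"}
_PLATFORM_DISPATCH["npu"] = "forward_npu"

# 2. Implement the per-platform forward; the backend owns quantization.
class FlashAttentionImpl(AttentionImpl):
    def forward_npu(self, query, key, value, attn_metadata):
        if attn_metadata.kv_cache_dtype == "fp8":
            # Quantize with the NPU's own FP8 format and kernels; the layer
            # and the other backends are untouched.
            key, value, attn_metadata = npu_quantize_kv(key, value, attn_metadata)
        return npu_fused_attention(query, key, value, attn_metadata)
```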
@lyj-jjj Following up: I've seen your RFCs (#2438, #2236, #2592) for NPU FP8 quantization. The framework in this PR directly enables your P0 (FA online FP8, #2236). For the NPU FA FP8 path, you would:
The quantization logic is fully owned by the backend, so you can use your own …
For P1 (MM/linear FP8, #2592): that's orthogonal to this PR (weight quantization vs attention quantization). The existing …
Happy to coordinate if you need any changes to the framework to support the NPU path.
TFLOPS metric, CUDA events timing, L2 flush, sweep mode. Ref: https://github.com/thu-ml/SageAttention/tree/main/bench
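For context, the measurement loop presumably follows the standard pattern from that bench directory; a self-contained sketch (the helper name `bench_ms` is mine):

```python
import torch

def bench_ms(fn, iters: int = 50) -> float:
    """CUDA-event timing with an L2 flush before every run (sketch)."""
    l2_buf = torch.empty(256 * 1024 * 1024, dtype=torch.int8, device="cuda")
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    total_ms = 0.0
    for _ in range(iters):
        l2_buf.zero_()  # overwrite a large buffer so fn() starts with a cold L2
        start.record()
        fn()
        end.record()
        torch.cuda.synchronize()  # wait so elapsed_time() is valid
        total_ms += start.elapsed_time(end)
    return total_ms / iters  # TFLOPS = op_flops / (ms * 1e9)
```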
Priority: SageAttn > FlashAttn > SDPA. SageAttn2 v2.2.0 with SM90 FP8 kernels is 8% faster than FA3 on H100 for HunyuanVideo 1.5 (4.00 vs 4.35 s/it).
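The selection is presumably a simple priority cascade; a sketch with assumed availability flags (names are illustrative, probed at import time in practice):

```python
HAS_SAGE_ATTN = False   # True when sageattention >= 2.2.0 imports cleanly
HAS_FLASH_ATTN = False  # True when flash-attn (FA3 on Hopper) imports cleanly

def pick_attention_backend() -> str:
    if HAS_SAGE_ATTN:    # fastest: SM90 FP8 kernels
        return "sage"
    if HAS_FLASH_ATTN:   # next: FlashAttention
        return "flash"
    return "sdpa"        # portable PyTorch fallback
```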
@lishunyang12 Thanks for the work. Have we tested fa3-fp8 on the HunyuanImage 3.0 DiT part?
@lishunyang12 Hello, any updates?
Summary
Adds an extensible FP8 quantization framework for diffusion attention, targeting CUDA (Hopper FA3) first with per-platform extension points for NPU/XPU/ROCm.
Key Design
- `kv_cache_dtype="fp8"` via metadata. Each backend decides whether and how to quantize; non-FP8 backends (SDPA) skip it entirely, with no wasted quant/dequant cycle.
- `_supported_kv_cache_dtypes` dict maps platform → supported dtypes. Currently CUDA only. NPU/XPU contributors uncomment one line and implement `forward_npu()`.
- `_PLATFORM_DISPATCH` replaces the if/elif chain. New platforms register by adding an entry.
- `AttentionBackend.supports_kv_cache_dtype()` catches unsupported configs before model loading.
- `_handle_kv_cache_dtype()` warns and clears unsupported dtypes before platform dispatch, so there is no silent corruption.
- `--kv-cache-dtype fp8` flag, `is_quantized_kv_cache()` utility.
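A condensed sketch of how these pieces could fit together; names follow the list above, but exact signatures are assumptions:

```python
import logging

logger = logging.getLogger(__name__)

class AttentionBackend:
    # Platform name -> KV cache dtypes the backend can consume.
    _supported_kv_cache_dtypes: dict[str, set[str]] = {
        "cuda": {"auto", "fp8"},
        # "npu": {"auto", "fp8"},  # uncomment and implement forward_npu()
    }

    @classmethod
    def supports_kv_cache_dtype(cls, platform: str, dtype: str) -> bool:
        # Checked before model loading so bad configs fail early.
        return dtype in cls._supported_kv_cache_dtypes.get(platform, {"auto"})

# Registry instead of an if/elif chain: new platforms add one entry.
_PLATFORM_DISPATCH = {"cuda": "forward_cuda"}

def _handle_kv_cache_dtype(platform: str, dtype: str) -> str:
    # Warn and clear unsupported dtypes instead of silently corrupting K/V.
    if not AttentionBackend.supports_kv_cache_dtype(platform, dtype):
        logger.warning("kv_cache_dtype=%s unsupported on %s; using 'auto'",
                       dtype, platform)
        return "auto"
    return dtype
```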
When FP8 Helps
FP8 acceleration depends on how much of the pipeline is spent in attention. Attention is O(n²) in sequence length while FFN is O(n), so longer sequences mean a higher attention fraction and bigger FP8 gains (see the back-of-envelope example after the list below).
Best for: Long video generation, high-resolution images (2K+), large models with long sequences.
Limited benefit: Small images (1024²), CPU-offloaded models (PCIe bottleneck), low-step turbo models.
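To make the scaling claim concrete, a back-of-envelope calculation (illustrative constants, not measured from these models):

```python
# Attention FLOPs grow ~n^2 * d per layer, FFN FLOPs ~n * d^2, so the share
# of runtime that FP8 attention can accelerate rises with sequence length.
def attn_fraction(n_tokens: int, d_model: int, ffn_mult: int = 4) -> float:
    attn = 4 * n_tokens**2 * d_model            # QK^T + PV matmuls
    ffn = 4 * n_tokens * d_model**2 * ffn_mult  # two FFN matmuls
    return attn / (attn + ffn)

print(f"{attn_fraction(4_096, 3_072):.0%}")   # ~1K-image-sized latent: ~25%
print(f"{attn_fraction(65_536, 3_072):.0%}")  # long-video-sized latent: ~84%
```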
Precision
Uses fast quantization (scale=1.0, a direct saturating cast to `float8_e4m3fn`). Safe for diffusion models where Q/K/V values typically fall in [-15, 15], well within FP8 e4m3fn's ±448 range. FP8 e4m3 has a 3-bit mantissa (~1/8 the precision of BF16's 7-bit mantissa), but softmax normalization and residual connections prevent quantization error from accumulating across layers.
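In PyTorch terms, the fast path is essentially a clamp-and-cast; a sketch (the PR's actual helper may differ):

```python
import torch

def fast_cast_fp8(x: torch.Tensor) -> torch.Tensor:
    # scale=1.0: skip the amax reduction entirely and saturate at e4m3fn's
    # ±448 limit; diffusion Q/K/V in [-15, 15] never hit the clamp.
    return x.clamp(-448.0, 448.0).to(torch.float8_e4m3fn)
```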
Benchmark Results (H100 80GB)
Results confirm the scaling pattern: video models with long sequences see the most benefit.
Visual Comparison
HunyuanVideo 1.5 (480×832, 33 frames, 50 steps)
BF16 (38.4s)
baseline.mp4
FP8 (34.1s)
fp8.mp4
Z-Image Turbo 1024×1024 (BF16 vs FP8)
Z-Image Turbo 2048×2048 (BF16 vs FP8)
Qwen-Image 2512×2512 (BF16 vs FP8)
FLUX.2-dev 1024×1024 (BF16 vs FP8)
Usage
Adding FP8 for a New Platform
References
Known Limitations
`_forward_fp8` uses FA3's non-varlen API, which doesn't handle padding masks. Currently safe because diffusion runs batch_size=1 with equal-length sequences; future batch inference with mixed resolutions would need a padding fallback.
Test Plan
- Unit tests (`pytest tests/diffusion/quantization/test_kv_quant.py`)
Closes #1055