
[Quantization] FP8 quantization framework for diffusion attention #1413

Draft
lishunyang12 wants to merge 48 commits into vllm-project:main from lishunyang12:fp8-kv-quantization

Conversation


@lishunyang12 lishunyang12 commented Feb 20, 2026

Summary

Adds an extensible FP8 quantization framework for diffusion attention, targeting CUDA (Hopper FA3) first with per-platform extension points for NPU/XPU/ROCm.

Key Design

  • Backend-owned quantization: The attention layer signals kv_cache_dtype="fp8" via metadata. Each backend decides whether and how to quantize — non-FP8 backends (SDPA) skip it entirely, no wasted quant/dequant cycle.
  • Per-platform support: _supported_kv_cache_dtypes dict maps platform → supported dtypes. Currently CUDA only. NPU/XPU contributors uncomment one line and implement forward_npu().
  • Table-driven dispatch: _PLATFORM_DISPATCH replaces the if/elif chain. New platforms register by adding an entry (a minimal sketch follows this list).
  • Init-time validation: AttentionBackend.supports_kv_cache_dtype() catches unsupported configs before model loading.
  • Runtime guard: _handle_kv_cache_dtype() warns and clears unsupported dtypes before platform dispatch — no silent corruption.
  • Aligned with upstream vLLM: --kv-cache-dtype fp8 flag, is_quantized_kv_cache() utility.
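
A minimal sketch of the backend-owned flow described above (attribute and method names follow this PR's description; the real FlashAttentionImpl signatures may differ):

import logging

logger = logging.getLogger(__name__)


class FlashAttentionImpl:
    # platform -> supported KV-cache dtypes (CUDA only for now)
    _supported_kv_cache_dtypes = {
        "cuda": {"fp8", "fp8_e4m3"},
        # "npu": {"fp8"},  # uncomment once forward_npu() handles FP8
    }

    @classmethod
    def supports_kv_cache_dtype(cls, platform: str, dtype: str) -> bool:
        # Init-time validation: reject unsupported configs before model loading.
        return dtype in cls._supported_kv_cache_dtypes.get(platform, set())

    def _handle_kv_cache_dtype(self, platform: str, attn_metadata) -> None:
        # Runtime guard: warn and clear an unsupported dtype instead of
        # silently running a corrupted FP8 path.
        dtype = getattr(attn_metadata, "kv_cache_dtype", None)
        if dtype and not self.supports_kv_cache_dtype(platform, dtype):
            logger.warning("kv_cache_dtype=%s unsupported on %s; using native dtype",
                           dtype, platform)
            attn_metadata.kv_cache_dtype = None

    def forward_cuda(self, q, k, v, attn_metadata):
        raise NotImplementedError  # FA3 FP8 path / BF16 fallback lives here

    def forward(self, q, k, v, attn_metadata, platform: str = "cuda"):
        self._handle_kv_cache_dtype(platform, attn_metadata)
        # Table-driven dispatch: new platforms add one entry here instead of
        # extending an if/elif chain.
        dispatch = {"cuda": self.forward_cuda}  # _PLATFORM_DISPATCH in the PR
        return dispatch[platform](q, k, v, attn_metadata)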

When FP8 Helps

FP8 acceleration depends on how much of the pipeline is spent in attention. Attention cost scales as O(n²) in sequence length while the FFN scales as O(n), so longer sequences mean a larger attention fraction and bigger FP8 gains (a rough estimator appears at the end of this section).

Sequence Length | Attention Fraction | Expected FP8 Speedup
~1K tokens (1024² image) | ~10-15% | Negligible
~4K tokens (2048² image) | ~25-30% | Modest (~1.06×)
~13K tokens (33-frame video) | ~40-50% | Noticeable (~1.13×)
~50K tokens (121-frame video) | ~60-70% | Significant (~1.2×+)

Best for: Long video generation, high-resolution images (2K+), large models with long sequences.
Limited benefit: Small images (1024²), CPU-offloaded models (PCIe bottleneck), low-step turbo models.
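
For intuition, a toy FLOP count (ignoring softmax, normalization, kernel efficiency, and memory effects; the hidden size and FFN multiplier below are illustrative, not taken from any specific model) shows why the attention share grows with sequence length:

def attention_flops_fraction(n_tokens: int, hidden: int = 3072, ffn_mult: int = 4) -> float:
    """Toy share of per-block FLOPs spent in attention for a DiT-style block."""
    quadratic = 2 * n_tokens ** 2 * hidden                 # QK^T and PV matmuls, O(n^2)
    linear = (4 + 2 * ffn_mult) * n_tokens * hidden ** 2   # QKV/out projections + FFN, O(n)
    return quadratic / (quadratic + linear)

for n in (1_000, 4_000, 13_000, 50_000):
    print(f"{n:>6} tokens: ~{attention_flops_fraction(n):.0%} of block FLOPs in attention")

The measured wall-clock fractions in the table above also reflect kernel efficiency and memory bandwidth, so they will not match this count exactly; the point is only the O(n²)-vs-O(n) trend.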

Precision

Uses fast quantization (scale=1.0, a direct saturating cast to float8_e4m3fn). This is safe for diffusion models, where Q/K/V values are typically in [-15, 15], well within the FP8 e4m3fn range of ±448. FP8 e4m3 has only a 3-bit mantissa (versus BF16's 7 bits), but softmax normalization and residual connections keep the quantization error from accumulating across layers.
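
In PyTorch terms, the fast-quant path is roughly a saturating cast; a minimal sketch, not the exact kernel used here:

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn

def fast_quant_fp8(x: torch.Tensor) -> torch.Tensor:
    # scale = 1.0: clamp to the representable range, then cast directly.
    # Safe while |x| stays far below 448, as observed for diffusion Q/K/V.
    return x.clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)

k = torch.randn(1, 24, 4096, 128, dtype=torch.bfloat16)
k_fp8 = fast_quant_fp8(k)  # handed to FA3 with descale_k = 1.0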

Benchmark Results (H100 80GB)

Model | Resolution / Frames | BF16 | FP8 | Speedup
HunyuanVideo 1.5 | 480×832, 33 frames | 38.4s | 34.1s | 1.13×
Z-Image Turbo | 1024×1024 | 6.55s | 6.55s | 1.00×
Z-Image Turbo | 2048×2048 | 11.1s | 10.4s | 1.06×
FLUX.2-dev | 1024×1024 (CPU offload) | 63.3s | 62.4s | 1.01×

Results confirm the scaling pattern: video models with long sequences see the most benefit.

Visual Comparison

Side-by-side media are attached on the PR page; captions are listed below.

HunyuanVideo 1.5 (480×832, 33 frames, 50 steps): BF16 (38.4s, baseline.mp4) vs FP8 (34.1s, fp8.mp4)
Z-Image Turbo 1024×1024 (BF16 vs FP8)
Z-Image Turbo 2048×2048 (BF16 vs FP8)
Qwen-Image 2512×2512 (BF16 vs FP8)
FLUX.2-dev 1024×1024 (BF16 vs FP8)

Usage

# Text-to-video with FP8
python examples/offline_inference/text_to_video/text_to_video.py \
    --model hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-480p_t2v \
    --kv-cache-dtype fp8 \
    --num-frames 33 --num-inference-steps 50

# Text-to-image with FP8
python examples/offline_inference/text_to_image/text_to_image.py \
    --model Tongyi-MAI/Z-Image-Turbo \
    --kv-cache-dtype fp8 \
    --height 2048 --width 2048

Adding FP8 for a New Platform

# 1. Declare support in FlashAttentionImpl
_supported_kv_cache_dtypes = {
    "cuda": {"fp8", "fp8_e4m3"},
    "npu": {"fp8"},  # ← uncomment
}

# 2. Handle in forward_npu()
def forward_npu(self, query, key, value, attn_metadata):
    if is_quantized_kv_cache(attn_metadata.kv_cache_dtype):
        return self._forward_fp8_npu(...)
    # standard path...

Known Limitations

  • Hopper FP8 accumulator precision: on Hopper, the FP8 Tensor Core accumulator keeps roughly 22 bits (1 sign + 8 exponent + 13 mantissa); the last 10 mantissa bits of an FP32 accumulator are truncated. The DeepSeek V3 paper describes this as "14 bits" (sign + mantissa only; see "only uses the highest 14 bits", deepseek-ai/DeepGEMM#37), while SageAttention2 measured 22 effective bits including the exponent. Same hardware behavior, different counting. For very long sequences (121 frames / 50K+ tokens), the accumulated error can corrupt attention output (black screen). The fix is two-level accumulation (promote to CUDA Core FP32 every N WGMMAs), standard in CUTLASS ≥3.2 but not yet in upstream FA3. Blackwell (B200/SM100) largely solves this: testing shows the FP8 accumulator mantissa grew from 13 bits (Hopper) to 25 bits (Blackwell), exceeding FP32's 23-bit mantissa, so two-level accumulation may no longer be needed there.
  • No padding guard in FP8 path: _forward_fp8 uses FA3's non-varlen API which doesn't handle padding masks. Currently safe because diffusion runs batch_size=1 with equal-length sequences. Future batch inference with mixed resolutions would need a padding fallback.
  • Fast quant assumes bounded values: the scale=1.0 direct cast assumes Q/K/V values fit within the FP8 e4m3fn range (±448). This holds empirically for the tested diffusion models (values typically in [-15, 15]) but is not guaranteed for every architecture, and there is currently no runtime validation (a possible guard is sketched below).
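
A lightweight runtime guard is one possible mitigation; the helper below is hypothetical and not part of this PR:

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn

def kv_within_fp8_range(*tensors: torch.Tensor, margin: float = 0.5) -> bool:
    # One abs-max reduction per tensor; cheap relative to attention itself,
    # though .item() does force a device sync.
    return all(t.abs().amax().item() <= FP8_MAX * margin for t in tensors)

# Hypothetical use inside the FP8 path: fall back to the native-dtype kernel
# (or to a scaled quantization) instead of saturating silently.
# if not kv_within_fp8_range(query, key, value):
#     return self._forward_native(query, key, value, attn_metadata)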

Test Plan

  • 15/15 unit tests pass (pytest tests/diffusion/quantization/test_kv_quant.py)
  • E2E: HunyuanVideo 1.5 (480p, 33 frames) — correct output, 1.13× speedup
  • E2E: Z-Image Turbo (1024, 2048) — correct output, 1.06× at 2K
  • E2E: FLUX.2-dev (1024, CPU offload) — correct output
  • E2E: Qwen-Image-2512 — correct output
  • SDPA fallback: warns and runs in native dtype (no crash)

Closes #1055


@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: eb969b5a0b


Comment on lines +167 to +170
if HAS_FA3 and fa3_attn_func is not None:
    out = fa3_attn_func(
        query,
        key,

P1: Respect padding masks in native FP8 FlashAttention

When KV tensors are FP8 and FA3 is present, this branch calls fa3_attn_func directly and bypasses the existing masked/unpadded path in forward_cuda. That means attn_metadata.attn_mask is not applied for padded batches, so enabling KV FP8 can change attention semantics (queries attend to padding tokens) and produce incorrect outputs for variable-length prompts.


Comment thread vllm_omni/diffusion/attention/layer.py Outdated
Comment on lines +156 to +157
if self._kv_quant_enabled:
    key, value, attn_metadata = self._quantize_kv(key, value, attn_metadata)

P1: Gate FP8 KV quantization before ring attention

KV tensors are quantized to FP8 before the ring/local dispatch, so ring mode (ring_degree > 1) receives FP8 K/V even though the ring kernels consume raw q/k/v and do not use k_scale/v_scale to descale or dequantize. In this configuration, values are interpreted at the wrong scale (or can fail in non-FP8 kernels), which can corrupt ring-attention results when KV quantization is enabled.


@hsliuustc0106
Collaborator

KV quant may need more discussion

@lishunyang12
Collaborator Author

KV quant may need more discussion

I will give more context to show my decision-making process. There are many ways to implement this PR, each with its own trade-offs, and I'm not sure which one serves the purpose best.

@lishunyang12
Collaborator Author

I've posted a detailed design rationale and open questions as a separate issue: #1454

This covers the decision-making process for FP8 KV quantization (why dynamic per-tensor, the dual FA3/fallback strategy, and the quantization point), acknowledges the two P1 correctness issues (padding mask bypass and ring attention incompatibility), and proposes fixes for each.

@hsliuustc0106 would appreciate your input on the open questions there, especially around whether KV quant should be a separate config or tied to the existing --quantization fp8 flag.

@hsliuustc0106
Collaborator

Hi @lishunyang12 👋

Checking in on this FP8 KV quantization PR — it's been 15 days since the last update. Any progress to share?

Thanks!

@lishunyang12
Collaborator Author

@hsliuustc0106 Sorry for the long delay — got held up on some other work.

Just force-pushed a rebased version on top of the current main. The main changes since the original are summarized in the commit message below:

Still need to run the full test plan (roundtrip, SDPA smoke, FA3 native, memory profiling). Would appreciate any feedback on the overall approach, especially whether kv_quantization as a standalone config field makes sense vs. being tied to the weight quant config.

Reduce attention K/V memory by ~50% via per-tensor dynamic FP8
quantization. On Hopper GPUs with FA3, this also accelerates
attention via native FP8 tensor cores; on FA2/SDPA backends,
K/V are dequantized before the kernel (memory-only benefit).

- Add quantize_kv_fp8() / dequantize_fp8() utilities in vllm_omni/quantization/
- Add kv_quantization field to OmniDiffusionConfig
- Add k_scale / v_scale fields to AttentionMetadata
- Quantize K/V (+ joint K/V) in Attention.forward() after pre_attention
- FA3 native FP8 path with descale_k/descale_v in FlashAttentionImpl
- Dequant fallback for padded batches (varlen path) and SDPA backend
- Guard against ring attention + FP8 KV (incompatible)
- Add --kv-quantization CLI flag to text_to_image.py example
- Add unit tests for roundtrip, scales, zero tensor, config integration

Signed-off-by: lishunyang <lishunyang12@163.com>
@david6666666
Collaborator

Should we change --kv-quantization to --kv-cache-dtype to align with upstream vLLM?

@lishunyang12
Collaborator Author

Good idea. I'll rename --kv-quantization to --kv-cache-dtype to align with upstream vLLM. This also makes it easier to extend — e.g. --kv-cache-dtype mxfp8 for #2236 later.

@lishunyang12
Collaborator Author

Did some investigation on how upstream vLLM implements FP8 KV cache. Here's what's relevant for alignment:

Upstream implementation:

  • CLI: --kv-cache-dtype fp8 (config)
  • Quantization happens at cache-write time inside a CUDA kernel (reshape_and_cache_flash_kernel); this is needed because the LLM engine uses a paged KV cache
  • FA3 path: Q is also quantized via QuantFP8 module, descale_q/k/v all passed to FA3 (attention.py)
  • Non-FA3: hard NotImplementedError — no dequant fallback
  • Dynamic scale computation (calculate_kv_scales) is deprecated, being removed in v0.19
  • Scale loading from checkpoint: kv_cache.py

What we should align:

  1. ✅ Rename --kv-quantization to --kv-cache-dtype (already agreed above)
  2. Add Q quantization + descale_q for full FP8 FA3 benefit
  3. Align scale naming: _k_scale / _v_scale / _q_scale

What we intentionally diverge on:

  • No CUDA kernel needed — diffusion has no persistent KV cache, K/V are computed fresh each step. PyTorch-level tensor.to(float8_e4m3fn) is sufficient.
  • Per-call dynamic scale (not one-shot): diffusion K/V ranges shift across denoising timesteps, so recomputing the scale on each call is the correct behavior (a minimal sketch follows below).

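For reference, per-call dynamic per-tensor quantization reduces to something like the following sketch (quantize_per_tensor_fp8 is an illustrative name, not necessarily the quantize_kv_fp8 utility in this PR):

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn

def quantize_per_tensor_fp8(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # The scale is recomputed on every call, since diffusion K/V ranges shift
    # across denoising timesteps.
    scale = x.abs().amax().clamp(min=1e-12) / FP8_MAX
    x_fp8 = (x / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return x_fp8, scale  # scale is passed to FA3 as descale_k / descale_v

k_fp8, k_scale = quantize_per_tensor_fp8(torch.randn(1, 24, 4096, 128))
v_fp8, v_scale = quantize_per_tensor_fp8(torch.randn(1, 24, 4096, 128))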

@chatgpt-codex-connector

Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits.
Credits must be used to enable repository wide code reviews.

@lishunyang12
Collaborator Author

lishunyang12 commented Apr 7, 2026

@lyj-jjj Thanks for the detailed feedback — both points are now addressed in the latest push.

FP8 conversion moved into the attention backend. The layer now only sets attn_metadata.kv_cache_dtype = "fp8" as a signal. Each backend decides whether and how to quantize internally. SDPA simply skips it (no wasted quant/dequant cycle).

Per-platform extensibility for NPU. To add FP8 on NPU, you would:

  1. Uncomment "npu": {"fp8"} in FlashAttentionImpl._supported_kv_cache_dtypes
  2. Handle kv_cache_dtype in forward_npu() with your own FP8 operators

No changes to the layer, base class, or other backends needed. Unsupported platform+dtype combos are caught automatically with a warning.

@lishunyang12 lishunyang12 changed the title [Quantization] Add FP8 KV quantization for diffusion attention layers [Quantization] FP8 KV cache quantization framework for diffusion attention Apr 7, 2026
@lishunyang12
Collaborator Author

@lyj-jjj Following up — I've seen your RFCs (#2438, #2236, #2592) for NPU FP8 quantization. The framework in this PR directly enables your P0 (FA online FP8, #2236).

For the NPU FA FP8 path, you would:

  1. Uncomment "npu": {"fp8"} in FlashAttentionImpl._supported_kv_cache_dtypes
  2. In forward_npu(), check attn_metadata.kv_cache_dtype and call your mindiesd FP8 operators (rotation + block quant + FA)
  3. No changes needed to the layer, metadata, or other backends

The quantization logic is fully owned by the backend, so you can use your own FP8RotateQuantFA + fa_block_quant_preprocess pipeline inside forward_npu() without touching the CUDA path.

For P1 (MM/linear FP8, #2592) — that's orthogonal to this PR (weight quantization vs attention quantization). The existing QuantizationConfig framework (#1764) would be the right extension point there, similar to what was done for INT8 (#1470).

Happy to coordinate if you need any changes to the framework to support the NPU path.

TFLOPS metric, CUDA events timing, L2 flush, sweep mode.
Ref: https://github.com/thu-ml/SageAttention/tree/main/bench

Priority: SageAttn > FlashAttn > SDPA.
SageAttn2 v2.2.0 with SM90 FP8 kernels is 8% faster than FA3
on H100 for HunyuanVideo 1.5 (4.00 vs 4.35 s/it).

@lishunyang12 lishunyang12 changed the title [Quantization] FP8 KV cache quantization framework for diffusion attention [Quantization] FP8 quantization framework for diffusion attention Apr 13, 2026
@Gaohan123
Collaborator

@lishunyang12 Thanks for the work. Have we tested fa3-fp8 on HunyuanImage 3.0 DiT part?

@Gaohan123 Gaohan123 added this to the v0.20.0 milestone Apr 16, 2026
@Gaohan123
Collaborator

@lishunyang12 Hello, any updates?


Development

Successfully merging this pull request may close these issues.

[RFC]: FP8 Quantization for Key and Value Tensors in Diffusion Model Attention Layers
