[Quant] Wire quant_config through HunyuanVideo-1.5 and Wan2.2 DiT for online FP8 #2920

Open
lishunyang12 wants to merge 11 commits into vllm-project:main from lishunyang12:fp8-online-hunyuan-wan

@lishunyang12 lishunyang12 commented Apr 19, 2026

Purpose

Thread quant_config through the HunyuanVideo-1.5 and Wan2.2 DiT transformers so that --quantization fp8 actually activates the FP8 kernel. Before this PR the flag populated od_config.quantization_config but never reached the transformer's QKVParallelLinear / RowParallelLinear / FeedForward constructors — every linear picked up UnquantizedLinearMethod and the loader's _process_weights_after_loading had nothing to quantize.

Evidence that the original #1516 FP8 claim was a silent no-op: the FP8 and BF16 rows in that PR's latency table are within 0.1 s of each other, and model-load memory was identical. After this PR the engine log shows Selected CutlassFP8ScaledMMLinearKernel for Fp8OnlineLinearMethod on FP8 runs, and model-load memory drops measurably.

Pattern follows the #2728 / #2795 fix for Z-Image / Qwen-Image / FLUX.1-dev:

  • Main attention and FFN get quant_config= + prefix=
  • Modulation / AdaLayerNorm / entry / output-projection linears stay full precision
  • Cross-attention kept full precision (mirror of FLUX dual-stream)
  • .contiguous() guard at attention entry (FP8 kernels require it)
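
The wiring pattern in the list above can be illustrated with a minimal, self-contained sketch. Every class here (ToyQuantConfig, ToyLinear, ToyBlock) is a hypothetical stand-in, not the vLLM or vllm-omni API; the sketch only shows how a layer that never receives quant_config falls back to the unquantized path while modulation is deliberately left out.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyQuantConfig:
    """Hypothetical stand-in for a quantization config object."""
    method: str = "fp8"


class ToyLinear:
    """Hypothetical stand-in for a linear layer that picks its method at construction."""

    def __init__(self, in_features: int, out_features: int,
                 quant_config: Optional[ToyQuantConfig] = None,
                 prefix: str = "") -> None:
        self.prefix = prefix
        # With no quant_config the layer silently takes the unquantized path.
        self.method = "unquantized" if quant_config is None else quant_config.method


class ToyBlock:
    """Toy transformer block: attention and FFN are quantizable, modulation is not."""

    def __init__(self, dim: int, quant_config: Optional[ToyQuantConfig] = None,
                 prefix: str = "") -> None:
        # Attention and FFN linears receive quant_config plus a unique prefix.
        self.to_qkv = ToyLinear(dim, 3 * dim, quant_config, f"{prefix}.attn.to_qkv")
        self.to_out = ToyLinear(dim, dim, quant_config, f"{prefix}.attn.to_out")
        self.ff_in = ToyLinear(dim, 4 * dim, quant_config, f"{prefix}.ff.net.0")
        self.ff_out = ToyLinear(4 * dim, dim, quant_config, f"{prefix}.ff.net.2")
        # Modulation ignores quant_config on purpose and stays full precision.
        self.modulation = ToyLinear(dim, 6 * dim, None, f"{prefix}.modulation")


def layer_methods(block: ToyBlock) -> dict:
    """Map each layer attribute name to the method it selected."""
    return {name: layer.method for name, layer in vars(block).items()}
```

Calling layer_methods(ToyBlock(8)) without a config reports every layer as unquantized, which is the toy analogue of the no-op described under Purpose.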

Fixes #2912.

Benchmark (1×H100 80GB)

HunyuanVideo-1.5 (480p T2V, 33 frames, 30 steps)

| Config | Weights (load) | Time | Kernel log |
| --- | --- | --- | --- |
| BF16 | 33.81 GiB | 25.4 s | |
| FP8 (this PR) | 28.74 GiB (−15%) | 23.5 s (−8%) | Selected CutlassFP8ScaledMMLinearKernel |

Wan2.2 (TI2V-5B, 704×1280 T2V)

| Config | Frames | Peak VRAM | Time |
| --- | --- | --- | --- |
| BF16 | 49 | 44.89 GiB | 21.5 s |
| FP8 (this PR) | 49 | 41.46 GiB (−8%) | 19.6 s (−8%) |
| BF16 | 121 | 47.19 GiB | 66.2 s |
| FP8 (this PR) | 121 | 43.31 GiB (−8%) | 59.6 s (−10%) |

Seed 42, guidance 6.0 (HV-1.5) / 5.0 (Wan2.2). Tested with TI2V-5B (fits in 80 GiB BF16); A14B MoE wiring is identical but needs TP=2 and is out of scope here.

HunyuanVideo-1.5 FP8 preset ablation

Visual comparison of BF16 baseline against four per-layer FP8 presets (same prompt, same seed, 480×832, 33 frames, 30 steps). All FP8 configs reduce model-load memory; quality differences are primarily in fine detail / atmospheric texture.

| Preset | FP8 layers | BF16 layers (kept full precision) |
| --- | --- | --- |
| S1 | ff + ff_context | all attention + modulation |
| S2 | to_qkv, to_out[0], ff (video stream only) | encoder-stream attn + ff_context + modulation |
| S3 | All except encoder cross-attn (add_kv_proj, to_add_out BF16) | encoder cross-attn + modulation |
| S4 | Everything in attention + FFN | modulation, patch embed, proj_out, VAE, text encoders |

BF16 baseline

hv15_bf16_seed42_f33_ablation.mp4

S1 — FFN only

hv15_fp8_S1_seed42_f33_ablation.mp4

S2 — video stream only

hv15_fp8_S2_seed42_f33_ablation.mp4

S3 — all FP8 except encoder cross-attn

hv15_fp8_S3_seed42_f33_ablation.mp4

S4 — everything FP8

hv15_fp8_S4_seed42_f33_ablation.mp4

Observation: All four presets (S1 → S4) show similar quality reduction relative to BF16, with no visually decisive winner. This suggests the FFN path (common to every preset) is the primary source of FP8 drift — not attention. The shipped default is S4 (maximum memory savings, no measurable quality penalty over S1).

Known limitations

Online FP8 has visible quality reduction on video DiTs. Output stays coherent but fine detail — atmospheric depth, high-frequency texture, distant ridges — softens vs BF16. Same profile as #2795 shipped for Qwen-Image (LPIPS 0.32, threshold 0.35).

Block-wise FP8 is not available for online quantization. Fp8Config.__init__ in upstream vLLM (fp8.py:120-125) currently requires is_checkpoint_fp8_serialized=True for any weight_block_size. Block-wise would recover most of the quality gap but needs pre-quantized checkpoints. Tracked as follow-up.
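
The constraint described above can be pictured with a toy validator. This is not the upstream Fp8Config code; ToyFp8Config and its error message are hypothetical and only mirror the rule as stated in this PR: a weight_block_size is accepted only when the checkpoint is already FP8-serialized.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ToyFp8Config:
    """Hypothetical mirror of the rule described above, not the vLLM class."""
    is_checkpoint_fp8_serialized: bool = False
    weight_block_size: Optional[List[int]] = None

    def __post_init__(self) -> None:
        if self.weight_block_size is None:
            return  # per-tensor scaling: allowed for online quantization
        if not self.is_checkpoint_fp8_serialized:
            # Block-wise scales are rejected for checkpoints that are not
            # already FP8-serialized, so online quantization cannot use them.
            raise ValueError(
                "weight_block_size requires is_checkpoint_fp8_serialized=True"
            )
        if len(self.weight_block_size) != 2:
            raise ValueError("weight_block_size must be [block_n, block_k]")
```

Under this rule ToyFp8Config() is valid, ToyFp8Config(weight_block_size=[128, 128]) raises, and the same block size with is_checkpoint_fp8_serialized=True is accepted. The 128 × 128 block is only an example value.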

Recommended use: memory-constrained workflows where the ~15% memory / ~10% speed tradeoff is worth the detail softening. For quality-critical rendering, leave --quantization none.

Model layers quantized to FP8 (shipped config)

HunyuanVideo-1.5
| Layer | Role |
| --- | --- |
| HunyuanVideo15Attention.to_qkv · to_out[0] · add_kv_proj · to_add_out | joint attention Q/K/V/output (video + encoder streams) |
| FeedForward.net[0].proj (GELU) · net[2] (both ff and ff_context) | per-block FFN on both streams |

Kept full precision: modulation (raw nn.Linear), AdaLayerNormZero, patch embed, proj_out, VAE, text encoders (Qwen2.5-VL + ByT5 + SigLIP), token refiner path.

Matches Qwen-Image pattern from #2795. An env-var preset mechanism (HV15_FP8_PRESET) is also added for research sweeps (see ablation above); default S4 matches the behavior above.
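
A minimal sketch of how an env-var preset could map to per-role FP8 selection follows. Only the variable name HV15_FP8_PRESET, the preset names, and the layer roles come from this PR; the PRESETS table layout and the helpers resolve_preset and quant_config_for are hypothetical and are not the code in hunyuan_video_15_transformer.py.

```python
import os
from typing import Any, FrozenSet, Mapping, Optional

# Roles taken from the ablation table; "to_out" stands for to_out[0].
PRESETS: Mapping[str, FrozenSet[str]] = {
    "BF16": frozenset(),
    "S1": frozenset({"ff", "ff_context"}),
    "S2": frozenset({"to_qkv", "to_out", "ff"}),
    "S3": frozenset({"to_qkv", "to_out", "ff", "ff_context"}),
    "S4": frozenset({"to_qkv", "to_out", "add_kv_proj", "to_add_out",
                     "ff", "ff_context"}),
}


def resolve_preset(env: Optional[Mapping[str, str]] = None) -> FrozenSet[str]:
    """Return the set of roles to quantize; S4 is the default when unset."""
    env = os.environ if env is None else env
    name = env.get("HV15_FP8_PRESET", "S4")
    if name not in PRESETS:
        raise ValueError(f"unknown HV15_FP8_PRESET {name!r}; expected one of {sorted(PRESETS)}")
    return PRESETS[name]


def quant_config_for(role: str, quant_config: Any, enabled: FrozenSet[str]) -> Any:
    """Pass quant_config through for enabled roles, None (full precision) otherwise."""
    return quant_config if role in enabled else None
```

Each layer constructor would then receive quant_config_for(role, quant_config, enabled) in place of the raw config, so a single table decides which roles get quantized.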

Wan2.2 (T2V / I2V / TI2V / VACE)
Layer Role
WanFeedForward.net_0 (GELU via ColumnParallelGELU) · net_2 per-block FFN

Kept full precision:

Why FFN-only for Wan2.2: with full attention+FFN FP8, long-sequence outputs collapsed to noise (LPIPS 0.93). Cross-attn skip alone reduced it but left composition drift (LPIPS 0.52). FFN-only is the most aggressive setting that produced stable output on long videos.

Changes

HunyuanVideo-1.5

  • hunyuan_video_15_transformer.py: HunyuanVideo15Attention, HunyuanVideo15TransformerBlock, HunyuanVideo15Transformer3DModel thread quant_config / prefix to all attention + FFN linears; modulation layers stay raw nn.Linear.
  • pipeline_hunyuan_video_1_5.py + pipeline_hunyuan_video_1_5_i2v.py: pass quant_config=od_config.quantization_config to transformer.

Wan2.2 (all four pipelines)

Bench infrastructure

  • benchmarks/diffusion/bench_video_fp8.py — new; BF16 vs FP8 perf / peak-VRAM / LPIPS / PSNR / SSIM with --presets, --frames-list, --weight-block-size, MP4 output. Used to produce the tables and ablation above.

Test plan

  • HV-1.5 T2V 480p FP8 — kernel selects, weight memory drops 15%, output coherent
  • HV-1.5 FP8 preset ablation (S1 / S2 / S3 / S4) at 480p — all presets produce valid video, no decisive quality winner, S4 shipped as default
  • Wan2.2 TI2V-5B T2V FP8 — 49 frames and 121 frames both produce valid video
  • Visual inspection BF16 vs FP8 on both models at multiple frame counts
  • Pre-commit (ruff, format, typos) — passing
  • Selected CutlassFP8ScaledMMLinearKernel confirmed in engine log on FP8 runs
  • HV-1.5 I2V FP8 — wiring identical to T2V (untested)
  • Wan2.2 T2V-A14B MoE FP8 — needs 2×H100
  • Wan2.2 VACE FP8 — wiring threaded, not yet end-to-end tested
  • TP=2 across both models — wiring is TP-aware via existing parallel primitives

Follow-ups

  1. Static FP8 checkpoints for HV-1.5 and Wan2.2 via NVIDIA ModelOpt: unlocks block-wise FP8 (closer to BF16 quality). Tracked under "Static" column in [RFC]: Continuous Quantization Support #1854.
  2. HV-1.5 seed propagation: OmniDiffusionSamplingParams.seed doesn't propagate to HV-1.5's noise generator across the worker subprocess boundary, which blocks reliable LPIPS measurement. Separate bugfix PR.
  3. SageAttention FP8 backend: FP8 on the attention kernel itself (not just projections). Would close most of the remaining quality gap and deliver a more significant speedup. Separate design PR.
  4. Text encoder + VAE FP8 ([Quantization] Enable FP8 online quantization for Z-image text encoder #1338): larger memory savings on HV-1.5's Qwen2.5-VL; VAE FP8 remains a research question (Conv3d kernel support).

cc @ArtificialRay @DarkLight1337

Signed-off-by: lishunyang <lishunyang12@163.com>
Signed-off-by: lishunyang <lishunyang12@163.com>
torch.Generator handles don't cross process boundaries; the worker subprocess was
using fresh RNG per generate call, producing unrelated BF16/FP8 outputs (LPIPS ~0.59).
Switch to integer seed= via sampling params (pickles correctly). Use pynvml polling
for peak VRAM since torch.cuda.max_memory_allocated reads 0 in the caller.

Signed-off-by: lishunyang <lishunyang12@163.com>
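
The peak-VRAM polling mentioned in the commit above can be sketched as a small background sampler. This is an illustrative assumption about the approach, not the code in bench_video_fp8.py: PeakSampler and nvml_reader are hypothetical names, and the reader is injectable so the sampler can be exercised without a GPU.

```python
import threading
from typing import Callable, Optional


class PeakSampler:
    """Poll a memory reader on a background thread and keep the maximum seen."""

    def __init__(self, read_used_bytes: Callable[[], int], interval_s: float = 0.05) -> None:
        self._read = read_used_bytes
        self._interval = interval_s
        self._stop = threading.Event()
        self._thread: Optional[threading.Thread] = None
        self.peak_bytes = 0

    def sample_once(self) -> None:
        """Take one reading and update the running peak."""
        self.peak_bytes = max(self.peak_bytes, self._read())

    def _run(self) -> None:
        while not self._stop.is_set():
            self.sample_once()
            self._stop.wait(self._interval)

    def __enter__(self) -> "PeakSampler":
        self._stop.clear()
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()
        return self

    def __exit__(self, *exc) -> None:
        self._stop.set()
        if self._thread is not None:
            self._thread.join()
        self.sample_once()  # one last reading after the workload finishes


def nvml_reader(device_index: int = 0) -> Callable[[], int]:
    """Build a reader backed by NVML (requires the third-party pynvml package)."""
    import pynvml

    pynvml.nvmlInit()
    handle = pynvml.nvmlDeviceGetHandleByIndex(device_index)
    return lambda: pynvml.nvmlDeviceGetMemoryInfo(handle).used
```

A benchmark would wrap the generate call in `with PeakSampler(nvml_reader()) as sampler:` and read sampler.peak_bytes afterwards. NVML reports device-wide usage, so the number includes any other process on the same GPU, and short spikes between polls can be missed.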
Signed-off-by: lishunyang <lishunyang12@163.com>
…n't lose earlier results

Signed-off-by: lishunyang <lishunyang12@163.com>
Handle all common dim orders ([T,H,W,C], [C,T,H,W], [T,C,H,W]) and print raw/normalized
shapes so anyone hitting the ValueError can see what came back.

Signed-off-by: lishunyang <lishunyang12@163.com>
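
The dim-order handling described in the commit above could look roughly like the sketch below. It is a heuristic under stated assumptions, not the bench script's code: axes_to_thwc and permute_shape are hypothetical helpers, and the channel axis is guessed as the first candidate axis of size 1, 3, or 4, which is ambiguous for clips of only 1, 3, or 4 frames.

```python
from typing import Sequence, Tuple

CHANNEL_SIZES = (1, 3, 4)  # grayscale, RGB, RGBA


def axes_to_thwc(shape: Sequence[int]) -> Tuple[int, int, int, int]:
    """Return the axis permutation that turns a 4-D video shape into [T, H, W, C]."""
    if len(shape) != 4:
        raise ValueError(f"expected a 4-D video tensor, got shape {tuple(shape)}")
    if shape[-1] in CHANNEL_SIZES:   # already [T, H, W, C]
        return (0, 1, 2, 3)
    if shape[0] in CHANNEL_SIZES:    # [C, T, H, W]
        return (1, 2, 3, 0)
    if shape[1] in CHANNEL_SIZES:    # [T, C, H, W]
        return (0, 2, 3, 1)
    raise ValueError(f"cannot locate a channel axis in shape {tuple(shape)}")


def permute_shape(shape: Sequence[int], axes: Sequence[int]) -> Tuple[int, ...]:
    """Apply an axis permutation to a shape tuple (what a transpose would produce)."""
    return tuple(shape[a] for a in axes)
```

The returned tuple can be passed directly to an array transpose, for example numpy.transpose(video, axes). Putting both the raw and the normalized shape in the error message follows the commit's intent of showing what came back.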
FP8 on the text-conditioning joint attention collapses output to noise.
Mirror FLUX dual-stream fix (vllm-project#2728): keep cross-attn BF16, keep self-attn
and FFN quantized.

Signed-off-by: lishunyang <lishunyang12@163.com>
Signed-off-by: lishunyang <lishunyang12@163.com>
Cross-attn skip alone still caused visible quality loss on long sequences
(astronaut-on-Mars test at 121 frames: composition shift, detail loss).
Keep attn1 full precision and quantize only FFN — same pattern as FLUX
keeping dual-stream BF16 and FP8'ing single-stream only (vllm-project#2728).
Memory gain is smaller but output quality matches BF16.

Signed-off-by: lishunyang <lishunyang12@163.com>
Transformer gains env-var-driven preset resolution for per-role FP8 selection:
  BF16 - nothing quantized
  S1   - FFN only (ff + ff_context)
  S2   - video stream only (to_qkv + to_out[0] + ff)
  S3   - all FP8 except encoder cross-attn (keeps add_kv_proj/to_add_out BF16)
  S4   - everything FP8 (default, pre-sweep behavior)

Bench script gains --presets and --frames-list to sweep the matrix in one run,
caches BF16 per frame count, emits combined markdown table to results.md.

Signed-off-by: lishunyang <lishunyang12@163.com>
Passes a dict spec ({method: fp8, weight_block_size: [M, N]}) to Omni when
specified, falls back to the 'fp8' string for per-tensor. Block-wise scales
typically recover most of the BF16-vs-FP8 quality gap at a small perf cost.

Signed-off-by: lishunyang <lishunyang12@163.com>
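
The string-versus-dict fallback described in the commit above can be sketched as follows. build_quant_spec is a hypothetical helper name; only the two output forms (the 'fp8' string and the {method, weight_block_size} dict) are taken from the commit message.

```python
from typing import Optional, Sequence, Union


def build_quant_spec(weight_block_size: Optional[Sequence[int]] = None) -> Union[str, dict]:
    """Return 'fp8' for per-tensor scaling, or a dict spec for block-wise scaling."""
    if weight_block_size is None:
        return "fp8"
    if len(weight_block_size) != 2 or any(int(v) <= 0 for v in weight_block_size):
        raise ValueError("weight_block_size must be two positive integers, e.g. 128 128")
    return {
        "method": "fp8",
        "weight_block_size": [int(v) for v in weight_block_size],
    }
```

With no block size the helper returns the plain string; with two positive integers it returns the dict form. Whether the engine accepts the dict still depends on the block-wise limitation listed under Known limitations.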
lishunyang12 added a commit to lishunyang12/vllm-omni that referenced this pull request Apr 19, 2026
…vllm-project#2920)

Threads quant_config / prefix through HunyuanVideo15Attention,
HunyuanVideo15TransformerBlock, and HunyuanVideo15Transformer3DModel so
the modelopt FP8 adapter from vllm-project#2913 has somewhere to bind per-layer scales.
Modulation, embeddings, proj_out stay raw nn.Linear (full precision).

Signed-off-by: lishunyang <lishunyang12@163.com>
lishunyang12 added a commit to lishunyang12/vllm-omni that referenced this pull request Apr 19, 2026
…eo-1.5

examples/quantization/quantize_hunyuanvideo_15_modelopt_fp8.py:
  Offline calibration helper that produces a ModelOpt FP8 diffusers checkpoint
  for HunyuanVideo-1.5. Calibrates with 8 video prompts x 10 denoising steps,
  skips precision-sensitive layers (modulation, embeddings, output proj,
  token refiner) matching the vllm-project#2728 / vllm-project#2795 pattern, disables MHA quantizers
  by default (HV-1.5 self-attention degrades visibly under FP8 - see vllm-project#2920).

vllm_omni/model_executor/stage_configs/hunyuan_video_15_dit_fp8.yaml:
  Stage config for serving the calibrated checkpoint via vllm-omni. Auto-detects
  ModelOpt metadata from the checkpoint (uses vllm-project#2913's adapter).

Signed-off-by: lishunyang <lishunyang12@163.com>
lishunyang12 added a commit to lishunyang12/vllm-omni that referenced this pull request Apr 19, 2026
…ject#2920)

Threads quant_config / prefix through WanSelfAttention, WanCrossAttention,
WanFeedForward (+ ColumnParallelGELU), WanTransformerBlock, and
WanTransformer3DModel / WanVACETransformer3DModel, plus the four pipelines
(T2V / I2V / TI2V / VACE). Modulation (scale_shift_table), patch_embedding
(Conv3d), time/text/image embedders, and proj_out stay full precision.

All attention + FFN linears receive quant_config so the ModelOpt FP8 adapter
from vllm-project#2913 can bind per-layer scales at load time. The aggressive skip
patterns from vllm-project#2920 (attn1/attn2 quant_config=None) are NOT applied here —
that was an online-FP8 quality workaround; static calibration handles it.

Signed-off-by: lishunyang <lishunyang12@163.com>
@david6666666 david6666666 marked this pull request as ready for review April 20, 2026 14:23
lishunyang12 added a commit to lishunyang12/vllm-omni that referenced this pull request May 2, 2026
…vllm-project#2920)

Threads quant_config / prefix through HunyuanVideo15Attention,
HunyuanVideo15TransformerBlock, and HunyuanVideo15Transformer3DModel so
the modelopt FP8 adapter from vllm-project#2913 has somewhere to bind per-layer scales.
Modulation, embeddings, proj_out stay raw nn.Linear (full precision).

Signed-off-by: lishunyang <lishunyang12@163.com>
lishunyang12 added a commit to lishunyang12/vllm-omni that referenced this pull request May 2, 2026
…eo-1.5

examples/quantization/quantize_hunyuanvideo_15_modelopt_fp8.py:
  Offline calibration helper that produces a ModelOpt FP8 diffusers checkpoint
  for HunyuanVideo-1.5. Calibrates with 8 video prompts x 10 denoising steps,
  skips precision-sensitive layers (modulation, embeddings, output proj,
  token refiner) matching the vllm-project#2728 / vllm-project#2795 pattern, disables MHA quantizers
  by default (HV-1.5 self-attention degrades visibly under FP8 - see vllm-project#2920).

vllm_omni/model_executor/stage_configs/hunyuan_video_15_dit_fp8.yaml:
  Stage config for serving the calibrated checkpoint via vllm-omni. Auto-detects
  ModelOpt metadata from the checkpoint (uses vllm-project#2913's adapter).

Signed-off-by: lishunyang <lishunyang12@163.com>
lishunyang12 added a commit to lishunyang12/vllm-omni that referenced this pull request May 2, 2026
…ject#2920)

Threads quant_config / prefix through WanSelfAttention, WanCrossAttention,
WanFeedForward (+ ColumnParallelGELU), WanTransformerBlock, and
WanTransformer3DModel / WanVACETransformer3DModel, plus the four pipelines
(T2V / I2V / TI2V / VACE). Modulation (scale_shift_table), patch_embedding
(Conv3d), time/text/image embedders, and proj_out stay full precision.

All attention + FFN linears receive quant_config so the ModelOpt FP8 adapter
from vllm-project#2913 can bind per-layer scales at load time. The aggressive skip
patterns from vllm-project#2920 (attn1/attn2 quant_config=None) are NOT applied here —
that was an online-FP8 quality workaround; static calibration handles it.

Signed-off-by: lishunyang <lishunyang12@163.com>
@Gaohan123

@lishunyang12 Hello, any updates?

@david6666666

please resolve conflicts, thx


Development

Successfully merging this pull request may close these issues.

[Doc]: Is HunyuanVideo-1.5 really support fp8 dynamic quantization
