Phase1 (video-gen) ModelOpt FP8 Follow-ups #57
Open
ArtificialRay wants to merge 26 commits into
Conversation
- …om for bf16 model
- …om for fp8 model (for vae)
- …ut calculation, and rewrite quant_quality script to automate model offline quant
Signed-off-by: ArtificialRay <shuaiweihuang@163.com>
Owner
Thanks for your contribution. May I know on which device you tested ModelOpt FP8?

Author
Thanks for the reply. I used an H100 80GB for the ModelOpt FP8 tests.

Owner
May I have your contact?

Author
You can contact me via WeChat: ArthurRay2333
This PR completes the follow-ups listed in vllm-project#2924. All benchmarking in the Validation sections below is done with `benchmarks/diffusion/quantization_quality.py`.

## Purpose

Phase 1 of vllm-project#2709 — extends ModelOpt FP8 support to video-gen models. This PR adds ModelOpt FP8 static quantization support for the Wan2.2 T2V-A14B / I2V-A14B MoE variants, the Wan2.2 VACE variant, and the HunyuanVideo-1.5 720p T2V/I2V variants, plus block-wise static FP8 quantization for all of the above model variants.
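For reference, the two weight-scale schemes differ only in scale granularity: per-tensor keeps one FP8 scale for the whole weight, per-block one scale per 128x128 tile. A minimal illustrative sketch (not this PR's code; it assumes 2-D weights with dimensions divisible by the block size):

```python
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3

def per_tensor_scale(w: torch.Tensor) -> torch.Tensor:
    # One static scale for the whole weight tensor (quant_algo: FP8).
    return w.abs().amax() / FP8_MAX

def per_block_scales(w: torch.Tensor, block: int = 128) -> torch.Tensor:
    # One scale per (block x block) tile (quant_algo: FP8_PB_WO); this PR
    # supports per-block export only for M = N = 128.
    m, n = w.shape
    tiles = w.reshape(m // block, block, n // block, block)
    return tiles.abs().amax(dim=(1, 3)) / FP8_MAX
```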
## Changes
### ModelOpt FP8 helpers
- `examples/quantization/quantize_hunyuanvideo_15_modelopt_fp8.py` -- HV-1.5 calibrator that adds I2V variant support and patches `quant_algo: FP8` for per-tensor quantization or `quant_algo: FP8_PB_WO` for per-block quantization. Per-block quantization is only supported for M=N=128.
- `examples/quantization/quantize_wan2_2_modelopt_fp8.py` -- Wan2.2 TI2V-5B / T2V-A14B / I2V-A14B calibrator; for the A14B models it loads two pipelines, one for each of the two separate transformer experts.
- `examples/quantization/check_modelopt_fp8_export.py` -- verifier. Adds on-disk transformer weight-reduction metrics and whole-repo (whole-model) weight-size reduction.
- `examples/offline_inference/vace/vace_video_generation.py` -- Wan2.2 VACE variant calibrator; supports T2V / I2V / R2V per-tensor or per-block static quantization.
All calibrators share `--weight-block-size 'M,N'` for block-wise FP8 and use the same fallback pattern as vllm-project#2924: `_force_export_quantized_weights` + `_patch_quant_config` + `hide_quantizers_from_state_dict` (a minimal sketch follows).
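A hedged sketch of the `_patch_quant_config` step, assuming the exported config lives at `config.json` and the quantization section is keyed `quantization_config` (both assumptions; the real helpers live in the example scripts above):

```python
import json
from pathlib import Path

def _patch_quant_config(export_dir: str, weight_block_size: tuple[int, int] | None) -> None:
    """Patch quant_algo into the exported config so the engine can
    auto-detect the ModelOpt FP8 checkpoint."""
    cfg_path = Path(export_dir) / "config.json"  # assumed location
    cfg = json.loads(cfg_path.read_text())
    # FP8 = per-tensor static scales; FP8_PB_WO = per-block weight scales.
    quant_cfg = {"quant_algo": "FP8" if weight_block_size is None else "FP8_PB_WO"}
    if weight_block_size is not None:
        quant_cfg["weight_block_size"] = list(weight_block_size)
    cfg["quantization_config"] = quant_cfg  # assumed key
    cfg_path.write_text(json.dumps(cfg, indent=2))
```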
### Script wiring

- `examples/offline_inference/image_to_video/image_to_video.py` -- adds `--quantization` and `--ignore-layers` to support I2V FP8 quantization testing.
- `benchmarks/diffusion/quantization_quality.py` -- supports quantization-quality benchmarking for both T2V and I2V video-gen models. Adds `throughput`, `peak VRAM`, and `peak VRAM reduction` metrics (a sketch of the peak-VRAM measurement follows this list). Also fixes a bug where the previous `Memory` metric was always 0.0, and a T2I benchmarking attribute error in `generate_image()`.
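As a reference for the new metric, a hedged sketch of capturing peak VRAM around one generation call (the `torch.cuda` calls are real APIs; the harness around them is illustrative, not the benchmark's actual code):

```python
import torch

def measure_peak_vram_gib(run_generation) -> float:
    # Reset the high-water mark first; forgetting this is one way a
    # "memory" metric can end up reading 0.0 or a stale value.
    torch.cuda.reset_peak_memory_stats()
    run_generation()  # e.g. one denoise + VAE decode pass
    return torch.cuda.max_memory_allocated() / 1024**3
```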
### Adapter

- `modelopt_fp8.py`: `_is_transformer_source` adds support for Wan2.2 A14B MoE by verifying both transformer architectures (a sketch follows).
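A minimal sketch of that check, assuming the Wan2.2 A14B layout keeps its two experts in `transformer/` and `transformer_2/` subfolders (the folder names and function shape are assumptions, not this PR's exact code):

```python
from pathlib import Path

def _looks_like_wan22_a14b_moe(model_dir: str) -> bool:
    # Both expert transformers must be present and carry a config.
    root = Path(model_dir)
    return all(
        (root / sub / "config.json").is_file()
        for sub in ("transformer", "transformer_2")
    )
```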
## Validation --- Wan2.2-I2V-A14B (1x H100 80GB, I2V 720x1280, 81 frames, 50 steps, seed=42)

`torch.compile` enabled (default). `--vae-use-tiling` is set during benchmarking, as the BF16 baseline hits CUDA OOM without it.

### BF16 baseline vs. per-tensor quantization
### BF16 baseline vs. per-block quantization
Engine signals confirming the path is wired correctly:
- `factory.py`: `Building quantization config: fp8` → `Building quantization config: modelopt` — auto-detect upgraded the user's `--quantization fp8` flag to ModelOpt based on `quant_algo: FP8` or `quant_algo: FP8_PB_WO` in `transformer/config.json` (see the sketch after this list).
- `data.py`: `Auto-detected quantization 'modelopt' from model config`.
- `modelopt.py:381`: (per-tensor) `Detected ModelOpt fp8 checkpoint (quant_algo=FP8).` / (per-block) `Detected ModelOpt fp8 checkpoint (quant_algo=FP8_PB_WO).`
- `__init__.py`: `Selected CutlassFP8ScaledMMLinearKernel for ModelOptFp8LinearMethod` (per-tensor only) — the ModelOpt FP8 kernel was selected.
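A hedged sketch of that auto-detect upgrade (the function name and config key are assumptions; the log strings are the ones quoted above):

```python
import json
from pathlib import Path

def resolve_quantization(model_dir: str, requested: str) -> str:
    # Upgrade a user-supplied `--quantization fp8` to ModelOpt when the
    # checkpoint's transformer config declares a ModelOpt quant_algo.
    cfg = json.loads((Path(model_dir) / "transformer" / "config.json").read_text())
    quant_algo = (cfg.get("quantization_config") or {}).get("quant_algo")  # assumed key
    if requested == "fp8" and quant_algo in ("FP8", "FP8_PB_WO"):
        return "modelopt"
    return requested
```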
### Visual comparison -- Wan2.2-I2V-A14B

BF16 baseline:
wan22_A14B_bf16.mp4
ModelOpt FP8 per-tensor (this PR):
wan22_A14B_fp8_per_tensor.mp4
ModelOpt FP8 per-block (this PR):
wan22_A14B_fp8_per_block.mp4
Same prompt ("A skateboarder in a purple bomber jacket doing a kickflip in a foggy urban plaza, overcast morning light, slow motion, european architecture in the background."), same seed, same sampling params. Output is BF16-equivalent.
## Validation --- HunyuanVideo-1.5 720p (1x H100 80GB, T2V 720x1280, 49 frames, 30 steps, seed=42)

`torch.compile` enabled (default). `--vae-use-tiling` is set during benchmarking, as both the BF16 baseline and ModelOpt FP8 hit CUDA OOM without it.

### BF16 baseline vs. per-tensor quantization
### BF16 baseline vs. per-block quantization
Engine signals confirming the path is wired correctly:
- `factory.py`: `Building quantization config: fp8` → `Building quantization config: modelopt` — auto-detect upgraded the user's `--quantization fp8` flag to ModelOpt based on `quant_algo: FP8` or `quant_algo: FP8_PB_WO` in `transformer/config.json`.
- `data.py`: `Auto-detected quantization 'modelopt' from model config`.
- `modelopt.py:381`: (per-tensor) `Detected ModelOpt fp8 checkpoint (quant_algo=FP8).` / (per-block) `Detected ModelOpt fp8 checkpoint (quant_algo=FP8_PB_WO).`
- `__init__.py`: `Selected CutlassFP8ScaledMMLinearKernel for ModelOptFp8LinearMethod` (per-tensor only) — the ModelOpt FP8 kernel was selected.

### Visual comparison -- HunyuanVideo-1.5 720p
BF16 baseline:
hunyuan_720p_bf16.mp4
ModelOpt FP8 per-tensor (this PR):
hunyuan_720p_fp8_per_tensor.mp4
ModelOpt FP8 per-block (this PR):
hunyuan_720p_fp8_per_block.mp4
Same prompt ("An astronaut in a white spacesuit riding a horse across the lunar surface, gray dust kicked up by the horse's hooves, Earth visible in the black sky, lunar lander in the distance, cinematic wide shot. Make sure the astronaut is really moving!") and negative prompt ("vibrant colors, overexposed, static, blurred details, subtitles, style, artwork, painting, picture, still, overall gray, worst quality, low quality, JPEG compression artifacts, ugly, mutilated, extra fingers, poorly drawn hands, poorly drawn face, deformed, disfigured, malformed limbs, fused fingers, still frame, cluttered background, three legs, many people in the background, walking backwards"), same seed, same sampling params. Output is BF16-equivalent.

## Validation --- Wan2.1-VACE-14B (1x H100 80GB, R2V 480x832, 49 frames, 30 steps, seed=42)
`torch.compile` enabled (default).

### BF16 baseline vs. per-tensor quantization
### BF16 baseline vs. per-block quantization
Engine signals confirming the path is wired correctly:
- `factory.py`: `Building quantization config: fp8` → `Building quantization config: modelopt` — auto-detect upgraded the user's `--quantization fp8` flag to ModelOpt based on `quant_algo: FP8` or `quant_algo: FP8_PB_WO` in `transformer/config.json`.
- `data.py`: `Auto-detected quantization 'modelopt' from model config`.
- `modelopt.py:381`: (per-tensor) `Detected ModelOpt fp8 checkpoint (quant_algo=FP8).` / (per-block) `Detected ModelOpt fp8 checkpoint (quant_algo=FP8_PB_WO).`
- `__init__.py`: `Selected CutlassFP8ScaledMMLinearKernel for ModelOptFp8LinearMethod` (per-tensor only) — the ModelOpt FP8 kernel was selected.

### Visual comparison -- Wan2.1-VACE-14B
BF16 baseline:
r2v_output_bf16.mp4
ModelOpt FP8 per-tensor (this PR):
r2v_output_fp8_per_tensor.mp4
ModelOpt FP8 per-block (this PR):
r2v_output_fp8_per_block.mp4
Same prompt ("An astronaut in a white spacesuit riding a horse across the lunar surface, gray dust kicked up by the horse's hooves, Earth visible in the black sky, lunar lander in the distance, cinematic wide shot. Make sure the astronaut is really moving!"), same seed, same sampling params. Output is BF16-equivalent.

## Test Plan
**Wan2.2-I2V-A14B**

- Exported checkpoint carries `quant_algo: FP8` for per-tensor and `quant_algo: FP8_PB_WO` for per-block
- Auto-detection verified (`Auto-detected quantization 'modelopt'`)

**HunyuanVideo-1.5 720p**

- Exported checkpoint carries `quant_algo: FP8` for per-tensor and `quant_algo: FP8_PB_WO` for per-block
- Auto-detection verified (`Auto-detected quantization 'modelopt'`)

**Wan2.1-VACE-14B**

- Exported checkpoint carries `quant_algo: FP8` for per-tensor and `quant_algo: FP8_PB_WO` for per-block
- Auto-detection verified (`Auto-detected quantization 'modelopt'`)

## How to use
Pre-calibrated checkpoints are published on Hugging Face:
**Option A:** use the public checkpoints; no calibration needed.
**Option B:** calibrate from BF16 yourself (for reproducibility / custom prompts).
## Known limitations

- With `--vae-use-tiling` enabled, wall-time speedup and throughput improvement are nearly negligible.

## Follow-ups
vllm-project-org/