[Quantization] Add FP8 support for Wan 2.2 transformer and Qwen Image VAE/text encoder#1412
[Quantization] Add FP8 support for Wan 2.2 transformer and Qwen Image VAE/text encoder#1412lishunyang12 wants to merge 1 commit into
Conversation
8d0734d to
d66e358
Compare
SamitHuang
left a comment
There was a problem hiding this comment.
this PR is clear. it should be ready to merge after checking the visual quality of quantization.
| type=str, | ||
| default=None, | ||
| choices=["fp8"], | ||
| help="Quantization method for the transformer. " |
There was a problem hiding this comment.
how about text encoder?
There was a problem hiding this comment.
Hii, thanks for review.
The text encoder (UMT5) is not quantized here — same as
what Z-Image does. Only the diffusion transformer layers
get FP8. The text encoder is relatively small compared to
the transformer, so quantizing it has less impact on
memory while potentially hurting prompt embedding quality.
We could add text encoder quantization as a follow-up.
|
@vllm-omni-reviewer |
🤖 VLLM-Omni PR ReviewCode Review: FP8 Quantization Support for Wan 2.2 Transformer1. OverviewThis PR adds FP8 quantization support for the Wan 2.2 video transformer by threading Overall Assessment: Positive - The changes are clean, consistent, and follow existing patterns. A few minor suggestions for robustness. 2. Code QualityStrengths
Potential Issues1. Inconsistent API usage in example scripts ( if args.quantization and ignored_layers:
quant_kwargs["quantization_config"] = {
"method": args.quantization,
"ignored_layers": ignored_layers,
}
elif args.quantization:
quant_kwargs["quantization"] = args.quantizationThis uses different keys ( 2. Variable reference in print statement ( if ignored_layers:
print(f" Ignored layers: {ignored_layers}")This correctly references 3. Architecture & DesignStrengths
Design Considerations1. Layer exclusion pattern: The PR correctly notes that 2. Missing 4. Security & SafetyNo significant security concerns. The changes are purely additive and don't introduce new attack vectors. Minor consideration: The 5. Testing & DocumentationTest Plan AssessmentThe test plan in the PR description is adequate but could be more comprehensive: Suggested additions:
Documentation
6. Specific Suggestions
|
|
@hsliuustc0106 @SamitHuang Please help check it as i uploaded the test results. Thx |
|
Regarding the GLM-5's suggestions:
|
|
This PR adds FP8 W8A8 quantization support for Wan 2.2 video transformer, enabling significant memory reduction on Ada/Hopper GPUs. The implementation follows the established Z-Image pattern consistently, threading quant_config through all 6 transformer classes and their parallel linear layers. The changes are well-structured, properly scoped (excluding text encoder and normalization layers as expected), and include comprehensive CLI support. The author has provided test results and addressed review feedback thoroughly. |
| import math | ||
| from collections.abc import Iterable | ||
| from typing import Any | ||
| from typing import TYPE_CHECKING, Any |
There was a problem hiding this comment.
Good use of TYPE_CHECKING to avoid runtime import overhead while maintaining type hints for QuantizationConfig. This keeps the quantization dependency optional at runtime.
| from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader | ||
| from vllm_omni.diffusion.models.schedulers import FlowUniPCMultistepScheduler | ||
| from vllm_omni.diffusion.models.wan2_2.wan2_2_transformer import WanTransformer3DModel | ||
| from vllm_omni.diffusion.quantization import get_vllm_quant_config_for_layers |
There was a problem hiding this comment.
Proper integration with the existing quantization infrastructure. get_vllm_quant_config_for_layers handles the ignored_layers filtering and config validation.
| if load_transformer_2: | ||
| transformer_2_config = load_transformer_config(model, "transformer_2", local_files_only) | ||
| self.transformer_2 = create_transformer_from_config(transformer_2_config) | ||
| self.transformer_2 = create_transformer_from_config(transformer_2_config, quant_config=quant_config) |
There was a problem hiding this comment.
Important: both transformer and transformer_2 receive the same quant_config, ensuring consistent quantization across the dual-transformer architecture.
| help="Number of GPUs used for tensor parallelism (TP) inside the DiT.", | ||
| ) | ||
| parser.add_argument( | ||
| "--quantization", |
There was a problem hiding this comment.
The CLI interface is well-designed with clear help text. The --quantization and --ignored-layers args provide flexibility for users to experiment with different quantization strategies.
| # Check if profiling is requested via environment variable | ||
| profiler_enabled = bool(os.getenv("VLLM_TORCH_PROFILER_DIR")) | ||
|
|
||
| # Build quantization kwargs |
There was a problem hiding this comment.
The quantization_config dict construction properly handles both the quantization method and ignored_layers, matching the OmniDiffusionConfig expectations.
|
|
||
|
|
||
| def create_transformer_from_config(config: dict) -> WanTransformer3DModel: | ||
| def create_transformer_from_config(config: dict, quant_config=None) -> WanTransformer3DModel: |
There was a problem hiding this comment.
Missing type annotation for quant_config parameter. Should be quant_config: QuantizationConfig | None = None to match the pattern used in the transformer classes and maintain type safety.
There was a problem hiding this comment.
Fixed in the latest push. Added TYPE_CHECKING import and QuantizationConfig type annotation to match the pattern in wan2_2_transformer.py:
# Before
def create_transformer_from_config(config: dict, quant_config=None) -> WanTransformer3DModel:
# After
def create_transformer_from_config(
config: dict, quant_config: "QuantizationConfig | None" = None
) -> WanTransformer3DModel:|
examples/offline_inference/text_to_video/text_to_video.py:97 Critical: This PR adds quantization support but provides no test coverage. We need tests to verify:
Without tests, we can't validate the 'significant memory reduction' claim or prevent regressions. |
|
Looking more critically at this PR, there are several concerns that should be addressed: Missing Test Coverage: This adds a significant feature (FP8 quantization) with zero test coverage. We need tests to validate:
Missing Performance Data: The PR claims "significant memory reduction" but the test results only show latency metrics, not actual memory usage. We need:
Type Safety: The Documentation: No documentation added explaining:
While the implementation follows the Z-Image pattern correctly, these gaps make it difficult to validate the feature works as intended and prevent future regressions. |
emmm, i already added the test to validate the memory reduction and show output quality consistency. False negative. |
|
@hsliuustc0106 lets just ignore AI comments as they are not valid. Doc should be provided in a separate PR. |
hsliuustc0106
left a comment
There was a problem hiding this comment.
Summary
This PR adds FP8 quantization support to Wan 2.2 transformers by threading quant_config through all parallel linear layers. The implementation follows the established Z-Image pattern and includes comprehensive test results showing ~41% memory reduction with minimal quality impact.
Pros:
- Clean, consistent implementation across all 6 transformer classes
- Follows established Z-Image pattern (commit b7604ae)
- Comprehensive test results with actual memory measurements and video outputs
- Proper use of TYPE_CHECKING to avoid circular imports
- Good CLI help text with examples
Cons:
- Inconsistent API usage in example scripts (two different ways to pass quantization config)
- Minor code duplication between text_to_video.py and image_to_video.py
Recommendation: Approve with minor suggestions for API consistency.
| "Example: --ignored-layers 'to_qkv,to_out'", | ||
| ) | ||
| return parser.parse_args() | ||
|
|
There was a problem hiding this comment.
Issue: Inconsistent API usage
The code uses two different approaches depending on whether ignored_layers is provided:
- With ignored_layers:
quantization_configdict withmethodandignored_layers - Without: Simple
quantizationstring
This could be confusing. Consider unifying to always use the same format:
if args.quantization:
quant_kwargs["quantization_config"] = {
"method": args.quantization,
**(({"ignored_layers": ignored_layers} if ignored_layers else {}))
}Or verify that Omni handles both formats identically.
| "Example: --ignored-layers 'to_qkv,to_out'", | ||
| ) | ||
| return parser.parse_args() | ||
|
|
There was a problem hiding this comment.
Issue: Same inconsistent API usage
Same concern as in text_to_video.py - consider unifying the quantization config format.
| @@ -28,6 +28,11 @@ | |||
| SequenceParallelOutput, | |||
| ) | |||
There was a problem hiding this comment.
Good practice: TYPE_CHECKING usage
Nice use of TYPE_CHECKING to avoid circular imports while maintaining type safety.
| @@ -92,14 +97,23 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: | |||
| class ColumnParallelGELU(nn.Module): | |||
| """Column parallel linear with GELU activation.""" | |||
There was a problem hiding this comment.
Suggestion: Add docstring
Consider adding a brief docstring explaining the quant_config parameter:
def __init__(
self,
dim_in: int,
dim_out: int,
*,
approximate: str = "tanh",
bias: bool = True,
quant_config: "QuantizationConfig | None" = None,
):
"""Column parallel linear with GELU activation.
Args:
quant_config: Optional quantization config for FP8/other methods.
"""| @@ -23,10 +23,16 @@ | |||
| from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader | |||
| from vllm_omni.diffusion.models.schedulers import FlowUniPCMultistepScheduler | |||
There was a problem hiding this comment.
Good: Consistent pattern
The quantization config extraction and threading follows the same pattern as Z-Image. This consistency makes the codebase easier to maintain.
| from vllm_omni.diffusion.models.schedulers import FlowUniPCMultistepScheduler | ||
| from vllm_omni.diffusion.models.wan2_2.wan2_2_transformer import WanTransformer3DModel | ||
| from vllm_omni.diffusion.quantization import get_vllm_quant_config_for_layers | ||
| from vllm_omni.diffusion.request import OmniDiffusionRequest |
There was a problem hiding this comment.
Suggestion: Add logging
Consider adding a log message when quantization is enabled for better visibility:
quant_config = get_vllm_quant_config_for_layers(od_config.quantization_config)
if quant_config is not None:
logger.info("Enabling quantization for Wan 2.2 transformer: %s", quant_config)|
Hi @lishunyang12 👋 This FP8 quantization PR hasn't been updated for 13 days. Is this still on your radar? Let us know if you need any support. Thanks! |
ea18a8d to
71a9035
Compare
ee61360 to
b9edf6b
Compare
… VAE/text encoder Signed-off-by: lishunyang <lishunyang12@163.com>
b9edf6b to
23fed75
Compare
Summary
This PR extends FP8 quantization support to two additional model families:
quant_configthrough all parallel linear layers (same pattern as Z-Image)nn.Linear,Conv2d,Conv3d)Subsumes #1414 (FP8 for Qwen Image VAE/encoder).
Wan 2.2 Changes
Wire
quant_configfrom pipelines through all parallel linear layers in the Wan 2.2 video transformer, following the same pattern established by Z-Image (commit b7604ae).wan2_2_transformer.pyquant_configparam to 6 classes (ColumnParallelGELU,WanFeedForward,WanSelfAttention,WanCrossAttention,WanTransformerBlock,WanTransformer3DModel) and pass to allColumnParallelLinear,RowParallelLinear,QKVParallelLinearlayerspipeline_wan2_2.pyquant_configviaget_vllm_quant_config_for_layersand pass to bothtransformerandtransformer_2pipeline_wan2_2_i2v.pypipeline_wan2_2_ti2v.pytext_to_video.py--quantizationand--ignored-layersargsimage_to_video.py--quantizationand--ignored-layersargsNot quantized (same as Z-Image pattern):
DistributedRMSNorm,Attention,Conv3dLayer,nn.Linear(proj_out),FP32LayerNorm, embedding layers.Qwen Image VAE/Encoder Changes
Add FP8 weight-only storage for Linear/Conv2d/Conv3d layers in the Qwen Image VAE and text encoder. Weights are stored in
float8_e4m3fnwith per-tensor scales and dequantized to BF16 before each forward pass — saving ~50% memory for these components.models/utils.pyapply_fp8_weight_storage()utility — quantizes weights, registers forward pre/post hooks for dequantpipeline_qwen_image.pypipeline_qwen_image_edit.pypipeline_qwen_image_edit_plus.pyWan 2.2 Test Results
T2V Pipeline (Wan2.2-T2V-A14B-Diffusers)
Environment: 1x GPU, 1280×720, 81 frames, 40 steps, seed=42
I2V Pipeline (Wan2.2-I2V-A14B-Diffusers)
Environment: 1x GPU, auto-resolution, 81 frames, 50 steps, seed=42
Test plan
--quantization fp8works end-to-end--quantization, behavior is identical (quant_config=None is no-op)wan22_fp8_quantized.mp4
wan22_fp8_ignored_layers.mp4
wan22_bf16_baseline.mp4
i2v_bf16_baseline.mp4
i2v_fp8_quantized.mp4