-
Notifications
You must be signed in to change notification settings - Fork 1.1k
[Quantization] Add FP8 support for Wan 2.2 transformer and Qwen Image VAE/text encoder #1412
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,6 +5,7 @@ | |
| import os | ||
| import time | ||
| from pathlib import Path | ||
| from typing import Any | ||
|
|
||
| import numpy as np | ||
| import torch | ||
|
|
@@ -137,6 +138,23 @@ def parse_args() -> argparse.Namespace: | |
| action="store_true", | ||
| help="Enable diffusion pipeline profiler to display stage durations.", | ||
| ) | ||
| parser.add_argument( | ||
| "--quantization", | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The CLI interface is well-designed with clear help text. The --quantization and --ignored-layers args provide flexibility for users to experiment with different quantization strategies. |
||
| type=str, | ||
| default=None, | ||
| choices=["fp8"], | ||
| help="Quantization method for the transformer. " | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. how about text encoder?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hii, thanks for review. |
||
| "Options: 'fp8' (FP8 W8A8 on Ada/Hopper, weight-only on older GPUs). " | ||
| "Default: None (no quantization, uses BF16).", | ||
| ) | ||
| parser.add_argument( | ||
| "--ignored-layers", | ||
| type=str, | ||
| default=None, | ||
| help="Comma-separated list of layer name patterns to skip quantization. " | ||
| "Only used when --quantization is set. " | ||
| "Example: --ignored-layers 'to_qkv,to_out'", | ||
| ) | ||
| return parser.parse_args() | ||
|
|
||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Issue: Inconsistent API usage The code uses two different approaches depending on whether
This could be confusing. Consider unifying to always use the same format: if args.quantization:
quant_kwargs["quantization_config"] = {
"method": args.quantization,
**(({"ignored_layers": ignored_layers} if ignored_layers else {}))
}Or verify that |
||
|
|
||
|
|
@@ -176,6 +194,17 @@ def main(): | |
| # Check if profiling is requested via environment variable | ||
| profiler_enabled = bool(os.getenv("VLLM_TORCH_PROFILER_DIR")) | ||
|
|
||
| # Build quantization kwargs | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The quantization_config dict construction properly handles both the quantization method and ignored_layers, matching the OmniDiffusionConfig expectations. |
||
| quant_kwargs: dict[str, Any] = {} | ||
| ignored_layers = [s.strip() for s in args.ignored_layers.split(",") if s.strip()] if args.ignored_layers else None | ||
| if args.quantization and ignored_layers: | ||
| quant_kwargs["quantization_config"] = { | ||
| "method": args.quantization, | ||
| "ignored_layers": ignored_layers, | ||
| } | ||
| elif args.quantization: | ||
| quant_kwargs["quantization"] = args.quantization | ||
|
|
||
| omni = Omni( | ||
| model=args.model, | ||
| enable_layerwise_offload=args.enable_layerwise_offload, | ||
|
|
@@ -190,6 +219,7 @@ def main(): | |
| cache_backend=args.cache_backend, | ||
| cache_config=cache_config, | ||
| enable_diffusion_pipeline_profiler=args.enable_diffusion_pipeline_profiler, | ||
| **quant_kwargs, | ||
| ) | ||
|
|
||
| if profiler_enabled: | ||
|
|
@@ -207,6 +237,9 @@ def main(): | |
| f" cfg_parallel_size={args.cfg_parallel_size}, tensor_parallel_size={args.tensor_parallel_size}," | ||
| f" vae_patch_parallel_size={args.vae_patch_parallel_size}, enable_expert_parallel={args.enable_expert_parallel}" | ||
| ) | ||
| print(f" Quantization: {args.quantization if args.quantization else 'None (BF16)'}") | ||
| if ignored_layers: | ||
| print(f" Ignored layers: {ignored_layers}") | ||
| print(f" Video size: {args.width}x{args.height}") | ||
| print(f"{'=' * 60}\n") | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,83 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
| """Utility functions for diffusion models.""" | ||
|
|
||
| import logging | ||
|
|
||
| import torch | ||
| from torch import nn | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
| # Maximum value for float8_e4m3fn | ||
| _FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max | ||
|
|
||
|
|
||
| _FP8_TARGET_LAYERS = (nn.Linear, nn.Conv2d, nn.Conv3d) | ||
|
|
||
|
|
||
| def apply_fp8_weight_storage(model: nn.Module) -> None: | ||
| """Apply FP8 weight-only storage to Linear/Conv2d/Conv3d layers. | ||
|
|
||
| Stores weights in float8_e4m3fn with per-tensor scales. | ||
| Dequantizes to the original compute dtype before each forward pass, | ||
| then re-quantizes afterward to free BF16 memory. | ||
|
|
||
| This saves ~50% of memory with no accuracy loss since computation | ||
| still happens in the original dtype. | ||
|
|
||
| Args: | ||
| model: The model whose layers will be quantized. | ||
| """ | ||
| count = 0 | ||
| for name, module in model.named_modules(): | ||
| if not isinstance(module, _FP8_TARGET_LAYERS): | ||
| continue | ||
|
|
||
| # P4: Idempotency guard -- skip if already quantized | ||
| if hasattr(module, "_fp8_weight"): | ||
| continue | ||
|
|
||
| weight = module.weight.data | ||
| compute_dtype = weight.dtype | ||
|
|
||
| # Compute per-tensor scale | ||
| amax = weight.abs().amax().clamp(min=1e-12) | ||
| scale = amax / _FP8_E4M3_MAX | ||
|
|
||
| # Quantize weight to FP8 | ||
| fp8_weight = (weight / scale).clamp(min=-_FP8_E4M3_MAX, max=_FP8_E4M3_MAX).to(torch.float8_e4m3fn) | ||
|
|
||
| # Store FP8 weight and metadata as buffers (not parameters) | ||
| module.register_buffer("_fp8_weight", fp8_weight) | ||
| module.register_buffer("_fp8_scale", scale.to(torch.float32)) | ||
| module._fp8_compute_dtype = compute_dtype | ||
|
|
||
| # P1: Keep the parameter at the original compute dtype so that | ||
| # model.dtype (derived from parameters) stays correct. We store | ||
| # the FP8 representation in the _fp8_weight buffer and only | ||
| # dequantize into module.weight.data inside the pre-hook. | ||
| # After forward, the post-hook swaps back to FP8 storage to | ||
| # free the BF16/FP16 memory. | ||
| module.weight.data = fp8_weight.to(compute_dtype) | ||
|
|
||
| def _pre_hook(mod, args): | ||
| # Dequantize: restore BF16/FP16 weight for computation | ||
| # P2: Cast back to compute dtype to avoid float32 promotion | ||
| # from the float32 scale tensor. | ||
| mod.weight.data = (mod._fp8_weight.to(mod._fp8_compute_dtype) * mod._fp8_scale).to(mod._fp8_compute_dtype) | ||
|
|
||
| def _post_hook(mod, args, output): | ||
| # Re-quantize: swap back to FP8-dequantized placeholder to | ||
| # free full-precision memory while keeping dtype correct. | ||
| mod.weight.data = mod._fp8_weight.to(mod._fp8_compute_dtype) | ||
|
|
||
| module.register_forward_pre_hook(_pre_hook) | ||
| module.register_forward_hook(_post_hook) | ||
| count += 1 | ||
|
|
||
| logger.info( | ||
| "Applied FP8 weight storage to %d layers in %s", | ||
| count, | ||
| model.__class__.__name__, | ||
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Issue: Same inconsistent API usage
Same concern as in text_to_video.py - consider unifying the quantization config format.