diff --git a/examples/offline_inference/vace/vace_video_generation.py b/examples/offline_inference/vace/vace_video_generation.py index 6ca0d74c52e..e183e0d5b58 100644 --- a/examples/offline_inference/vace/vace_video_generation.py +++ b/examples/offline_inference/vace/vace_video_generation.py @@ -52,6 +52,13 @@ def parse_args() -> argparse.Namespace: choices=["t2v", "i2v", "v2lf", "flf2v", "inpaint", "r2v"], help="Generation mode.", ) + parser.add_argument( + "--quantization", + type=str, + default=None, + choices=["fp8", "gguf"], + help="Quantization method for the transformer (fp8 for online FP8 quantization).", + ) parser.add_argument("--prompt", default="A cat walking in a garden", help="Text prompt.") parser.add_argument("--negative-prompt", default="", help="Negative prompt.") parser.add_argument("--image", type=str, default=None, help="Input image path (for I2V, R2V, FLF2V, inpaint).") @@ -159,6 +166,7 @@ def main(): flow_shift=args.flow_shift, enforce_eager=args.enforce_eager, parallel_config=parallel_config, + quantization = args.quantization ) prompt_data = build_prompts(args) diff --git a/examples/quantization/check_modelopt_fp8_export.py b/examples/quantization/check_modelopt_fp8_export.py new file mode 100644 index 00000000000..bab1211a6a1 --- /dev/null +++ b/examples/quantization/check_modelopt_fp8_export.py @@ -0,0 +1,311 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Verify a ModelOpt FP8 diffusers checkpoint exported by +quantize_hunyuanvideo_15_modelopt_fp8.py (or any sibling quantize_*.py). + +Three checks: + A. transformer/config.json has a sane quantization_config block. + B. transformer/*.safetensors contains FP8 (float8_e4m3fn) tensors. + C. transformer disk size is materially smaller than a BF16 baseline. 
def _check_config(transformer_dir: Path) -> int:
    """Validate the `quantization_config` block of transformer/config.json.

    Prints its findings as it goes.

    Args:
        transformer_dir: Directory expected to contain `config.json`.

    Returns:
        0 when the block looks correct, 1 on a hard failure (config missing,
        unreadable, or without a `quantization_config` dict), 2 when the block
        exists but looks inconsistent (warnings only).
    """
    cfg_path = transformer_dir / "config.json"
    if not cfg_path.exists():
        print(f"[FAIL] {cfg_path} missing.")
        return 1

    try:
        with cfg_path.open(encoding="utf-8") as f:
            cfg = json.load(f)
    except (OSError, json.JSONDecodeError) as exc:
        # A truncated or corrupt config is a checkpoint failure, not a crash of the checker.
        print(f"[FAIL] Could not read {cfg_path}: {exc}")
        return 1

    qc = cfg.get("quantization_config") if isinstance(cfg, dict) else None
    if not isinstance(qc, dict):
        print(f"[FAIL] No `quantization_config` block in {cfg_path}.")
        return 1

    print(f"[A] quantization_config from {cfg_path}:")
    print(json.dumps(qc, indent=2))

    issues = []
    if qc.get("quant_method") != "modelopt":
        issues.append(f"quant_method={qc.get('quant_method')!r} (expected 'modelopt')")

    quant_algo = qc.get("quant_algo")
    if quant_algo not in ("FP8", "FP8_PB_WO"):
        issues.append(
            f"quant_algo={quant_algo!r} (expected 'FP8' for per-tensor or "
            "'FP8_PB_WO' for 128x128 block-wise — other algos aren't routed by "
            "vllm-omni's adapter today)"
        )

    # Cross-check that the saved weight strategy and the dispatch field agree.
    # Producer scripts can in principle drift apart (e.g. metadata says "block"
    # but quant_algo still claims "FP8"), and that lands as an AssertionError at
    # weight load time because the runtime LinearMethod expects scalar scales but
    # finds 4D block ones. Failing here is much friendlier.
    cfg_groups = qc.get("config_groups")
    if not isinstance(cfg_groups, dict):
        # `"config_groups": null` (or any non-mapping) means "no groups declared".
        cfg_groups = {}
    weight_strategies = set()
    for group in cfg_groups.values():
        weights = group.get("weights") if isinstance(group, dict) else None
        strategy = weights.get("strategy") if isinstance(weights, dict) else None
        if strategy is None:
            continue
        # Non-string values are kept (as their repr) so they still show up in the report
        # without risking an unhashable-type error.
        weight_strategies.add(strategy if isinstance(strategy, str) else repr(strategy))

    if weight_strategies == {"block"} and quant_algo != "FP8_PB_WO":
        issues.append(
            f"weights.strategy='block' but quant_algo={quant_algo!r}. Per-block "
            "weight scales require FP8_PB_WO so upstream vLLM dispatches to "
            "ModelOptFp8PbWoLinearMethod; FP8 routes to per-tensor and crashes "
            "on the 4D weight_scale at weight load time."
        )
    elif quant_algo == "FP8_PB_WO" and weight_strategies != {"block"}:
        issues.append(
            f"quant_algo='FP8_PB_WO' but weights.strategy={weight_strategies!r} "
            "(expected {'block'}). FP8_PB_WO consumers expect 4D per-block scales."
        )

    if issues:
        print("[A] WARN — config looks incomplete:")
        for issue in issues:
            print(f" - {issue}")
        return 2
    print(f"[A] PASS — config looks correct (quant_algo={quant_algo}).")
    return 0


def _read_safetensors_header(path: Path) -> dict:
    """Read the JSON header of a safetensors file without materializing tensors.

    A safetensors file starts with an 8-byte little-endian unsigned header
    length followed by that many bytes of UTF-8 JSON.

    Returns {tensor_name: {'dtype': 'F8_E4M3', 'shape': [...], 'data_offsets': [...]}}.
    Header dtype strings: F8_E4M3, F8_E5M2, BF16, F16, F32, F64, I8, I16, I32, I64, BOOL, U8, ...

    Raises:
        ValueError: if the file is too short, the declared header length
            exceeds the file size, or the header is not a JSON object.
    """
    with open(path, "rb") as f:
        file_size = f.seek(0, 2)  # seeking to the end returns the file size
        f.seek(0)
        prefix = f.read(8)
        if len(prefix) != 8:
            raise ValueError(f"{path} is too short to be a safetensors file")
        header_len = int.from_bytes(prefix, "little")
        # Guard before reading: a corrupt length would otherwise request a huge buffer.
        if header_len > file_size - 8:
            raise ValueError(f"{path} declares a {header_len}-byte header but holds only {file_size} bytes")
        raw = f.read(header_len)

    header = json.loads(raw.decode("utf-8"))
    if not isinstance(header, dict):
        raise ValueError(f"{path} header is not a JSON object")
    # `__metadata__` is free-form file metadata, not a tensor entry; drop it so
    # callers can treat every remaining key as a tensor name.
    header.pop("__metadata__", None)
    return header


def _classify_weight_scale_granularity(weight_scale_shapes: list[list[int]]) -> str:
    """Infer per-tensor vs per-channel vs per-block from sample weight_scale shapes.

    ModelOpt block-wise produces shapes like `[16, 1, 16, 1]` (broadcasting dims of 1
    interleaved with block-count dims). We count "meaningful" dims — ones with size > 1 —
    and classify: 0 meaningful dims = per-tensor (scalar), 1 = per-channel, 2+ = per-block.

    Returns a human-readable label; a "mixed: ..." breakdown when the shapes
    do not all fall in the same class.
    """
    if not weight_scale_shapes:
        return "no weight_scale tensors found"

    def meaningful_dims(shape: list[int]) -> int:
        return sum(1 for d in shape if d > 1)

    per_tensor = sum(1 for s in weight_scale_shapes if meaningful_dims(s) == 0)
    per_channel = sum(1 for s in weight_scale_shapes if meaningful_dims(s) == 1)
    per_block = sum(1 for s in weight_scale_shapes if meaningful_dims(s) >= 2)
    total = len(weight_scale_shapes)
    if per_tensor == total:
        return "per-tensor (all scalar scales)"
    if per_channel == total:
        return "per-channel (1 meaningful dim)"
    if per_block == total:
        return "per-block (2+ meaningful dims — e.g. [M//bm, 1, N//bn, 1] for tiles)"
    return f"mixed: per-tensor={per_tensor}, per-channel={per_channel}, per-block={per_block} of {total}"
def _disk_size_gib(p: Path) -> float:
    """Return the combined size, in GiB, of every regular file found recursively under `p`."""
    total_bytes = 0
    for entry in p.rglob("*"):
        if entry.is_file():
            total_bytes += entry.stat().st_size
    return total_bytes / (1024**3)


def _transformer_subdirs(root: Path) -> list[Path]:
    """Return the transformer subfolders of `root` that exist on disk.

    Checks `root/transformer` and `root/transformer_2`, in that order.
    Wan2.2 MoE A14B (T2V/I2V) and Wan2.2-VACE-A14B export TWO transformer
    subfolders; single-transformer checkpoints just have `transformer/`.
    Falls back to `[root]` if neither exists (e.g., a baseline directory
    that wasn't structured as a diffusers repo).
    """
    candidates = (root / "transformer", root / "transformer_2")
    existing = [candidate for candidate in candidates if candidate.is_dir()]
    return existing or [root]
def _combine_statuses(statuses: list[int]) -> int:
    """Fold per-check status codes into one overall code.

    Codes are 0 (pass), 1 (fail) and 2 (warn). A failure always wins over a
    warning; OR-ing the raw codes would turn "warn + fail" into 3 and hide the
    failure behind the warnings branch.
    """
    if any(code == 1 for code in statuses):
        return 1
    if any(code != 0 for code in statuses):
        return 2
    return 0


def main() -> None:
    """CLI entry point: run checks A-C on an exported checkpoint and report.

    Exits with status 1 when the transformer directory is missing or any check
    reports a hard failure; warnings alone do not change the exit status.
    """
    p = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
    p.add_argument("--output", required=True, help="Path to the exported ModelOpt FP8 checkpoint root.")
    p.add_argument(
        "--baseline",
        default=None,
        help="Optional BF16 baseline (local diffusers dir or HF id) for disk-size comparison.",
    )
    args = p.parse_args()

    out_root = Path(args.output).expanduser().resolve()
    transformer_dir = out_root / "transformer"
    if not transformer_dir.exists():
        print(f"[FAIL] {transformer_dir} does not exist.")
        sys.exit(1)

    print(f"Checking: {out_root}\n")

    # Run every check (each prints its own findings) before deciding the outcome.
    statuses = [
        _check_config(transformer_dir),
        _check_safetensors(transformer_dir),
    ]
    # The size comparison is informational only; its return value is ignored.
    _check_size_vs_baseline(transformer_dir, args.baseline)
    fail = _combine_statuses(statuses)

    print()
    if fail == 0:
        print("=" * 60)
        print("ALL CHECKS PASSED — checkpoint looks ready for vllm-omni serving.")
    elif fail == 1:
        print("=" * 60)
        print("FAILURES detected — calibration may need to be re-run.")
        sys.exit(1)
    else:
        print("=" * 60)
        print("WARNINGS only — checkpoint may serve but with caveats. See [A] above.")
`--variant auto` detects from the loaded +class, but you can pin it with `--variant t2v|i2v`.): +- `hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-480p_t2v` +- `hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-720p_t2v` +- `hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-480p_i2v` +- `hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-720p_i2v` + +For I2V variants, diffusers' HunyuanVideo15ImageToVideoPipeline takes a +required `image` kwarg (and derives height/width from the image), so +calibration must pair every prompt with a reference image — pass +`--reference-images `. + +Recommended resolutions per variant (CLI overrides accepted; T2V uses these +defaults, I2V derives from the reference image and ignores --height/--width): +- 480p: --height 480 --width 832 (default) +- 720p: --height 720 --width 1280 + +Example (480p T2V): + python examples/quantization/quantize_hunyuanvideo_15_modelopt_fp8.py \ + --model hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-480p_t2v \ + --output ./hv15-480p-t2v-modelopt-fp8 \ + --overwrite + +Example (480p I2V): + python examples/quantization/quantize_hunyuanvideo_15_modelopt_fp8.py \ + --model hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-480p_i2v \ + --variant i2v \ + --reference-images /path/to/ref_images \ + --output ./hv15-480p-i2v-modelopt-fp8 \ + --overwrite +""" + +from __future__ import annotations + +import argparse +import copy +import json +import re +import shutil +import sys +from pathlib import Path +from typing import Any + +import torch +from diffusers import DiffusionPipeline + +DEFAULT_PROMPTS = [ + "A dog running across a field of golden wheat.", + "An astronaut riding a horse across the surface of Mars, red dust swirling, cinematic wide shot.", + "A hummingbird hovering in front of a vibrant red flower, slow motion, macro shot.", + "A crackling campfire at night under a starry sky, sparks rising into the dark.", + "An underwater shot of a coral reef with tropical fish swimming by, sun rays piercing 
def _build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the HunyuanVideo-1.5 FP8 quantization script."""
    p = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
    p.add_argument("--model", required=True, help="Input HV-1.5 diffusers directory or HF id.")
    p.add_argument("--output", required=True, help="Output directory for the ModelOpt FP8 checkpoint.")
    p.add_argument("--dtype", choices=("bfloat16", "float16"), default="bfloat16")
    p.add_argument("--height", type=int, default=480)
    p.add_argument("--width", type=int, default=832)
    p.add_argument(
        "--num-frames",
        type=int,
        default=33,
        help="Frames per calibration sample. 33 matches the typical short benchmark.",
    )
    p.add_argument("--guidance-scale", type=float, default=6.0)
    p.add_argument(
        "--calib-steps",
        type=int,
        default=10,
        help="Denoising steps per calibration prompt (10 is enough for amax statistics).",
    )
    p.add_argument("--calib-size", type=int, default=8, help="How many prompts to use for calibration.")
    p.add_argument("--seed", type=int, default=42)
    p.add_argument(
        "--prompt",
        action="append",
        default=[],
        help="Custom calibration prompt. Repeat to provide multiple.",
    )
    p.add_argument(
        "--variant",
        choices=("auto", "t2v", "i2v"),
        default="auto",
        help="HunyuanVideo-1.5 pipeline variant. `auto` detects from the loaded pipeline class "
        "(HunyuanVideo15Pipeline -> t2v, HunyuanVideo15ImageToVideoPipeline -> i2v). "
        "Pass `i2v` only if you also pass --reference-images.",
    )
    p.add_argument(
        "--reference-images",
        type=str,
        default=None,
        help="Required for i2v variants. Directory of jpg/jpeg/png/webp files (or a single image). "
        "Every calibration sample is paired with a cycled ref image since `image` is a required "
        "kwarg, not optional, in HunyuanVideo15ImageToVideoPipeline. The pipeline derives "
        "height/width from the image, so --height/--width are ignored under i2v.",
    )
    p.add_argument(
        "--quantize-mha",
        action="store_true",
        help="Enable FP8 attention K/V/softmax quantizers. Off by default — empirically degrades HV-1.5 video output.",
    )
    p.add_argument(
        "--weight-block-size",
        type=str,
        default=None,
        help="Per-block weight quantization as 'M,N'. Only '128,128' is accepted because upstream "
        "vLLM's ModelOptFp8PbWoLinearMethod hardcodes that block shape. Default: per-tensor. "
        "Block-wise saves checkpoints with FP8_PB_WO routing (per-block static weights + per-token-"
        "group dynamic activations); per-tensor uses static FP8 with calibrated activation scales.",
    )
    p.add_argument("--overwrite", action="store_true", help="Replace an existing output directory.")
    return p


def _parse_block_size(spec: str | None) -> list[int] | None:
    """Parse the --weight-block-size value into `[M, N]`.

    Args:
        spec: Raw CLI string such as "128,128", or None when the flag was omitted.

    Returns:
        None for per-tensor quantization (flag omitted), otherwise `[128, 128]`.

    Raises:
        SystemExit: if the value is not two comma-separated integers, or names
            a block shape other than 128x128.
    """
    if spec is None:
        return None
    fields = [field.strip() for field in spec.split(",") if field.strip()]
    try:
        parts = [int(field) for field in fields]
    except ValueError:
        # Report bad input the same way as a wrong field count, not as a traceback.
        raise SystemExit(f"--weight-block-size must be 'M,N' (2 ints), got {spec!r}") from None
    if len(parts) != 2:
        raise SystemExit(f"--weight-block-size must be 'M,N' (2 ints), got {spec!r}")
    # The --weight-block-size help text states only '128,128' is accepted; reject
    # anything else here, before calibration starts, rather than during export.
    if tuple(parts) != (128, 128):
        raise SystemExit(
            f"--weight-block-size {tuple(parts)} not supported: upstream vLLM's "
            "ModelOptFp8PbWoLinearMethod hardcodes (128, 128). Pass '128,128' or omit the flag."
        )
    return parts
def _ensure_paths(args: argparse.Namespace) -> tuple[str, Path]:
    """Resolve the model and output locations, clearing a stale output directory.

    Returns:
        `(model_path, output_dir)` where `model_path` is `args.model` untouched
        and `output_dir` is the absolute, user-expanded `args.output`.

    Raises:
        SystemExit: if the output directory exists and --overwrite was not
            given, or if removing the output directory would delete a local
            input model.
    """
    model_path = args.model
    output_dir = Path(args.output).expanduser().resolve()

    # `args.model` may be an HF id rather than a path; only guard real local paths.
    local_model = Path(model_path).expanduser()
    if local_model.exists():
        local_model = local_model.resolve()
        if local_model == output_dir or output_dir in local_model.parents:
            raise SystemExit(
                f"Output directory {output_dir} is (or contains) the input model {local_model}; "
                "refusing to delete the input. Choose a different --output."
            )

    if output_dir.exists():
        if not args.overwrite:
            raise SystemExit(f"Output directory already exists: {output_dir}\nPass --overwrite to replace it.")
        shutil.rmtree(output_dir)
    return model_path, output_dir


def _select_dtype(name: str) -> torch.dtype:
    """Map the --dtype choice ("bfloat16" or "float16") to the torch dtype."""
    return {"bfloat16": torch.bfloat16, "float16": torch.float16}[name]


def _build_prompts(args: argparse.Namespace) -> list[str]:
    """Return exactly `args.calib_size` calibration prompts.

    Uses the user's --prompt values when given, otherwise DEFAULT_PROMPTS,
    repeating the list cyclically when it is shorter than the requested size.

    Raises:
        SystemExit: if `args.calib_size` is not positive.
    """
    prompts = args.prompt or DEFAULT_PROMPTS
    if args.calib_size <= 0:
        raise SystemExit("--calib-size must be positive.")
    if len(prompts) < args.calib_size:
        repeats = (args.calib_size + len(prompts) - 1) // len(prompts)
        prompts = (prompts * repeats)[: args.calib_size]
    return prompts[: args.calib_size]


def _load_reference_images(spec: str | None) -> list[Any]:
    """Load reference images as RGB PIL images from a directory or a single file.

    Args:
        spec: Path to an image file or to a directory of jpg/jpeg/png/webp
            files; None means no reference images.

    Returns:
        A list of RGB images (directory entries in sorted path order); empty
        when `spec` is None.

    Raises:
        SystemExit: if the path does not exist or the directory holds no
            supported image files.
    """
    if spec is None:
        return []

    p = Path(spec).expanduser()
    if not p.exists():
        raise SystemExit(f"--reference-images path not found: {p}")
    if p.is_file():
        image_paths = [p]
    else:
        image_paths = sorted(
            f for f in p.iterdir() if f.is_file() and f.suffix.lower() in (".jpg", ".jpeg", ".png", ".webp")
        )
        if not image_paths:
            raise SystemExit(f"No image files (jpg/jpeg/png/webp) found in {p}")

    # Imported only once there is something to open, so path errors surface first.
    from PIL import Image

    images = []
    for image_path in image_paths:
        # The context manager closes the file handle; the image produced by
        # convert() is the one kept in the returned list.
        with Image.open(image_path) as img:
            images.append(img.convert("RGB"))
    return images
def _build_calib_samples(prompts: list[str], variant: str, ref_images: list[Any]) -> list[tuple[str, Any]]:
    """Pair each calibration prompt with a ref image (i2v) or None (t2v).

    Under i2v the reference images are cycled so every prompt gets one.

    Raises:
        ValueError: if `variant` is "i2v" but `ref_images` is empty.
    """
    if variant == "i2v":
        if not ref_images:
            # Named error instead of the ZeroDivisionError the modulo below would raise.
            raise ValueError("i2v calibration needs at least one reference image.")
        return [(prompt, ref_images[i % len(ref_images)]) for i, prompt in enumerate(prompts)]
    return [(prompt, None) for prompt in prompts]


# Layers to KEEP at full precision (mirror of the #2920 wiring + #2728/#2795 skip pattern).
# - x_embedder, image_embedder, context_embedder*, time_embed*, cond_type_embed: entry/embedding
# - norm_out, norm1*.linear, norm1_context*.linear, norm2*, norm2_context*: AdaLayerNorm modulation
# - proj_out: final output projection
# - token_refiner*: text-encoder refinement uses diffusers raw nn.Linear
# Compiled once at import time instead of on every filter call.
_HV15_FULL_PRECISION_RE = re.compile(
    r"(proj_out.*|"
    r".*(x_embedder|image_embedder|context_embedder|context_embedder_2|"
    r"time_embed|cond_type_embed|"
    r"norm_out|norm1\.linear|norm1_context\.linear|norm2|norm2_context|"
    r"token_refiner).*)"
)

# Attention (MHA) quantizer names: Q/K/V BMM inputs, softmax, and second-BMM output.
_MHA_QUANTIZER_RE = re.compile(
    r".*(q_bmm_quantizer|k_bmm_quantizer|v_bmm_quantizer|softmax_quantizer|bmm2_output_quantizer).*"
)


def _filter_func_hv15(name: str) -> bool:
    """Return True when `name` belongs to a layer that must stay at full precision.

    `proj_out` only matches at the start of the name; every other token
    matches anywhere in it.
    """
    return _HV15_FULL_PRECISION_RE.match(name) is not None


def _mha_filter_func(name: str) -> bool:
    """Return True when `name` refers to one of the attention (MHA) quantizers."""
    return _MHA_QUANTIZER_RE.match(name) is not None


def _disable_known_problematic_quantizers(mtq: Any, backbone: torch.nn.Module, *, quantize_mha: bool) -> None:
    """Disable quantizers on the full-precision layers and, unless requested, on attention.

    Args:
        mtq: The `modelopt.torch.quantization` module (or a compatible object).
        backbone: The transformer whose quantizers are filtered.
        quantize_mha: When False, the MHA quantizers are disabled as well.
    """
    if not hasattr(mtq, "disable_quantizer"):
        # Still best-effort (no crash), but no longer silent: without this hook
        # the skip list is not applied at all.
        print(
            "[warn] mtq.disable_quantizer is unavailable; the full-precision skip list "
            "and MHA quantizer filter were NOT applied.",
            file=sys.stderr,
        )
        return
    mtq.disable_quantizer(backbone, _filter_func_hv15)
    if not quantize_mha:
        mtq.disable_quantizer(backbone, _mha_filter_func)


def _load_pipeline(model_path: str, dtype: torch.dtype) -> DiffusionPipeline:
    """Load the diffusers pipeline in `dtype`, silence its progress bar, and move it to CUDA."""
    pipe = DiffusionPipeline.from_pretrained(model_path, torch_dtype=dtype)
    if hasattr(pipe, "set_progress_bar_config"):
        pipe.set_progress_bar_config(disable=True)
    pipe.to("cuda")
    return pipe
def _summarize_export(output_dir: Path) -> None:
    """Print the key quantization fields of the exported transformer/config.json.

    Emits a `[warn]` line on stderr and returns when the config file or its
    `quantization_config` dict is missing.
    """
    cfg_path = output_dir / "transformer" / "config.json"
    if not cfg_path.exists():
        print(f"[warn] {cfg_path} missing.", file=sys.stderr)
        return

    with cfg_path.open(encoding="utf-8") as handle:
        qc = json.load(handle).get("quantization_config")
    if not isinstance(qc, dict):
        print("[warn] No quantization_config in transformer/config.json.", file=sys.stderr)
        return

    print("Export summary:")
    print(f" quant_method: {qc.get('quant_method')}")
    print(f" quant_algo: {qc.get('quant_algo')}")
    producer = qc.get("producer")
    if isinstance(producer, dict):
        print(f" producer: {producer.get('name')} {producer.get('version')}")
    print(f" config path: {cfg_path}")


def _force_export_quantized_weights(backbone: torch.nn.Module, dtype: torch.dtype) -> int:
    """Convert in-memory weights of quantized modules to actual FP8 storage.

    `export_hf_checkpoint` skips this step for unknown model types (HV-1.5 isn't
    in ModelOpt's recognized-model registry), so the per-weight export helper is
    called here directly. Same workaround as the HunyuanImage-3 calibration
    helper.

    Returns:
        The number of weights handed to the export helper.
    """
    from modelopt.torch.export.quant_utils import (
        QUANTIZATION_NONE,
        get_quantization_format,
        quantizer_attr_names,
        weight_attr_names,
    )
    from modelopt.torch.export.unified_export_hf import _export_quantized_weight

    converted = 0
    for module_name, submodule in backbone.named_modules():
        try:
            fmt = get_quantization_format(submodule)
        except Exception as exc:
            # A module that cannot be inspected is reported and skipped, not fatal.
            print(f"[warn] Could not inspect quantization format for {module_name}: {exc}", file=sys.stderr)
            continue
        if fmt == QUANTIZATION_NONE:
            continue
        for weight_name in weight_attr_names(submodule):
            attr = quantizer_attr_names(weight_name).weight_quantizer
            quantizer = getattr(submodule, attr, None)
            # Only weights whose quantizer exists and is enabled are exported.
            if quantizer is not None and getattr(quantizer, "is_enabled", False):
                _export_quantized_weight(submodule, dtype, weight_name)
                converted += 1
    return converted


def _hv15_quant_config_block(weight_block_size: list[int] | None = None) -> dict:
    """Build the ModelOpt FP8 `quantization_config` expected by vllm-omni's adapter (#2913).

    Args:
        weight_block_size: None for per-tensor weights, or `[128, 128]` for
            block-wise weights.

    Returns:
        The config dict; `quant_algo` is "FP8" for per-tensor and "FP8_PB_WO"
        for block-wise.

    Raises:
        ValueError: for any block shape other than (128, 128). Upstream's
            FP8_PB_WO hardcodes _WEIGHT_BLOCK_SIZE = (128, 128), so any other
            block shape produces a checkpoint vLLM cannot serve.
    """
    blockwise = weight_block_size is not None
    if blockwise and tuple(weight_block_size) != (128, 128):
        raise ValueError(
            f"--weight-block-size {tuple(weight_block_size)} not supported: upstream vLLM's "
            "ModelOptFp8PbWoLinearMethod hardcodes (128, 128). Pass '128,128' or omit the flag."
        )

    weights_cfg: dict = {"dynamic": False, "num_bits": 8, "type": "float"}
    if blockwise:
        rows, cols = weight_block_size[0], weight_block_size[1]
        weights_cfg["strategy"] = "block"
        weights_cfg["block_structure"] = f"{rows}x{cols}"

    group = {
        "input_activations": {"dynamic": False, "num_bits": 8, "type": "float"},
        "weights": weights_cfg,
        "targets": ["Linear"],
    }
    ignored_layers = [
        "context_embedder*",
        "context_embedder_2*",
        "cond_type_embed*",
        "image_embedder*",
        "norm1.linear*",
        "norm1_context.linear*",
        "norm2*",
        "norm2_context*",
        "norm_out*",
        "proj_out*",
        "time_embed*",
        "token_refiner*",
        "x_embedder*",
    ]
    return {
        "config_groups": {"group_0": group},
        "ignore": ignored_layers,
        "producer": {"name": "modelopt"},
        "quant_algo": "FP8_PB_WO" if blockwise else "FP8",
        "quant_method": "modelopt",
    }


def _patch_quant_config(output_dir: Path, weight_block_size: list[int] | None = None) -> None:
    """Rewrite `quantization_config` in transformer/config.json so vllm-omni's
    adapter (#2913) recognises the checkpoint as ModelOpt FP8.

    The block comes from `_hv15_quant_config_block`; a `producer` dict already
    present in the file is carried over.
    """
    cfg_path = output_dir / "transformer" / "config.json"
    cfg = json.loads(cfg_path.read_text(encoding="utf-8"))

    replacement = _hv15_quant_config_block(weight_block_size=weight_block_size)
    previous = cfg.get("quantization_config")
    previous_producer = previous.get("producer") if isinstance(previous, dict) else None
    if isinstance(previous_producer, dict):
        replacement["producer"] = previous_producer

    cfg["quantization_config"] = replacement
    cfg_path.write_text(json.dumps(cfg, indent=2), encoding="utf-8")
def main() -> None:
    """Run the HunyuanVideo-1.5 ModelOpt FP8 quantization workflow end to end.

    Steps, in order: parse CLI arguments, load the diffusers pipeline, build
    calibration samples, calibrate ``pipe.transformer`` with ``mtq.quantize``,
    convert the quantized weights to FP8 in memory, save the pipeline, patch
    ``transformer/config.json`` and print a command for validating the result.

    Raises:
        SystemExit: if CUDA is unavailable, if the reference-image arguments
            are inconsistent with the variant, or if no weight was exported.
    """
    args = _build_parser().parse_args()
    if not torch.cuda.is_available():
        raise SystemExit("CUDA is required for ModelOpt FP8 quantization.")

    mtq = _require_modelopt()
    # NOTE(review): _ensure_paths is defined outside this chunk; presumably it
    # validates/prepares the output directory - confirm before reordering.
    model_path, output_dir = _ensure_paths(args)
    dtype = _select_dtype(args.dtype)
    prompts = _build_prompts(args)
    weight_block_size = _parse_block_size(args.weight_block_size)

    # Cheap argument check that does not need the pipeline to be loaded.
    if args.reference_images is not None and args.variant == "t2v":
        raise SystemExit("--reference-images is only meaningful with --variant i2v (or auto-detected i2v).")

    pipe = _load_pipeline(model_path, dtype)
    # The effective variant is resolved from the loaded pipeline, so the i2v
    # requirement can only be checked after loading.
    variant = _resolve_variant(pipe, args.variant)
    if variant == "i2v" and args.reference_images is None:
        raise SystemExit(
            "i2v variant requires --reference-images: HunyuanVideo15ImageToVideoPipeline "
            "takes a required `image` kwarg, so calibration must pair every prompt with a "
            "reference image."
        )
    ref_images = _load_reference_images(args.reference_images) if variant == "i2v" else []
    samples = _build_calib_samples(prompts, variant, ref_images)
    sample_label = f"i2v={len(samples)}" if variant == "i2v" else f"t2v={len(samples)}"

    # Echo the resolved settings before the expensive calibration starts.
    print("Quantization plan:")
    print(f" input: {args.model}")
    print(f" output: {output_dir}")
    print(f" dtype: {dtype}")
    print(f" variant: {variant} (requested={args.variant}, class={pipe.__class__.__name__})")
    if variant == "i2v":
        print(" height/width: derived from reference image (i2v ignores --height/--width)")
        print(f" reference imgs: {len(ref_images)}")
    else:
        print(f" height/width: {args.height}x{args.width}")
    print(f" num_frames: {args.num_frames}")
    print(f" calib_size: {len(samples)} ({sample_label})")
    print(f" calib_steps: {args.calib_steps}")
    print(f" quantize_mha: {args.quantize_mha}")
    print(
        f" weight strategy: {'block-wise ' + str(weight_block_size) if weight_block_size else 'per-tensor (default)'}"
    )

    backbone = pipe.transformer

    # Deep-copy so the override below does not mutate ModelOpt's shared default config.
    quant_config = copy.deepcopy(mtq.FP8_DEFAULT_CFG)
    if weight_block_size is not None:
        # Switch from per-tensor (default) to block-wise weight quantization.
        # ModelOpt's wildcard "*weight_quantizer" matches every linear's weight quantizer.
        quant_config["quant_cfg"]["*weight_quantizer"] = {
            "num_bits": (4, 3),  # E4M3 (FP8 weights, same as default)
            # Keys look like negative dimension indices - presumably ModelOpt's
            # per-axis block-size convention; TODO confirm against ModelOpt docs.
            "block_sizes": {-1: weight_block_size[1], -2: weight_block_size[0]},
        }
        print(
            f" -> overriding weight quantizer with block_sizes={weight_block_size} "
            f"({weight_block_size[0]}x{weight_block_size[1]} tiles)"
        )

    forward_loop = _build_forward_loop(pipe, args, samples, variant)
    quantized = mtq.quantize(backbone, quant_config, forward_loop)
    # mtq.quantize may hand back a module; when it does, rebind both the
    # pipeline attribute and the local so later steps act on that instance.
    if quantized is not None:
        pipe.transformer = quantized
        backbone = quantized

    _disable_known_problematic_quantizers(mtq, backbone, quantize_mha=args.quantize_mha)

    print("\nForcing FP8 weight serialization (HV-1.5 isn't in ModelOpt's recognized-model registry,")
    print("so we have to call the per-weight export helper ourselves)...")
    exported = _force_export_quantized_weights(backbone, dtype)
    print(f" -> {exported} weights converted to FP8 in memory")
    # Zero exported weights means the checkpoint would not contain FP8 data; abort before saving.
    if exported == 0:
        raise SystemExit(
            "No quantized weights were exported. Calibration may have skipped every layer "
            "(check the disable_quantizer regex) or `mtq.quantize` did not actually wrap any "
            "weight quantizers."
        )

    print("\nSaving pipeline with FP8 transformer...")
    _save_pipeline_with_fp8_transformer(pipe, model_path, output_dir)
    # The config patch must follow the save: it edits transformer/config.json in the output dir.
    _patch_quant_config(output_dir, weight_block_size=weight_block_size)
    print(f"Saved to: {output_dir}")
    _summarize_export(output_dir)

    # Print a ready-to-edit command for validating the export; it differs per variant.
    print("\nNext: validate the checkpoint with vllm-omni:")
    if variant == "i2v":
        print(
            " python examples/offline_inference/image_to_video/image_to_video.py \\\n"
            f" --model {output_dir} \\\n"
            " --quantization fp8 \\\n"
            " --prompt 'A subject from the reference image moves through the scene.' \\\n"
            " --image \\\n"
            f" --num-frames {args.num_frames} \\\n"
            " --num-inference-steps 30 --guidance-scale 6.0 --seed 42 \\\n"
            " --output outputs/hv15_i2v_modelopt_fp8.mp4 \\\n"
            " --enforce-eager"
        )
    else:
        print(
            " python examples/offline_inference/text_to_video/text_to_video.py \\\n"
            f" --model {output_dir} \\\n"
            " --quantization fp8 \\\n"
            " --prompt 'A dog running across a field of golden wheat.' \\\n"
            f" --height {args.height} --width {args.width} --num-frames {args.num_frames} \\\n"
            " --num-inference-steps 30 --guidance-scale 6.0 --seed 42 \\\n"
            " --output outputs/hv15_t2v_modelopt_fp8.mp4 \\\n"
            " --enforce-eager"
        )
    print(
        "\n (--quantization fp8 is auto-upgraded to ModelOpt FP8 at runtime because the "
        "checkpoint's config.json has modelopt metadata.)"
    )
# Built-in calibration prompt pool. `_build_calib_samples` falls back to this
# list when no custom `--prompt` is supplied on the command line.
DEFAULT_PROMPTS = [
    "A dog running across a field of golden wheat.",
    "An astronaut riding a horse across the surface of Mars, red dust swirling, cinematic wide shot.",
    "A hummingbird hovering in front of a vibrant red flower, slow motion, macro shot.",
    "A crackling campfire at night under a starry sky, sparks rising into the dark.",
    "An underwater shot of a coral reef with tropical fish swimming by, sun rays piercing the water.",
    "A close-up of a blooming rose covered in morning dew, soft natural light.",
    "A peaceful mountain village at dawn, mist rolling over the rooftops, cinematic establishing shot.",
    "A skateboarder doing a kickflip in an urban plaza, slow motion, golden hour lighting.",
]
Repeat to provide multiple.", + ) + p.add_argument( + "--quantize-mha", + action="store_true", + help="Enable FP8 attention K/V/softmax quantizers. Off by default — Wan2.2's long attention " + "sequences amplified FP8 drift in the online ablation (see #2920).", + ) + p.add_argument( + "--weight-block-size", + type=str, + default=None, + help="Per-block weight quantization as 'M,N'. Only '128,128' is accepted because upstream " + "vLLM's ModelOptFp8PbWoLinearMethod hardcodes that block shape. Default: per-tensor. " + "Block-wise saves checkpoints with FP8_PB_WO routing (per-block static weights + per-token-" + "group dynamic activations); per-tensor uses static FP8 with calibrated activation scales.", + ) + p.add_argument( + "--calib-boundary-ratio", + type=float, + default=None, + help="Pass-1-only boundary_ratio override for Wan2.2 MoE calibration. Only takes " + "effect when the loaded pipeline has transformer_2. Lowering it (e.g. 0.5) shifts " + "more denoising steps onto `transformer` so its quantizers see a richer amax " + "sample WITHOUT bumping --calib-steps. Pass 2 always restores the model's " + "production boundary_ratio (A14B = 0.875) to keep transformer_2's amax in " + "production distribution. If unset, both passes use the production value (default).", + ) + p.add_argument( + "--is-i2v", + action="store_true", + help="Set when quantizing a Wan2.2 I2V model (e.g. Wan2.2-I2V-A14B-Diffusers). " + "diffusers' WanImageToVideoPipeline takes a required `image` kwarg, so calibration " + "must pair every prompt with a reference image — pass --reference-images.", + ) + p.add_argument( + "--reference-images", + type=str, + default=None, + help="Requires --is-i2v. Directory of jpg/jpeg/png/webp files (or a single image). " + "Every calibration sample is paired with a cycled ref image since image_embedder " + "is required, not optional, in I2V pipelines. 
def _build_calib_samples(
    args: argparse.Namespace,
    is_i2v: bool,
    ref_images: list[Any],
) -> list[tuple[str, Any]]:
    """Build exactly ``args.calib_size`` calibration ``(prompt, image_or_None)`` pairs.

    The prompt pool is ``args.prompt`` when custom prompts were given, otherwise
    ``DEFAULT_PROMPTS``. The pool is cycled (or truncated) so the number of
    samples always equals ``--calib-size``, independent of the pool length.
    Previously ``--calib-size`` was validated but otherwise ignored.

    - Non-I2V (T2V/TI2V/A14B-T2V): every sample is ``(prompt, None)``.
    - I2V: sample ``i`` is paired with ``ref_images[i % len(ref_images)]``,
      because diffusers' WanImageToVideoPipeline requires the ``image`` kwarg.

    Args:
        args: Parsed CLI namespace; reads ``calib_size`` and ``prompt``.
        is_i2v: Whether every sample needs a reference image.
        ref_images: Loaded reference images (ignored unless ``is_i2v``).

    Returns:
        A list of ``args.calib_size`` ``(prompt, reference_image_or_None)`` tuples.

    Raises:
        SystemExit: if ``--calib-size`` is not positive, or if ``is_i2v`` is set
            but ``ref_images`` is empty.
    """
    if args.calib_size <= 0:
        raise SystemExit("--calib-size must be positive.")

    pool = args.prompt or DEFAULT_PROMPTS
    # Cycle through the pool so any --calib-size works with any number of prompts.
    prompts = [pool[i % len(pool)] for i in range(args.calib_size)]

    if not is_i2v:
        return [(prompt, None) for prompt in prompts]

    # main() already enforces --reference-images with --is-i2v; this guard turns
    # a would-be ZeroDivisionError into a readable message for other callers.
    if not ref_images:
        raise SystemExit("I2V calibration requires at least one reference image.")
    return [(prompt, ref_images[i % len(ref_images)]) for i, prompt in enumerate(prompts)]
def _load_pipeline(model_path: str, dtype: torch.dtype) -> DiffusionPipeline:
    """Load the diffusers pipeline and place its components on the GPU(s).

    Single-transformer pipelines, and dual-transformer pipelines on a machine
    with fewer than two CUDA devices, are moved to ``"cuda"`` as a whole. When
    ``transformer_2`` exists and at least two devices are visible,
    ``transformer`` and the encoders/VAE go to ``cuda:0`` while
    ``transformer_2`` goes to ``cuda:1`` with hooks bridging the two devices.

    Args:
        model_path: Local diffusers directory or Hugging Face model id.
        dtype: Torch dtype passed to ``from_pretrained`` as ``torch_dtype``.

    Returns:
        The loaded pipeline with progress bars disabled when supported.
    """
    pipeline = DiffusionPipeline.from_pretrained(model_path, torch_dtype=dtype)
    if hasattr(pipeline, "set_progress_bar_config"):
        pipeline.set_progress_bar_config(disable=True)

    second_backbone = getattr(pipeline, "transformer_2", None)
    if second_backbone is None or torch.cuda.device_count() < 2:
        pipeline.to("cuda")
        return pipeline

    # diffusers' WanPipeline picks a transformer per step by boundary_timestep
    # but does NOT move activations between devices. Bridge transformer_2 with
    # forward hooks instead: the pre-hook moves inputs cuda:0 -> cuda:1 and the
    # post-hook moves outputs back cuda:1 -> cuda:0, so the pipeline (and
    # scheduler.step) only ever sees cuda:0 tensors.
    first_gpu = torch.device("cuda:0")
    second_gpu = torch.device("cuda:1")
    pipeline.transformer.to(first_gpu)
    second_backbone.to(second_gpu)
    for attr in ("text_encoder", "vae", "image_encoder"):
        module = getattr(pipeline, attr, None)
        if module is None:
            continue
        module.to(first_gpu)
    second_backbone.register_forward_pre_hook(_make_input_device_hook(second_gpu), with_kwargs=True)
    second_backbone.register_forward_hook(_make_output_device_hook(first_gpu))
    print(f" device map: transformer={first_gpu}, transformer_2={second_gpu} (cross-device hooks installed)")
    return pipeline
def _summarize_export(output_dir: Path, subfolder: str = "transformer") -> None:
    """Print the quantization metadata stored in ``<output_dir>/<subfolder>/config.json``.

    A warning is written to stderr (and nothing to stdout) when the config file
    is missing or has no ``quantization_config`` dict.

    Args:
        output_dir: Root of the exported checkpoint.
        subfolder: Transformer subfolder to inspect.
    """
    cfg_path = output_dir / subfolder / "config.json"
    if not cfg_path.exists():
        print(f"[warn] {cfg_path} missing.", file=sys.stderr)
        return
    with cfg_path.open(encoding="utf-8") as f:
        cfg = json.load(f)
    # A config whose JSON root is not an object cannot carry the block either.
    qc = cfg.get("quantization_config") if isinstance(cfg, dict) else None
    if not isinstance(qc, dict):
        print(f"[warn] No quantization_config in {subfolder}/config.json.", file=sys.stderr)
        return
    print(f"Export summary ({subfolder}):")
    print(f" quant_method: {qc.get('quant_method')}")
    print(f" quant_algo: {qc.get('quant_algo')}")
    producer = qc.get("producer")
    if isinstance(producer, dict):
        # The producer block written by this script has only a "name", so skip
        # absent fields instead of printing a literal "None".
        present = [str(producer[key]) for key in ("name", "version") if producer.get(key) is not None]
        if present:
            label = " ".join(present)
            print(f" producer: {label}")
    print(f" config path: {cfg_path}")
def _wan22_quant_config_block(weight_block_size: list[int] | None = None) -> dict:
    """Return the ``quantization_config`` dict written for a Wan2.2 transformer.

    The layout mirrors the ModelOpt FP8 metadata expected by vllm-omni's
    adapter (#2913). With ``weight_block_size=None`` the block describes
    per-tensor FP8 (``quant_algo="FP8"``); with a block size it describes
    block-wise weights (``quant_algo="FP8_PB_WO"``).

    Args:
        weight_block_size: ``[M, N]`` block shape, or ``None`` for per-tensor.

    Raises:
        ValueError: if a block size other than ``(128, 128)`` is requested,
            since upstream's FP8_PB_WO hardcodes that shape and any other
            would produce a checkpoint vLLM cannot serve.
    """
    use_blocks = weight_block_size is not None
    if use_blocks and tuple(weight_block_size) != (128, 128):
        raise ValueError(
            f"--weight-block-size {tuple(weight_block_size)} not supported: upstream vLLM's "
            "ModelOptFp8PbWoLinearMethod hardcodes (128, 128). Pass '128,128' or omit the flag."
        )

    weight_spec: dict = {"dynamic": False, "num_bits": 8, "type": "float"}
    if use_blocks:
        weight_spec.update(
            strategy="block",
            block_structure=f"{weight_block_size[0]}x{weight_block_size[1]}",
        )

    activation_spec = {"dynamic": False, "num_bits": 8, "type": "float"}
    # Module-name patterns listed under "ignore" in the emitted metadata.
    skipped_modules = [
        "condition_embedder*",
        "norm_out*",
        "output_scale_shift_prepare*",
        "patch_embedding*",
        "proj_out*",
        "scale_shift_table*",
        "timestep_proj_prepare*",
    ]
    algo = "FP8_PB_WO" if use_blocks else "FP8"

    return {
        "config_groups": {
            "group_0": {
                "input_activations": activation_spec,
                "weights": weight_spec,
                "targets": ["Linear"],
            }
        },
        "ignore": skipped_modules,
        "producer": {"name": "modelopt"},
        "quant_algo": algo,
        "quant_method": "modelopt",
    }
Single-transformer variants (TI2V-5B) skip it. + """ + from modelopt.torch.export.diffusers_utils import hide_quantizers_from_state_dict + + src = Path(model_path) + if not src.exists(): + from huggingface_hub import snapshot_download + + src = Path(snapshot_download(model_path)) + + if output_dir.exists(): + shutil.rmtree(output_dir) + shutil.copytree(src, output_dir, ignore=shutil.ignore_patterns("transformer", "transformer_2")) + + backbones: list[tuple[str, torch.nn.Module]] = [("transformer", pipe.transformer)] + transformer_2 = getattr(pipe, "transformer_2", None) + if transformer_2 is not None: + backbones.append(("transformer_2", transformer_2)) + + for subfolder, backbone in backbones: + out = output_dir / subfolder + with hide_quantizers_from_state_dict(backbone): + backbone.save_pretrained( + str(out), + safe_serialization=True, + max_shard_size=max_shard_size, + ) + + +def _calibrate( + backbone: torch.nn.Module, + label: str, + *, + mtq: Any, + quant_config: dict, + forward_loop, + quantize_mha: bool, +) -> torch.nn.Module: + """Wrap one transformer backbone with quantizers and run calibration. + + Returns the (possibly replaced) backbone module so the caller can rebind + `pipe.transformer` / `pipe.transformer_2` to the wrapped instance. The + backbone's weights remain in their original dtype here — call + `_force_export` afterwards to commit FP8 storage. 
def main() -> None:
    """Run the Wan2.2 ModelOpt FP8 quantization workflow end to end.

    Steps, in order: parse and validate CLI arguments, prepare the output
    path, load the pipeline, calibrate ``transformer`` (and ``transformer_2``
    for dual-transformer pipelines), convert the calibrated weights to FP8,
    save the checkpoint, patch each transformer's ``config.json`` and print a
    command for validating the export.

    All argument validation happens before ``_ensure_paths``, which removes an
    existing output directory under ``--overwrite``, so a bad invocation never
    destroys a previous export. An unsupported ``--weight-block-size`` is also
    rejected up front instead of surfacing as a ``ValueError`` after
    calibration and saving have already completed.

    Raises:
        SystemExit: if CUDA is unavailable, arguments are inconsistent or
            unsupported, or a backbone exported no quantized weights.
    """
    args = _build_parser().parse_args()
    if not torch.cuda.is_available():
        raise SystemExit("CUDA is required for ModelOpt FP8 quantization.")

    mtq = _require_modelopt()
    dtype = _select_dtype(args.dtype)
    weight_block_size = _parse_block_size(args.weight_block_size)
    # Fail fast: the metadata builder is what rejects unsupported block shapes,
    # and it would otherwise only run once the checkpoint has been written.
    try:
        _wan22_quant_config_block(weight_block_size=weight_block_size)
    except ValueError as exc:
        raise SystemExit(str(exc)) from exc

    if args.reference_images is not None and not args.is_i2v:
        raise SystemExit("--reference-images requires --is-i2v.")
    if args.is_i2v and args.reference_images is None:
        raise SystemExit(
            "--is-i2v requires --reference-images: diffusers' WanImageToVideoPipeline "
            "takes a required `image` kwarg, so calibration must pair every prompt with "
            "a reference image."
        )
    ref_images = _load_reference_images(args.reference_images) if args.is_i2v else []
    samples = _build_calib_samples(args, args.is_i2v, ref_images)
    sample_label = f"I2V={len(samples)}" if args.is_i2v else f"T2V={len(samples)}"

    # Destructive under --overwrite, so it runs only after every check above passed.
    model_path, output_dir = _ensure_paths(args)

    print("Quantization plan:")
    print(f" input: {args.model}")
    print(f" output: {output_dir}")
    print(f" dtype: {dtype}")
    print(f" height/width: {args.height}x{args.width}")
    print(f" num_frames: {args.num_frames}")
    print(f" calib_size: {len(samples)} ({sample_label})")
    print(f" calib_steps: {args.calib_steps}")
    print(f" quantize_mha: {args.quantize_mha}")
    print(f" is_i2v: {args.is_i2v}")
    if args.is_i2v:
        print(f" reference imgs: {len(ref_images)}")
    print(
        f" weight strategy: {'block-wise ' + str(weight_block_size) if weight_block_size else 'per-tensor (default)'}"
    )

    pipe = _load_pipeline(model_path, dtype)
    is_dual = getattr(pipe, "transformer_2", None) is not None
    if is_dual:
        print(" detected MoE A14B variant (transformer + transformer_2)")

    # Capture the model's production boundary_ratio (from model_index.json) so
    # we can restore it before pass 2. --calib-boundary-ratio only overrides
    # pass 1 to give `transformer` more amax samples; pass 2 must run at the
    # production boundary so `transformer_2` calibrates on the same noise
    # distribution it will see at inference time.
    production_boundary = pipe.config.get("boundary_ratio") if is_dual else None

    # Deep-copy so the override below does not mutate ModelOpt's shared default config.
    quant_config = copy.deepcopy(mtq.FP8_DEFAULT_CFG)
    if weight_block_size is not None:
        quant_config["quant_cfg"]["*weight_quantizer"] = {
            "num_bits": (4, 3),
            "block_sizes": {-1: weight_block_size[1], -2: weight_block_size[0]},
        }
        print(
            f" -> overriding weight quantizer with block_sizes={weight_block_size} "
            f"({weight_block_size[0]}x{weight_block_size[1]} tiles)"
        )

    forward_loop = _build_forward_loop(pipe, args, samples)

    # Single-transformer (TI2V-5B) does one pass; MoE A14B variants do two.
    # The diffusers Wan22 pipeline routes between transformer (high noise) and
    # transformer_2 (low noise) by boundary_timestep, so each forward_loop run
    # exercises the backbone currently being calibrated. mtq.quantize wraps
    # quantizers and then drives the forward_loop to collect amax statistics.
    #
    # Both backbones are calibrated BEFORE any _force_export call: pass 2 still
    # runs `transformer` inside the pipeline, so its weights must not have been
    # converted to FP8 storage yet.
    if is_dual and args.calib_boundary_ratio is not None:
        pipe.register_to_config(boundary_ratio=args.calib_boundary_ratio)
        print(
            f"\n pass 1 boundary_ratio: {args.calib_boundary_ratio} "
            f"(override of production {production_boundary} for transformer sample boost)"
        )

    pipe.transformer = _calibrate(
        pipe.transformer,
        "transformer",
        mtq=mtq,
        quant_config=quant_config,
        forward_loop=forward_loop,
        quantize_mha=args.quantize_mha,
    )
    if is_dual:
        if args.calib_boundary_ratio is not None:
            pipe.register_to_config(boundary_ratio=production_boundary)
            print(
                f"\n pass 2 boundary_ratio: {production_boundary} "
                "(restored to production for transformer_2 in-distribution calibration)"
            )
        pipe.transformer_2 = _calibrate(
            pipe.transformer_2,
            "transformer_2",
            mtq=mtq,
            quant_config=quant_config,
            forward_loop=forward_loop,
            quantize_mha=args.quantize_mha,
        )

    _force_export(pipe.transformer, "transformer", dtype)
    if is_dual:
        _force_export(pipe.transformer_2, "transformer_2", dtype)

    print("\nSaving pipeline with FP8 transformer(s)...")
    _save_pipeline_with_fp8_transformers(pipe, model_path, output_dir)
    # Patch after saving: each call edits the config.json written into its subfolder.
    _patch_quant_config(output_dir, subfolder="transformer", weight_block_size=weight_block_size)
    if is_dual:
        _patch_quant_config(output_dir, subfolder="transformer_2", weight_block_size=weight_block_size)
    print(f"Saved to: {output_dir}")
    _summarize_export(output_dir, subfolder="transformer")
    if is_dual:
        _summarize_export(output_dir, subfolder="transformer_2")

    print("\nNext: validate the checkpoint with vllm-omni:")
    if args.is_i2v:
        print(
            " python examples/offline_inference/image_to_video/image_to_video.py \\\n"
            f" --model {output_dir} \\\n"
            " --quantization fp8 \\\n"
            " --prompt 'A subject from the reference image moves through the scene.' \\\n"
            # A bare `--image` would swallow the next flag, so show a placeholder path.
            " --image /path/to/reference_image.png \\\n"
            f" --height {args.height} --width {args.width} --num-frames {args.num_frames} \\\n"
            " --num-inference-steps 30 --guidance-scale 5.0 --seed 42 \\\n"
            " --output outputs/wan22_i2v_modelopt_fp8.mp4"
        )
    else:
        print(
            " python examples/offline_inference/text_to_video/text_to_video.py \\\n"
            f" --model {output_dir} \\\n"
            " --quantization fp8 \\\n"
            " --prompt 'A dog running across a field of golden wheat.' \\\n"
            f" --height {args.height} --width {args.width} --num-frames {args.num_frames} \\\n"
            " --num-inference-steps 30 --guidance-scale 5.0 --seed 42 \\\n"
            " --output outputs/wan22_modelopt_fp8.mp4"
        )
    print(
        "\n (--quantization fp8 is auto-upgraded to ModelOpt FP8 at runtime because the "
        "checkpoint's config.json has modelopt metadata.)"
    )
The other modes +(I2V/FLF2V/inpaint) require encoded video + mask inputs and can be wired in +later by extending `_build_calib_samples`. + +Layers kept full precision match the Wan2.2 pattern: condition embedder +(time/text/image), patch embedding, modulation (scale_shift_table), final +norm + proj_out, and sequence-parallel helpers. All attention + FFN linears +are quantized — including the vace_blocks' own attention/FFN linears (since +they're standard `WanTransformerBlock` subclasses). + +Supported targets: +- `Wan-AI/Wan2.1-VACE-1.3B-diffusers` (single-transformer, ~10GB BF16) +- `Wan-AI/Wan2.1-VACE-14B-diffusers` (single-transformer, ~38GB BF16) +- `Wan-AI/Wan2.2-VACE-A14B-Diffusers` (MoE + VACE, dual-transformer; needs 2+ GPUs BF16; Model not released yet, but the wiring is ready) + +For dual-transformer VACE the diffusers pipeline routes between `transformer` +and `transformer_2` by `boundary_timestep` exactly like Wan2.2 MoE T2V/I2V. + +Example (VACE T2V calibration, no reference images): + python examples/quantization/quantize_wan2_2_vace_modelopt_fp8.py \ + --model Wan-AI/Wan2.1-VACE-1.3B-diffusers \ + --output ./wan21-vace-1.3b-fp8 \ + --overwrite + +Example (VACE T2V + R2V mix, with reference images): + python examples/quantization/quantize_wan2_2_vace_modelopt_fp8.py \ + --model Wan-AI/Wan2.1-VACE-14B-diffusers \ + --output ./wan21-vace-14b-fp8 \ + --reference-images /path/to/ref_images/ \ + --overwrite +""" + +from __future__ import annotations + +import argparse +import copy +import json +import re +import shutil +import sys +from pathlib import Path +from typing import Any + +import torch +from diffusers import DiffusionPipeline + +DEFAULT_PROMPTS = [ + "A dog running across a field of golden wheat.", + "An astronaut riding a horse across the surface of Mars, red dust swirling, cinematic wide shot.", + "A hummingbird hovering in front of a vibrant red flower, slow motion, macro shot.", + "A crackling campfire at night under a starry sky, sparks 
rising into the dark.", + "An underwater shot of a coral reef with tropical fish swimming by, sun rays piercing the water.", + "A close-up of a blooming rose covered in morning dew, soft natural light.", + "A peaceful mountain village at dawn, mist rolling over the rooftops, cinematic establishing shot.", + "A skateboarder doing a kickflip in an urban plaza, slow motion, golden hour lighting.", +] + +# R2V prompts pair with --reference-images. Phrasing explicitly references "the +# subject from the reference image" so prompt and ref_image are semantically +# coupled — mimics how users actually write R2V prompts in production. +VACE_DEFAULT_PROMPTS_R2V = [ + "The subject from the reference image walks confidently through a snowy forest at dusk.", + "Recreate the reference subject dancing under spinning disco lights in a vibrant nightclub.", + "The reference subject sails across a calm ocean at golden hour, sun glinting off the water.", + "Render the reference subject in a cyberpunk cityscape at night, neon reflections on rainy streets.", +] + + +def _build_parser() -> argparse.ArgumentParser: + p = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) + p.add_argument("--model", required=True, help="Input Wan VACE diffusers directory or HF id.") + p.add_argument("--output", required=True, help="Output directory for the ModelOpt FP8 checkpoint.") + p.add_argument("--dtype", choices=("bfloat16", "float16"), default="bfloat16") + p.add_argument("--height", type=int, default=480, help="Calibration video height (VACE 480p default).") + p.add_argument("--width", type=int, default=832, help="Calibration video width (VACE 480p default).") + p.add_argument( + "--num-frames", + type=int, + default=33, + help="Frames per calibration sample. 
Smaller frame counts reduce memory pressure during " + "calibration; amax statistics are largely independent of frame count.", + ) + p.add_argument("--guidance-scale", type=float, default=5.0) + p.add_argument( + "--calib-steps", + type=int, + default=10, + help="Denoising steps per calibration prompt (10 is enough for amax statistics).", + ) + p.add_argument("--calib-size", type=int, default=8, help="How many prompts to use for calibration.") + p.add_argument("--seed", type=int, default=42) + p.add_argument( + "--prompt", + action="append", + default=[], + help="Custom calibration prompt. Repeat to provide multiple. When --reference-images is " + "set, every custom prompt is paired with a cycled ref image (assumes R2V phrasing).", + ) + p.add_argument( + "--quantize-mha", + action="store_true", + help="Enable FP8 attention K/V/softmax quantizers. Off by default — Wan's long attention " + "sequences amplified FP8 drift in the online ablation (see #2920).", + ) + p.add_argument( + "--weight-block-size", + type=str, + default=None, + help="Per-block weight quantization as 'M,N'. Only '128,128' is accepted because upstream " + "vLLM's ModelOptFp8PbWoLinearMethod hardcodes that block shape. Default: per-tensor. " + "Block-wise saves checkpoints with FP8_PB_WO routing (per-block static weights + per-token-" + "group dynamic activations); per-tensor uses static FP8 with calibrated activation scales.", + ) + p.add_argument( + "--calib-boundary-ratio", + type=float, + default=None, + help="Pass-1-only boundary_ratio override for dual-transformer VACE (e.g. Wan2.2-VACE-A14B). " + "Lowering it (e.g. 0.5) shifts more denoising steps onto `transformer` so its quantizers " + "see a richer amax sample WITHOUT bumping --calib-steps. Pass 2 always restores the " + "model's production boundary_ratio. No-op for single-transformer VACE (Wan2.1-VACE-*).", + ) + p.add_argument( + "--reference-images", + type=str, + default=None, + help="Optional. 
Directory of jpg/jpeg/png/webp files (or a single image). When provided, " + "half the calibration samples become R2V (paired with cycled ref images, using " + "VACE_DEFAULT_PROMPTS_R2V) so vace_blocks' amax covers real ref-image latent " + "distributions; the other half stay T2V (zero-conditioning). When omitted, calibration " + "runs T2V-only — vace_blocks see auto-generated zero vace_context, which works but " + "amax is conservative for R2V-mode inference.", + ) + p.add_argument("--overwrite", action="store_true", help="Replace an existing output directory.") + return p + + +def _parse_block_size(spec: str | None) -> list[int] | None: + if spec is None: + return None + parts = [int(x) for x in spec.split(",") if x.strip()] + if len(parts) != 2: + raise SystemExit(f"--weight-block-size must be 'M,N' (2 ints), got {spec!r}") + return parts + + +def _require_modelopt() -> Any: + try: + import modelopt.torch.quantization as mtq + except ModuleNotFoundError as exc: + raise SystemExit( + "NVIDIA ModelOpt is not installed. 
Install with:\n" + " pip install 'nvidia-modelopt[all]'\n" + f"Original error: {exc}" + ) from exc + return mtq + + +def _ensure_paths(args: argparse.Namespace) -> tuple[str, Path]: + model_path = args.model + output_dir = Path(args.output).expanduser().resolve() + if output_dir.exists(): + if not args.overwrite: + raise SystemExit(f"Output directory already exists: {output_dir}\nPass --overwrite to replace it.") + shutil.rmtree(output_dir) + return model_path, output_dir + + +def _select_dtype(name: str) -> torch.dtype: + return {"bfloat16": torch.bfloat16, "float16": torch.float16}[name] + + +def _load_reference_images(spec: str | None) -> list[Any]: + """Load PIL.Image list from a directory or a single file path.""" + if spec is None: + return [] + from PIL import Image + + p = Path(spec).expanduser() + if not p.exists(): + raise SystemExit(f"--reference-images path not found: {p}") + if p.is_file(): + return [Image.open(p).convert("RGB")] + image_paths = sorted( + f for f in p.iterdir() if f.is_file() and f.suffix.lower() in (".jpg", ".jpeg", ".png", ".webp") + ) + if not image_paths: + raise SystemExit(f"No image files (jpg/jpeg/png/webp) found in {p}") + return [Image.open(f).convert("RGB") for f in image_paths] + + +def _cycle_to_size(items: list, size: int) -> list: + if not items: + raise SystemExit("Cannot build calibration prompts: pool is empty.") + repeats = (size + len(items) - 1) // len(items) + return (items * repeats)[:size] + + +def _build_calib_samples( + args: argparse.Namespace, + ref_images: list[Any], +) -> list[tuple[str, Any]]: + """Build calibration (prompt, reference_image_or_None) pairs for VACE. + + - No --reference-images: T2V-only calibration. vace_blocks see auto-generated + zero vace_context (vae.encode(zeros) + mask=1). 
+ - With --reference-images, no --prompt: half samples are T2V (DEFAULT_PROMPTS), + half are R2V (VACE_DEFAULT_PROMPTS_R2V paired with cycled ref images), + covering both zero- and real-conditioning extremes for vace_blocks' amax. + - With --reference-images and --prompt: every user prompt is paired with a + cycled ref image (assumes the user wrote R2V-style prompts). + """ + if args.calib_size <= 0: + raise SystemExit("--calib-size must be positive.") + + if not ref_images: + prompts = args.prompt or DEFAULT_PROMPTS + return [(p, None) for p in _cycle_to_size(prompts, args.calib_size)] + + custom_prompts = args.prompt or [] + if custom_prompts: + pool = _cycle_to_size(custom_prompts, args.calib_size) + return [(prompt, ref_images[i % len(ref_images)]) for i, prompt in enumerate(pool)] + + n_r2v = min(args.calib_size // 2, len(ref_images)) + n_t2v = args.calib_size - n_r2v + t2v_pool = _cycle_to_size(DEFAULT_PROMPTS, n_t2v) if n_t2v > 0 else [] + r2v_pool = _cycle_to_size(VACE_DEFAULT_PROMPTS_R2V, n_r2v) if n_r2v > 0 else [] + samples: list[tuple[str, Any]] = [(p, None) for p in t2v_pool] + samples.extend((prompt, ref_images[i % len(ref_images)]) for i, prompt in enumerate(r2v_pool)) + return samples + + +# Layers to KEEP at full precision. Wan VACE inherits the base Wan module +# naming (condition_embedder, patch_embedding, scale_shift_table, norm_out, +# proj_out, timestep_proj_prepare/output_scale_shift_prepare). vace_blocks +# carry their own proj_in/proj_out Linears (full path: vace_blocks.{i}.proj_*), +# which the regex below intentionally does NOT match — they are quantized +# alongside the rest of the vace_blocks' attention/FFN linears. 
+def _filter_func_wan22(name: str) -> bool: + pattern = re.compile( + r"(proj_out.*|" + r".*(condition_embedder|patch_embedding|" + r"norm_out|scale_shift_table|" + r"timestep_proj_prepare|output_scale_shift_prepare).*)" + ) + return pattern.match(name) is not None + + +def _mha_filter_func(name: str) -> bool: + pattern = re.compile( + r".*(q_bmm_quantizer|k_bmm_quantizer|v_bmm_quantizer|softmax_quantizer|bmm2_output_quantizer).*" + ) + return pattern.match(name) is not None + + +def _disable_known_problematic_quantizers(mtq: Any, backbone: torch.nn.Module, *, quantize_mha: bool) -> None: + if not hasattr(mtq, "disable_quantizer"): + return + mtq.disable_quantizer(backbone, _filter_func_wan22) + if not quantize_mha: + mtq.disable_quantizer(backbone, _mha_filter_func) + + +def _move_tensor(value: Any, device: torch.device) -> Any: + if isinstance(value, torch.Tensor): + return value.to(device) + if isinstance(value, (tuple, list)): + moved = [_move_tensor(v, device) for v in value] + return type(value)(moved) + return value + + +def _make_input_device_hook(target_device: torch.device): + """Pre-hook that moves all tensor args/kwargs onto the module's device.""" + + def pre_hook(_module, args, kwargs): + new_args = tuple(_move_tensor(a, target_device) for a in args) + new_kwargs = {k: _move_tensor(v, target_device) for k, v in kwargs.items()} + return new_args, new_kwargs + + return pre_hook + + +def _make_output_device_hook(primary_device: torch.device): + """Post-hook that moves outputs back to the pipeline's primary device.""" + + def post_hook(_module, _args, output): + return _move_tensor(output, primary_device) + + return post_hook + + +def _load_pipeline(model_path: str, dtype: torch.dtype) -> DiffusionPipeline: + pipe = DiffusionPipeline.from_pretrained(model_path, torch_dtype=dtype) + if hasattr(pipe, "set_progress_bar_config"): + pipe.set_progress_bar_config(disable=True) + + transformer_2 = getattr(pipe, "transformer_2", None) + if transformer_2 is not 
None and torch.cuda.device_count() >= 2: + # diffusers' WanPipeline routes between the two by boundary_timestep but + # does NOT transfer activations across devices; bridge transformer_2 with + # forward hooks: pre-hook moves inputs cuda:0 -> cuda:1, post-hook moves + # outputs back cuda:1 -> cuda:0. The pipeline then sees a uniform cuda:0 + # state and scheduler.step works without modification. + primary = torch.device("cuda:0") + secondary = torch.device("cuda:1") + pipe.transformer.to(primary) + transformer_2.to(secondary) + for component_name in ("text_encoder", "vae", "image_encoder"): + component = getattr(pipe, component_name, None) + if component is not None: + component.to(primary) + transformer_2.register_forward_pre_hook(_make_input_device_hook(secondary), with_kwargs=True) + transformer_2.register_forward_hook(_make_output_device_hook(primary)) + print(f" device map: transformer={primary}, transformer_2={secondary} (cross-device hooks installed)") + else: + pipe.to("cuda") + return pipe + + +def _build_forward_loop( + pipe: DiffusionPipeline, + args: argparse.Namespace, + samples: list[tuple[str, Any]], +): + """Build a forward_loop that drives `pipe` over the calibration samples. + + Samples carrying a reference image are forwarded with `reference_images=[img]` + (the kwarg expected by diffusers' WanVACEPipeline). Samples with ref=None + call pipe(prompt=...) — diffusers VACE pipeline auto-fills vace_context with + zeros + mask=1 in this case (T2V mode). + """ + generator = torch.Generator(device="cuda") + + # Try setting guidance on the pipeline's guider if present (newer diffusers APIs). 
+ guider = getattr(pipe, "guider", None) + if guider is not None and hasattr(guider, "guidance_scale"): + try: + guider.guidance_scale = args.guidance_scale + except Exception: + pass + + base_kwargs = dict( + height=args.height, + width=args.width, + num_frames=args.num_frames, + num_inference_steps=args.calib_steps, + output_type="latent", + ) + + def forward_loop(*_unused_args, **_unused_kwargs) -> None: + with torch.inference_mode(): + for idx, (prompt, ref_image) in enumerate(samples): + generator.manual_seed(args.seed + idx) + kwargs = dict(base_kwargs) + if ref_image is not None: + kwargs["reference_images"] = [ref_image] + # Try with guidance_scale first; fall back without on TypeError + # for pipelines that take CFG via guider config only. + try: + pipe(prompt=prompt, generator=generator, guidance_scale=args.guidance_scale, **kwargs) + except TypeError as exc: + if "guidance_scale" not in str(exc): + raise + pipe(prompt=prompt, generator=generator, **kwargs) + + return forward_loop + + +def _summarize_export(output_dir: Path, subfolder: str = "transformer") -> None: + cfg_path = output_dir / subfolder / "config.json" + if not cfg_path.exists(): + print(f"[warn] {cfg_path} missing.", file=sys.stderr) + return + with cfg_path.open(encoding="utf-8") as f: + cfg = json.load(f) + qc = cfg.get("quantization_config") + if not isinstance(qc, dict): + print(f"[warn] No quantization_config in {subfolder}/config.json.", file=sys.stderr) + return + print(f"Export summary ({subfolder}):") + print(f" quant_method: {qc.get('quant_method')}") + print(f" quant_algo: {qc.get('quant_algo')}") + producer = qc.get("producer") + if isinstance(producer, dict): + print(f" producer: {producer.get('name')} {producer.get('version')}") + print(f" config path: {cfg_path}") + + +def _force_export_quantized_weights(backbone: torch.nn.Module, dtype: torch.dtype) -> int: + """Convert in-memory weights of quantized modules to actual FP8 storage. 
+ + `export_hf_checkpoint` skips this step for unknown model types (Wan VACE + isn't in ModelOpt's recognized-model registry), so we must call the + per-weight export helper ourselves. + """ + from modelopt.torch.export.quant_utils import ( + QUANTIZATION_NONE, + get_quantization_format, + quantizer_attr_names, + weight_attr_names, + ) + from modelopt.torch.export.unified_export_hf import _export_quantized_weight + + exported = 0 + for name, module in backbone.named_modules(): + try: + quantization_format = get_quantization_format(module) + except Exception as exc: + print(f"[warn] Could not inspect quantization format for {name}: {exc}", file=sys.stderr) + continue + if quantization_format == QUANTIZATION_NONE: + continue + for weight_name in weight_attr_names(module): + quantizer_attrs = quantizer_attr_names(weight_name) + weight_quantizer = getattr(module, quantizer_attrs.weight_quantizer, None) + if weight_quantizer is None or not getattr(weight_quantizer, "is_enabled", False): + continue + _export_quantized_weight(module, dtype, weight_name) + exported += 1 + return exported + + +def _wan22_quant_config_block(weight_block_size: list[int] | None = None) -> dict: + """Mirror ModelOpt FP8 metadata expected by vllm-omni's adapter (#2913). + For per-block weight quantization, upstream's FP8_PB_WO hardcodes _WEIGHT_BLOCK_SIZE = (128, 128), so any other + block shape produces a checkpoint vLLM cannot serve. + """ + if weight_block_size is not None and tuple(weight_block_size) != (128, 128): + raise ValueError( + f"--weight-block-size {tuple(weight_block_size)} not supported: upstream vLLM's " + "ModelOptFp8PbWoLinearMethod hardcodes (128, 128). Pass '128,128' or omit the flag." 
+ ) + + weights_cfg: dict = {"dynamic": False, "num_bits": 8, "type": "float"} + if weight_block_size is not None: + weights_cfg["strategy"] = "block" + weights_cfg["block_structure"] = f"{weight_block_size[0]}x{weight_block_size[1]}" + return { + "config_groups": { + "group_0": { + "input_activations": {"dynamic": False, "num_bits": 8, "type": "float"}, + "weights": weights_cfg, + "targets": ["Linear"], + } + }, + "ignore": [ + "condition_embedder*", + "norm_out*", + "output_scale_shift_prepare*", + "patch_embedding*", + "proj_out*", + "scale_shift_table*", + "timestep_proj_prepare*", + ], + "producer": {"name": "modelopt"}, + "quant_algo": "FP8_PB_WO" if weight_block_size is not None else "FP8", + "quant_method": "modelopt", + } + + +def _patch_quant_config( + output_dir: Path, + subfolder: str = "transformer", + weight_block_size: list[int] | None = None, +) -> None: + """Inject quant_algo: FP8 + config_groups into /config.json so + vllm-omni's adapter (#2913) recognises the checkpoint as ModelOpt FP8. 
+ """ + cfg_path = output_dir / subfolder / "config.json" + with cfg_path.open(encoding="utf-8") as f: + cfg = json.load(f) + + new_qc = _wan22_quant_config_block(weight_block_size=weight_block_size) + existing = cfg.get("quantization_config") + if isinstance(existing, dict): + producer = existing.get("producer") + if isinstance(producer, dict): + new_qc["producer"] = producer + + cfg["quantization_config"] = new_qc + with cfg_path.open("w", encoding="utf-8") as f: + json.dump(cfg, f, indent=2) + + +def _save_pipeline_with_fp8_transformers( + pipe: DiffusionPipeline, + model_path: str, + output_dir: Path, + max_shard_size: str = "5GB", +) -> None: + """Copy source dir verbatim minus transformer/(_2), then save quantized transformer(s).""" + from modelopt.torch.export.diffusers_utils import hide_quantizers_from_state_dict + + src = Path(model_path) + if not src.exists(): + from huggingface_hub import snapshot_download + + src = Path(snapshot_download(model_path)) + + if output_dir.exists(): + shutil.rmtree(output_dir) + shutil.copytree(src, output_dir, ignore=shutil.ignore_patterns("transformer", "transformer_2")) + + backbones: list[tuple[str, torch.nn.Module]] = [("transformer", pipe.transformer)] + transformer_2 = getattr(pipe, "transformer_2", None) + if transformer_2 is not None: + backbones.append(("transformer_2", transformer_2)) + + for subfolder, backbone in backbones: + out = output_dir / subfolder + with hide_quantizers_from_state_dict(backbone): + backbone.save_pretrained( + str(out), + safe_serialization=True, + max_shard_size=max_shard_size, + ) + + +def _calibrate( + backbone: torch.nn.Module, + label: str, + *, + mtq: Any, + quant_config: dict, + forward_loop, + quantize_mha: bool, +) -> torch.nn.Module: + """Wrap one transformer backbone with quantizers and run calibration.""" + print(f"\nCalibrating {label}...") + quantized = mtq.quantize(backbone, quant_config, forward_loop) + if quantized is not None: + backbone = quantized + 
_disable_known_problematic_quantizers(mtq, backbone, quantize_mha=quantize_mha) + return backbone + + +def _force_export(backbone: torch.nn.Module, label: str, dtype: torch.dtype) -> None: + """Convert calibrated weights to actual FP8 storage.""" + print(f"\nForcing FP8 weight serialization for {label} (Wan VACE isn't in ModelOpt's") + print("recognized-model registry, so we call the per-weight export helper ourselves)...") + exported = _force_export_quantized_weights(backbone, dtype) + print(f" -> {exported} weights converted to FP8 in {label}") + if exported == 0: + raise SystemExit( + f"No quantized weights were exported in {label}. Calibration may have skipped every " + "layer (check the disable_quantizer regex) or `mtq.quantize` did not actually wrap " + "any weight quantizers." + ) + + +def main() -> None: + args = _build_parser().parse_args() + if not torch.cuda.is_available(): + raise SystemExit("CUDA is required for ModelOpt FP8 quantization.") + + mtq = _require_modelopt() + model_path, output_dir = _ensure_paths(args) + dtype = _select_dtype(args.dtype) + weight_block_size = _parse_block_size(args.weight_block_size) + + ref_images = _load_reference_images(args.reference_images) + samples = _build_calib_samples(args, ref_images) + n_r2v = sum(1 for _, ref in samples if ref is not None) + n_t2v = len(samples) - n_r2v + + print("Quantization plan:") + print(f" input: {args.model}") + print(f" output: {output_dir}") + print(f" dtype: {dtype}") + print(f" height/width: {args.height}x{args.width}") + print(f" num_frames: {args.num_frames}") + print(f" calib_size: {len(samples)} (T2V={n_t2v}, R2V={n_r2v})") + print(f" calib_steps: {args.calib_steps}") + print(f" quantize_mha: {args.quantize_mha}") + print(f" reference imgs: {len(ref_images)}") + print( + f" weight strategy: {'block-wise ' + str(weight_block_size) if weight_block_size else 'per-tensor (default)'}" + ) + + pipe = _load_pipeline(model_path, dtype) + is_dual = getattr(pipe, "transformer_2", None) 
is not None + if is_dual: + print(" detected dual-transformer VACE variant (transformer + transformer_2)") + + # Production boundary_ratio captured from model_index.json so pass 2 can + # restore it after a --calib-boundary-ratio override on pass 1. + production_boundary = pipe.config.get("boundary_ratio") if is_dual else None + + quant_config = copy.deepcopy(mtq.FP8_DEFAULT_CFG) + if weight_block_size is not None: + quant_config["quant_cfg"]["*weight_quantizer"] = { + "num_bits": (4, 3), + "block_sizes": {-1: weight_block_size[1], -2: weight_block_size[0]}, + } + print( + f" -> overriding weight quantizer with block_sizes={weight_block_size} " + f"({weight_block_size[0]}x{weight_block_size[1]} tiles)" + ) + + forward_loop = _build_forward_loop(pipe, args, samples) + + # Single-transformer VACE does one pass; dual-transformer (Wan2.2-VACE-A14B) + # does two. Calibration must complete for BOTH backbones BEFORE any + # _force_export call — transformer's weights must still be BF16 at that point. 
+ if is_dual and args.calib_boundary_ratio is not None: + pipe.register_to_config(boundary_ratio=args.calib_boundary_ratio) + print( + f"\n pass 1 boundary_ratio: {args.calib_boundary_ratio} " + f"(override of production {production_boundary} for transformer sample boost)" + ) + + pipe.transformer = _calibrate( + pipe.transformer, + "transformer", + mtq=mtq, + quant_config=quant_config, + forward_loop=forward_loop, + quantize_mha=args.quantize_mha, + ) + if is_dual: + if args.calib_boundary_ratio is not None: + pipe.register_to_config(boundary_ratio=production_boundary) + print( + f"\n pass 2 boundary_ratio: {production_boundary} " + "(restored to production for transformer_2 in-distribution calibration)" + ) + pipe.transformer_2 = _calibrate( + pipe.transformer_2, + "transformer_2", + mtq=mtq, + quant_config=quant_config, + forward_loop=forward_loop, + quantize_mha=args.quantize_mha, + ) + + _force_export(pipe.transformer, "transformer", dtype) + if is_dual: + _force_export(pipe.transformer_2, "transformer_2", dtype) + + print("\nSaving pipeline with FP8 transformer(s)...") + _save_pipeline_with_fp8_transformers(pipe, model_path, output_dir) + _patch_quant_config(output_dir, subfolder="transformer", weight_block_size=weight_block_size) + if is_dual: + _patch_quant_config(output_dir, subfolder="transformer_2", weight_block_size=weight_block_size) + print(f"Saved to: {output_dir}") + _summarize_export(output_dir, subfolder="transformer") + if is_dual: + _summarize_export(output_dir, subfolder="transformer_2") + + print("\nNext: validate the checkpoint with vllm-omni:") + if n_r2v > 0: + print( + " python examples/offline_inference/vace/vace_video_generation.py \\\n" + " --mode r2v \\\n" + f" --model {output_dir} \\\n" + " --quantization fp8 \\\n" + " --prompt 'The subject from the reference image walks through a snowy forest at dusk.' 
\\\n" + " --image \\\n" + f" --height {args.height} --width {args.width} --num-frames {args.num_frames} \\\n" + " --num-inference-steps 30 --guidance-scale 5.0 \\\n" + " --output outputs/wan_vace_r2v_modelopt_fp8.mp4" + ) + print( + "\n (T2V also works with this checkpoint — drop --mode r2v / --image and pass a " + "plain prompt; vace_blocks were calibrated on both zero- and real-conditioning samples.)" + ) + else: + print( + " python examples/offline_inference/vace/vace_video_generation.py \\\n" + " --mode t2v \\\n" + f" --model {output_dir} \\\n" + " --quantization fp8 \\\n" + " --prompt 'A dog running across a field of golden wheat.' \\\n" + f" --height {args.height} --width {args.width} --num-frames {args.num_frames} \\\n" + " --num-inference-steps 30 --guidance-scale 5.0 \\\n" + " --output outputs/wan_vace_modelopt_fp8.mp4" + ) + print( + "\n (R2V/I2V inference will still work but vace_blocks' amax was calibrated on " + "zero vace_context only — re-run quantization with --reference-images for tighter " + "R2V scales.)" + ) + print( + "\n (--quantization fp8 is auto-upgraded to ModelOpt FP8 at runtime because the " + "checkpoint's config.json has modelopt metadata.)" + ) + + +if __name__ == "__main__": + main() diff --git a/vllm_omni/diffusion/models/hunyuan_video/hunyuan_video_15_transformer.py b/vllm_omni/diffusion/models/hunyuan_video/hunyuan_video_15_transformer.py index 30ba621847d..1e9db1f0eab 100644 --- a/vllm_omni/diffusion/models/hunyuan_video/hunyuan_video_15_transformer.py +++ b/vllm_omni/diffusion/models/hunyuan_video/hunyuan_video_15_transformer.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable -from typing import Any +from typing import TYPE_CHECKING, Any import torch import torch.nn.functional as F @@ -31,6 +31,9 @@ from vllm_omni.diffusion.layers.rope import RotaryEmbedding from vllm_omni.diffusion.models.flux.flux_transformer import FeedForward +if TYPE_CHECKING: + from 
vllm.model_executor.layers.quantization.base_config import QuantizationConfig + logger = init_logger(__name__) @@ -332,6 +335,8 @@ def __init__( out_bias: bool = True, eps: float = 1e-6, out_dim: int | None = None, + quant_config: "QuantizationConfig | None" = None, + prefix: str = "", ): super().__init__() @@ -350,6 +355,8 @@ def __init__( head_size=self.head_dim, total_num_heads=self.heads, bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.to_qkv", ) self.to_out = nn.ModuleList( @@ -360,6 +367,8 @@ def __init__( bias=out_bias, input_is_parallel=True, return_bias=False, + quant_config=quant_config, + prefix=f"{prefix}.to_out.0", ), nn.Identity(), # placeholder for dropout (none used) ] @@ -374,6 +383,8 @@ def __init__( head_size=self.head_dim, total_num_heads=self.heads, bias=added_proj_bias, + quant_config=quant_config, + prefix=f"{prefix}.add_kv_proj", ) self.to_add_out = RowParallelLinear( @@ -382,6 +393,8 @@ def __init__( bias=out_bias, input_is_parallel=True, return_bias=False, + quant_config=quant_config, + prefix=f"{prefix}.to_add_out", ) self.rope = RotaryEmbedding(is_neox_style=False) @@ -401,6 +414,8 @@ def forward( image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, hidden_states_mask: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + # Ensure contiguous for FP8 quantized linear layers + hidden_states = hidden_states.contiguous() qkv, _ = self.to_qkv(hidden_states) q_size = self.to_qkv.num_heads * self.head_dim kv_size = self.to_qkv.num_kv_heads * self.head_dim @@ -421,6 +436,7 @@ def forward( key = self.rope(key, cos, sin) if encoder_hidden_states is not None: + encoder_hidden_states = encoder_hidden_states.contiguous() encoder_qkv, _ = self.add_kv_proj(encoder_hidden_states) add_q_size = self.add_kv_proj.num_heads * self.head_dim add_kv_size = self.add_kv_proj.num_kv_heads * self.head_dim @@ -492,6 +508,8 @@ def __init__( attention_head_dim: int, mlp_ratio: float = 4.0, qk_norm: str = 
"rms_norm", + quant_config: "QuantizationConfig | None" = None, + prefix: str = "", ) -> None: super().__init__() hidden_size = num_attention_heads * attention_head_dim @@ -507,13 +525,23 @@ def __init__( out_dim=hidden_size, bias=True, eps=1e-6, + quant_config=quant_config, + prefix=f"{prefix}.attn", ) self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - self.ff = FeedForward(dim=hidden_size, dim_out=hidden_size, mult=mlp_ratio) + self.ff = FeedForward( + dim=hidden_size, dim_out=hidden_size, mult=mlp_ratio, quant_config=quant_config, prefix=f"{prefix}.ff" + ) self.norm2_context = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - self.ff_context = FeedForward(dim=hidden_size, dim_out=hidden_size, mult=mlp_ratio) + self.ff_context = FeedForward( + dim=hidden_size, + dim_out=hidden_size, + mult=mlp_ratio, + quant_config=quant_config, + prefix=f"{prefix}.ff_context", + ) def forward( self, @@ -601,6 +629,7 @@ def __init__( target_size: int = 640, task_type: str = "i2v", use_meanflow: bool = False, + quant_config: "QuantizationConfig | None" = None, ): super().__init__() @@ -632,9 +661,14 @@ def __init__( self.transformer_blocks = nn.ModuleList( [ HunyuanVideo15TransformerBlock( - num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm + num_attention_heads, + attention_head_dim, + mlp_ratio=mlp_ratio, + qk_norm=qk_norm, + quant_config=quant_config, + prefix=f"transformer_blocks.{i}", ) - for _ in range(num_layers) + for i in range(num_layers) ] ) diff --git a/vllm_omni/diffusion/models/hunyuan_video/pipeline_hunyuan_video_1_5.py b/vllm_omni/diffusion/models/hunyuan_video/pipeline_hunyuan_video_1_5.py index 6445bfee215..b007e00eed0 100644 --- a/vllm_omni/diffusion/models/hunyuan_video/pipeline_hunyuan_video_1_5.py +++ b/vllm_omni/diffusion/models/hunyuan_video/pipeline_hunyuan_video_1_5.py @@ -124,7 +124,9 @@ def __init__( self.scheduler._shift = od_config.flow_shift transformer_kwargs = 
get_transformer_config_kwargs(od_config.tf_model_config, HunyuanVideo15Transformer3DModel) - self.transformer = HunyuanVideo15Transformer3DModel(od_config=od_config, **transformer_kwargs) + self.transformer = HunyuanVideo15Transformer3DModel( + od_config=od_config, quant_config=od_config.quantization_config, **transformer_kwargs + ) # Check if model uses meanflow (distilled variants) self.use_meanflow = getattr(od_config.tf_model_config, "use_meanflow", False) diff --git a/vllm_omni/diffusion/models/hunyuan_video/pipeline_hunyuan_video_1_5_i2v.py b/vllm_omni/diffusion/models/hunyuan_video/pipeline_hunyuan_video_1_5_i2v.py index c1acd1a895a..99b17bad424 100644 --- a/vllm_omni/diffusion/models/hunyuan_video/pipeline_hunyuan_video_1_5_i2v.py +++ b/vllm_omni/diffusion/models/hunyuan_video/pipeline_hunyuan_video_1_5_i2v.py @@ -153,7 +153,9 @@ def __init__( self.scheduler._shift = od_config.flow_shift transformer_kwargs = get_transformer_config_kwargs(od_config.tf_model_config, HunyuanVideo15Transformer3DModel) - self.transformer = HunyuanVideo15Transformer3DModel(od_config=od_config, **transformer_kwargs) + self.transformer = HunyuanVideo15Transformer3DModel( + od_config=od_config, quant_config=od_config.quantization_config, **transformer_kwargs + ) self.use_meanflow = getattr(od_config.tf_model_config, "use_meanflow", False) diff --git a/vllm_omni/diffusion/models/wan2_2/wan2_2_vace_transformer.py b/vllm_omni/diffusion/models/wan2_2/wan2_2_vace_transformer.py index 3f66a0d1abe..bebf4fac165 100644 --- a/vllm_omni/diffusion/models/wan2_2/wan2_2_vace_transformer.py +++ b/vllm_omni/diffusion/models/wan2_2/wan2_2_vace_transformer.py @@ -5,7 +5,7 @@ from __future__ import annotations -from typing import Any +from typing import TYPE_CHECKING, Any import torch import torch.nn as nn @@ -21,6 +21,9 @@ WanTransformerBlock, ) +if TYPE_CHECKING: + from vllm.model_executor.layers.quantization.base_config import QuantizationConfig + class VaceWanTransformerBlock(WanTransformerBlock): 
"""VACE variant of WanTransformerBlock with proj_in/proj_out for skip connections.""" diff --git a/vllm_omni/model_executor/stage_configs/flux2_klein_dit_2gpu_fp8.yaml b/vllm_omni/model_executor/stage_configs/flux2_klein_dit_2gpu_fp8.yaml new file mode 100644 index 00000000000..0b4ebe8efd4 --- /dev/null +++ b/vllm_omni/model_executor/stage_configs/flux2_klein_dit_2gpu_fp8.yaml @@ -0,0 +1,30 @@ +# Stage config for running FLUX.2-klein DiT with ModelOpt FP8 auto-detect. +# The following config is for 2 GPUs. + +stage_args: + - stage_id: 0 + stage_type: diffusion + runtime: + devices: "0,1" + max_batch_size: 1 + engine_args: + model_stage: dit + model_class_name: Flux2KleinPipeline + max_num_seqs: 1 + enforce_eager: true + trust_remote_code: true + distributed_executor_backend: "mp" + parallel_config: + tensor_parallel_size: 2 + + final_output: true + final_output_type: image + is_comprehension: false + default_sampling_params: + seed: 42 + +runtime: + enabled: true + defaults: + window_size: -1 + max_inflight: 1 diff --git a/vllm_omni/model_executor/stage_configs/flux_dit_2gpu_fp8.yaml b/vllm_omni/model_executor/stage_configs/flux_dit_2gpu_fp8.yaml new file mode 100644 index 00000000000..45e4ebeff3d --- /dev/null +++ b/vllm_omni/model_executor/stage_configs/flux_dit_2gpu_fp8.yaml @@ -0,0 +1,30 @@ +# Stage config for running FLUX.1 DiT with ModelOpt FP8 auto-detect. +# The following config is for 2 GPUs. 
+ +stage_args: + - stage_id: 0 + stage_type: diffusion + runtime: + devices: "0,1" + max_batch_size: 1 + engine_args: + model_stage: dit + model_class_name: FluxPipeline + max_num_seqs: 1 + enforce_eager: true + trust_remote_code: true + distributed_executor_backend: "mp" + parallel_config: + tensor_parallel_size: 2 + + final_output: true + final_output_type: image + is_comprehension: false + default_sampling_params: + seed: 42 + +runtime: + enabled: true + defaults: + window_size: -1 + max_inflight: 1 diff --git a/vllm_omni/model_executor/stage_configs/hunyuan_video_15_dit_fp8.yaml b/vllm_omni/model_executor/stage_configs/hunyuan_video_15_dit_fp8.yaml new file mode 100644 index 00000000000..4220344b059 --- /dev/null +++ b/vllm_omni/model_executor/stage_configs/hunyuan_video_15_dit_fp8.yaml @@ -0,0 +1,33 @@ +# Stage config for running HunyuanVideo-1.5 DiT with ModelOpt FP8 auto-detect. +# Single GPU. Bump `tensor_parallel_size` and `devices` for multi-GPU TP. +# +# Use with a ModelOpt FP8 checkpoint (e.g. produced by +# examples/quantization/quantize_hunyuanvideo_15_modelopt_fp8.py). + +stage_args: + - stage_id: 0 + stage_type: diffusion + runtime: + devices: "0" + max_batch_size: 1 + engine_args: + model_stage: dit + model_class_name: HunyuanVideo15Pipeline + max_num_seqs: 1 + enforce_eager: true + trust_remote_code: true + distributed_executor_backend: "mp" + parallel_config: + tensor_parallel_size: 1 + + final_output: true + final_output_type: video + is_comprehension: false + default_sampling_params: + seed: 42 + +runtime: + enabled: true + defaults: + window_size: -1 + max_inflight: 1 diff --git a/vllm_omni/model_executor/stage_configs/qwen_image_dit_2gpu_fp8.yaml b/vllm_omni/model_executor/stage_configs/qwen_image_dit_2gpu_fp8.yaml new file mode 100644 index 00000000000..1f0b60a7724 --- /dev/null +++ b/vllm_omni/model_executor/stage_configs/qwen_image_dit_2gpu_fp8.yaml @@ -0,0 +1,30 @@ +# Stage config for running Qwen-Image DiT with ModelOpt FP8 auto-detect. 
+# The following config is for 2 GPUs. + +stage_args: + - stage_id: 0 + stage_type: diffusion + runtime: + devices: "0,1" + max_batch_size: 1 + engine_args: + model_stage: dit + model_class_name: QwenImagePipeline + max_num_seqs: 1 + enforce_eager: true + trust_remote_code: true + distributed_executor_backend: "mp" + parallel_config: + tensor_parallel_size: 2 + + final_output: true + final_output_type: image + is_comprehension: false + default_sampling_params: + seed: 42 + +runtime: + enabled: true + defaults: + window_size: -1 + max_inflight: 1 diff --git a/vllm_omni/model_executor/stage_configs/wan2_2_ti2v_dit_fp8.yaml b/vllm_omni/model_executor/stage_configs/wan2_2_ti2v_dit_fp8.yaml new file mode 100644 index 00000000000..b6291073f12 --- /dev/null +++ b/vllm_omni/model_executor/stage_configs/wan2_2_ti2v_dit_fp8.yaml @@ -0,0 +1,34 @@ +# Stage config for running Wan2.2 TI2V-5B DiT with ModelOpt FP8 auto-detect. +# Single GPU (TI2V-5B fits in 80GB at BF16; FP8 roughly halves that). +# For the A14B MoE variants, bump `tensor_parallel_size` and `devices`. +# +# Use with a ModelOpt FP8 checkpoint (e.g. produced by +# examples/quantization/quantize_wan2_2_modelopt_fp8.py). 
+ +stage_args: + - stage_id: 0 + stage_type: diffusion + runtime: + devices: "0" + max_batch_size: 1 + engine_args: + model_stage: dit + model_class_name: Wan22TI2VPipeline + max_num_seqs: 1 + enforce_eager: true + trust_remote_code: true + distributed_executor_backend: "mp" + parallel_config: + tensor_parallel_size: 1 + + final_output: true + final_output_type: video + is_comprehension: false + default_sampling_params: + seed: 42 + +runtime: + enabled: true + defaults: + window_size: -1 + max_inflight: 1 diff --git a/vllm_omni/model_executor/stage_configs/z_image_dit_2gpu_fp8.yaml b/vllm_omni/model_executor/stage_configs/z_image_dit_2gpu_fp8.yaml new file mode 100644 index 00000000000..7d94a18cb26 --- /dev/null +++ b/vllm_omni/model_executor/stage_configs/z_image_dit_2gpu_fp8.yaml @@ -0,0 +1,30 @@ +# Stage config for running Z-Image DiT with ModelOpt FP8 auto-detect. +# The following config is for 2 GPUs. + +stage_args: + - stage_id: 0 + stage_type: diffusion + runtime: + devices: "0,1" + max_batch_size: 1 + engine_args: + model_stage: dit + model_class_name: ZImagePipeline + max_num_seqs: 1 + enforce_eager: true + trust_remote_code: true + distributed_executor_backend: "mp" + parallel_config: + tensor_parallel_size: 2 + + final_output: true + final_output_type: image + is_comprehension: false + default_sampling_params: + seed: 42 + +runtime: + enabled: true + defaults: + window_size: -1 + max_inflight: 1