From d43aa6ba868d499e027fd824424903250aa0770c Mon Sep 17 00:00:00 2001 From: roG0d Date: Fri, 17 Apr 2026 19:40:40 +0000 Subject: [PATCH 01/24] fix Signed-off-by: roG0d --- tests/diffusion/layers/test_rope.py | 31 +++ .../model_loader/test_diffusers_loader.py | 125 +++++++++ .../diffusion/quantization/test_fp8_config.py | 60 ++++- tests/diffusion/test_diffusion_worker.py | 30 ++- vllm_omni/diffusion/attention/layer.py | 2 - vllm_omni/diffusion/data.py | 37 ++- vllm_omni/diffusion/layers/rope.py | 20 +- .../model_loader/diffusers_loader.py | 254 ++++++++++++++++-- .../diffusion/worker/diffusion_worker.py | 39 +++ .../flux2_klein_dit_2gpu_fp8.yaml | 30 +++ .../stage_configs/flux_dit_2gpu_fp8.yaml | 30 +++ .../qwen_image_dit_2gpu_fp8.yaml | 30 +++ .../stage_configs/z_image_dit_2gpu_fp8.yaml | 30 +++ vllm_omni/quantization/factory.py | 95 ++++++- 14 files changed, 757 insertions(+), 56 deletions(-) create mode 100644 tests/diffusion/layers/test_rope.py create mode 100644 vllm_omni/model_executor/stage_configs/flux2_klein_dit_2gpu_fp8.yaml create mode 100644 vllm_omni/model_executor/stage_configs/flux_dit_2gpu_fp8.yaml create mode 100644 vllm_omni/model_executor/stage_configs/qwen_image_dit_2gpu_fp8.yaml create mode 100644 vllm_omni/model_executor/stage_configs/z_image_dit_2gpu_fp8.yaml diff --git a/tests/diffusion/layers/test_rope.py b/tests/diffusion/layers/test_rope.py new file mode 100644 index 00000000000..ca15040c432 --- /dev/null +++ b/tests/diffusion/layers/test_rope.py @@ -0,0 +1,31 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch + +from vllm_omni.diffusion.layers.rope import RotaryEmbedding + +pytestmark = [pytest.mark.core_model, pytest.mark.diffusion, pytest.mark.cpu] + + +def test_cuda_rope_accepts_3d_query(monkeypatch): + import vllm.vllm_flash_attn.layers.rotary as rotary + + def fake_apply_rotary_emb(x, cos, sin, interleaved=False): + assert x.shape == (1, 4, 2, 8) + assert cos.shape == (4, 4) + assert sin.shape == (4, 4) + return x + 1 + + monkeypatch.setattr(rotary, "apply_rotary_emb", fake_apply_rotary_emb) + + rope = RotaryEmbedding(is_neox_style=False) + x = torch.zeros(4, 2, 8) + cos = torch.zeros(4, 4) + sin = torch.zeros(4, 4) + + out = rope.forward_cuda(x, cos, sin) + + assert out.shape == x.shape + assert torch.equal(out, torch.ones_like(x)) diff --git a/tests/diffusion/model_loader/test_diffusers_loader.py b/tests/diffusion/model_loader/test_diffusers_loader.py index 1b13c400f70..1da7f1c195a 100644 --- a/tests/diffusion/model_loader/test_diffusers_loader.py +++ b/tests/diffusion/model_loader/test_diffusers_loader.py @@ -4,6 +4,7 @@ import pytest import torch import torch.nn as nn +from vllm.model_executor.layers.quantization.modelopt import ModelOptFp8Config from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.model_loader.gguf_adapters import get_gguf_adapter @@ -93,3 +94,127 @@ def test_qwen_model_class_selects_qwen_gguf_adapter(): adapter = get_gguf_adapter("dummy.gguf", object(), source, od_config) assert adapter.__class__.__name__ == "QwenImageGGUFAdapter" + + +def test_loader_auto_detects_quant_config_from_transformer_config(): + od_config = type( + "Config", + (), + { + "quantization_config": None, + "tf_model_config": type( + "TransformerConfig", + (), + { + "quant_config": ModelOptFp8Config.from_config( + { + "quant_method": "modelopt", + "quant_algo": "FP8", + "ignore": [], + } + ), + "quant_method": "modelopt", + }, + )(), + "set_tf_model_config": lambda self, tf_model_config: setattr( + self, + "quantization_config", + tf_model_config.quant_config, + ), + }, + )() + + DiffusersPipelineLoader._auto_detect_quant_config(od_config) + + assert od_config.quantization_config is od_config.tf_model_config.quant_config + + +class _PackedModelOptModel(nn.Module): + def __init__(self) -> None: + super().__init__() + self.transformer = nn.Module() + self.transformer.block = nn.Module() + self.transformer.block.to_qkv = nn.Linear(2, 2, bias=False) + + +def test_modelopt_adapter_dequantizes_fp8_weight_for_full_precision_target(): + loader = object.__new__(DiffusersPipelineLoader) + model = _PackedModelOptModel() + source = DiffusersPipelineLoader.ComponentSource( + model_or_path="dummy", + subfolder="transformer", + revision=None, + prefix="transformer.", + ) + fp8_weight = torch.tensor([[2.0, -4.0], [1.0, 3.0]], dtype=torch.float32).to(torch.float8_e4m3fn) + scale = torch.tensor([0.5], dtype=torch.float32) + + adapted = list( + loader._adapt_modelopt_fp8_weights( + model, + source, + iter( + [ + ("transformer.block.to_q.weight_scale", scale), + ("transformer.block.to_q.input_scale", torch.tensor([1.0])), + ("transformer.block.to_q.weight", fp8_weight), + ] + ), + {"transformer.block.to_q.weight_scale": scale}, + ) + ) + + assert [name for name, _ in adapted] == ["transformer.block.to_q.weight"] + assert adapted[0][1].dtype == model.transformer.block.to_qkv.weight.dtype + assert torch.allclose(adapted[0][1], fp8_weight.to(torch.float32) * scale) + + +class _QuantizedPackedModelOptModel(nn.Module): + def __init__(self) -> None: + super().__init__() + self.transformer = nn.Module() + self.transformer.block = nn.Module() + self.transformer.block.to_qkv = nn.Module() + self.transformer.block.to_qkv.register_parameter( + "weight", + nn.Parameter(torch.empty(2, 2, dtype=torch.float8_e4m3fn), requires_grad=False), + ) + self.transformer.block.to_qkv.register_parameter( + "weight_scale", + nn.Parameter(torch.empty(1), requires_grad=False), + ) + self.transformer.block.to_qkv.register_parameter( + "input_scale", + nn.Parameter(torch.empty(1), requires_grad=False), + ) + + +def test_modelopt_adapter_keeps_scale_tensors_for_quantized_target(): + loader = object.__new__(DiffusersPipelineLoader) + model = _QuantizedPackedModelOptModel() + source = DiffusersPipelineLoader.ComponentSource( + model_or_path="dummy", + subfolder="transformer", + revision=None, + prefix="transformer.", + ) + scale = torch.tensor([0.5], dtype=torch.float32) + + adapted = list( + loader._adapt_modelopt_fp8_weights( + model, + source, + iter( + [ + ("transformer.block.to_q.weight_scale", scale), + ("transformer.block.to_q.input_scale", torch.tensor([1.0])), + ] + ), + {"transformer.block.to_q.weight_scale": scale}, + ) + ) + + assert [name for name, _ in adapted] == [ + "transformer.block.to_q.weight_scale", + "transformer.block.to_q.input_scale", + ] diff --git a/tests/diffusion/quantization/test_fp8_config.py b/tests/diffusion/quantization/test_fp8_config.py index 574af7a6699..a85c4925358 100644 --- a/tests/diffusion/quantization/test_fp8_config.py +++ b/tests/diffusion/quantization/test_fp8_config.py @@ -54,6 +54,44 @@ def test_build_quant_config_dict_not_mutated(): assert original == copy +def test_build_quant_config_modelopt_fp8_config_json(): + from vllm.model_executor.layers.quantization.modelopt import ModelOptFp8Config + + from vllm_omni.quantization import build_quant_config + + config = build_quant_config( + { + "quant_method": "modelopt", + "quant_algo": "FP8", + "ignore": ["proj_out"], + "producer": {"name": "modelopt"}, + } + ) + + assert isinstance(config, ModelOptFp8Config) + assert config.get_name() == "modelopt" + assert config.is_checkpoint_fp8_serialized + + +def test_build_quant_config_modelopt_nested_checkpoint_metadata(): + from vllm.model_executor.layers.quantization.modelopt import ModelOptFp8Config + + from vllm_omni.quantization import build_quant_config + + config = build_quant_config( + { + "producer": {"name": "modelopt"}, + "quantization": { + "quant_algo": "FP8", + "exclude_modules": ["proj_out"], + }, + } + ) + + assert isinstance(config, ModelOptFp8Config) + assert config.get_name() == "modelopt" + + def test_build_quant_config_per_component(): from vllm_omni.quantization import ComponentQuantizationConfig, build_quant_config @@ -91,7 +129,7 @@ def test_flat_dict_not_misdetected_as_per_component(): as a per-component dict — it should raise ValueError for missing 'method'.""" from vllm_omni.quantization import build_quant_config - with pytest.raises(ValueError, match="must have a 'method' key"): + with pytest.raises(ValueError, match="must have a 'method' or 'quant_method' key"): build_quant_config({"activation_scheme": "static"}) @@ -194,6 +232,26 @@ def test_integration_per_component(): assert config.quantization_config.component_configs["vae"] is None +def test_transformer_config_auto_detects_modelopt_fp8(): + from vllm.model_executor.layers.quantization.modelopt import ModelOptFp8Config + + from vllm_omni.diffusion.data import TransformerConfig + + config = TransformerConfig.from_dict( + { + "_class_name": "FluxTransformer2DModel", + "quantization_config": { + "quant_method": "modelopt", + "quant_algo": "FP8", + "ignore": ["proj_out"], + }, + } + ) + + assert isinstance(config.quant_config, ModelOptFp8Config) + assert config.quant_method == "modelopt" + + def test_supported_methods_includes_vllm(): from vllm_omni.quantization import SUPPORTED_QUANTIZATION_METHODS diff --git a/tests/diffusion/test_diffusion_worker.py b/tests/diffusion/test_diffusion_worker.py index e2bd7ef8a32..e543e8b6e98 100644 --- a/tests/diffusion/test_diffusion_worker.py +++ b/tests/diffusion/test_diffusion_worker.py @@ -14,7 +14,7 @@ import torch from pytest_mock import MockerFixture -from vllm_omni.diffusion.worker.diffusion_worker import DiffusionWorker +from vllm_omni.diffusion.worker.diffusion_worker import DiffusionWorker, _make_diffusion_vllm_model_config pytestmark = [pytest.mark.core_model, pytest.mark.diffusion, pytest.mark.cpu] @@ -78,6 +78,34 @@ def test_load_weights_empty_iterable(self, mocker: MockerFixture, mock_gpu_worke assert result == set() +def test_diffusion_vllm_model_config_supplies_dtype_for_quant_methods(): + from types import SimpleNamespace + + from vllm_omni.quantization import build_quant_config + + od_config = SimpleNamespace( + model="dummy", + dtype=torch.bfloat16, + quantization_config=build_quant_config( + { + "quant_method": "modelopt", + "quant_algo": "FP8", + "ignore": [], + } + ), + tf_model_config=SimpleNamespace(), + enforce_eager=True, + is_moe=False, + ) + + model_config = _make_diffusion_vllm_model_config(od_config) + + assert model_config.dtype is torch.bfloat16 + assert model_config.quantization == "modelopt" + assert model_config.quantization_config is od_config.quantization_config + assert model_config.is_quantized() + + class TestDiffusionWorkerSleep: """Test DiffusionWorker.sleep method.""" diff --git a/vllm_omni/diffusion/attention/layer.py b/vllm_omni/diffusion/attention/layer.py index 4fdf2ff1612..72ff0fc5fa6 100644 --- a/vllm_omni/diffusion/attention/layer.py +++ b/vllm_omni/diffusion/attention/layer.py @@ -5,8 +5,6 @@ # DeepSpeed Team & Jiarui Fang # Adapted from # https://github.com/feifeibear/long-context-attention/blob/main/yunchang/attention/layer.py - - import torch import torch.nn as nn from vllm.logger import init_logger diff --git a/vllm_omni/diffusion/data.py b/vllm_omni/diffusion/data.py index 0a19eb11974..211e91d9ca7 100644 --- a/vllm_omni/diffusion/data.py +++ b/vllm_omni/diffusion/data.py @@ -207,10 +207,11 @@ def from_dict(cls, data: dict[str, Any]) -> "TransformerConfig": quant_method: str | None = None quant_config: QuantizationConfig | None = None disk_qc = params.get("quantization_config") - if isinstance(disk_qc, dict) and "quant_method" in disk_qc: - quant_method = disk_qc["quant_method"] - kwargs = {k: v for k, v in disk_qc.items() if k != "quant_method"} - quant_config = build_quant_config(quant_method, **kwargs) + if isinstance(disk_qc, dict): + raw_quant_method = disk_qc.get("quant_method", disk_qc.get("method")) + quant_config = build_quant_config(disk_qc) + if quant_config is not None: + quant_method = raw_quant_method if raw_quant_method is not None else quant_config.get_name() return cls(params=params, quant_method=quant_method, quant_config=quant_config) @@ -616,14 +617,9 @@ def __post_init__(self): # Auto-detect quantization from TransformerConfig if not explicitly set. # This covers the case where tf_model_config is passed at construction - # time. For late (post-construction) assignment, callers should use + # time. For late (post-construction) assignment, callers should use # set_tf_model_config() which propagates quant_config automatically. - if self.quantization_config is None and self.tf_model_config.quant_config is not None: - self.quantization_config = self.tf_model_config.quant_config - logger.info( - "Auto-detected quantization '%s' from model config", - self.tf_model_config.quant_method, - ) + self._propagate_quantization_from_tf_config(self.tf_model_config) # Resolve quantization_config: str/dict -> QuantizationConfig via build_quant_config. if self.quantization_config is not None: @@ -644,6 +640,14 @@ def __post_init__(self): elif self.max_cpu_loras < 1: raise ValueError("max_cpu_loras must be >= 1 for diffusion LoRA") + def _propagate_quantization_from_tf_config(self, tf_config: "TransformerConfig") -> None: + if self.quantization_config is None and tf_config.quant_config is not None: + self.quantization_config = tf_config.quant_config + logger.info( + "Auto-detected quantization '%s' from model config", + tf_config.quant_method, + ) + def set_tf_model_config(self, tf_config: "TransformerConfig") -> None: """Assign `tf_model_config` and propagate quantization if detected. @@ -659,12 +663,7 @@ def set_tf_model_config(self, tf_config: "TransformerConfig") -> None: `TransformerConfig.from_dict`. """ self.tf_model_config = tf_config - if self.quantization_config is None and tf_config.quant_config is not None: - self.quantization_config = tf_config.quant_config - logger.info( - "Auto-detected quantization '%s' from model config", - tf_config.quant_method, - ) + self._propagate_quantization_from_tf_config(tf_config) def update_multimodal_support(self) -> None: # Resolve serving-visible multimodal behavior from shared metadata @@ -690,7 +689,7 @@ def enrich_config(self) -> None: self.update_multimodal_support() tf_config_dict = get_hf_file_to_dict("transformer/config.json", self.model) - self.tf_model_config = TransformerConfig.from_dict(tf_config_dict) + self.set_tf_model_config(TransformerConfig.from_dict(tf_config_dict)) else: raise FileNotFoundError("model_index.json not found") except (AttributeError, OSError, ValueError, FileNotFoundError): @@ -698,7 +697,7 @@ def enrich_config(self) -> None: if cfg is None: raise ValueError(f"Could not find config.json or model_index.json for model {self.model}") - self.tf_model_config = TransformerConfig.from_dict(cfg) + self.set_tf_model_config(TransformerConfig.from_dict(cfg)) model_type = cfg.get("model_type") architectures = cfg.get("architectures") or [] diff --git a/vllm_omni/diffusion/layers/rope.py b/vllm_omni/diffusion/layers/rope.py index 61ddb4d84af..517aeec1867 100644 --- a/vllm_omni/diffusion/layers/rope.py +++ b/vllm_omni/diffusion/layers/rope.py @@ -65,6 +65,18 @@ def apply_rotary_emb_mindiesd( return rotary_position_embedding(x, cos, sin, rotated_mode="rotated_half", head_first=False, fused=True) +def _ensure_batch_dim(x: torch.Tensor) -> tuple[torch.Tensor, bool]: + if x.dim() == 3: + return x.unsqueeze(0), True + return x, False + + +def _restore_batch_dim(x: torch.Tensor, squeezed: bool) -> torch.Tensor: + if squeezed: + return x.squeeze(0) + return x + + class RotaryEmbedding(CustomOp): """ rotary positional embedding. @@ -98,12 +110,14 @@ def forward_cuda( cos = cos[0] sin = sin[0] - return apply_rotary_emb( + x, squeezed = _ensure_batch_dim(x) + output = apply_rotary_emb( x, cos, sin, interleaved=self.interleaved, ) + return _restore_batch_dim(output, squeezed) def forward_hip( self, @@ -119,12 +133,14 @@ def forward_hip( cos = cos[0] sin = sin[0] - return self.apply_rotary_emb_flash_attn( + x, squeezed = _ensure_batch_dim(x) + output = self.apply_rotary_emb_flash_attn( x, cos, sin, interleaved=self.interleaved, ) + return _restore_batch_dim(output, squeezed) def forward_npu( self, diff --git a/vllm_omni/diffusion/model_loader/diffusers_loader.py b/vllm_omni/diffusion/model_loader/diffusers_loader.py index 146afb26fbc..c4559658682 100644 --- a/vllm_omni/diffusion/model_loader/diffusers_loader.py +++ b/vllm_omni/diffusion/model_loader/diffusers_loader.py @@ -7,10 +7,11 @@ import time from collections.abc import Generator, Iterable from pathlib import Path -from typing import TYPE_CHECKING, cast +from typing import cast import torch from huggingface_hub import hf_hub_download +from safetensors import safe_open from torch import nn from vllm.config import ModelConfig from vllm.config.load import LoadConfig @@ -34,9 +35,6 @@ from vllm_omni.diffusion.model_loader.gguf_adapters import get_gguf_adapter from vllm_omni.diffusion.registry import initialize_model -if TYPE_CHECKING: - from vllm_omni.diffusion.data import OmniDiffusionConfig - logger = init_logger(__name__) @@ -48,6 +46,22 @@ def _natural_sort_key(filepath: str) -> list: MODEL_INDEX = "model_index.json" DIFFUSION_MODEL_WEIGHTS_INDEX = "diffusion_pytorch_model.safetensors.index.json" +MODEL_OPT_SCALE_SUFFIXES = (".input_scale", ".weight_scale", ".weight_scale_inv") +MODEL_OPT_PACKED_MODULES_MAPPING = { + "to_qkv": ("to_q", "to_k", "to_v"), + "add_kv_proj": ("add_q_proj", "add_k_proj", "add_v_proj"), + "w13": ("w1", "w3"), +} +FP8_DTYPES = tuple( + dtype + for dtype in ( + getattr(torch, "float8_e4m3fn", None), + getattr(torch, "float8_e5m2", None), + getattr(torch, "float8_e4m3fnuz", None), + getattr(torch, "float8_e5m2fnuz", None), + ) + if dtype is not None +) class DiffusersPipelineLoader: @@ -172,7 +186,11 @@ def _prepare_weights( return hf_folder, hf_weights_files, use_safetensors - def _get_weights_iterator(self, source: "ComponentSource") -> Generator[tuple[str, torch.Tensor], None, None]: + def _get_weights_iterator( + self, + source: "ComponentSource", + model: nn.Module | None = None, + ) -> Generator[tuple[str, torch.Tensor], None, None]: """Get an iterator for the model weights based on the load format.""" _, hf_weights_files, use_safetensors = self._prepare_weights( source.model_or_path, @@ -208,7 +226,184 @@ def _get_weights_iterator(self, source: "ComponentSource") -> Generator[tuple[st if self.counter_before_loading_weights == 0.0: self.counter_before_loading_weights = time.perf_counter() # Apply the prefix. - return ((source.prefix + name, tensor) for (name, tensor) in weights_iterator) + prefixed_weights_iterator = ((source.prefix + name, tensor) for (name, tensor) in weights_iterator) + if model is not None and self._should_adapt_modelopt_fp8_weights(source, use_safetensors): + scale_tensors = self._collect_modelopt_scale_tensors(hf_weights_files, source.prefix) + return self._adapt_modelopt_fp8_weights(model, source, prefixed_weights_iterator, scale_tensors) + return prefixed_weights_iterator + + def _get_source_quant_config(self, source: "ComponentSource") -> object | None: + od_config = self.od_config + if od_config is None: + return None + + quant_config = od_config.quantization_config + if quant_config is None: + return None + + if hasattr(quant_config, "resolve"): + return quant_config.resolve(source.prefix.rstrip(".")) + return quant_config + + @staticmethod + def _is_modelopt_fp8_checkpoint_quant_config(quant_config: object) -> bool: + return ( + hasattr(quant_config, "get_name") + and quant_config.get_name() == "modelopt" + and bool(getattr(quant_config, "is_checkpoint_fp8_serialized", False)) + ) + + def _should_adapt_modelopt_fp8_weights(self, source: "ComponentSource", use_safetensors: bool) -> bool: + if not use_safetensors or not self._is_transformer_source(source): + return False + + quant_config = self._get_source_quant_config(source) + if quant_config is None: + return False + return self._is_modelopt_fp8_checkpoint_quant_config(quant_config) + + @staticmethod + def _is_modelopt_scale(name: str) -> bool: + return name.endswith(MODEL_OPT_SCALE_SUFFIXES) + + @staticmethod + def _is_fp8_tensor(tensor: torch.Tensor) -> bool: + return tensor.dtype in FP8_DTYPES + + @staticmethod + def _get_weight_scale_name(weight_name: str) -> str | None: + if weight_name.endswith(".weight"): + return weight_name[: -len(".weight")] + ".weight_scale" + return None + + @staticmethod + def _replace_module_name(name: str, old: str, new: str) -> str: + if name.startswith(f"{old}."): + return f"{new}.{name[len(old) + 1 :]}" + return name.replace(f".{old}.", f".{new}.") + + def _get_modelopt_packed_name_pairs(self, model: nn.Module) -> tuple[tuple[str, str], ...]: + mapping: dict[str, tuple[str, ...]] = dict(MODEL_OPT_PACKED_MODULES_MAPPING) + for _, module in model.named_modules(): + packed_mapping = getattr(module, "packed_modules_mapping", None) + if isinstance(packed_mapping, dict): + for packed_name, shard_names in packed_mapping.items(): + if isinstance(shard_names, (list, tuple)): + mapping[str(packed_name)] = tuple(str(shard_name) for shard_name in shard_names) + + pairs: list[tuple[str, str]] = [] + for packed_name, shard_names in mapping.items(): + pairs.extend((packed_name, shard_name) for shard_name in shard_names) + return tuple(pairs) + + def _resolve_modelopt_target_name( + self, + name: str, + loadable_names: set[str], + packed_name_pairs: tuple[tuple[str, str], ...], + ) -> str | None: + if name in loadable_names: + return name + + if ".to_out.0." in name: + candidate = name.replace(".to_out.0.", ".to_out.") + if candidate in loadable_names: + return candidate + + for packed_name, shard_name in packed_name_pairs: + candidate = self._replace_module_name(name, shard_name, packed_name) + if candidate != name and candidate in loadable_names: + return candidate + return None + + @staticmethod + def _reshape_modelopt_weight_scale(scale: torch.Tensor, weight_shape: torch.Size) -> torch.Tensor: + if scale.numel() == 1: + return scale.reshape(()) + if len(weight_shape) == 2 and scale.ndim == 1 and scale.shape[0] == weight_shape[0]: + return scale.reshape(-1, 1) + if tuple(scale.shape) == tuple(weight_shape): + return scale + if ( + len(weight_shape) == 2 + and scale.ndim == 4 + and scale.shape[1] == 1 + and scale.shape[3] == 1 + and weight_shape[0] % scale.shape[0] == 0 + and weight_shape[1] % scale.shape[2] == 0 + ): + block_n = weight_shape[0] // scale.shape[0] + block_k = weight_shape[1] // scale.shape[2] + return scale.expand(scale.shape[0], block_n, scale.shape[2], block_k).reshape(weight_shape) + raise ValueError(f"Unsupported ModelOpt FP8 weight_scale shape {tuple(scale.shape)} for weight {weight_shape}") + + def _dequantize_modelopt_fp8_weight( + self, + name: str, + loaded_weight: torch.Tensor, + scale_tensors: dict[str, torch.Tensor], + target_dtype: torch.dtype, + ) -> torch.Tensor: + scale_name = self._get_weight_scale_name(name) + if scale_name is None or scale_name not in scale_tensors: + raise ValueError(f"Missing ModelOpt FP8 weight_scale for full-precision target weight {name!r}") + + weight = loaded_weight.to(dtype=torch.float32) + scale = scale_tensors[scale_name].to(dtype=torch.float32, device=weight.device) + scale = self._reshape_modelopt_weight_scale(scale, loaded_weight.shape) + return (weight * scale).to(dtype=target_dtype) + + def _collect_modelopt_scale_tensors( + self, + hf_weights_files: list[str], + prefix: str, + ) -> dict[str, torch.Tensor]: + scale_tensors: dict[str, torch.Tensor] = {} + for filename in sorted(hf_weights_files, key=_natural_sort_key): + if not filename.endswith(".safetensors"): + continue + with safe_open(filename, framework="pt", device="cpu") as f: + for name in f.keys(): + if self._is_modelopt_scale(name): + scale_tensors[prefix + name] = f.get_tensor(name) + return scale_tensors + + def _adapt_modelopt_fp8_weights( + self, + model: nn.Module, + source: "ComponentSource", + weights: Iterable[tuple[str, torch.Tensor]], + scale_tensors: dict[str, torch.Tensor], + ) -> Generator[tuple[str, torch.Tensor], None, None]: + loadable_tensors = self._get_model_loadable_tensors(model) + loadable_names = set(loadable_tensors) + packed_name_pairs = self._get_modelopt_packed_name_pairs(model) + + skipped_scales = 0 + dequantized_weights = 0 + for name, tensor in weights: + target_name = self._resolve_modelopt_target_name(name, loadable_names, packed_name_pairs) + if self._is_modelopt_scale(name): + if target_name is None: + skipped_scales += 1 + continue + yield name, tensor + continue + + if self._is_fp8_tensor(tensor) and target_name is not None: + target_tensor = loadable_tensors[target_name] + if target_tensor.dtype not in FP8_DTYPES: + tensor = self._dequantize_modelopt_fp8_weight(name, tensor, scale_tensors, target_tensor.dtype) + dequantized_weights += 1 + yield name, tensor + + if skipped_scales or dequantized_weights: + logger.info_once( + "Adapted ModelOpt FP8 %s weights: dequantized %d full-precision weights, skipped %d scale tensors", + source.prefix or source.subfolder or "model", + dequantized_weights, + skipped_scales, + ) def get_all_weights( self, @@ -216,7 +411,7 @@ def get_all_weights( ) -> Generator[tuple[str, torch.Tensor], None, None]: sources = self._get_weight_sources(model) for source in sources: - yield from self._get_weights_iterator(source) + yield from self._get_weights_iterator(source, model=model) def _get_weight_sources(self, model: nn.Module) -> tuple["ComponentSource", ...]: return tuple( @@ -262,6 +457,9 @@ def load_model( device: torch.device | None = None, ) -> nn.Module: """Load a model with the given configurations.""" + self.od_config = od_config + self._auto_detect_quant_config(od_config) + # CPU offload + FP8: load weights on device for FP8 quantization if load_device == "cpu" and od_config.quantization_config is not None: load_device = device.type @@ -275,11 +473,7 @@ def load_model( ) else: with target_device: - if load_format == "default": - model = initialize_model(od_config) - elif load_format == "custom_pipeline": - model_cls = resolve_obj_by_qualname(custom_pipeline_name) - model = model_cls(od_config=od_config) + model = self._initialize_pipeline_model(od_config, load_format, custom_pipeline_name) logger.debug("Loading weights on %s ...", load_device) if self._is_gguf_quantization(od_config): self._load_weights_with_gguf(model, od_config) @@ -293,6 +487,14 @@ def load_model( return model.eval() + @staticmethod + def _auto_detect_quant_config(od_config: OmniDiffusionConfig) -> None: + if od_config.quantization_config is not None: + return + tf_model_config = getattr(od_config, "tf_model_config", None) + if getattr(tf_model_config, "quant_config", None) is not None: + od_config.set_tf_model_config(tf_model_config) + def _process_weights_after_loading(self, model: nn.Module, target_device: torch.device) -> None: """Process weights after loading for quantization methods. @@ -429,11 +631,27 @@ def _is_transformer_source(self, source: "ComponentSource") -> bool: return source.prefix.startswith("transformer.") def _get_model_loadable_names(self, model: nn.Module) -> set[str]: + return set(self._get_model_loadable_tensors(model)) + + def _get_model_loadable_tensors(self, model: nn.Module) -> dict[str, torch.Tensor]: # Avoid model.state_dict() here because GGUF uses UninitializedParameter # which raises during detach(). Collect names directly. - names = {name for name, _ in model.named_parameters()} - names.update(name for name, _ in model.named_buffers()) - return names + loadable_tensors: dict[str, torch.Tensor] = {name: param for name, param in model.named_parameters()} + loadable_tensors.update({name: buffer for name, buffer in model.named_buffers()}) + return loadable_tensors + + @staticmethod + def _initialize_pipeline_model( + od_config: OmniDiffusionConfig, + load_format: str, + custom_pipeline_name: str | None, + ) -> nn.Module: + if load_format == "default": + return initialize_model(od_config) + if load_format == "custom_pipeline": + model_cls = resolve_obj_by_qualname(custom_pipeline_name) + return model_cls(od_config=od_config) + raise ValueError(f"Unknown load_format: {load_format}") def _resolve_gguf_model_path(self, gguf_model: str, revision: str | None) -> str: if os.path.isfile(gguf_model): @@ -538,11 +756,7 @@ def _load_model_with_hsdp( # directly on GPU, HSDP needs weights on CPU first so they can be redistributed # across GPUs by apply_hsdp_to_model. The model's load_weights handles weight # mapping (QKV fusion, etc.). - if load_format == "default": - model = initialize_model(od_config) - elif load_format == "custom_pipeline": - model_cls = resolve_obj_by_qualname(custom_pipeline_name) - model = model_cls(od_config=od_config) + model = self._initialize_pipeline_model(od_config, load_format, custom_pipeline_name) self.load_weights(model) # Collect all transformers to shard (some models have transformer_2 for MoE) diff --git a/vllm_omni/diffusion/worker/diffusion_worker.py b/vllm_omni/diffusion/worker/diffusion_worker.py index 160309e0d8d..a5740544c40 100644 --- a/vllm_omni/diffusion/worker/diffusion_worker.py +++ b/vllm_omni/diffusion/worker/diffusion_worker.py @@ -13,6 +13,7 @@ import os from collections.abc import Iterable from contextlib import AbstractContextManager, nullcontext +from dataclasses import dataclass from typing import Any import torch @@ -49,6 +50,42 @@ logger = init_logger(__name__) +@dataclass +class _DiffusionVllmModelConfig: + model: str + dtype: torch.dtype + quantization: str | None = None + quantization_config: Any | None = None + hf_config: Any | None = None + multimodal_config: Any | None = None + enforce_eager: bool = False + disable_cascade_attn: bool = False + is_moe: bool = False + + def is_quantized(self) -> bool: + return self.quantization is not None + + def is_model_moe(self) -> bool: + return self.is_moe + + def is_nvfp4_quantized(self) -> bool: + return self.quantization == "modelopt_fp4" + + +def _make_diffusion_vllm_model_config(od_config: OmniDiffusionConfig) -> _DiffusionVllmModelConfig: + quant_config = getattr(od_config, "quantization_config", None) + quantization = quant_config.get_name() if quant_config is not None and hasattr(quant_config, "get_name") else None + return _DiffusionVllmModelConfig( + model=od_config.model, + dtype=od_config.dtype, + quantization=quantization, + quantization_config=quant_config, + hf_config=getattr(od_config, "tf_model_config", None), + enforce_eager=getattr(od_config, "enforce_eager", False), + is_moe=bool(getattr(od_config, "is_moe", False)), + ) + + class DiffusionWorker: """ A worker that manages GPU infrastructure and delegates to the model runner. @@ -116,6 +153,8 @@ def init_device(self) -> None: vllm_config.parallel_config.data_parallel_size = self.od_config.parallel_config.data_parallel_size vllm_config.parallel_config.enable_expert_parallel = self.od_config.parallel_config.enable_expert_parallel vllm_config.profiler_config = self.od_config.profiler_config + vllm_config.model_config = _make_diffusion_vllm_model_config(self.od_config) # type: ignore[assignment] + vllm_config.quant_config = getattr(self.od_config, "quantization_config", None) self.vllm_config = vllm_config # Initialize distributed environment diff --git a/vllm_omni/model_executor/stage_configs/flux2_klein_dit_2gpu_fp8.yaml b/vllm_omni/model_executor/stage_configs/flux2_klein_dit_2gpu_fp8.yaml new file mode 100644 index 00000000000..0b4ebe8efd4 --- /dev/null +++ b/vllm_omni/model_executor/stage_configs/flux2_klein_dit_2gpu_fp8.yaml @@ -0,0 +1,30 @@ +# Stage config for running FLUX.2-klein DiT with ModelOpt FP8 auto-detect. +# The following config is for 2 GPUs. + +stage_args: + - stage_id: 0 + stage_type: diffusion + runtime: + devices: "0,1" + max_batch_size: 1 + engine_args: + model_stage: dit + model_class_name: Flux2KleinPipeline + max_num_seqs: 1 + enforce_eager: true + trust_remote_code: true + distributed_executor_backend: "mp" + parallel_config: + tensor_parallel_size: 2 + + final_output: true + final_output_type: image + is_comprehension: false + default_sampling_params: + seed: 42 + +runtime: + enabled: true + defaults: + window_size: -1 + max_inflight: 1 diff --git a/vllm_omni/model_executor/stage_configs/flux_dit_2gpu_fp8.yaml b/vllm_omni/model_executor/stage_configs/flux_dit_2gpu_fp8.yaml new file mode 100644 index 00000000000..45e4ebeff3d --- /dev/null +++ b/vllm_omni/model_executor/stage_configs/flux_dit_2gpu_fp8.yaml @@ -0,0 +1,30 @@ +# Stage config for running FLUX.1 DiT with ModelOpt FP8 auto-detect. +# The following config is for 2 GPUs. + +stage_args: + - stage_id: 0 + stage_type: diffusion + runtime: + devices: "0,1" + max_batch_size: 1 + engine_args: + model_stage: dit + model_class_name: FluxPipeline + max_num_seqs: 1 + enforce_eager: true + trust_remote_code: true + distributed_executor_backend: "mp" + parallel_config: + tensor_parallel_size: 2 + + final_output: true + final_output_type: image + is_comprehension: false + default_sampling_params: + seed: 42 + +runtime: + enabled: true + defaults: + window_size: -1 + max_inflight: 1 diff --git a/vllm_omni/model_executor/stage_configs/qwen_image_dit_2gpu_fp8.yaml b/vllm_omni/model_executor/stage_configs/qwen_image_dit_2gpu_fp8.yaml new file mode 100644 index 00000000000..1f0b60a7724 --- /dev/null +++ b/vllm_omni/model_executor/stage_configs/qwen_image_dit_2gpu_fp8.yaml @@ -0,0 +1,30 @@ +# Stage config for running Qwen-Image DiT with ModelOpt FP8 auto-detect. +# The following config is for 2 GPUs. + +stage_args: + - stage_id: 0 + stage_type: diffusion + runtime: + devices: "0,1" + max_batch_size: 1 + engine_args: + model_stage: dit + model_class_name: QwenImagePipeline + max_num_seqs: 1 + enforce_eager: true + trust_remote_code: true + distributed_executor_backend: "mp" + parallel_config: + tensor_parallel_size: 2 + + final_output: true + final_output_type: image + is_comprehension: false + default_sampling_params: + seed: 42 + +runtime: + enabled: true + defaults: + window_size: -1 + max_inflight: 1 diff --git a/vllm_omni/model_executor/stage_configs/z_image_dit_2gpu_fp8.yaml b/vllm_omni/model_executor/stage_configs/z_image_dit_2gpu_fp8.yaml new file mode 100644 index 00000000000..7d94a18cb26 --- /dev/null +++ b/vllm_omni/model_executor/stage_configs/z_image_dit_2gpu_fp8.yaml @@ -0,0 +1,30 @@ +# Stage config for running Z-Image DiT with ModelOpt FP8 auto-detect. +# The following config is for 2 GPUs. + +stage_args: + - stage_id: 0 + stage_type: diffusion + runtime: + devices: "0,1" + max_batch_size: 1 + engine_args: + model_stage: dit + model_class_name: ZImagePipeline + max_num_seqs: 1 + enforce_eager: true + trust_remote_code: true + distributed_executor_backend: "mp" + parallel_config: + tensor_parallel_size: 2 + + final_output: true + final_output_type: image + is_comprehension: false + default_sampling_params: + seed: 42 + +runtime: + enabled: true + defaults: + window_size: -1 + max_inflight: 1 diff --git a/vllm_omni/quantization/factory.py b/vllm_omni/quantization/factory.py index f85589d69bb..845d5297ecf 100644 --- a/vllm_omni/quantization/factory.py +++ b/vllm_omni/quantization/factory.py @@ -60,17 +60,84 @@ def _build_inc(**kw: Any) -> QuantizationConfig: "int8": _build_int8, "inc": _build_inc, "auto-round": _build_inc, + "auto_round": _build_inc, } SUPPORTED_QUANTIZATION_METHODS: list[str] = list(dict.fromkeys(QUANTIZATION_METHODS + list(_OVERRIDES.keys()))) +_MODEL_OPT_METHODS = { + "modelopt", + "modelopt_fp4", + "modelopt_mxfp8", + "modelopt_mixed", +} +_MODEL_OPT_FP8_ALGOS = { + "FP8", + "FP8_PER_CHANNEL_PER_TOKEN", + "FP8_PB_WO", +} + + +def _normalize_method_name(method: Any) -> str: + return str(method).lower().replace("-", "_") + + +def _detect_modelopt_method(config: Mapping[str, Any]) -> str | None: + method = config.get("method", config.get("quant_method")) + if method is not None: + normalized_method = _normalize_method_name(method) + if normalized_method in _MODEL_OPT_METHODS: + return normalized_method + + producer = config.get("producer") + if isinstance(producer, Mapping) and str(producer.get("name", "")).lower() == "modelopt": + quantization = config.get("quantization") + if isinstance(quantization, Mapping): + quant_algo = str(quantization.get("quant_algo", "")).upper() + else: + quant_algo = str(config.get("quant_algo", "")).upper() + + if quant_algo in _MODEL_OPT_FP8_ALGOS: + return "modelopt" + if quant_algo == "NVFP4": + return "modelopt_fp4" + if quant_algo == "MXFP8": + return "modelopt_mxfp8" + if quant_algo == "MIXED_PRECISION": + return "modelopt_mixed" + + return None + + +def _build_modelopt_from_config(method: str, config: Mapping[str, Any]) -> QuantizationConfig: + config_cls = get_quantization_config(method) + normalized_config = dict(config) + normalized_config.setdefault("quant_method", method) + return config_cls.from_config(normalized_config) + + +def _pop_method_name(spec: dict[str, Any]) -> str | None: + method = spec.pop("method", None) + if method is None: + method = spec.pop("quant_method", None) + return method + + +def _build_from_method_and_config(method: str, config: Mapping[str, Any]) -> QuantizationConfig: + normalized_config = {"quant_method": method, **config} + modelopt_method = _detect_modelopt_method(normalized_config) + if modelopt_method is not None: + return _build_modelopt_from_config(modelopt_method, normalized_config) + return _build_single(method, **config) + + def _build_single(method: str, **kwargs: Any) -> QuantizationConfig: """Build a single QuantizationConfig by method name. Resolution: _OVERRIDES first, then vLLM registry via from_config(). """ - method = method.lower() + method = _normalize_method_name(method) if method in _OVERRIDES: return _OVERRIDES[method](**kwargs) @@ -92,16 +159,17 @@ def _build_single(method: str, **kwargs: Any) -> QuantizationConfig: def _is_per_component_dict(spec: dict[str, Any]) -> bool: """Check if a dict describes per-component quantization. - A per-component dict has no "method" key and all values are + A per-component dict has no "method" / "quant_method" key and all values are str, dict, or None. To avoid misdetecting a flat config with all-string values (e.g. {"activation_scheme": "static"}), we - require at least one value to be None or a dict with "method". + require at least one value to be None or a dict with "method" / + "quant_method". """ - if "method" in spec: + if "method" in spec or "quant_method" in spec: return False if not all(isinstance(v, (dict, str, type(None))) for v in spec.values()): return False - return any(v is None or (isinstance(v, dict) and "method" in v) for v in spec.values()) + return any(v is None or (isinstance(v, dict) and ("method" in v or "quant_method" in v)) for v in spec.values()) def _build_component_config(spec: dict[str, Any]) -> ComponentQuantizationConfig: @@ -116,10 +184,10 @@ def _build_component_config(spec: dict[str, Any]) -> ComponentQuantizationConfig config = _build_single(value) elif isinstance(value, dict): value = dict(value) # avoid mutating caller's dict - method = value.pop("method", None) + method = _pop_method_name(value) if method is None: - raise ValueError(f"Component '{prefix}' config dict must have a 'method' key") - config = _build_single(method, **value) + raise ValueError(f"Component '{prefix}' config dict must have a 'method' or 'quant_method' key") + config = _build_from_method_and_config(method, value) else: raise TypeError(f"Component '{prefix}' config must be str, dict, or None, got {type(value).__name__}") @@ -164,14 +232,19 @@ def build_quant_config( if _is_per_component_dict(spec): return _build_component_config(spec) - method = spec.pop("method", None) + modelopt_method = _detect_modelopt_method(spec) + if modelopt_method is not None: + logger.info("Building quantization config: %s", modelopt_method) + return _build_modelopt_from_config(modelopt_method, spec) + + method = _pop_method_name(spec) if method is None: raise ValueError( - "Dict quantization config must have a 'method' key or " + "Dict quantization config must have a 'method' or 'quant_method' key or " "be a per-component config with component prefixes as keys." ) merged = {**spec, **kwargs} logger.info("Building quantization config: %s", method) - return _build_single(method, **merged) + return _build_from_method_and_config(method, merged) raise TypeError(f"quantization config must be str, dict, QuantizationConfig, or None, got {type(spec).__name__}") From 9846e09eaf3841b19bcbadb53238f6cacdc4cbe1 Mon Sep 17 00:00:00 2001 From: roG0d Date: Fri, 17 Apr 2026 19:43:59 +0000 Subject: [PATCH 02/24] fix Signed-off-by: roG0d --- tests/diffusion/layers/test_rope.py | 31 ----------------------------- 1 file changed, 31 deletions(-) delete mode 100644 tests/diffusion/layers/test_rope.py diff --git a/tests/diffusion/layers/test_rope.py b/tests/diffusion/layers/test_rope.py deleted file mode 100644 index ca15040c432..00000000000 --- a/tests/diffusion/layers/test_rope.py +++ /dev/null @@ -1,31 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pytest -import torch - -from vllm_omni.diffusion.layers.rope import RotaryEmbedding - -pytestmark = [pytest.mark.core_model, pytest.mark.diffusion, pytest.mark.cpu] - - -def test_cuda_rope_accepts_3d_query(monkeypatch): - import vllm.vllm_flash_attn.layers.rotary as rotary - - def fake_apply_rotary_emb(x, cos, sin, interleaved=False): - assert x.shape == (1, 4, 2, 8) - assert cos.shape == (4, 4) - assert sin.shape == (4, 4) - return x + 1 - - monkeypatch.setattr(rotary, "apply_rotary_emb", fake_apply_rotary_emb) - - rope = RotaryEmbedding(is_neox_style=False) - x = torch.zeros(4, 2, 8) - cos = torch.zeros(4, 4) - sin = torch.zeros(4, 4) - - out = rope.forward_cuda(x, cos, sin) - - assert out.shape == x.shape - assert torch.equal(out, torch.ones_like(x)) From 12c8e5fb0784d58d824aa2c8f7a1ebfad56f1591 Mon Sep 17 00:00:00 2001 From: roG0d Date: Sat, 18 Apr 2026 03:17:23 +0000 Subject: [PATCH 03/24] refactoring Signed-off-by: roG0d --- .../model_loader/test_diffusers_loader.py | 72 ++++++++++- vllm_omni/diffusion/attention/layer.py | 2 + .../model_loader/diffusers_loader.py | 113 +++++++++--------- 3 files changed, 130 insertions(+), 57 deletions(-) diff --git a/tests/diffusion/model_loader/test_diffusers_loader.py b/tests/diffusion/model_loader/test_diffusers_loader.py index 1da7f1c195a..b94b663ed5e 100644 --- a/tests/diffusion/model_loader/test_diffusers_loader.py +++ b/tests/diffusion/model_loader/test_diffusers_loader.py @@ -160,7 +160,37 @@ def test_modelopt_adapter_dequantizes_fp8_weight_for_full_precision_target(): ("transformer.block.to_q.weight", fp8_weight), ] ), - {"transformer.block.to_q.weight_scale": scale}, + ) + ) + + assert [name for name, _ in adapted] == ["transformer.block.to_q.weight"] + assert adapted[0][1].dtype == model.transformer.block.to_qkv.weight.dtype + assert torch.allclose(adapted[0][1], fp8_weight.to(torch.float32) * scale) + + +def test_modelopt_adapter_dequantizes_fp8_weight_when_scale_arrives_late(): + loader = object.__new__(DiffusersPipelineLoader) + model = _PackedModelOptModel() + source = DiffusersPipelineLoader.ComponentSource( + model_or_path="dummy", + subfolder="transformer", + revision=None, + prefix="transformer.", + ) + fp8_weight = torch.tensor([[2.0, -4.0], [1.0, 3.0]], dtype=torch.float32).to(torch.float8_e4m3fn) + scale = torch.tensor([0.5], dtype=torch.float32) + + adapted = list( + loader._adapt_modelopt_fp8_weights( + model, + source, + iter( + [ + ("transformer.block.to_q.weight", fp8_weight), + ("transformer.block.to_q.weight_scale", scale), + ("transformer.block.to_q.input_scale", torch.tensor([1.0])), + ] + ), ) ) @@ -189,6 +219,45 @@ def __init__(self) -> None: ) +class _ChildPackedModelOptModel(nn.Module): + def __init__(self) -> None: + super().__init__() + self.transformer = nn.Module() + self.transformer.packed_modules_mapping = {"packed_proj": ["proj_a", "proj_b"]} + self.transformer.block = nn.Module() + self.transformer.block.packed_proj = nn.Linear(2, 2, bias=False) + + +def test_modelopt_adapter_uses_child_packed_modules_mapping(): + loader = object.__new__(DiffusersPipelineLoader) + model = _ChildPackedModelOptModel() + source = DiffusersPipelineLoader.ComponentSource( + model_or_path="dummy", + subfolder="transformer", + revision=None, + prefix="transformer.", + ) + fp8_weight = torch.tensor([[2.0, -4.0], [1.0, 3.0]], dtype=torch.float32).to(torch.float8_e4m3fn) + scale = torch.tensor([0.5], dtype=torch.float32) + + adapted = list( + loader._adapt_modelopt_fp8_weights( + model, + source, + iter( + [ + ("transformer.block.proj_a.weight", fp8_weight), + ("transformer.block.proj_a.weight_scale", scale), + ] + ), + ) + ) + + assert [name for name, _ in adapted] == ["transformer.block.proj_a.weight"] + assert adapted[0][1].dtype == model.transformer.block.packed_proj.weight.dtype + assert torch.allclose(adapted[0][1], fp8_weight.to(torch.float32) * scale) + + def test_modelopt_adapter_keeps_scale_tensors_for_quantized_target(): loader = object.__new__(DiffusersPipelineLoader) model = _QuantizedPackedModelOptModel() @@ -210,7 +279,6 @@ def test_modelopt_adapter_keeps_scale_tensors_for_quantized_target(): ("transformer.block.to_q.input_scale", torch.tensor([1.0])), ] ), - {"transformer.block.to_q.weight_scale": scale}, ) ) diff --git a/vllm_omni/diffusion/attention/layer.py b/vllm_omni/diffusion/attention/layer.py index 72ff0fc5fa6..4fdf2ff1612 100644 --- a/vllm_omni/diffusion/attention/layer.py +++ b/vllm_omni/diffusion/attention/layer.py @@ -5,6 +5,8 @@ # DeepSpeed Team & Jiarui Fang # Adapted from # https://github.com/feifeibear/long-context-attention/blob/main/yunchang/attention/layer.py + + import torch import torch.nn as nn from vllm.logger import init_logger diff --git a/vllm_omni/diffusion/model_loader/diffusers_loader.py b/vllm_omni/diffusion/model_loader/diffusers_loader.py index c4559658682..1f5daacbcd0 100644 --- a/vllm_omni/diffusion/model_loader/diffusers_loader.py +++ b/vllm_omni/diffusion/model_loader/diffusers_loader.py @@ -11,7 +11,6 @@ import torch from huggingface_hub import hf_hub_download -from safetensors import safe_open from torch import nn from vllm.config import ModelConfig from vllm.config.load import LoadConfig @@ -27,6 +26,7 @@ multi_thread_safetensors_weights_iterator, safetensors_weights_iterator, ) +from vllm.model_executor.utils import get_packed_modules_mapping from vllm.utils.import_utils import resolve_obj_by_qualname from vllm.utils.torch_utils import set_default_torch_dtype @@ -228,19 +228,14 @@ def _get_weights_iterator( # Apply the prefix. prefixed_weights_iterator = ((source.prefix + name, tensor) for (name, tensor) in weights_iterator) if model is not None and self._should_adapt_modelopt_fp8_weights(source, use_safetensors): - scale_tensors = self._collect_modelopt_scale_tensors(hf_weights_files, source.prefix) - return self._adapt_modelopt_fp8_weights(model, source, prefixed_weights_iterator, scale_tensors) + return self._adapt_modelopt_fp8_weights(model, source, prefixed_weights_iterator) return prefixed_weights_iterator def _get_source_quant_config(self, source: "ComponentSource") -> object | None: - od_config = self.od_config - if od_config is None: - return None - - quant_config = od_config.quantization_config - if quant_config is None: + if self.od_config is None: return None + quant_config = self.od_config.quantization_config if hasattr(quant_config, "resolve"): return quant_config.resolve(source.prefix.rstrip(".")) return quant_config @@ -276,31 +271,23 @@ def _get_weight_scale_name(weight_name: str) -> str | None: return weight_name[: -len(".weight")] + ".weight_scale" return None - @staticmethod - def _replace_module_name(name: str, old: str, new: str) -> str: - if name.startswith(f"{old}."): - return f"{new}.{name[len(old) + 1 :]}" - return name.replace(f".{old}.", f".{new}.") - - def _get_modelopt_packed_name_pairs(self, model: nn.Module) -> tuple[tuple[str, str], ...]: - mapping: dict[str, tuple[str, ...]] = dict(MODEL_OPT_PACKED_MODULES_MAPPING) - for _, module in model.named_modules(): - packed_mapping = getattr(module, "packed_modules_mapping", None) - if isinstance(packed_mapping, dict): - for packed_name, shard_names in packed_mapping.items(): - if isinstance(shard_names, (list, tuple)): - mapping[str(packed_name)] = tuple(str(shard_name) for shard_name in shard_names) - - pairs: list[tuple[str, str]] = [] - for packed_name, shard_names in mapping.items(): - pairs.extend((packed_name, shard_name) for shard_name in shard_names) - return tuple(pairs) + def _get_modelopt_packed_modules_mapping(self, model: nn.Module) -> dict[str, tuple[str, ...]]: + mapping = { + packed_name: tuple(shard_names) for packed_name, shard_names in MODEL_OPT_PACKED_MODULES_MAPPING.items() + } + mapping.update( + { + str(packed_name): tuple(str(shard_name) for shard_name in shard_names) + for packed_name, shard_names in get_packed_modules_mapping(model).items() + } + ) + return mapping def _resolve_modelopt_target_name( self, name: str, loadable_names: set[str], - packed_name_pairs: tuple[tuple[str, str], ...], + packed_modules_mapping: dict[str, tuple[str, ...]], ) -> str | None: if name in loadable_names: return name @@ -310,10 +297,14 @@ def _resolve_modelopt_target_name( if candidate in loadable_names: return candidate - for packed_name, shard_name in packed_name_pairs: - candidate = self._replace_module_name(name, shard_name, packed_name) - if candidate != name and candidate in loadable_names: - return candidate + for packed_name, shard_names in packed_modules_mapping.items(): + for shard_name in shard_names: + if name.startswith(f"{shard_name}."): + candidate = f"{packed_name}.{name[len(shard_name) + 1 :]}" + else: + candidate = name.replace(f".{shard_name}.", f".{packed_name}.") + if candidate != name and candidate in loadable_names: + return candidate return None @staticmethod @@ -353,50 +344,62 @@ def _dequantize_modelopt_fp8_weight( scale = self._reshape_modelopt_weight_scale(scale, loaded_weight.shape) return (weight * scale).to(dtype=target_dtype) - def _collect_modelopt_scale_tensors( - self, - hf_weights_files: list[str], - prefix: str, - ) -> dict[str, torch.Tensor]: - scale_tensors: dict[str, torch.Tensor] = {} - for filename in sorted(hf_weights_files, key=_natural_sort_key): - if not filename.endswith(".safetensors"): - continue - with safe_open(filename, framework="pt", device="cpu") as f: - for name in f.keys(): - if self._is_modelopt_scale(name): - scale_tensors[prefix + name] = f.get_tensor(name) - return scale_tensors - def _adapt_modelopt_fp8_weights( self, model: nn.Module, source: "ComponentSource", weights: Iterable[tuple[str, torch.Tensor]], - scale_tensors: dict[str, torch.Tensor], ) -> Generator[tuple[str, torch.Tensor], None, None]: loadable_tensors = self._get_model_loadable_tensors(model) loadable_names = set(loadable_tensors) - packed_name_pairs = self._get_modelopt_packed_name_pairs(model) + packed_modules_mapping = self._get_modelopt_packed_modules_mapping(model) + scale_tensors: dict[str, torch.Tensor] = {} + pending_weights: dict[str, list[tuple[str, torch.Tensor, torch.dtype]]] = {} skipped_scales = 0 dequantized_weights = 0 for name, tensor in weights: - target_name = self._resolve_modelopt_target_name(name, loadable_names, packed_name_pairs) + target_name = self._resolve_modelopt_target_name(name, loadable_names, packed_modules_mapping) if self._is_modelopt_scale(name): + scale_tensors[name] = tensor if target_name is None: skipped_scales += 1 - continue - yield name, tensor + else: + yield name, tensor + + for weight_name, weight_tensor, target_dtype in pending_weights.pop(name, []): + yield ( + weight_name, + self._dequantize_modelopt_fp8_weight( + weight_name, + weight_tensor, + scale_tensors, + target_dtype, + ), + ) + dequantized_weights += 1 continue if self._is_fp8_tensor(tensor) and target_name is not None: target_tensor = loadable_tensors[target_name] if target_tensor.dtype not in FP8_DTYPES: - tensor = self._dequantize_modelopt_fp8_weight(name, tensor, scale_tensors, target_tensor.dtype) - dequantized_weights += 1 + scale_name = self._get_weight_scale_name(name) + if scale_name is None: + raise ValueError(f"Missing ModelOpt FP8 weight_scale name for weight {name!r}") + if scale_name in scale_tensors: + tensor = self._dequantize_modelopt_fp8_weight(name, tensor, scale_tensors, target_tensor.dtype) + dequantized_weights += 1 + else: + pending_weights.setdefault(scale_name, []).append((name, tensor, target_tensor.dtype)) + continue yield name, tensor + if pending_weights: + missing_scale_names = ", ".join(repr(name) for name in sorted(pending_weights)) + raise ValueError( + f"Missing ModelOpt FP8 weight_scale for full-precision target weights: {missing_scale_names}" + ) + if skipped_scales or dequantized_weights: logger.info_once( "Adapted ModelOpt FP8 %s weights: dequantized %d full-precision weights, skipped %d scale tensors", From f79a574cab49149aa4d2eaa19405cc6609b1a7dc Mon Sep 17 00:00:00 2001 From: roG0d Date: Sat, 18 Apr 2026 03:35:02 +0000 Subject: [PATCH 04/24] refactoring Signed-off-by: roG0d --- .../model_loader/test_diffusers_loader.py | 159 ------------- .../model_loader/test_modelopt_fp8_adapter.py | 149 +++++++++++++ .../checkpoint_adapters/__init__.py | 22 ++ .../checkpoint_adapters/modelopt_fp8.py | 210 ++++++++++++++++++ .../model_loader/diffusers_loader.py | 199 ++--------------- 5 files changed, 397 insertions(+), 342 deletions(-) create mode 100644 tests/diffusion/model_loader/test_modelopt_fp8_adapter.py create mode 100644 vllm_omni/diffusion/model_loader/checkpoint_adapters/__init__.py create mode 100644 vllm_omni/diffusion/model_loader/checkpoint_adapters/modelopt_fp8.py diff --git a/tests/diffusion/model_loader/test_diffusers_loader.py b/tests/diffusion/model_loader/test_diffusers_loader.py index b94b663ed5e..931b79b2a30 100644 --- a/tests/diffusion/model_loader/test_diffusers_loader.py +++ b/tests/diffusion/model_loader/test_diffusers_loader.py @@ -127,162 +127,3 @@ def test_loader_auto_detects_quant_config_from_transformer_config(): DiffusersPipelineLoader._auto_detect_quant_config(od_config) assert od_config.quantization_config is od_config.tf_model_config.quant_config - - -class _PackedModelOptModel(nn.Module): - def __init__(self) -> None: - super().__init__() - self.transformer = nn.Module() - self.transformer.block = nn.Module() - self.transformer.block.to_qkv = nn.Linear(2, 2, bias=False) - - -def test_modelopt_adapter_dequantizes_fp8_weight_for_full_precision_target(): - loader = object.__new__(DiffusersPipelineLoader) - model = _PackedModelOptModel() - source = DiffusersPipelineLoader.ComponentSource( - model_or_path="dummy", - subfolder="transformer", - revision=None, - prefix="transformer.", - ) - fp8_weight = torch.tensor([[2.0, -4.0], [1.0, 3.0]], dtype=torch.float32).to(torch.float8_e4m3fn) - scale = torch.tensor([0.5], dtype=torch.float32) - - adapted = list( - loader._adapt_modelopt_fp8_weights( - model, - source, - iter( - [ - ("transformer.block.to_q.weight_scale", scale), - ("transformer.block.to_q.input_scale", torch.tensor([1.0])), - ("transformer.block.to_q.weight", fp8_weight), - ] - ), - ) - ) - - assert [name for name, _ in adapted] == ["transformer.block.to_q.weight"] - assert adapted[0][1].dtype == model.transformer.block.to_qkv.weight.dtype - assert torch.allclose(adapted[0][1], fp8_weight.to(torch.float32) * scale) - - -def test_modelopt_adapter_dequantizes_fp8_weight_when_scale_arrives_late(): - loader = object.__new__(DiffusersPipelineLoader) - model = _PackedModelOptModel() - source = DiffusersPipelineLoader.ComponentSource( - model_or_path="dummy", - subfolder="transformer", - revision=None, - prefix="transformer.", - ) - fp8_weight = torch.tensor([[2.0, -4.0], [1.0, 3.0]], dtype=torch.float32).to(torch.float8_e4m3fn) - scale = torch.tensor([0.5], dtype=torch.float32) - - adapted = list( - loader._adapt_modelopt_fp8_weights( - model, - source, - iter( - [ - ("transformer.block.to_q.weight", fp8_weight), - ("transformer.block.to_q.weight_scale", scale), - ("transformer.block.to_q.input_scale", torch.tensor([1.0])), - ] - ), - ) - ) - - assert [name for name, _ in adapted] == ["transformer.block.to_q.weight"] - assert adapted[0][1].dtype == model.transformer.block.to_qkv.weight.dtype - assert torch.allclose(adapted[0][1], fp8_weight.to(torch.float32) * scale) - - -class _QuantizedPackedModelOptModel(nn.Module): - def __init__(self) -> None: - super().__init__() - self.transformer = nn.Module() - self.transformer.block = nn.Module() - self.transformer.block.to_qkv = nn.Module() - self.transformer.block.to_qkv.register_parameter( - "weight", - nn.Parameter(torch.empty(2, 2, dtype=torch.float8_e4m3fn), requires_grad=False), - ) - self.transformer.block.to_qkv.register_parameter( - "weight_scale", - nn.Parameter(torch.empty(1), requires_grad=False), - ) - self.transformer.block.to_qkv.register_parameter( - "input_scale", - nn.Parameter(torch.empty(1), requires_grad=False), - ) - - -class _ChildPackedModelOptModel(nn.Module): - def __init__(self) -> None: - super().__init__() - self.transformer = nn.Module() - self.transformer.packed_modules_mapping = {"packed_proj": ["proj_a", "proj_b"]} - self.transformer.block = nn.Module() - self.transformer.block.packed_proj = nn.Linear(2, 2, bias=False) - - -def test_modelopt_adapter_uses_child_packed_modules_mapping(): - loader = object.__new__(DiffusersPipelineLoader) - model = _ChildPackedModelOptModel() - source = DiffusersPipelineLoader.ComponentSource( - model_or_path="dummy", - subfolder="transformer", - revision=None, - prefix="transformer.", - ) - fp8_weight = torch.tensor([[2.0, -4.0], [1.0, 3.0]], dtype=torch.float32).to(torch.float8_e4m3fn) - scale = torch.tensor([0.5], dtype=torch.float32) - - adapted = list( - loader._adapt_modelopt_fp8_weights( - model, - source, - iter( - [ - ("transformer.block.proj_a.weight", fp8_weight), - ("transformer.block.proj_a.weight_scale", scale), - ] - ), - ) - ) - - assert [name for name, _ in adapted] == ["transformer.block.proj_a.weight"] - assert adapted[0][1].dtype == model.transformer.block.packed_proj.weight.dtype - assert torch.allclose(adapted[0][1], fp8_weight.to(torch.float32) * scale) - - -def test_modelopt_adapter_keeps_scale_tensors_for_quantized_target(): - loader = object.__new__(DiffusersPipelineLoader) - model = _QuantizedPackedModelOptModel() - source = DiffusersPipelineLoader.ComponentSource( - model_or_path="dummy", - subfolder="transformer", - revision=None, - prefix="transformer.", - ) - scale = torch.tensor([0.5], dtype=torch.float32) - - adapted = list( - loader._adapt_modelopt_fp8_weights( - model, - source, - iter( - [ - ("transformer.block.to_q.weight_scale", scale), - ("transformer.block.to_q.input_scale", torch.tensor([1.0])), - ] - ), - ) - ) - - assert [name for name, _ in adapted] == [ - "transformer.block.to_q.weight_scale", - "transformer.block.to_q.input_scale", - ] diff --git a/tests/diffusion/model_loader/test_modelopt_fp8_adapter.py b/tests/diffusion/model_loader/test_modelopt_fp8_adapter.py new file mode 100644 index 00000000000..728afc59295 --- /dev/null +++ b/tests/diffusion/model_loader/test_modelopt_fp8_adapter.py @@ -0,0 +1,149 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch +import torch.nn as nn + +from vllm_omni.diffusion.model_loader.checkpoint_adapters import ( + ModelOptFp8CheckpointAdapter, +) +from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader + +pytestmark = [pytest.mark.core_model, pytest.mark.diffusion, pytest.mark.cpu] + + +class _PackedModelOptModel(nn.Module): + def __init__(self) -> None: + super().__init__() + self.transformer = nn.Module() + self.transformer.block = nn.Module() + self.transformer.block.to_qkv = nn.Linear(2, 2, bias=False) + + +class _QuantizedPackedModelOptModel(nn.Module): + def __init__(self) -> None: + super().__init__() + self.transformer = nn.Module() + self.transformer.block = nn.Module() + self.transformer.block.to_qkv = nn.Module() + self.transformer.block.to_qkv.register_parameter( + "weight", + nn.Parameter(torch.empty(2, 2, dtype=torch.float8_e4m3fn), requires_grad=False), + ) + self.transformer.block.to_qkv.register_parameter( + "weight_scale", + nn.Parameter(torch.empty(1), requires_grad=False), + ) + self.transformer.block.to_qkv.register_parameter( + "input_scale", + nn.Parameter(torch.empty(1), requires_grad=False), + ) + + +class _ChildPackedModelOptModel(nn.Module): + def __init__(self) -> None: + super().__init__() + self.transformer = nn.Module() + self.transformer.packed_modules_mapping = {"packed_proj": ["proj_a", "proj_b"]} + self.transformer.block = nn.Module() + self.transformer.block.packed_proj = nn.Linear(2, 2, bias=False) + + +def _make_source() -> DiffusersPipelineLoader.ComponentSource: + return DiffusersPipelineLoader.ComponentSource( + model_or_path="dummy", + subfolder="transformer", + revision=None, + prefix="transformer.", + ) + + +def test_modelopt_adapter_dequantizes_fp8_weight_for_full_precision_target(): + model = _PackedModelOptModel() + adapter = ModelOptFp8CheckpointAdapter(model, _make_source()) + fp8_weight = torch.tensor([[2.0, -4.0], [1.0, 3.0]], dtype=torch.float32).to(torch.float8_e4m3fn) + scale = torch.tensor([0.5], dtype=torch.float32) + + adapted = list( + adapter.adapt( + iter( + [ + ("transformer.block.to_q.weight_scale", scale), + ("transformer.block.to_q.input_scale", torch.tensor([1.0])), + ("transformer.block.to_q.weight", fp8_weight), + ] + ) + ) + ) + + assert [name for name, _ in adapted] == ["transformer.block.to_q.weight"] + assert adapted[0][1].dtype == model.transformer.block.to_qkv.weight.dtype + assert torch.allclose(adapted[0][1], fp8_weight.to(torch.float32) * scale) + + +def test_modelopt_adapter_dequantizes_fp8_weight_when_scale_arrives_late(): + model = _PackedModelOptModel() + adapter = ModelOptFp8CheckpointAdapter(model, _make_source()) + fp8_weight = torch.tensor([[2.0, -4.0], [1.0, 3.0]], dtype=torch.float32).to(torch.float8_e4m3fn) + scale = torch.tensor([0.5], dtype=torch.float32) + + adapted = list( + adapter.adapt( + iter( + [ + ("transformer.block.to_q.weight", fp8_weight), + ("transformer.block.to_q.weight_scale", scale), + ("transformer.block.to_q.input_scale", torch.tensor([1.0])), + ] + ) + ) + ) + + assert [name for name, _ in adapted] == ["transformer.block.to_q.weight"] + assert adapted[0][1].dtype == model.transformer.block.to_qkv.weight.dtype + assert torch.allclose(adapted[0][1], fp8_weight.to(torch.float32) * scale) + + +def test_modelopt_adapter_uses_child_packed_modules_mapping(): + model = _ChildPackedModelOptModel() + adapter = ModelOptFp8CheckpointAdapter(model, _make_source()) + fp8_weight = torch.tensor([[2.0, -4.0], [1.0, 3.0]], dtype=torch.float32).to(torch.float8_e4m3fn) + scale = torch.tensor([0.5], dtype=torch.float32) + + adapted = list( + adapter.adapt( + iter( + [ + ("transformer.block.proj_a.weight", fp8_weight), + ("transformer.block.proj_a.weight_scale", scale), + ] + ) + ) + ) + + assert [name for name, _ in adapted] == ["transformer.block.proj_a.weight"] + assert adapted[0][1].dtype == model.transformer.block.packed_proj.weight.dtype + assert torch.allclose(adapted[0][1], fp8_weight.to(torch.float32) * scale) + + +def test_modelopt_adapter_keeps_scale_tensors_for_quantized_target(): + model = _QuantizedPackedModelOptModel() + adapter = ModelOptFp8CheckpointAdapter(model, _make_source()) + scale = torch.tensor([0.5], dtype=torch.float32) + + adapted = list( + adapter.adapt( + iter( + [ + ("transformer.block.to_q.weight_scale", scale), + ("transformer.block.to_q.input_scale", torch.tensor([1.0])), + ] + ) + ) + ) + + assert [name for name, _ in adapted] == [ + "transformer.block.to_q.weight_scale", + "transformer.block.to_q.input_scale", + ] diff --git a/vllm_omni/diffusion/model_loader/checkpoint_adapters/__init__.py b/vllm_omni/diffusion/model_loader/checkpoint_adapters/__init__.py new file mode 100644 index 00000000000..569fed8da58 --- /dev/null +++ b/vllm_omni/diffusion/model_loader/checkpoint_adapters/__init__.py @@ -0,0 +1,22 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from torch import nn + +from .modelopt_fp8 import ModelOptFp8CheckpointAdapter + + +def get_checkpoint_adapter( + model: nn.Module, + source: object, + quant_config: object | None, + use_safetensors: bool, +) -> ModelOptFp8CheckpointAdapter | None: + if ModelOptFp8CheckpointAdapter.is_compatible(source, quant_config, use_safetensors): + return ModelOptFp8CheckpointAdapter(model, source) + return None + + +__all__ = [ + "ModelOptFp8CheckpointAdapter", + "get_checkpoint_adapter", +] diff --git a/vllm_omni/diffusion/model_loader/checkpoint_adapters/modelopt_fp8.py b/vllm_omni/diffusion/model_loader/checkpoint_adapters/modelopt_fp8.py new file mode 100644 index 00000000000..92bc5d1a41c --- /dev/null +++ b/vllm_omni/diffusion/model_loader/checkpoint_adapters/modelopt_fp8.py @@ -0,0 +1,210 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Generator, Iterable + +import torch +from torch import nn +from vllm.logger import init_logger +from vllm.model_executor.utils import get_packed_modules_mapping + +logger = init_logger(__name__) + +MODEL_OPT_SCALE_SUFFIXES = (".input_scale", ".weight_scale", ".weight_scale_inv") +DEFAULT_PACKED_MODULES_MAPPING = { + "to_qkv": ("to_q", "to_k", "to_v"), + "add_kv_proj": ("add_q_proj", "add_k_proj", "add_v_proj"), + "w13": ("w1", "w3"), +} +FP8_DTYPES = tuple( + dtype + for dtype in ( + getattr(torch, "float8_e4m3fn", None), + getattr(torch, "float8_e5m2", None), + getattr(torch, "float8_e4m3fnuz", None), + getattr(torch, "float8_e5m2fnuz", None), + ) + if dtype is not None +) + + +class ModelOptFp8CheckpointAdapter: + def __init__(self, model: nn.Module, source: object): + self._loadable_tensors = self._get_model_loadable_tensors(model) + self._loadable_names = set(self._loadable_tensors) + self._packed_modules_mapping = self._get_packed_modules_mapping(model) + self._source_label = getattr(source, "prefix", "") or getattr(source, "subfolder", None) or "model" + + @classmethod + def is_compatible( + cls, + source: object, + quant_config: object | None, + use_safetensors: bool, + ) -> bool: + return use_safetensors and cls._is_transformer_source(source) and cls._is_checkpoint_quant_config(quant_config) + + @staticmethod + def _is_transformer_source(source: object) -> bool: + if getattr(source, "subfolder", None) == "transformer": + return True + return str(getattr(source, "prefix", "")).startswith("transformer.") + + @staticmethod + def _is_checkpoint_quant_config(quant_config: object | None) -> bool: + return ( + quant_config is not None + and hasattr(quant_config, "get_name") + and quant_config.get_name() == "modelopt" + and bool(getattr(quant_config, "is_checkpoint_fp8_serialized", False)) + ) + + @staticmethod + def _get_model_loadable_tensors(model: nn.Module) -> dict[str, torch.Tensor]: + loadable_tensors: dict[str, torch.Tensor] = {name: param for name, param in model.named_parameters()} + loadable_tensors.update({name: buffer for name, buffer in model.named_buffers()}) + return loadable_tensors + + @staticmethod + def _is_scale(name: str) -> bool: + return name.endswith(MODEL_OPT_SCALE_SUFFIXES) + + @staticmethod + def _is_fp8_tensor(tensor: torch.Tensor) -> bool: + return tensor.dtype in FP8_DTYPES + + @staticmethod + def _get_weight_scale_name(weight_name: str) -> str | None: + if weight_name.endswith(".weight"): + return weight_name[: -len(".weight")] + ".weight_scale" + return None + + @staticmethod + def _replace_module_name(name: str, old: str, new: str) -> str: + if name.startswith(f"{old}."): + return f"{new}.{name[len(old) + 1 :]}" + return name.replace(f".{old}.", f".{new}.") + + @staticmethod + def _get_packed_modules_mapping(model: nn.Module) -> dict[str, tuple[str, ...]]: + mapping = { + packed_name: tuple(shard_names) for packed_name, shard_names in DEFAULT_PACKED_MODULES_MAPPING.items() + } + mapping.update( + { + str(packed_name): tuple(str(shard_name) for shard_name in shard_names) + for packed_name, shard_names in get_packed_modules_mapping(model).items() + } + ) + return mapping + + def _resolve_target_name(self, name: str) -> str | None: + if name in self._loadable_names: + return name + + if ".to_out.0." in name: + candidate = name.replace(".to_out.0.", ".to_out.") + if candidate in self._loadable_names: + return candidate + + for packed_name, shard_names in self._packed_modules_mapping.items(): + for shard_name in shard_names: + candidate = self._replace_module_name(name, shard_name, packed_name) + if candidate != name and candidate in self._loadable_names: + return candidate + return None + + @staticmethod + def _reshape_weight_scale(scale: torch.Tensor, weight_shape: torch.Size) -> torch.Tensor: + if scale.numel() == 1: + return scale.reshape(()) + if len(weight_shape) == 2 and scale.ndim == 1 and scale.shape[0] == weight_shape[0]: + return scale.reshape(-1, 1) + if tuple(scale.shape) == tuple(weight_shape): + return scale + if ( + len(weight_shape) == 2 + and scale.ndim == 4 + and scale.shape[1] == 1 + and scale.shape[3] == 1 + and weight_shape[0] % scale.shape[0] == 0 + and weight_shape[1] % scale.shape[2] == 0 + ): + block_n = weight_shape[0] // scale.shape[0] + block_k = weight_shape[1] // scale.shape[2] + return scale.expand(scale.shape[0], block_n, scale.shape[2], block_k).reshape(weight_shape) + raise ValueError(f"Unsupported ModelOpt FP8 weight_scale shape {tuple(scale.shape)} for weight {weight_shape}") + + def _dequantize_weight( + self, + name: str, + loaded_weight: torch.Tensor, + scale_tensors: dict[str, torch.Tensor], + target_dtype: torch.dtype, + ) -> torch.Tensor: + scale_name = self._get_weight_scale_name(name) + if scale_name is None or scale_name not in scale_tensors: + raise ValueError(f"Missing ModelOpt FP8 weight_scale for full-precision target weight {name!r}") + + weight = loaded_weight.to(dtype=torch.float32) + scale = scale_tensors[scale_name].to(dtype=torch.float32, device=weight.device) + scale = self._reshape_weight_scale(scale, loaded_weight.shape) + return (weight * scale).to(dtype=target_dtype) + + def adapt( + self, + weights: Iterable[tuple[str, torch.Tensor]], + ) -> Generator[tuple[str, torch.Tensor], None, None]: + scale_tensors: dict[str, torch.Tensor] = {} + pending_weights: dict[str, list[tuple[str, torch.Tensor, torch.dtype]]] = {} + + skipped_scales = 0 + dequantized_weights = 0 + for name, tensor in weights: + target_name = self._resolve_target_name(name) + if self._is_scale(name): + scale_tensors[name] = tensor + if target_name is None: + skipped_scales += 1 + else: + yield name, tensor + + for weight_name, weight_tensor, target_dtype in pending_weights.pop(name, []): + yield ( + weight_name, + self._dequantize_weight( + weight_name, + weight_tensor, + scale_tensors, + target_dtype, + ), + ) + dequantized_weights += 1 + continue + + if self._is_fp8_tensor(tensor) and target_name is not None: + target_tensor = self._loadable_tensors[target_name] + if target_tensor.dtype not in FP8_DTYPES: + scale_name = self._get_weight_scale_name(name) + if scale_name is None: + raise ValueError(f"Missing ModelOpt FP8 weight_scale name for weight {name!r}") + if scale_name in scale_tensors: + tensor = self._dequantize_weight(name, tensor, scale_tensors, target_tensor.dtype) + dequantized_weights += 1 + else: + pending_weights.setdefault(scale_name, []).append((name, tensor, target_tensor.dtype)) + continue + yield name, tensor + + if pending_weights: + missing_scale_names = ", ".join(repr(name) for name in sorted(pending_weights)) + raise ValueError( + f"Missing ModelOpt FP8 weight_scale for full-precision target weights: {missing_scale_names}" + ) + + if skipped_scales or dequantized_weights: + logger.info_once( + "Adapted ModelOpt FP8 %s weights: dequantized %d full-precision weights, skipped %d scale tensors", + self._source_label, + dequantized_weights, + skipped_scales, + ) diff --git a/vllm_omni/diffusion/model_loader/diffusers_loader.py b/vllm_omni/diffusion/model_loader/diffusers_loader.py index 1f5daacbcd0..09ec82ebcfd 100644 --- a/vllm_omni/diffusion/model_loader/diffusers_loader.py +++ b/vllm_omni/diffusion/model_loader/diffusers_loader.py @@ -26,12 +26,14 @@ multi_thread_safetensors_weights_iterator, safetensors_weights_iterator, ) -from vllm.model_executor.utils import get_packed_modules_mapping from vllm.utils.import_utils import resolve_obj_by_qualname from vllm.utils.torch_utils import set_default_torch_dtype from vllm_omni.diffusion.data import OmniDiffusionConfig from vllm_omni.diffusion.distributed.hsdp import HSDPInferenceConfig +from vllm_omni.diffusion.model_loader.checkpoint_adapters import ( + get_checkpoint_adapter, +) from vllm_omni.diffusion.model_loader.gguf_adapters import get_gguf_adapter from vllm_omni.diffusion.registry import initialize_model @@ -46,22 +48,6 @@ def _natural_sort_key(filepath: str) -> list: MODEL_INDEX = "model_index.json" DIFFUSION_MODEL_WEIGHTS_INDEX = "diffusion_pytorch_model.safetensors.index.json" -MODEL_OPT_SCALE_SUFFIXES = (".input_scale", ".weight_scale", ".weight_scale_inv") -MODEL_OPT_PACKED_MODULES_MAPPING = { - "to_qkv": ("to_q", "to_k", "to_v"), - "add_kv_proj": ("add_q_proj", "add_k_proj", "add_v_proj"), - "w13": ("w1", "w3"), -} -FP8_DTYPES = tuple( - dtype - for dtype in ( - getattr(torch, "float8_e4m3fn", None), - getattr(torch, "float8_e5m2", None), - getattr(torch, "float8_e4m3fnuz", None), - getattr(torch, "float8_e5m2fnuz", None), - ) - if dtype is not None -) class DiffusersPipelineLoader: @@ -227,8 +213,10 @@ def _get_weights_iterator( self.counter_before_loading_weights = time.perf_counter() # Apply the prefix. prefixed_weights_iterator = ((source.prefix + name, tensor) for (name, tensor) in weights_iterator) - if model is not None and self._should_adapt_modelopt_fp8_weights(source, use_safetensors): - return self._adapt_modelopt_fp8_weights(model, source, prefixed_weights_iterator) + if model is not None: + checkpoint_adapter = self._get_checkpoint_adapter(model, source, use_safetensors) + if checkpoint_adapter is not None: + return checkpoint_adapter.adapt(prefixed_weights_iterator) return prefixed_weights_iterator def _get_source_quant_config(self, source: "ComponentSource") -> object | None: @@ -240,173 +228,18 @@ def _get_source_quant_config(self, source: "ComponentSource") -> object | None: return quant_config.resolve(source.prefix.rstrip(".")) return quant_config - @staticmethod - def _is_modelopt_fp8_checkpoint_quant_config(quant_config: object) -> bool: - return ( - hasattr(quant_config, "get_name") - and quant_config.get_name() == "modelopt" - and bool(getattr(quant_config, "is_checkpoint_fp8_serialized", False)) - ) - - def _should_adapt_modelopt_fp8_weights(self, source: "ComponentSource", use_safetensors: bool) -> bool: - if not use_safetensors or not self._is_transformer_source(source): - return False - - quant_config = self._get_source_quant_config(source) - if quant_config is None: - return False - return self._is_modelopt_fp8_checkpoint_quant_config(quant_config) - - @staticmethod - def _is_modelopt_scale(name: str) -> bool: - return name.endswith(MODEL_OPT_SCALE_SUFFIXES) - - @staticmethod - def _is_fp8_tensor(tensor: torch.Tensor) -> bool: - return tensor.dtype in FP8_DTYPES - - @staticmethod - def _get_weight_scale_name(weight_name: str) -> str | None: - if weight_name.endswith(".weight"): - return weight_name[: -len(".weight")] + ".weight_scale" - return None - - def _get_modelopt_packed_modules_mapping(self, model: nn.Module) -> dict[str, tuple[str, ...]]: - mapping = { - packed_name: tuple(shard_names) for packed_name, shard_names in MODEL_OPT_PACKED_MODULES_MAPPING.items() - } - mapping.update( - { - str(packed_name): tuple(str(shard_name) for shard_name in shard_names) - for packed_name, shard_names in get_packed_modules_mapping(model).items() - } - ) - return mapping - - def _resolve_modelopt_target_name( - self, - name: str, - loadable_names: set[str], - packed_modules_mapping: dict[str, tuple[str, ...]], - ) -> str | None: - if name in loadable_names: - return name - - if ".to_out.0." in name: - candidate = name.replace(".to_out.0.", ".to_out.") - if candidate in loadable_names: - return candidate - - for packed_name, shard_names in packed_modules_mapping.items(): - for shard_name in shard_names: - if name.startswith(f"{shard_name}."): - candidate = f"{packed_name}.{name[len(shard_name) + 1 :]}" - else: - candidate = name.replace(f".{shard_name}.", f".{packed_name}.") - if candidate != name and candidate in loadable_names: - return candidate - return None - - @staticmethod - def _reshape_modelopt_weight_scale(scale: torch.Tensor, weight_shape: torch.Size) -> torch.Tensor: - if scale.numel() == 1: - return scale.reshape(()) - if len(weight_shape) == 2 and scale.ndim == 1 and scale.shape[0] == weight_shape[0]: - return scale.reshape(-1, 1) - if tuple(scale.shape) == tuple(weight_shape): - return scale - if ( - len(weight_shape) == 2 - and scale.ndim == 4 - and scale.shape[1] == 1 - and scale.shape[3] == 1 - and weight_shape[0] % scale.shape[0] == 0 - and weight_shape[1] % scale.shape[2] == 0 - ): - block_n = weight_shape[0] // scale.shape[0] - block_k = weight_shape[1] // scale.shape[2] - return scale.expand(scale.shape[0], block_n, scale.shape[2], block_k).reshape(weight_shape) - raise ValueError(f"Unsupported ModelOpt FP8 weight_scale shape {tuple(scale.shape)} for weight {weight_shape}") - - def _dequantize_modelopt_fp8_weight( - self, - name: str, - loaded_weight: torch.Tensor, - scale_tensors: dict[str, torch.Tensor], - target_dtype: torch.dtype, - ) -> torch.Tensor: - scale_name = self._get_weight_scale_name(name) - if scale_name is None or scale_name not in scale_tensors: - raise ValueError(f"Missing ModelOpt FP8 weight_scale for full-precision target weight {name!r}") - - weight = loaded_weight.to(dtype=torch.float32) - scale = scale_tensors[scale_name].to(dtype=torch.float32, device=weight.device) - scale = self._reshape_modelopt_weight_scale(scale, loaded_weight.shape) - return (weight * scale).to(dtype=target_dtype) - - def _adapt_modelopt_fp8_weights( + def _get_checkpoint_adapter( self, model: nn.Module, source: "ComponentSource", - weights: Iterable[tuple[str, torch.Tensor]], - ) -> Generator[tuple[str, torch.Tensor], None, None]: - loadable_tensors = self._get_model_loadable_tensors(model) - loadable_names = set(loadable_tensors) - packed_modules_mapping = self._get_modelopt_packed_modules_mapping(model) - scale_tensors: dict[str, torch.Tensor] = {} - pending_weights: dict[str, list[tuple[str, torch.Tensor, torch.dtype]]] = {} - - skipped_scales = 0 - dequantized_weights = 0 - for name, tensor in weights: - target_name = self._resolve_modelopt_target_name(name, loadable_names, packed_modules_mapping) - if self._is_modelopt_scale(name): - scale_tensors[name] = tensor - if target_name is None: - skipped_scales += 1 - else: - yield name, tensor - - for weight_name, weight_tensor, target_dtype in pending_weights.pop(name, []): - yield ( - weight_name, - self._dequantize_modelopt_fp8_weight( - weight_name, - weight_tensor, - scale_tensors, - target_dtype, - ), - ) - dequantized_weights += 1 - continue - - if self._is_fp8_tensor(tensor) and target_name is not None: - target_tensor = loadable_tensors[target_name] - if target_tensor.dtype not in FP8_DTYPES: - scale_name = self._get_weight_scale_name(name) - if scale_name is None: - raise ValueError(f"Missing ModelOpt FP8 weight_scale name for weight {name!r}") - if scale_name in scale_tensors: - tensor = self._dequantize_modelopt_fp8_weight(name, tensor, scale_tensors, target_tensor.dtype) - dequantized_weights += 1 - else: - pending_weights.setdefault(scale_name, []).append((name, tensor, target_tensor.dtype)) - continue - yield name, tensor - - if pending_weights: - missing_scale_names = ", ".join(repr(name) for name in sorted(pending_weights)) - raise ValueError( - f"Missing ModelOpt FP8 weight_scale for full-precision target weights: {missing_scale_names}" - ) - - if skipped_scales or dequantized_weights: - logger.info_once( - "Adapted ModelOpt FP8 %s weights: dequantized %d full-precision weights, skipped %d scale tensors", - source.prefix or source.subfolder or "model", - dequantized_weights, - skipped_scales, - ) + use_safetensors: bool, + ): + return get_checkpoint_adapter( + model=model, + source=source, + quant_config=self._get_source_quant_config(source), + use_safetensors=use_safetensors, + ) def get_all_weights( self, From e398ab32d69aa70951db10b0c5bd9bf28469970d Mon Sep 17 00:00:00 2001 From: roG0d Date: Sat, 18 Apr 2026 04:11:55 +0000 Subject: [PATCH 05/24] continue refacoring Signed-off-by: roG0d --- .../model_loader/test_modelopt_fp8_adapter.py | 63 +----- .../diffusion/quantization/test_fp8_config.py | 19 -- .../checkpoint_adapters/modelopt_fp8.py | 182 +++++++++++------- 3 files changed, 116 insertions(+), 148 deletions(-) diff --git a/tests/diffusion/model_loader/test_modelopt_fp8_adapter.py b/tests/diffusion/model_loader/test_modelopt_fp8_adapter.py index 728afc59295..513e80d2887 100644 --- a/tests/diffusion/model_loader/test_modelopt_fp8_adapter.py +++ b/tests/diffusion/model_loader/test_modelopt_fp8_adapter.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from types import SimpleNamespace + import pytest import torch import torch.nn as nn @@ -8,7 +10,6 @@ from vllm_omni.diffusion.model_loader.checkpoint_adapters import ( ModelOptFp8CheckpointAdapter, ) -from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader pytestmark = [pytest.mark.core_model, pytest.mark.diffusion, pytest.mark.cpu] @@ -41,20 +42,9 @@ def __init__(self) -> None: ) -class _ChildPackedModelOptModel(nn.Module): - def __init__(self) -> None: - super().__init__() - self.transformer = nn.Module() - self.transformer.packed_modules_mapping = {"packed_proj": ["proj_a", "proj_b"]} - self.transformer.block = nn.Module() - self.transformer.block.packed_proj = nn.Linear(2, 2, bias=False) - - -def _make_source() -> DiffusersPipelineLoader.ComponentSource: - return DiffusersPipelineLoader.ComponentSource( - model_or_path="dummy", +def _make_source() -> SimpleNamespace: + return SimpleNamespace( subfolder="transformer", - revision=None, prefix="transformer.", ) @@ -82,51 +72,6 @@ def test_modelopt_adapter_dequantizes_fp8_weight_for_full_precision_target(): assert torch.allclose(adapted[0][1], fp8_weight.to(torch.float32) * scale) -def test_modelopt_adapter_dequantizes_fp8_weight_when_scale_arrives_late(): - model = _PackedModelOptModel() - adapter = ModelOptFp8CheckpointAdapter(model, _make_source()) - fp8_weight = torch.tensor([[2.0, -4.0], [1.0, 3.0]], dtype=torch.float32).to(torch.float8_e4m3fn) - scale = torch.tensor([0.5], dtype=torch.float32) - - adapted = list( - adapter.adapt( - iter( - [ - ("transformer.block.to_q.weight", fp8_weight), - ("transformer.block.to_q.weight_scale", scale), - ("transformer.block.to_q.input_scale", torch.tensor([1.0])), - ] - ) - ) - ) - - assert [name for name, _ in adapted] == ["transformer.block.to_q.weight"] - assert adapted[0][1].dtype == model.transformer.block.to_qkv.weight.dtype - assert torch.allclose(adapted[0][1], fp8_weight.to(torch.float32) * scale) - - -def test_modelopt_adapter_uses_child_packed_modules_mapping(): - model = _ChildPackedModelOptModel() - adapter = ModelOptFp8CheckpointAdapter(model, _make_source()) - fp8_weight = torch.tensor([[2.0, -4.0], [1.0, 3.0]], dtype=torch.float32).to(torch.float8_e4m3fn) - scale = torch.tensor([0.5], dtype=torch.float32) - - adapted = list( - adapter.adapt( - iter( - [ - ("transformer.block.proj_a.weight", fp8_weight), - ("transformer.block.proj_a.weight_scale", scale), - ] - ) - ) - ) - - assert [name for name, _ in adapted] == ["transformer.block.proj_a.weight"] - assert adapted[0][1].dtype == model.transformer.block.packed_proj.weight.dtype - assert torch.allclose(adapted[0][1], fp8_weight.to(torch.float32) * scale) - - def test_modelopt_adapter_keeps_scale_tensors_for_quantized_target(): model = _QuantizedPackedModelOptModel() adapter = ModelOptFp8CheckpointAdapter(model, _make_source()) diff --git a/tests/diffusion/quantization/test_fp8_config.py b/tests/diffusion/quantization/test_fp8_config.py index a85c4925358..0a4117d0d78 100644 --- a/tests/diffusion/quantization/test_fp8_config.py +++ b/tests/diffusion/quantization/test_fp8_config.py @@ -73,25 +73,6 @@ def test_build_quant_config_modelopt_fp8_config_json(): assert config.is_checkpoint_fp8_serialized -def test_build_quant_config_modelopt_nested_checkpoint_metadata(): - from vllm.model_executor.layers.quantization.modelopt import ModelOptFp8Config - - from vllm_omni.quantization import build_quant_config - - config = build_quant_config( - { - "producer": {"name": "modelopt"}, - "quantization": { - "quant_algo": "FP8", - "exclude_modules": ["proj_out"], - }, - } - ) - - assert isinstance(config, ModelOptFp8Config) - assert config.get_name() == "modelopt" - - def test_build_quant_config_per_component(): from vllm_omni.quantization import ComponentQuantizationConfig, build_quant_config diff --git a/vllm_omni/diffusion/model_loader/checkpoint_adapters/modelopt_fp8.py b/vllm_omni/diffusion/model_loader/checkpoint_adapters/modelopt_fp8.py index 92bc5d1a41c..64661a15f29 100644 --- a/vllm_omni/diffusion/model_loader/checkpoint_adapters/modelopt_fp8.py +++ b/vllm_omni/diffusion/model_loader/checkpoint_adapters/modelopt_fp8.py @@ -1,10 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Generator, Iterable +from dataclasses import dataclass, field import torch from torch import nn from vllm.logger import init_logger +from vllm.model_executor.models.utils import WeightsMapper from vllm.model_executor.utils import get_packed_modules_mapping logger = init_logger(__name__) @@ -27,11 +29,18 @@ ) +@dataclass +class _AdaptState: + scale_tensors: dict[str, torch.Tensor] = field(default_factory=dict) + pending_weights: dict[str, list[tuple[str, torch.Tensor, torch.dtype]]] = field(default_factory=dict) + skipped_scales: int = 0 + dequantized_weights: int = 0 + + class ModelOptFp8CheckpointAdapter: def __init__(self, model: nn.Module, source: object): self._loadable_tensors = self._get_model_loadable_tensors(model) - self._loadable_names = set(self._loadable_tensors) - self._packed_modules_mapping = self._get_packed_modules_mapping(model) + self._weights_mapper = self._get_weights_mapper(model) self._source_label = getattr(source, "prefix", "") or getattr(source, "subfolder", None) or "model" @classmethod @@ -78,14 +87,8 @@ def _get_weight_scale_name(weight_name: str) -> str | None: return weight_name[: -len(".weight")] + ".weight_scale" return None - @staticmethod - def _replace_module_name(name: str, old: str, new: str) -> str: - if name.startswith(f"{old}."): - return f"{new}.{name[len(old) + 1 :]}" - return name.replace(f".{old}.", f".{new}.") - - @staticmethod - def _get_packed_modules_mapping(model: nn.Module) -> dict[str, tuple[str, ...]]: + @classmethod + def _get_weights_mapper(cls, model: nn.Module) -> WeightsMapper: mapping = { packed_name: tuple(shard_names) for packed_name, shard_names in DEFAULT_PACKED_MODULES_MAPPING.items() } @@ -95,22 +98,25 @@ def _get_packed_modules_mapping(model: nn.Module) -> dict[str, tuple[str, ...]]: for packed_name, shard_names in get_packed_modules_mapping(model).items() } ) - return mapping + + orig_to_new_substr = {".to_out.0.": ".to_out."} + orig_to_new_prefix: dict[str, str] = {} + for packed_name, shard_names in mapping.items(): + for shard_name in shard_names: + orig_to_new_substr[f".{shard_name}."] = f".{packed_name}." + orig_to_new_prefix[f"{shard_name}."] = f"{packed_name}." + return WeightsMapper( + orig_to_new_substr=orig_to_new_substr, + orig_to_new_prefix=orig_to_new_prefix, + ) def _resolve_target_name(self, name: str) -> str | None: - if name in self._loadable_names: + if name in self._loadable_tensors: return name - if ".to_out.0." in name: - candidate = name.replace(".to_out.0.", ".to_out.") - if candidate in self._loadable_names: + for candidate in self._weights_mapper.apply_list([name]): + if candidate != name and candidate in self._loadable_tensors: return candidate - - for packed_name, shard_names in self._packed_modules_mapping.items(): - for shard_name in shard_names: - candidate = self._replace_module_name(name, shard_name, packed_name) - if candidate != name and candidate in self._loadable_names: - return candidate return None @staticmethod @@ -138,73 +144,109 @@ def _dequantize_weight( self, name: str, loaded_weight: torch.Tensor, - scale_tensors: dict[str, torch.Tensor], + state: _AdaptState, target_dtype: torch.dtype, ) -> torch.Tensor: scale_name = self._get_weight_scale_name(name) - if scale_name is None or scale_name not in scale_tensors: + if scale_name is None or scale_name not in state.scale_tensors: raise ValueError(f"Missing ModelOpt FP8 weight_scale for full-precision target weight {name!r}") weight = loaded_weight.to(dtype=torch.float32) - scale = scale_tensors[scale_name].to(dtype=torch.float32, device=weight.device) + scale = state.scale_tensors[scale_name].to(dtype=torch.float32, device=weight.device) scale = self._reshape_weight_scale(scale, loaded_weight.shape) return (weight * scale).to(dtype=target_dtype) + def _flush_pending_weights( + self, + scale_name: str, + state: _AdaptState, + ) -> Generator[tuple[str, torch.Tensor], None, None]: + for weight_name, weight_tensor, target_dtype in state.pending_weights.pop(scale_name, []): + yield weight_name, self._dequantize_weight(weight_name, weight_tensor, state, target_dtype) + state.dequantized_weights += 1 + + def _handle_scale_tensor( + self, + name: str, + tensor: torch.Tensor, + target_name: str | None, + state: _AdaptState, + ) -> Generator[tuple[str, torch.Tensor], None, None]: + state.scale_tensors[name] = tensor + if target_name is None: + state.skipped_scales += 1 + else: + yield name, tensor + yield from self._flush_pending_weights(name, state) + + def _target_dtype_for_dequantization( + self, + tensor: torch.Tensor, + target_name: str | None, + ) -> torch.dtype | None: + if target_name is None or not self._is_fp8_tensor(tensor): + return None + + target_dtype = self._loadable_tensors[target_name].dtype + if target_dtype in FP8_DTYPES: + return None + return target_dtype + + def _maybe_dequantize_or_defer_weight( + self, + name: str, + tensor: torch.Tensor, + target_dtype: torch.dtype, + state: _AdaptState, + ) -> torch.Tensor | None: + scale_name = self._get_weight_scale_name(name) + if scale_name is None: + raise ValueError(f"Missing ModelOpt FP8 weight_scale name for weight {name!r}") + + if scale_name not in state.scale_tensors: + state.pending_weights.setdefault(scale_name, []).append((name, tensor, target_dtype)) + return None + + state.dequantized_weights += 1 + return self._dequantize_weight(name, tensor, state, target_dtype) + + @staticmethod + def _check_pending_weights(state: _AdaptState) -> None: + if not state.pending_weights: + return + + missing_scale_names = ", ".join(repr(name) for name in sorted(state.pending_weights)) + raise ValueError(f"Missing ModelOpt FP8 weight_scale for full-precision target weights: {missing_scale_names}") + + def _log_adaptation_summary(self, state: _AdaptState) -> None: + if not state.skipped_scales and not state.dequantized_weights: + return + + logger.info_once( + "Adapted ModelOpt FP8 %s weights: dequantized %d full-precision weights, skipped %d scale tensors", + self._source_label, + state.dequantized_weights, + state.skipped_scales, + ) + def adapt( self, weights: Iterable[tuple[str, torch.Tensor]], ) -> Generator[tuple[str, torch.Tensor], None, None]: - scale_tensors: dict[str, torch.Tensor] = {} - pending_weights: dict[str, list[tuple[str, torch.Tensor, torch.dtype]]] = {} + state = _AdaptState() - skipped_scales = 0 - dequantized_weights = 0 for name, tensor in weights: target_name = self._resolve_target_name(name) if self._is_scale(name): - scale_tensors[name] = tensor - if target_name is None: - skipped_scales += 1 - else: - yield name, tensor - - for weight_name, weight_tensor, target_dtype in pending_weights.pop(name, []): - yield ( - weight_name, - self._dequantize_weight( - weight_name, - weight_tensor, - scale_tensors, - target_dtype, - ), - ) - dequantized_weights += 1 + yield from self._handle_scale_tensor(name, tensor, target_name, state) continue - if self._is_fp8_tensor(tensor) and target_name is not None: - target_tensor = self._loadable_tensors[target_name] - if target_tensor.dtype not in FP8_DTYPES: - scale_name = self._get_weight_scale_name(name) - if scale_name is None: - raise ValueError(f"Missing ModelOpt FP8 weight_scale name for weight {name!r}") - if scale_name in scale_tensors: - tensor = self._dequantize_weight(name, tensor, scale_tensors, target_tensor.dtype) - dequantized_weights += 1 - else: - pending_weights.setdefault(scale_name, []).append((name, tensor, target_tensor.dtype)) - continue + target_dtype = self._target_dtype_for_dequantization(tensor, target_name) + if target_dtype is not None: + tensor = self._maybe_dequantize_or_defer_weight(name, tensor, target_dtype, state) + if tensor is None: + continue yield name, tensor - if pending_weights: - missing_scale_names = ", ".join(repr(name) for name in sorted(pending_weights)) - raise ValueError( - f"Missing ModelOpt FP8 weight_scale for full-precision target weights: {missing_scale_names}" - ) - - if skipped_scales or dequantized_weights: - logger.info_once( - "Adapted ModelOpt FP8 %s weights: dequantized %d full-precision weights, skipped %d scale tensors", - self._source_label, - dequantized_weights, - skipped_scales, - ) + self._check_pending_weights(state) + self._log_adaptation_summary(state) From 8a3b83d98d76f19829737ecce15aed513015aaa3 Mon Sep 17 00:00:00 2001 From: roG0d Date: Sun, 19 Apr 2026 06:45:36 +0000 Subject: [PATCH 06/24] fix huawei Signed-off-by: roG0d --- vllm_omni/diffusion/data.py | 17 +++++- .../hunyuan_image3/hunyuan_fused_moe.py | 2 +- .../hunyuan_image3_transformer.py | 59 +++++++++++++------ 3 files changed, 59 insertions(+), 19 deletions(-) diff --git a/vllm_omni/diffusion/data.py b/vllm_omni/diffusion/data.py index 211e91d9ca7..2bab09fa594 100644 --- a/vllm_omni/diffusion/data.py +++ b/vllm_omni/diffusion/data.py @@ -641,13 +641,28 @@ def __post_init__(self): raise ValueError("max_cpu_loras must be >= 1 for diffusion LoRA") def _propagate_quantization_from_tf_config(self, tf_config: "TransformerConfig") -> None: - if self.quantization_config is None and tf_config.quant_config is not None: + if tf_config.quant_config is None: + return + + is_checkpoint_fp8 = bool(getattr(tf_config.quant_config, "is_checkpoint_fp8_serialized", False)) + should_use_checkpoint_config = self.quantization_config is None or ( + is_checkpoint_fp8 and self._is_generic_fp8_quant_config(self.quantization_config) + ) + if should_use_checkpoint_config: self.quantization_config = tf_config.quant_config logger.info( "Auto-detected quantization '%s' from model config", tf_config.quant_method, ) + @staticmethod + def _is_generic_fp8_quant_config(quant_config: object) -> bool: + if isinstance(quant_config, str): + return quant_config.lower() == "fp8" + if hasattr(quant_config, "get_name"): + return quant_config.get_name() == "fp8" + return False + def set_tf_model_config(self, tf_config: "TransformerConfig") -> None: """Assign `tf_model_config` and propagate quantization if detected. diff --git a/vllm_omni/diffusion/models/hunyuan_image3/hunyuan_fused_moe.py b/vllm_omni/diffusion/models/hunyuan_image3/hunyuan_fused_moe.py index cb44717c533..7a34edbec31 100644 --- a/vllm_omni/diffusion/models/hunyuan_image3/hunyuan_fused_moe.py +++ b/vllm_omni/diffusion/models/hunyuan_image3/hunyuan_fused_moe.py @@ -32,7 +32,7 @@ def __init__(self, *, prefix: str = "", **kwargs: Any) -> None: self._init_hook_handle = self.register_forward_pre_hook(self._initialize_kernel_hook, with_kwargs=True) def _initialize_kernel_hook(self, module: Any, args: Any, kwargs: Any) -> None: - if self.quant_method: + if self.quant_method and getattr(self.quant_method, "moe_kernel", None) is None: self.quant_method.process_weights_after_loading(self) self._init_hook_handle.remove() diff --git a/vllm_omni/diffusion/models/hunyuan_image3/hunyuan_image3_transformer.py b/vllm_omni/diffusion/models/hunyuan_image3/hunyuan_image3_transformer.py index fbdacddaf34..8f4f27ffe29 100644 --- a/vllm_omni/diffusion/models/hunyuan_image3/hunyuan_image3_transformer.py +++ b/vllm_omni/diffusion/models/hunyuan_image3/hunyuan_image3_transformer.py @@ -1723,7 +1723,7 @@ def __init__( rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, - quant_config=None, + quant_config=quant_config, bias=attention_bias, cache_config=None, prefix=f"{prefix}.self_attn", @@ -2034,6 +2034,44 @@ def contains_unexpected_keyword(name, keywords): return True return False + def is_scalar_quant_scale(name: str, tensor: torch.Tensor) -> bool: + return tensor.dim() == 0 and name.endswith((".input_scale", ".weight_scale")) + + def load_split_param( + name: str, + tensor: torch.Tensor, + den: int, + split_param: list[tuple[str | int, int]], + func: Callable[[torch.Tensor], torch.Tensor] | None, + ) -> None: + param = params_dict[name] + weight_loader = param.weight_loader + if is_scalar_quant_scale(name, tensor): + for shard_id, _ in split_param: + weight_loader(param, tensor, shard_id) + return + + assert tensor.shape[0] % den == 0 + units = tensor.shape[0] // den + offset = 0 + tensor = func(tensor) if func else tensor + for shard_id, num in split_param: + new_offset = offset + num * units + weight_loader(param, tensor[offset:new_offset], shard_id) + offset = new_offset + + def get_loaded_weight_shard( + name: str, + tensor: torch.Tensor, + offset: int, + den: int, + ) -> torch.Tensor: + if is_scalar_quant_scale(name, tensor): + return tensor + assert tensor.shape[0] % den == 0 + units = tensor.shape[0] // den + return tensor[offset * units : offset * units + units] + for name, loaded_weight in weights: # print(f"Loading weight name: {name}, tp_rank: {tp_rank}", flush=True) if contains_unexpected_keyword(name, unexpected_keywords): @@ -2109,19 +2147,7 @@ def contains_unexpected_keyword(name, keywords): if is_pp_missing_parameter(name, self): continue - assert loaded_weight.shape[0] % den == 0 - units = loaded_weight.shape[0] // den - param = params_dict[name] - weight_loader = param.weight_loader - offset = 0 - for shard_id, num in split_param: - new_offset = offset + num * units - if func: - weight_loader(param, func(loaded_weight)[offset:new_offset], shard_id) - else: - weight_loader(param, loaded_weight[offset:new_offset], shard_id) - offset = new_offset - + load_split_param(name, loaded_weight, den, split_param, func) break else: # Skip loading extra bias for GPTQ models. @@ -2151,12 +2177,11 @@ def contains_unexpected_keyword(name, keywords): continue param = params_dict[name_mapped] weight_loader = cast(Callable[..., bool], param.weight_loader) - assert loaded_weight.shape[0] % den == 0 - units = loaded_weight.shape[0] // den + loaded_weight_shard = get_loaded_weight_shard(name, loaded_weight, offset, den) success = weight_loader( param, - loaded_weight[offset * units : offset * units + units], + loaded_weight_shard, name_mapped, shard_id=shard_id, expert_id=expert_id, From b2b15f053339ab2d26ad89c5b64c1ebb34da20f7 Mon Sep 17 00:00:00 2001 From: roG0d Date: Sun, 19 Apr 2026 17:57:48 +0000 Subject: [PATCH 07/24] fix online server problem Signed-off-by: roG0d --- vllm_omni/entrypoints/openai/api_server.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py index 745b719d5b2..cf09a5a6c58 100644 --- a/vllm_omni/entrypoints/openai/api_server.py +++ b/vllm_omni/entrypoints/openai/api_server.py @@ -63,8 +63,8 @@ from vllm.entrypoints.openai.utils import validate_json_request from vllm.entrypoints.pooling.classify.serving import ServingClassification from vllm.entrypoints.pooling.embed.serving import ServingEmbedding as OpenAIServingEmbedding -from vllm.entrypoints.pooling.pooling.serving import OpenAIServingPooling -from vllm.entrypoints.pooling.score.serving import ServingScores +from vllm.entrypoints.pooling.pooling.serving import ServingPooling as OpenAIServingPooling +from vllm.entrypoints.pooling.scoring.serving import ServingScores from vllm.entrypoints.serve.disagg.serving import ServingTokens # vLLM moved `base` from openai.basic.api_router to serve.instrumentator.basic. From fc91876c80ec8a9d1ea11ea6849a5614c4e4b3d2 Mon Sep 17 00:00:00 2001 From: lishunyang Date: Mon, 20 Apr 2026 03:04:12 +0800 Subject: [PATCH 08/24] [Quant] Wire quant_config through HunyuanVideo-1.5 DiT (extracted from #2920) Threads quant_config / prefix through HunyuanVideo15Attention, HunyuanVideo15TransformerBlock, and HunyuanVideo15Transformer3DModel so the modelopt FP8 adapter from #2913 has somewhere to bind per-layer scales. Modulation, embeddings, proj_out stay raw nn.Linear (full precision). Signed-off-by: lishunyang --- .../hunyuan_video_15_transformer.py | 44 ++++++++++++++++--- .../pipeline_hunyuan_video_1_5.py | 4 +- .../pipeline_hunyuan_video_1_5_i2v.py | 4 +- 3 files changed, 45 insertions(+), 7 deletions(-) diff --git a/vllm_omni/diffusion/models/hunyuan_video/hunyuan_video_15_transformer.py b/vllm_omni/diffusion/models/hunyuan_video/hunyuan_video_15_transformer.py index 6600b17d5cd..9ee3272aabd 100644 --- a/vllm_omni/diffusion/models/hunyuan_video/hunyuan_video_15_transformer.py +++ b/vllm_omni/diffusion/models/hunyuan_video/hunyuan_video_15_transformer.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable -from typing import Any +from typing import TYPE_CHECKING, Any import torch import torch.nn.functional as F @@ -27,6 +27,9 @@ from vllm_omni.diffusion.layers.rope import RotaryEmbedding from vllm_omni.diffusion.models.flux.flux_transformer import FeedForward +if TYPE_CHECKING: + from vllm.model_executor.layers.quantization.base_config import QuantizationConfig + logger = init_logger(__name__) @@ -328,6 +331,8 @@ def __init__( out_bias: bool = True, eps: float = 1e-6, out_dim: int | None = None, + quant_config: "QuantizationConfig | None" = None, + prefix: str = "", ): super().__init__() @@ -346,6 +351,8 @@ def __init__( head_size=self.head_dim, total_num_heads=self.heads, bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.to_qkv", ) self.to_out = nn.ModuleList( @@ -356,6 +363,8 @@ def __init__( bias=out_bias, input_is_parallel=True, return_bias=False, + quant_config=quant_config, + prefix=f"{prefix}.to_out.0", ), nn.Identity(), # placeholder for dropout (none used) ] @@ -370,6 +379,8 @@ def __init__( head_size=self.head_dim, total_num_heads=self.heads, bias=added_proj_bias, + quant_config=quant_config, + prefix=f"{prefix}.add_kv_proj", ) self.to_add_out = RowParallelLinear( @@ -378,6 +389,8 @@ def __init__( bias=out_bias, input_is_parallel=True, return_bias=False, + quant_config=quant_config, + prefix=f"{prefix}.to_add_out", ) self.rope = RotaryEmbedding(is_neox_style=False) @@ -396,6 +409,8 @@ def forward( attention_mask: torch.Tensor | None = None, image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: + # Ensure contiguous for FP8 quantized linear layers + hidden_states = hidden_states.contiguous() qkv, _ = self.to_qkv(hidden_states) q_size = self.to_qkv.num_heads * self.head_dim kv_size = self.to_qkv.num_kv_heads * self.head_dim @@ -416,6 +431,7 @@ def forward( key = self.rope(key, cos, sin) if encoder_hidden_states is not None: + encoder_hidden_states = encoder_hidden_states.contiguous() encoder_qkv, _ = self.add_kv_proj(encoder_hidden_states) add_q_size = self.add_kv_proj.num_heads * self.head_dim add_kv_size = self.add_kv_proj.num_kv_heads * self.head_dim @@ -469,6 +485,8 @@ def __init__( attention_head_dim: int, mlp_ratio: float = 4.0, qk_norm: str = "rms_norm", + quant_config: "QuantizationConfig | None" = None, + prefix: str = "", ) -> None: super().__init__() hidden_size = num_attention_heads * attention_head_dim @@ -484,13 +502,23 @@ def __init__( out_dim=hidden_size, bias=True, eps=1e-6, + quant_config=quant_config, + prefix=f"{prefix}.attn", ) self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - self.ff = FeedForward(dim=hidden_size, dim_out=hidden_size, mult=mlp_ratio) + self.ff = FeedForward( + dim=hidden_size, dim_out=hidden_size, mult=mlp_ratio, quant_config=quant_config, prefix=f"{prefix}.ff" + ) self.norm2_context = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - self.ff_context = FeedForward(dim=hidden_size, dim_out=hidden_size, mult=mlp_ratio) + self.ff_context = FeedForward( + dim=hidden_size, + dim_out=hidden_size, + mult=mlp_ratio, + quant_config=quant_config, + prefix=f"{prefix}.ff_context", + ) def forward( self, @@ -568,6 +596,7 @@ def __init__( target_size: int = 640, task_type: str = "i2v", use_meanflow: bool = False, + quant_config: "QuantizationConfig | None" = None, ): super().__init__() @@ -599,9 +628,14 @@ def __init__( self.transformer_blocks = nn.ModuleList( [ HunyuanVideo15TransformerBlock( - num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm + num_attention_heads, + attention_head_dim, + mlp_ratio=mlp_ratio, + qk_norm=qk_norm, + quant_config=quant_config, + prefix=f"transformer_blocks.{i}", ) - for _ in range(num_layers) + for i in range(num_layers) ] ) diff --git a/vllm_omni/diffusion/models/hunyuan_video/pipeline_hunyuan_video_1_5.py b/vllm_omni/diffusion/models/hunyuan_video/pipeline_hunyuan_video_1_5.py index 6445bfee215..b007e00eed0 100644 --- a/vllm_omni/diffusion/models/hunyuan_video/pipeline_hunyuan_video_1_5.py +++ b/vllm_omni/diffusion/models/hunyuan_video/pipeline_hunyuan_video_1_5.py @@ -124,7 +124,9 @@ def __init__( self.scheduler._shift = od_config.flow_shift transformer_kwargs = get_transformer_config_kwargs(od_config.tf_model_config, HunyuanVideo15Transformer3DModel) - self.transformer = HunyuanVideo15Transformer3DModel(od_config=od_config, **transformer_kwargs) + self.transformer = HunyuanVideo15Transformer3DModel( + od_config=od_config, quant_config=od_config.quantization_config, **transformer_kwargs + ) # Check if model uses meanflow (distilled variants) self.use_meanflow = getattr(od_config.tf_model_config, "use_meanflow", False) diff --git a/vllm_omni/diffusion/models/hunyuan_video/pipeline_hunyuan_video_1_5_i2v.py b/vllm_omni/diffusion/models/hunyuan_video/pipeline_hunyuan_video_1_5_i2v.py index c1acd1a895a..99b17bad424 100644 --- a/vllm_omni/diffusion/models/hunyuan_video/pipeline_hunyuan_video_1_5_i2v.py +++ b/vllm_omni/diffusion/models/hunyuan_video/pipeline_hunyuan_video_1_5_i2v.py @@ -153,7 +153,9 @@ def __init__( self.scheduler._shift = od_config.flow_shift transformer_kwargs = get_transformer_config_kwargs(od_config.tf_model_config, HunyuanVideo15Transformer3DModel) - self.transformer = HunyuanVideo15Transformer3DModel(od_config=od_config, **transformer_kwargs) + self.transformer = HunyuanVideo15Transformer3DModel( + od_config=od_config, quant_config=od_config.quantization_config, **transformer_kwargs + ) self.use_meanflow = getattr(od_config.tf_model_config, "use_meanflow", False) From 8aa8f997e30c67f55b4f436d3e3ad54bbfca91dc Mon Sep 17 00:00:00 2001 From: lishunyang Date: Mon, 20 Apr 2026 03:07:26 +0800 Subject: [PATCH 09/24] [Quant] ModelOpt FP8 calibration script + stage config for HunyuanVideo-1.5 examples/quantization/quantize_hunyuanvideo_15_modelopt_fp8.py: Offline calibration helper that produces a ModelOpt FP8 diffusers checkpoint for HunyuanVideo-1.5. Calibrates with 8 video prompts x 10 denoising steps, skips precision-sensitive layers (modulation, embeddings, output proj, token refiner) matching the #2728 / #2795 pattern, disables MHA quantizers by default (HV-1.5 self-attention degrades visibly under FP8 - see #2920). vllm_omni/model_executor/stage_configs/hunyuan_video_15_dit_fp8.yaml: Stage config for serving the calibrated checkpoint via vllm-omni. Auto-detects ModelOpt metadata from the checkpoint (uses #2913's adapter). Signed-off-by: lishunyang --- .../quantize_hunyuanvideo_15_modelopt_fp8.py | 282 ++++++++++++++++++ .../hunyuan_video_15_dit_fp8.yaml | 33 ++ 2 files changed, 315 insertions(+) create mode 100644 examples/quantization/quantize_hunyuanvideo_15_modelopt_fp8.py create mode 100644 vllm_omni/model_executor/stage_configs/hunyuan_video_15_dit_fp8.yaml diff --git a/examples/quantization/quantize_hunyuanvideo_15_modelopt_fp8.py b/examples/quantization/quantize_hunyuanvideo_15_modelopt_fp8.py new file mode 100644 index 00000000000..e60b57f5543 --- /dev/null +++ b/examples/quantization/quantize_hunyuanvideo_15_modelopt_fp8.py @@ -0,0 +1,282 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Quantize HunyuanVideo-1.5 (480p T2V) to a ModelOpt FP8 Hugging Face checkpoint. + +Calibrates the DiT transformer using a small video prompt set and exports a +diffusers-style directory whose transformer carries ModelOpt FP8 metadata. +The exported checkpoint is consumable by vllm-omni's ModelOpt FP8 adapter +(see vllm_omni/diffusion/model_loader/checkpoint_adapters/modelopt_fp8.py). + +Layers kept full precision match the #2728 / #2795 pattern: modulation, +AdaLayerNorm, entry/exit projections, embeddings, the token refiner path, +and final proj_out. MHA quantizers are off by default; HV-1.5 self-attention +empirically degrades under FP8 (see #2920 ablation). + +Example: + python examples/quantization/quantize_hunyuanvideo_15_modelopt_fp8.py \\ + --model hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-480p_t2v \\ + --output ./hv15-480p-modelopt-fp8 \\ + --overwrite +""" + +from __future__ import annotations + +import argparse +import copy +import json +import re +import shutil +import sys +from pathlib import Path +from typing import Any + +import torch +from diffusers import DiffusionPipeline + +DEFAULT_PROMPTS = [ + "A dog running across a field of golden wheat.", + "An astronaut riding a horse across the surface of Mars, red dust swirling, cinematic wide shot.", + "A hummingbird hovering in front of a vibrant red flower, slow motion, macro shot.", + "A crackling campfire at night under a starry sky, sparks rising into the dark.", + "An underwater shot of a coral reef with tropical fish swimming by, sun rays piercing the water.", + "A close-up of a blooming rose covered in morning dew, soft natural light.", + "A peaceful mountain village at dawn, mist rolling over the rooftops, cinematic establishing shot.", + "A skateboarder doing a kickflip in an urban plaza, slow motion, golden hour lighting.", +] + + +def _build_parser() -> argparse.ArgumentParser: + p = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) + p.add_argument("--model", required=True, help="Input HV-1.5 diffusers directory or HF id.") + p.add_argument("--output", required=True, help="Output directory for the ModelOpt FP8 checkpoint.") + p.add_argument("--dtype", choices=("bfloat16", "float16"), default="bfloat16") + p.add_argument("--height", type=int, default=480) + p.add_argument("--width", type=int, default=832) + p.add_argument( + "--num-frames", + type=int, + default=33, + help="Frames per calibration sample. 33 matches the typical short benchmark.", + ) + p.add_argument("--guidance-scale", type=float, default=6.0) + p.add_argument( + "--calib-steps", + type=int, + default=10, + help="Denoising steps per calibration prompt (10 is enough for amax statistics).", + ) + p.add_argument("--calib-size", type=int, default=8, help="How many prompts to use for calibration.") + p.add_argument("--seed", type=int, default=42) + p.add_argument( + "--prompt", + action="append", + default=[], + help="Custom calibration prompt. Repeat to provide multiple.", + ) + p.add_argument( + "--quantize-mha", + action="store_true", + help="Enable FP8 attention K/V/softmax quantizers. Off by default — empirically degrades HV-1.5 video output.", + ) + p.add_argument("--overwrite", action="store_true", help="Replace an existing output directory.") + return p + + +def _require_modelopt() -> tuple[Any, Any]: + try: + import modelopt.torch.quantization as mtq + from modelopt.torch.export import export_hf_checkpoint + except ModuleNotFoundError as exc: + raise SystemExit( + "NVIDIA ModelOpt is not installed. Install with:\n" + " pip install 'nvidia-modelopt[all]'\n" + f"Original error: {exc}" + ) from exc + return mtq, export_hf_checkpoint + + +def _ensure_paths(args: argparse.Namespace) -> tuple[str, Path]: + model_path = args.model + output_dir = Path(args.output).expanduser().resolve() + if output_dir.exists(): + if not args.overwrite: + raise SystemExit(f"Output directory already exists: {output_dir}\nPass --overwrite to replace it.") + shutil.rmtree(output_dir) + return model_path, output_dir + + +def _select_dtype(name: str) -> torch.dtype: + return {"bfloat16": torch.bfloat16, "float16": torch.float16}[name] + + +def _build_prompts(args: argparse.Namespace) -> list[str]: + prompts = args.prompt or DEFAULT_PROMPTS + if args.calib_size <= 0: + raise SystemExit("--calib-size must be positive.") + if len(prompts) < args.calib_size: + repeats = (args.calib_size + len(prompts) - 1) // len(prompts) + prompts = (prompts * repeats)[: args.calib_size] + return prompts[: args.calib_size] + + +# Layers to KEEP at full precision (mirror of the #2920 wiring + #2728/#2795 skip pattern). +# - x_embedder, image_embedder, context_embedder*, time_embed*, cond_type_embed: entry/embedding +# - norm_out, norm1*.linear, norm1_context*.linear, norm2*, norm2_context*: AdaLayerNorm modulation +# - proj_out: final output projection +# - token_refiner*: text-encoder refinement uses diffusers raw nn.Linear +def _filter_func_hv15(name: str) -> bool: + pattern = re.compile( + r"(proj_out.*|" + r".*(x_embedder|image_embedder|context_embedder|context_embedder_2|" + r"time_embed|cond_type_embed|" + r"norm_out|norm1\.linear|norm1_context\.linear|norm2|norm2_context|" + r"token_refiner).*)" + ) + return pattern.match(name) is not None + + +def _mha_filter_func(name: str) -> bool: + pattern = re.compile( + r".*(q_bmm_quantizer|k_bmm_quantizer|v_bmm_quantizer|softmax_quantizer|bmm2_output_quantizer).*" + ) + return pattern.match(name) is not None + + +def _disable_known_problematic_quantizers(mtq: Any, backbone: torch.nn.Module, *, quantize_mha: bool) -> None: + if not hasattr(mtq, "disable_quantizer"): + return + mtq.disable_quantizer(backbone, _filter_func_hv15) + if not quantize_mha: + mtq.disable_quantizer(backbone, _mha_filter_func) + + +def _load_pipeline(model_path: str, dtype: torch.dtype) -> DiffusionPipeline: + pipe = DiffusionPipeline.from_pretrained(model_path, torch_dtype=dtype) + if hasattr(pipe, "set_progress_bar_config"): + pipe.set_progress_bar_config(disable=True) + pipe.to("cuda") + return pipe + + +def _build_forward_loop(pipe: DiffusionPipeline, args: argparse.Namespace, prompts: list[str]): + generator = torch.Generator(device="cuda") + + def forward_loop(*_unused_args, **_unused_kwargs) -> None: + with torch.inference_mode(): + for idx, prompt in enumerate(prompts): + generator.manual_seed(args.seed + idx) + pipe( + prompt=prompt, + height=args.height, + width=args.width, + num_frames=args.num_frames, + num_inference_steps=args.calib_steps, + guidance_scale=args.guidance_scale, + generator=generator, + output_type="latent", + ) + + return forward_loop + + +def _summarize_export(output_dir: Path) -> None: + cfg_path = output_dir / "transformer" / "config.json" + if not cfg_path.exists(): + print(f"[warn] {cfg_path} missing.", file=sys.stderr) + return + with cfg_path.open(encoding="utf-8") as f: + cfg = json.load(f) + qc = cfg.get("quantization_config") + if not isinstance(qc, dict): + print("[warn] No quantization_config in transformer/config.json.", file=sys.stderr) + return + print("Export summary:") + print(f" quant_method: {qc.get('quant_method')}") + print(f" quant_algo: {qc.get('quant_algo')}") + producer = qc.get("producer") + if isinstance(producer, dict): + print(f" producer: {producer.get('name')} {producer.get('version')}") + print(f" config path: {cfg_path}") + + +def _export_with_fallback( + pipe: DiffusionPipeline, + export_hf_checkpoint: Any, + model_path: str, + output_dir: Path, +) -> str: + try: + export_hf_checkpoint(pipe, export_dir=str(output_dir)) + return "pipeline" + except Exception as exc: + print( + f"[warn] Whole-pipeline export failed, falling back to transformer-only export: {exc}", + file=sys.stderr, + ) + if output_dir.exists(): + shutil.rmtree(output_dir) + # Resolve the source directory: a local path stays as-is; an HF id resolves via snapshot_download. + src = Path(model_path) + if not src.exists(): + from huggingface_hub import snapshot_download + + src = Path(snapshot_download(model_path)) + shutil.copytree(src, output_dir, dirs_exist_ok=True) + shutil.rmtree(output_dir / "transformer", ignore_errors=True) + export_hf_checkpoint(pipe.transformer, export_dir=str(output_dir / "transformer")) + return "transformer" + + +def main() -> None: + args = _build_parser().parse_args() + if not torch.cuda.is_available(): + raise SystemExit("CUDA is required for ModelOpt FP8 quantization.") + + mtq, export_hf_checkpoint = _require_modelopt() + model_path, output_dir = _ensure_paths(args) + dtype = _select_dtype(args.dtype) + prompts = _build_prompts(args) + + print("Quantization plan:") + print(f" input: {args.model}") + print(f" output: {output_dir}") + print(f" dtype: {dtype}") + print(f" height/width: {args.height}x{args.width}") + print(f" num_frames: {args.num_frames}") + print(f" calib_size: {len(prompts)}") + print(f" calib_steps: {args.calib_steps}") + print(f" quantize_mha: {args.quantize_mha}") + + pipe = _load_pipeline(model_path, dtype) + backbone = pipe.transformer + + quant_config = copy.deepcopy(mtq.FP8_DEFAULT_CFG) + + forward_loop = _build_forward_loop(pipe, args, prompts) + quantized = mtq.quantize(backbone, quant_config, forward_loop) + if quantized is not None: + pipe.transformer = quantized + backbone = quantized + + _disable_known_problematic_quantizers(mtq, backbone, quantize_mha=args.quantize_mha) + + export_mode = _export_with_fallback(pipe, export_hf_checkpoint, model_path, output_dir) + print(f"Export mode: {export_mode}") + _summarize_export(output_dir) + + print("\nNext: validate the checkpoint with vllm-omni:") + print( + " python examples/offline_inference/text_to_video/text_to_video.py \\\n" + f" --model {output_dir} \\\n" + " --stage-configs-path vllm_omni/model_executor/stage_configs/hunyuan_video_15_dit_fp8.yaml \\\n" + " --prompt 'A dog running across a field of golden wheat.' \\\n" + f" --height {args.height} --width {args.width} --num-frames {args.num_frames} \\\n" + " --num-inference-steps 30 --guidance-scale 6.0 --seed 42 \\\n" + " --output outputs/hv15_modelopt_fp8.mp4 \\\n" + " --enforce-eager" + ) + + +if __name__ == "__main__": + main() diff --git a/vllm_omni/model_executor/stage_configs/hunyuan_video_15_dit_fp8.yaml b/vllm_omni/model_executor/stage_configs/hunyuan_video_15_dit_fp8.yaml new file mode 100644 index 00000000000..4220344b059 --- /dev/null +++ b/vllm_omni/model_executor/stage_configs/hunyuan_video_15_dit_fp8.yaml @@ -0,0 +1,33 @@ +# Stage config for running HunyuanVideo-1.5 DiT with ModelOpt FP8 auto-detect. +# Single GPU. Bump `tensor_parallel_size` and `devices` for multi-GPU TP. +# +# Use with a ModelOpt FP8 checkpoint (e.g. produced by +# scripts/quantization/quantize_hunyuanvideo_15_modelopt_fp8.py). + +stage_args: + - stage_id: 0 + stage_type: diffusion + runtime: + devices: "0" + max_batch_size: 1 + engine_args: + model_stage: dit + model_class_name: HunyuanVideo15Pipeline + max_num_seqs: 1 + enforce_eager: true + trust_remote_code: true + distributed_executor_backend: "mp" + parallel_config: + tensor_parallel_size: 1 + + final_output: true + final_output_type: video + is_comprehension: false + default_sampling_params: + seed: 42 + +runtime: + enabled: true + defaults: + window_size: -1 + max_inflight: 1 From bd2a01deee42b0519f64f79dc16f76bf5937d128 Mon Sep 17 00:00:00 2001 From: lishunyang Date: Mon, 20 Apr 2026 03:50:29 +0800 Subject: [PATCH 10/24] [Quant] HV-1.5 calibration: handle Guider config (fall back if guidance_scale kwarg unsupported) HV-1.5's diffusers pipeline uses the new Guider abstraction (guider_config.json in the checkpoint) rather than a guidance_scale kwarg. Try setting it on the guider object once up front; in the per-prompt call, try with guidance_scale first and fall back without it on TypeError. Calibration only needs amax stats, so the exact CFG value isn't critical. Signed-off-by: lishunyang --- .../quantize_hunyuanvideo_15_modelopt_fp8.py | 36 +++++++++++++------ 1 file changed, 26 insertions(+), 10 deletions(-) diff --git a/examples/quantization/quantize_hunyuanvideo_15_modelopt_fp8.py b/examples/quantization/quantize_hunyuanvideo_15_modelopt_fp8.py index e60b57f5543..2e9da999325 100644 --- a/examples/quantization/quantize_hunyuanvideo_15_modelopt_fp8.py +++ b/examples/quantization/quantize_hunyuanvideo_15_modelopt_fp8.py @@ -162,20 +162,36 @@ def _load_pipeline(model_path: str, dtype: torch.dtype) -> DiffusionPipeline: def _build_forward_loop(pipe: DiffusionPipeline, args: argparse.Namespace, prompts: list[str]): generator = torch.Generator(device="cuda") + # Try to set guidance on the pipeline's guider object up front (modern + # diffusers HV-1.5 uses a Guider abstraction, not a per-call kwarg). Falls + # back silently — calibration uses whatever default the pipeline ships with. + guider = getattr(pipe, "guider", None) + if guider is not None and hasattr(guider, "guidance_scale"): + try: + guider.guidance_scale = args.guidance_scale + except Exception: + pass + + base_kwargs = dict( + height=args.height, + width=args.width, + num_frames=args.num_frames, + num_inference_steps=args.calib_steps, + output_type="latent", + ) + def forward_loop(*_unused_args, **_unused_kwargs) -> None: with torch.inference_mode(): for idx, prompt in enumerate(prompts): generator.manual_seed(args.seed + idx) - pipe( - prompt=prompt, - height=args.height, - width=args.width, - num_frames=args.num_frames, - num_inference_steps=args.calib_steps, - guidance_scale=args.guidance_scale, - generator=generator, - output_type="latent", - ) + # Try with guidance_scale first; fall back without on TypeError + # for pipelines (like HV-1.5) that take CFG via guider config. + try: + pipe(prompt=prompt, generator=generator, guidance_scale=args.guidance_scale, **base_kwargs) + except TypeError as exc: + if "guidance_scale" not in str(exc): + raise + pipe(prompt=prompt, generator=generator, **base_kwargs) return forward_loop From f76a50cb484460fb5e8697d6f22e3fb771df901c Mon Sep 17 00:00:00 2001 From: lishunyang Date: Mon, 20 Apr 2026 04:04:31 +0800 Subject: [PATCH 11/24] [Quant] Add check_modelopt_fp8_export.py post-calibration verifier Three checks: (A) transformer/config.json has sane quantization_config, (B) safetensors contain FP8 tensors, (C) optional disk-size delta vs BF16. Run after the quantize_*_modelopt_fp8.py scripts to spot issues before attempting to serve. Signed-off-by: lishunyang --- .../quantization/check_modelopt_fp8_export.py | 183 ++++++++++++++++++ 1 file changed, 183 insertions(+) create mode 100644 examples/quantization/check_modelopt_fp8_export.py diff --git a/examples/quantization/check_modelopt_fp8_export.py b/examples/quantization/check_modelopt_fp8_export.py new file mode 100644 index 00000000000..87668f10036 --- /dev/null +++ b/examples/quantization/check_modelopt_fp8_export.py @@ -0,0 +1,183 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Verify a ModelOpt FP8 diffusers checkpoint exported by +quantize_hunyuanvideo_15_modelopt_fp8.py (or any sibling quantize_*.py). + +Three checks: + A. transformer/config.json has a sane quantization_config block. + B. transformer/*.safetensors contains FP8 (float8_e4m3fn) tensors. + C. transformer disk size is materially smaller than a BF16 baseline. + +Example: + python examples/quantization/check_modelopt_fp8_export.py \\ + --output ./hv15-480p-modelopt-fp8 + + # Optional: compare disk size against a local or HF BF16 baseline. + python examples/quantization/check_modelopt_fp8_export.py \\ + --output ./hv15-480p-modelopt-fp8 \\ + --baseline hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-480p_t2v +""" + +from __future__ import annotations + +import argparse +import json +import sys +from collections import Counter +from pathlib import Path + + +def _check_config(transformer_dir: Path) -> int: + """Returns 0 on pass, 1 on fail. Prints findings.""" + cfg_path = transformer_dir / "config.json" + if not cfg_path.exists(): + print(f"[FAIL] {cfg_path} missing.") + return 1 + + with cfg_path.open(encoding="utf-8") as f: + cfg = json.load(f) + + qc = cfg.get("quantization_config") + if not isinstance(qc, dict): + print(f"[FAIL] No `quantization_config` block in {cfg_path}.") + return 1 + + print(f"[A] quantization_config from {cfg_path}:") + print(json.dumps(qc, indent=2)) + + issues = [] + if qc.get("quant_method") != "modelopt": + issues.append(f"quant_method={qc.get('quant_method')!r} (expected 'modelopt')") + if qc.get("quant_algo") != "FP8": + issues.append(f"quant_algo={qc.get('quant_algo')!r} (expected 'FP8' — vllm-omni adapter may not auto-detect)") + + if issues: + print("[A] WARN — config looks incomplete:") + for issue in issues: + print(f" - {issue}") + return 2 + print("[A] PASS — config looks correct.") + return 0 + + +def _check_safetensors(transformer_dir: Path) -> int: + """Returns 0 on pass, 1 on fail. Counts dtypes across all *.safetensors.""" + try: + from safetensors import safe_open + except ImportError: + print("[B] SKIP — safetensors not installed.") + return 0 + + files = sorted(transformer_dir.glob("*.safetensors")) + if not files: + print(f"[FAIL] No *.safetensors in {transformer_dir}.") + return 1 + + counts: Counter[str] = Counter() + sample_fp8_keys: list[str] = [] + sample_scale_keys: list[str] = [] + for f in files: + with safe_open(f, framework="pt", device="cpu") as h: + for k in h.keys(): + t = h.get_tensor(k) + dtype = str(t.dtype) + counts[dtype] += 1 + if "float8" in dtype and len(sample_fp8_keys) < 5: + sample_fp8_keys.append(k) + if k.endswith(("_scale", ".weight_scale", ".input_scale")) and len(sample_scale_keys) < 5: + sample_scale_keys.append(k) + + print(f"\n[B] Tensor dtype counts across {len(files)} safetensors file(s):") + for dtype, count in sorted(counts.items(), key=lambda kv: -kv[1]): + print(f" {dtype:30s} {count:>6d}") + + fp8_count = sum(c for d, c in counts.items() if "float8" in d) + if fp8_count == 0: + print("[B] FAIL — no FP8 tensors found. Calibration likely did not actually quantize the weights.") + return 1 + + print(f"[B] PASS — {fp8_count} FP8 tensors present.") + if sample_fp8_keys: + print(f" sample FP8 tensors: {sample_fp8_keys[:3]}") + if sample_scale_keys: + print(f" sample scale tensors: {sample_scale_keys[:3]}") + return 0 + + +def _disk_size_gib(p: Path) -> float: + return sum(f.stat().st_size for f in p.rglob("*") if f.is_file()) / (1024**3) + + +def _check_size_vs_baseline(transformer_dir: Path, baseline: str | None) -> int: + """Returns 0 always (informational only).""" + fp8_size = _disk_size_gib(transformer_dir) + print(f"\n[C] FP8 transformer disk size: {fp8_size:.2f} GiB") + + if baseline is None: + print("[C] SKIP — pass --baseline to compare against BF16.") + return 0 + + baseline_path = Path(baseline) + if not baseline_path.exists(): + # Try HF download. + try: + from huggingface_hub import snapshot_download + except ImportError: + print("[C] SKIP — huggingface_hub not installed and baseline not a local path.") + return 0 + print(f" Downloading baseline transformer from HF: {baseline}") + baseline_path = Path(snapshot_download(baseline, allow_patterns=["transformer/*"])) + + bf16_dir = baseline_path / "transformer" if (baseline_path / "transformer").exists() else baseline_path + bf16_size = _disk_size_gib(bf16_dir) + if bf16_size == 0: + print(f"[C] WARN — baseline transformer dir empty: {bf16_dir}") + return 0 + + reduction = (1 - fp8_size / bf16_size) * 100 + print(f"[C] BF16 baseline transformer disk size: {bf16_size:.2f} GiB ({bf16_dir})") + print(f"[C] Disk reduction: {reduction:.1f}% (FP8 transformer is {fp8_size / bf16_size:.0%} of BF16)") + if reduction < 30: + print("[C] WARN — FP8 should typically reduce disk by ~40-50%; <30% suggests partial quantization.") + return 0 + + +def main() -> None: + p = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) + p.add_argument("--output", required=True, help="Path to the exported ModelOpt FP8 checkpoint root.") + p.add_argument( + "--baseline", + default=None, + help="Optional BF16 baseline (local diffusers dir or HF id) for disk-size comparison.", + ) + args = p.parse_args() + + out_root = Path(args.output).expanduser().resolve() + transformer_dir = out_root / "transformer" + if not transformer_dir.exists(): + print(f"[FAIL] {transformer_dir} does not exist.") + sys.exit(1) + + print(f"Checking: {out_root}\n") + + fail = 0 + fail |= _check_config(transformer_dir) + fail |= _check_safetensors(transformer_dir) + _check_size_vs_baseline(transformer_dir, args.baseline) + + print() + if fail == 0: + print("=" * 60) + print("ALL CHECKS PASSED — checkpoint looks ready for vllm-omni serving.") + elif fail == 1: + print("=" * 60) + print("FAILURES detected — calibration may need to be re-run.") + sys.exit(1) + else: + print("=" * 60) + print("WARNINGS only — checkpoint may serve but with caveats. See [A] above.") + + +if __name__ == "__main__": + main() From b2f4522bdb660f0240e29a08b53d75f08c25d4e2 Mon Sep 17 00:00:00 2001 From: lishunyang Date: Mon, 20 Apr 2026 04:08:13 +0800 Subject: [PATCH 12/24] [Quant] check script: read safetensors header for dtype (not get_tensor view) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit torch's get_tensor() returns FP8 storage as bf16 views on some safetensors versions, giving false negatives. Read the on-disk dtype from the header directly — that's what actually determines whether the checkpoint is FP8. Signed-off-by: lishunyang --- .../quantization/check_modelopt_fp8_export.py | 67 ++++++++++++------- 1 file changed, 41 insertions(+), 26 deletions(-) diff --git a/examples/quantization/check_modelopt_fp8_export.py b/examples/quantization/check_modelopt_fp8_export.py index 87668f10036..dd4c2ed43e8 100644 --- a/examples/quantization/check_modelopt_fp8_export.py +++ b/examples/quantization/check_modelopt_fp8_export.py @@ -61,47 +61,62 @@ def _check_config(transformer_dir: Path) -> int: return 0 -def _check_safetensors(transformer_dir: Path) -> int: - """Returns 0 on pass, 1 on fail. Counts dtypes across all *.safetensors.""" - try: - from safetensors import safe_open - except ImportError: - print("[B] SKIP — safetensors not installed.") - return 0 +def _read_safetensors_header(path: Path) -> dict: + """Read the JSON header of a safetensors file. Bypass-safe — doesn't materialize tensors. + + Returns {tensor_name: {'dtype': 'F8_E4M3', 'shape': [...], 'data_offsets': [...]}}. + Header dtype strings: F8_E4M3, F8_E5M2, BF16, F16, F32, F64, I8, I16, I32, I64, BOOL, U8, ... + """ + import struct + + with open(path, "rb") as f: + header_len = struct.unpack(" int: + """Returns 0 on pass, 1 on fail. Reads on-disk dtype from the safetensors header.""" files = sorted(transformer_dir.glob("*.safetensors")) if not files: print(f"[FAIL] No *.safetensors in {transformer_dir}.") return 1 - counts: Counter[str] = Counter() + header_dtype_counts: Counter[str] = Counter() sample_fp8_keys: list[str] = [] sample_scale_keys: list[str] = [] for f in files: - with safe_open(f, framework="pt", device="cpu") as h: - for k in h.keys(): - t = h.get_tensor(k) - dtype = str(t.dtype) - counts[dtype] += 1 - if "float8" in dtype and len(sample_fp8_keys) < 5: - sample_fp8_keys.append(k) - if k.endswith(("_scale", ".weight_scale", ".input_scale")) and len(sample_scale_keys) < 5: - sample_scale_keys.append(k) - - print(f"\n[B] Tensor dtype counts across {len(files)} safetensors file(s):") - for dtype, count in sorted(counts.items(), key=lambda kv: -kv[1]): - print(f" {dtype:30s} {count:>6d}") - - fp8_count = sum(c for d, c in counts.items() if "float8" in d) + try: + header = _read_safetensors_header(f) + except Exception as exc: + print(f"[B] WARN — could not parse header of {f}: {exc}") + continue + for k, info in header.items(): + dtype = info.get("dtype", "?") + header_dtype_counts[dtype] += 1 + if dtype.startswith("F8") and len(sample_fp8_keys) < 5: + sample_fp8_keys.append(k) + if k.endswith(("_scale", ".weight_scale", ".input_scale", "_scale_inv")) and len(sample_scale_keys) < 5: + sample_scale_keys.append(k) + + print(f"\n[B] On-disk dtype counts across {len(files)} safetensors file(s) (from header, not get_tensor):") + for dtype, count in sorted(header_dtype_counts.items(), key=lambda kv: -kv[1]): + marker = " <-- FP8" if dtype.startswith("F8") else "" + print(f" {dtype:10s} {count:>6d}{marker}") + + fp8_count = sum(c for d, c in header_dtype_counts.items() if d.startswith("F8")) if fp8_count == 0: - print("[B] FAIL — no FP8 tensors found. Calibration likely did not actually quantize the weights.") + print("[B] FAIL — no FP8 tensors on disk. Calibration likely did not actually quantize the weights.") return 1 - print(f"[B] PASS — {fp8_count} FP8 tensors present.") + print(f"[B] PASS — {fp8_count} FP8 tensors stored on disk.") if sample_fp8_keys: - print(f" sample FP8 tensors: {sample_fp8_keys[:3]}") + print(f" sample FP8 tensors: {sample_fp8_keys[:3]}") if sample_scale_keys: print(f" sample scale tensors: {sample_scale_keys[:3]}") + print(" (Note: torch's get_tensor() may return these as bf16 views on some versions —") + print(" irrelevant; vLLM's loader uses native FP8 ops.)") return 0 From 2b3883839e84b88db60dc80369481e4482cf57ed Mon Sep 17 00:00:00 2001 From: lishunyang Date: Mon, 20 Apr 2026 04:13:07 +0800 Subject: [PATCH 13/24] [Quant] Force FP8 weight serialization + patch quant_algo for HV-1.5 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The default export_hf_checkpoint() doesn't actually serialize weights as FP8 for unknown model types like HunyuanVideo15Transformer3DModel — it saves BF16 placeholders. The HunyuanImage-3 calibration helper hit the same bug. Three changes: - Manually call modelopt.torch.export.unified_export_hf._export_quantized_weight per-module to convert in-memory tensors to actual FP8. - Save the pipeline by hand (copy source minus transformer/, then save the quantized transformer with hide_quantizers_from_state_dict). - Patch transformer/config.json to inject quant_algo: FP8 + config_groups so vllm-omni's adapter (#2913) auto-detects it. Signed-off-by: lishunyang --- .../quantize_hunyuanvideo_15_modelopt_fp8.py | 161 +++++++++++++++--- 1 file changed, 133 insertions(+), 28 deletions(-) diff --git a/examples/quantization/quantize_hunyuanvideo_15_modelopt_fp8.py b/examples/quantization/quantize_hunyuanvideo_15_modelopt_fp8.py index 2e9da999325..6e2f3f31b40 100644 --- a/examples/quantization/quantize_hunyuanvideo_15_modelopt_fp8.py +++ b/examples/quantization/quantize_hunyuanvideo_15_modelopt_fp8.py @@ -83,17 +83,16 @@ def _build_parser() -> argparse.ArgumentParser: return p -def _require_modelopt() -> tuple[Any, Any]: +def _require_modelopt() -> Any: try: import modelopt.torch.quantization as mtq - from modelopt.torch.export import export_hf_checkpoint except ModuleNotFoundError as exc: raise SystemExit( "NVIDIA ModelOpt is not installed. Install with:\n" " pip install 'nvidia-modelopt[all]'\n" f"Original error: {exc}" ) from exc - return mtq, export_hf_checkpoint + return mtq def _ensure_paths(args: argparse.Namespace) -> tuple[str, Path]: @@ -216,32 +215,125 @@ def _summarize_export(output_dir: Path) -> None: print(f" config path: {cfg_path}") -def _export_with_fallback( +def _force_export_quantized_weights(backbone: torch.nn.Module, dtype: torch.dtype) -> int: + """Convert in-memory weights of quantized modules to actual FP8 storage. + + `export_hf_checkpoint` skips this step for unknown model types (HV-1.5 isn't + in ModelOpt's recognized-model registry), so we must call the per-weight + export helper ourselves. Same workaround as the HunyuanImage-3 calibration + helper. + """ + from modelopt.torch.export.quant_utils import ( + QUANTIZATION_NONE, + get_quantization_format, + quantizer_attr_names, + weight_attr_names, + ) + from modelopt.torch.export.unified_export_hf import _export_quantized_weight + + exported = 0 + for name, module in backbone.named_modules(): + try: + quantization_format = get_quantization_format(module) + except Exception as exc: + print(f"[warn] Could not inspect quantization format for {name}: {exc}", file=sys.stderr) + continue + if quantization_format == QUANTIZATION_NONE: + continue + for weight_name in weight_attr_names(module): + quantizer_attrs = quantizer_attr_names(weight_name) + weight_quantizer = getattr(module, quantizer_attrs.weight_quantizer, None) + if weight_quantizer is None or not getattr(weight_quantizer, "is_enabled", False): + continue + _export_quantized_weight(module, dtype, weight_name) + exported += 1 + return exported + + +def _hv15_quant_config_block() -> dict: + """Mirror ModelOpt FP8 metadata expected by vllm-omni's adapter (#2913). + + Same shape as the HunyuanImage-3 author's _hunyuan_quant_config(). + """ + return { + "config_groups": { + "group_0": { + "input_activations": {"dynamic": False, "num_bits": 8, "type": "float"}, + "weights": {"dynamic": False, "num_bits": 8, "type": "float"}, + "targets": ["Linear"], + } + }, + "ignore": [ + "context_embedder*", + "context_embedder_2*", + "cond_type_embed*", + "image_embedder*", + "norm1.linear*", + "norm1_context.linear*", + "norm2*", + "norm2_context*", + "norm_out*", + "proj_out*", + "time_embed*", + "token_refiner*", + "x_embedder*", + ], + "producer": {"name": "modelopt"}, + "quant_algo": "FP8", + "quant_method": "modelopt", + } + + +def _patch_quant_config(output_dir: Path) -> None: + """Inject quant_algo: FP8 + config_groups into transformer/config.json so + vllm-omni's adapter (#2913) recognises the checkpoint as ModelOpt FP8.""" + cfg_path = output_dir / "transformer" / "config.json" + with cfg_path.open(encoding="utf-8") as f: + cfg = json.load(f) + + new_qc = _hv15_quant_config_block() + existing = cfg.get("quantization_config") + if isinstance(existing, dict): + producer = existing.get("producer") + if isinstance(producer, dict): + new_qc["producer"] = producer + + cfg["quantization_config"] = new_qc + with cfg_path.open("w", encoding="utf-8") as f: + json.dump(cfg, f, indent=2) + + +def _save_pipeline_with_fp8_transformer( pipe: DiffusionPipeline, - export_hf_checkpoint: Any, model_path: str, output_dir: Path, -) -> str: - try: - export_hf_checkpoint(pipe, export_dir=str(output_dir)) - return "pipeline" - except Exception as exc: - print( - f"[warn] Whole-pipeline export failed, falling back to transformer-only export: {exc}", - file=sys.stderr, - ) - if output_dir.exists(): - shutil.rmtree(output_dir) - # Resolve the source directory: a local path stays as-is; an HF id resolves via snapshot_download. - src = Path(model_path) - if not src.exists(): - from huggingface_hub import snapshot_download + max_shard_size: str = "5GB", +) -> None: + """Save the pipeline with the (now FP8) transformer. + + Copies the source directory verbatim except for `transformer/`, then + saves the transformer with quantizers hidden so the state dict contains + only the FP8 weights + scale tensors. + """ + from modelopt.torch.export.diffusers_utils import hide_quantizers_from_state_dict - src = Path(snapshot_download(model_path)) - shutil.copytree(src, output_dir, dirs_exist_ok=True) - shutil.rmtree(output_dir / "transformer", ignore_errors=True) - export_hf_checkpoint(pipe.transformer, export_dir=str(output_dir / "transformer")) - return "transformer" + src = Path(model_path) + if not src.exists(): + from huggingface_hub import snapshot_download + + src = Path(snapshot_download(model_path)) + + if output_dir.exists(): + shutil.rmtree(output_dir) + shutil.copytree(src, output_dir, ignore=shutil.ignore_patterns("transformer")) + + transformer_out = output_dir / "transformer" + with hide_quantizers_from_state_dict(pipe): + pipe.transformer.save_pretrained( + str(transformer_out), + safe_serialization=True, + max_shard_size=max_shard_size, + ) def main() -> None: @@ -249,7 +341,7 @@ def main() -> None: if not torch.cuda.is_available(): raise SystemExit("CUDA is required for ModelOpt FP8 quantization.") - mtq, export_hf_checkpoint = _require_modelopt() + mtq = _require_modelopt() model_path, output_dir = _ensure_paths(args) dtype = _select_dtype(args.dtype) prompts = _build_prompts(args) @@ -277,8 +369,21 @@ def main() -> None: _disable_known_problematic_quantizers(mtq, backbone, quantize_mha=args.quantize_mha) - export_mode = _export_with_fallback(pipe, export_hf_checkpoint, model_path, output_dir) - print(f"Export mode: {export_mode}") + print("\nForcing FP8 weight serialization (HV-1.5 isn't in ModelOpt's recognized-model registry,") + print("so we have to call the per-weight export helper ourselves)...") + exported = _force_export_quantized_weights(backbone, dtype) + print(f" -> {exported} weights converted to FP8 in memory") + if exported == 0: + raise SystemExit( + "No quantized weights were exported. Calibration may have skipped every layer " + "(check the disable_quantizer regex) or `mtq.quantize` did not actually wrap any " + "weight quantizers." + ) + + print("\nSaving pipeline with FP8 transformer...") + _save_pipeline_with_fp8_transformer(pipe, model_path, output_dir) + _patch_quant_config(output_dir) + print(f"Saved to: {output_dir}") _summarize_export(output_dir) print("\nNext: validate the checkpoint with vllm-omni:") From d876c588421fc95ab33c189f956344d06b9c473f Mon Sep 17 00:00:00 2001 From: lishunyang Date: Mon, 20 Apr 2026 04:23:03 +0800 Subject: [PATCH 14/24] [Quant] hide_quantizers_from_state_dict: pass transformer (nn.Module), not pipeline MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Diffusers pipelines are ConfigMixin, not nn.Module — they don't have .named_modules(). Pass pipe.transformer directly. Signed-off-by: lishunyang --- .../quantization/quantize_hunyuanvideo_15_modelopt_fp8.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/quantization/quantize_hunyuanvideo_15_modelopt_fp8.py b/examples/quantization/quantize_hunyuanvideo_15_modelopt_fp8.py index 6e2f3f31b40..d7ad1a3f1a8 100644 --- a/examples/quantization/quantize_hunyuanvideo_15_modelopt_fp8.py +++ b/examples/quantization/quantize_hunyuanvideo_15_modelopt_fp8.py @@ -328,7 +328,9 @@ def _save_pipeline_with_fp8_transformer( shutil.copytree(src, output_dir, ignore=shutil.ignore_patterns("transformer")) transformer_out = output_dir / "transformer" - with hide_quantizers_from_state_dict(pipe): + # `hide_quantizers_from_state_dict` walks named_modules(); pass the actual + # nn.Module (transformer), not the diffusers Pipeline wrapper. + with hide_quantizers_from_state_dict(pipe.transformer): pipe.transformer.save_pretrained( str(transformer_out), safe_serialization=True, From 737db254f122f18826d4bac8752caca5a794190f Mon Sep 17 00:00:00 2001 From: lishunyang Date: Mon, 20 Apr 2026 04:39:24 +0800 Subject: [PATCH 15/24] [Quant] Fix calibrator's 'next' hint: text_to_video.py uses --quantization fp8, not --stage-configs-path Signed-off-by: lishunyang --- .../quantization/quantize_hunyuanvideo_15_modelopt_fp8.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/examples/quantization/quantize_hunyuanvideo_15_modelopt_fp8.py b/examples/quantization/quantize_hunyuanvideo_15_modelopt_fp8.py index d7ad1a3f1a8..8a1fedd94d7 100644 --- a/examples/quantization/quantize_hunyuanvideo_15_modelopt_fp8.py +++ b/examples/quantization/quantize_hunyuanvideo_15_modelopt_fp8.py @@ -392,13 +392,17 @@ def main() -> None: print( " python examples/offline_inference/text_to_video/text_to_video.py \\\n" f" --model {output_dir} \\\n" - " --stage-configs-path vllm_omni/model_executor/stage_configs/hunyuan_video_15_dit_fp8.yaml \\\n" + " --quantization fp8 \\\n" " --prompt 'A dog running across a field of golden wheat.' \\\n" f" --height {args.height} --width {args.width} --num-frames {args.num_frames} \\\n" " --num-inference-steps 30 --guidance-scale 6.0 --seed 42 \\\n" " --output outputs/hv15_modelopt_fp8.mp4 \\\n" " --enforce-eager" ) + print( + "\n (--quantization fp8 is auto-upgraded to ModelOpt FP8 at runtime because the " + "checkpoint's config.json has modelopt metadata.)" + ) if __name__ == "__main__": From 26b53f3bd94dd84f5bbfb4b021e575bee3e70e90 Mon Sep 17 00:00:00 2001 From: lishunyang Date: Mon, 20 Apr 2026 05:04:29 +0800 Subject: [PATCH 16/24] [Quant] Add --weight-block-size to switch HV-1.5 ModelOpt FP8 to per-block MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When --weight-block-size 'M,N' is given, override the weight quantizer with block_sizes={-1: N, -2: M} so each linear gets a (out//M, in//N) scale tensor instead of a scalar. Patched config_groups advertises strategy='block' + block_structure='MxN' so consumers know what to expect. Static FP8 is exempt from upstream vLLM's online block-wise gate, so this just works at serving time via #2913's adapter. Default behavior unchanged (per-tensor) — pass --weight-block-size 128,128 to opt in. Signed-off-by: lishunyang --- .../quantize_hunyuanvideo_15_modelopt_fp8.py | 67 +++++++++++++++---- 1 file changed, 53 insertions(+), 14 deletions(-) diff --git a/examples/quantization/quantize_hunyuanvideo_15_modelopt_fp8.py b/examples/quantization/quantize_hunyuanvideo_15_modelopt_fp8.py index 8a1fedd94d7..95a5cf91364 100644 --- a/examples/quantization/quantize_hunyuanvideo_15_modelopt_fp8.py +++ b/examples/quantization/quantize_hunyuanvideo_15_modelopt_fp8.py @@ -79,10 +79,27 @@ def _build_parser() -> argparse.ArgumentParser: action="store_true", help="Enable FP8 attention K/V/softmax quantizers. Off by default — empirically degrades HV-1.5 video output.", ) + p.add_argument( + "--weight-block-size", + type=str, + default=None, + help="Per-block weight quantization as 'M,N' (e.g. '128,128' for 128x128 tiles). " + "Default: per-tensor (one scale per linear). Block-wise typically gives tighter quality at " + "negligible memory cost. Static FP8 is exempt from upstream vLLM's online block-wise gate.", + ) p.add_argument("--overwrite", action="store_true", help="Replace an existing output directory.") return p +def _parse_block_size(spec: str | None) -> list[int] | None: + if spec is None: + return None + parts = [int(x) for x in spec.split(",") if x.strip()] + if len(parts) != 2: + raise SystemExit(f"--weight-block-size must be 'M,N' (2 ints), got {spec!r}") + return parts + + def _require_modelopt() -> Any: try: import modelopt.torch.quantization as mtq @@ -250,16 +267,22 @@ def _force_export_quantized_weights(backbone: torch.nn.Module, dtype: torch.dtyp return exported -def _hv15_quant_config_block() -> dict: +def _hv15_quant_config_block(weight_block_size: list[int] | None = None) -> dict: """Mirror ModelOpt FP8 metadata expected by vllm-omni's adapter (#2913). - Same shape as the HunyuanImage-3 author's _hunyuan_quant_config(). + Same shape as the HunyuanImage-3 author's _hunyuan_quant_config(). When + `weight_block_size` is given, advertise block-wise weight quantization in + the saved metadata (so consumers know to expect multi-element scale tensors). """ + weights_cfg: dict = {"dynamic": False, "num_bits": 8, "type": "float"} + if weight_block_size is not None: + weights_cfg["strategy"] = "block" + weights_cfg["block_structure"] = f"{weight_block_size[0]}x{weight_block_size[1]}" return { "config_groups": { "group_0": { "input_activations": {"dynamic": False, "num_bits": 8, "type": "float"}, - "weights": {"dynamic": False, "num_bits": 8, "type": "float"}, + "weights": weights_cfg, "targets": ["Linear"], } }, @@ -284,14 +307,14 @@ def _hv15_quant_config_block() -> dict: } -def _patch_quant_config(output_dir: Path) -> None: +def _patch_quant_config(output_dir: Path, weight_block_size: list[int] | None = None) -> None: """Inject quant_algo: FP8 + config_groups into transformer/config.json so vllm-omni's adapter (#2913) recognises the checkpoint as ModelOpt FP8.""" cfg_path = output_dir / "transformer" / "config.json" with cfg_path.open(encoding="utf-8") as f: cfg = json.load(f) - new_qc = _hv15_quant_config_block() + new_qc = _hv15_quant_config_block(weight_block_size=weight_block_size) existing = cfg.get("quantization_config") if isinstance(existing, dict): producer = existing.get("producer") @@ -349,19 +372,35 @@ def main() -> None: prompts = _build_prompts(args) print("Quantization plan:") - print(f" input: {args.model}") - print(f" output: {output_dir}") - print(f" dtype: {dtype}") - print(f" height/width: {args.height}x{args.width}") - print(f" num_frames: {args.num_frames}") - print(f" calib_size: {len(prompts)}") - print(f" calib_steps: {args.calib_steps}") - print(f" quantize_mha: {args.quantize_mha}") + weight_block_size = _parse_block_size(args.weight_block_size) + + print(f" input: {args.model}") + print(f" output: {output_dir}") + print(f" dtype: {dtype}") + print(f" height/width: {args.height}x{args.width}") + print(f" num_frames: {args.num_frames}") + print(f" calib_size: {len(prompts)}") + print(f" calib_steps: {args.calib_steps}") + print(f" quantize_mha: {args.quantize_mha}") + print( + f" weight strategy: {'block-wise ' + str(weight_block_size) if weight_block_size else 'per-tensor (default)'}" + ) pipe = _load_pipeline(model_path, dtype) backbone = pipe.transformer quant_config = copy.deepcopy(mtq.FP8_DEFAULT_CFG) + if weight_block_size is not None: + # Switch from per-tensor (default) to block-wise weight quantization. + # ModelOpt's wildcard "*weight_quantizer" matches every linear's weight quantizer. + quant_config["quant_cfg"]["*weight_quantizer"] = { + "num_bits": (4, 3), # E4M3 (FP8 weights, same as default) + "block_sizes": {-1: weight_block_size[1], -2: weight_block_size[0]}, + } + print( + f" -> overriding weight quantizer with block_sizes={weight_block_size} " + f"({weight_block_size[0]}x{weight_block_size[1]} tiles)" + ) forward_loop = _build_forward_loop(pipe, args, prompts) quantized = mtq.quantize(backbone, quant_config, forward_loop) @@ -384,7 +423,7 @@ def main() -> None: print("\nSaving pipeline with FP8 transformer...") _save_pipeline_with_fp8_transformer(pipe, model_path, output_dir) - _patch_quant_config(output_dir) + _patch_quant_config(output_dir, weight_block_size=weight_block_size) print(f"Saved to: {output_dir}") _summarize_export(output_dir) From 23af41fd283241ebe209c8f9ab6bbc1a26e7810a Mon Sep 17 00:00:00 2001 From: lishunyang Date: Mon, 20 Apr 2026 05:24:26 +0800 Subject: [PATCH 17/24] [Quant] check script: classify weight_scale granularity (per-tensor vs per-block) Reads shape info from safetensors header and classifies the checkpoint as per-tensor / per-channel / per-block based on whether weight_scale tensors are scalar, 1-D, or N-D. Helps verify --weight-block-size actually took effect (or if ModelOpt silently flattened to per-tensor). Signed-off-by: lishunyang --- .../quantization/check_modelopt_fp8_export.py | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/examples/quantization/check_modelopt_fp8_export.py b/examples/quantization/check_modelopt_fp8_export.py index dd4c2ed43e8..9ec8d90ab75 100644 --- a/examples/quantization/check_modelopt_fp8_export.py +++ b/examples/quantization/check_modelopt_fp8_export.py @@ -76,6 +76,23 @@ def _read_safetensors_header(path: Path) -> dict: return header +def _classify_weight_scale_granularity(weight_scale_shapes: list[list[int]]) -> str: + """Infer per-tensor vs per-channel vs per-block from sample weight_scale shapes.""" + if not weight_scale_shapes: + return "no weight_scale tensors found" + scalar = sum(1 for s in weight_scale_shapes if len(s) == 0 or (len(s) == 1 and s[0] == 1)) + per_channel = sum(1 for s in weight_scale_shapes if len(s) == 1 and s[0] > 1) + per_block = sum(1 for s in weight_scale_shapes if len(s) >= 2 and all(x > 1 for x in s)) + total = len(weight_scale_shapes) + if scalar == total: + return "per-tensor (all scalar scales)" + if per_channel == total: + return "per-channel (all 1-D scales)" + if per_block == total: + return "per-block (all N-D scales)" + return f"mixed: scalar={scalar}, per-channel={per_channel}, per-block={per_block} of {total}" + + def _check_safetensors(transformer_dir: Path) -> int: """Returns 0 on pass, 1 on fail. Reads on-disk dtype from the safetensors header.""" files = sorted(transformer_dir.glob("*.safetensors")) @@ -86,6 +103,8 @@ def _check_safetensors(transformer_dir: Path) -> int: header_dtype_counts: Counter[str] = Counter() sample_fp8_keys: list[str] = [] sample_scale_keys: list[str] = [] + weight_scale_shapes: list[list[int]] = [] + sample_weight_scale_entries: list[tuple[str, list[int]]] = [] for f in files: try: header = _read_safetensors_header(f) @@ -99,6 +118,10 @@ def _check_safetensors(transformer_dir: Path) -> int: sample_fp8_keys.append(k) if k.endswith(("_scale", ".weight_scale", ".input_scale", "_scale_inv")) and len(sample_scale_keys) < 5: sample_scale_keys.append(k) + if k.endswith(".weight_scale"): + weight_scale_shapes.append(info.get("shape", [])) + if len(sample_weight_scale_entries) < 5: + sample_weight_scale_entries.append((k, info.get("shape", []))) print(f"\n[B] On-disk dtype counts across {len(files)} safetensors file(s) (from header, not get_tensor):") for dtype, count in sorted(header_dtype_counts.items(), key=lambda kv: -kv[1]): @@ -117,6 +140,11 @@ def _check_safetensors(transformer_dir: Path) -> int: print(f" sample scale tensors: {sample_scale_keys[:3]}") print(" (Note: torch's get_tensor() may return these as bf16 views on some versions —") print(" irrelevant; vLLM's loader uses native FP8 ops.)") + + # Weight-scale granularity — per-tensor (scalar) vs per-channel (1-D) vs per-block (N-D). + print(f"\n weight_scale granularity: {_classify_weight_scale_granularity(weight_scale_shapes)}") + for key, shape in sample_weight_scale_entries[:3]: + print(f" {key}: shape {shape}") return 0 From 9c2216380b64aee435c432af604ab1eef244bf20 Mon Sep 17 00:00:00 2001 From: lishunyang Date: Mon, 20 Apr 2026 05:26:20 +0800 Subject: [PATCH 18/24] [Quant] check: count only meaningful (>1) dims when classifying scale granularity ModelOpt block-wise produces shapes like [16, 1, 16, 1] where size-1 dims are broadcasting axes. Classify by non-unity dim count: 0=per-tensor, 1=per-channel, 2+=per-block. Signed-off-by: lishunyang --- .../quantization/check_modelopt_fp8_export.py | 25 +++++++++++++------ 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/examples/quantization/check_modelopt_fp8_export.py b/examples/quantization/check_modelopt_fp8_export.py index 9ec8d90ab75..eafa3b62fb2 100644 --- a/examples/quantization/check_modelopt_fp8_export.py +++ b/examples/quantization/check_modelopt_fp8_export.py @@ -77,20 +77,29 @@ def _read_safetensors_header(path: Path) -> dict: def _classify_weight_scale_granularity(weight_scale_shapes: list[list[int]]) -> str: - """Infer per-tensor vs per-channel vs per-block from sample weight_scale shapes.""" + """Infer per-tensor vs per-channel vs per-block from sample weight_scale shapes. + + ModelOpt block-wise produces shapes like `[16, 1, 16, 1]` (broadcasting dims of 1 + interleaved with block-count dims). We count "meaningful" dims — ones with size > 1 — + and classify: 0 meaningful dims = per-tensor (scalar), 1 = per-channel, 2+ = per-block. + """ if not weight_scale_shapes: return "no weight_scale tensors found" - scalar = sum(1 for s in weight_scale_shapes if len(s) == 0 or (len(s) == 1 and s[0] == 1)) - per_channel = sum(1 for s in weight_scale_shapes if len(s) == 1 and s[0] > 1) - per_block = sum(1 for s in weight_scale_shapes if len(s) >= 2 and all(x > 1 for x in s)) + + def meaningful_dims(shape: list[int]) -> int: + return sum(1 for d in shape if d > 1) + + per_tensor = sum(1 for s in weight_scale_shapes if meaningful_dims(s) == 0) + per_channel = sum(1 for s in weight_scale_shapes if meaningful_dims(s) == 1) + per_block = sum(1 for s in weight_scale_shapes if meaningful_dims(s) >= 2) total = len(weight_scale_shapes) - if scalar == total: + if per_tensor == total: return "per-tensor (all scalar scales)" if per_channel == total: - return "per-channel (all 1-D scales)" + return "per-channel (1 meaningful dim)" if per_block == total: - return "per-block (all N-D scales)" - return f"mixed: scalar={scalar}, per-channel={per_channel}, per-block={per_block} of {total}" + return "per-block (2+ meaningful dims — e.g. [M//bm, 1, N//bn, 1] for tiles)" + return f"mixed: per-tensor={per_tensor}, per-channel={per_channel}, per-block={per_block} of {total}" def _check_safetensors(transformer_dir: Path) -> int: From 4047d1c0ffbe8f3c0edeec08ed034231d9eb817e Mon Sep 17 00:00:00 2001 From: lishunyang Date: Mon, 20 Apr 2026 05:30:25 +0800 Subject: [PATCH 19/24] [Quant] Wire quant_config through Wan2.2 DiT (extracted from #2920) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Threads quant_config / prefix through WanSelfAttention, WanCrossAttention, WanFeedForward (+ ColumnParallelGELU), WanTransformerBlock, and WanTransformer3DModel / WanVACETransformer3DModel, plus the four pipelines (T2V / I2V / TI2V / VACE). Modulation (scale_shift_table), patch_embedding (Conv3d), time/text/image embedders, and proj_out stay full precision. All attention + FFN linears receive quant_config so the ModelOpt FP8 adapter from #2913 can bind per-layer scales at load time. The aggressive skip patterns from #2920 (attn1/attn2 quant_config=None) are NOT applied here — that was an online-FP8 quality workaround; static calibration handles it. Signed-off-by: lishunyang --- .../models/wan2_2/pipeline_wan2_2.py | 6 +- .../models/wan2_2/pipeline_wan2_2_i2v.py | 8 +- .../models/wan2_2/pipeline_wan2_2_ti2v.py | 4 +- .../models/wan2_2/pipeline_wan2_2_vace.py | 6 +- .../models/wan2_2/wan2_2_transformer.py | 75 +++++++++++++++++-- .../models/wan2_2/wan2_2_vace_transformer.py | 23 +++++- 6 files changed, 106 insertions(+), 16 deletions(-) diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py index 652425d5097..df45034a258 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py @@ -119,9 +119,11 @@ def load_transformer_config(model_path: str, subfolder: str = "transformer", loc return {} -def create_transformer_from_config(config: dict) -> WanTransformer3DModel: +def create_transformer_from_config(config: dict, quant_config=None) -> WanTransformer3DModel: """Create WanTransformer3DModel from config dict.""" kwargs = {} + if quant_config is not None: + kwargs["quant_config"] = quant_config if "patch_size" in config: kwargs["patch_size"] = tuple(config["patch_size"]) @@ -374,7 +376,7 @@ def __init__( def _create_transformer(self, config: dict) -> WanTransformer3DModel: """Create a transformer from a config dict. Subclasses may override.""" - return create_transformer_from_config(config) + return create_transformer_from_config(config, quant_config=self.od_config.quantization_config) @property def guidance_scale(self): diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py index 24e8965a39e..dac78555c41 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py @@ -244,10 +244,14 @@ def __init__( # Transformers (weights loaded via load_weights) # Load config from model directory or HF Hub to get correct in_channels for I2V models transformer_config = load_transformer_config(model, "transformer", local_files_only) - self.transformer = create_transformer_from_config(transformer_config) + self.transformer = create_transformer_from_config( + transformer_config, quant_config=od_config.quantization_config + ) if self.has_transformer_2: transformer_2_config = load_transformer_config(model, "transformer_2", local_files_only) - self.transformer_2 = create_transformer_from_config(transformer_2_config) + self.transformer_2 = create_transformer_from_config( + transformer_2_config, quant_config=od_config.quantization_config + ) else: self.transformer_2 = None diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py index dba76ba8af8..671a16b6fb6 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py @@ -202,7 +202,9 @@ def __init__( # Single transformer (TI2V uses dense 5B model, not MoE) # Load config from model to get correct dimensions transformer_config = load_transformer_config(model, "transformer", local_files_only) - self.transformer = create_transformer_from_config(transformer_config) + self.transformer = create_transformer_from_config( + transformer_config, quant_config=od_config.quantization_config + ) self._sample_solver = "unipc" self._flow_shift = od_config.flow_shift if od_config.flow_shift is not None else 5.0 diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_vace.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_vace.py index 11408e2d24b..7ade2ac2fa4 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_vace.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_vace.py @@ -40,9 +40,11 @@ logger = init_logger(__name__) -def create_vace_transformer_from_config(config: dict) -> WanVACETransformer3DModel: +def create_vace_transformer_from_config(config: dict, quant_config=None) -> WanVACETransformer3DModel: """Create WanVACETransformer3DModel from config dict.""" kwargs = {} + if quant_config is not None: + kwargs["quant_config"] = quant_config if "patch_size" in config: kwargs["patch_size"] = tuple(config["patch_size"]) if "num_attention_heads" in config: @@ -174,7 +176,7 @@ def __init__( def _create_transformer(self, config: dict) -> WanVACETransformer3DModel: """Build VACE transformer directly from config dict.""" - return create_vace_transformer_from_config(config) + return create_vace_transformer_from_config(config, quant_config=self.od_config.quantization_config) def diffuse( self, diff --git a/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py b/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py index d4d81b78eb8..d39d3bee5ae 100644 --- a/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py +++ b/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py @@ -3,7 +3,7 @@ import math from collections.abc import Iterable -from typing import Any +from typing import TYPE_CHECKING, Any import torch import torch.nn as nn @@ -32,6 +32,9 @@ from vllm_omni.diffusion.layers.norm import LayerNorm, RMSNorm from vllm_omni.platforms import current_omni_platform +if TYPE_CHECKING: + from vllm.model_executor.layers.quantization.base_config import QuantizationConfig + logger = init_logger(__name__) @@ -100,7 +103,16 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class ColumnParallelGELU(nn.Module): """Column parallel linear with GELU activation.""" - def __init__(self, dim_in: int, dim_out: int, *, approximate: str = "tanh", bias: bool = True): + def __init__( + self, + dim_in: int, + dim_out: int, + *, + approximate: str = "tanh", + bias: bool = True, + quant_config: "QuantizationConfig | None" = None, + prefix: str = "", + ): super().__init__() self.proj = ColumnParallelLinear( dim_in, @@ -108,6 +120,8 @@ def __init__(self, dim_in: int, dim_out: int, *, approximate: str = "tanh", bias bias=bias, gather_output=False, return_bias=False, + quant_config=quant_config, + prefix=f"{prefix}.proj", ) self.approximate = approximate @@ -128,12 +142,16 @@ def __init__( inner_dim: int, dim_out: int | None = None, bias: bool = True, + quant_config: "QuantizationConfig | None" = None, + prefix: str = "", ) -> None: super().__init__() dim_out = dim_out or dim # ColumnParallel: scatter to each tp_rank - self.net_0 = ColumnParallelGELU(dim, inner_dim, approximate="tanh", bias=bias) + self.net_0 = ColumnParallelGELU( + dim, inner_dim, approximate="tanh", bias=bias, quant_config=quant_config, prefix=f"{prefix}.net_0" + ) # Placeholder for weight loading compatibility self.net_1 = nn.Identity() # RowParallel: gather from each tp_rank @@ -143,6 +161,8 @@ def __init__( bias=bias, input_is_parallel=True, return_bias=False, + quant_config=quant_config, + prefix=f"{prefix}.net_2", ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -357,6 +377,8 @@ def __init__( head_dim: int, eps: float = 1e-5, dropout: float = 0.0, + quant_config: "QuantizationConfig | None" = None, + prefix: str = "", ): super().__init__() @@ -371,6 +393,8 @@ def __init__( head_size=head_dim, total_num_heads=num_heads, bias=True, + quant_config=quant_config, + prefix=f"{prefix}.to_qkv", ) self.num_heads = self.to_qkv.num_heads @@ -391,6 +415,8 @@ def __init__( bias=True, input_is_parallel=True, return_bias=False, + quant_config=quant_config, + prefix=f"{prefix}.to_out", ) self.dropout = nn.Dropout(dropout) @@ -409,6 +435,8 @@ def forward( rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, attn_mask: torch.Tensor | None = None, ) -> torch.Tensor: + # Ensure contiguous for FP8 quantized linear layers + hidden_states = hidden_states.contiguous() # Fused QKV projection qkv, _ = self.to_qkv(hidden_states) @@ -462,6 +490,8 @@ def __init__( eps: float = 1e-5, dropout: float = 0.0, added_kv_proj_dim: int | None = None, + quant_config: "QuantizationConfig | None" = None, + prefix: str = "", ): super().__init__() @@ -478,6 +508,8 @@ def __init__( bias=True, gather_output=False, return_bias=False, + quant_config=quant_config, + prefix=f"{prefix}.to_q", ) # Separate K and V projections for cross-attention @@ -487,6 +519,8 @@ def __init__( bias=True, gather_output=False, return_bias=False, + quant_config=quant_config, + prefix=f"{prefix}.to_k", ) self.to_v = ColumnParallelLinear( @@ -495,6 +529,8 @@ def __init__( bias=True, gather_output=False, return_bias=False, + quant_config=quant_config, + prefix=f"{prefix}.to_v", ) tp_size = get_tensor_model_parallel_world_size() @@ -518,6 +554,8 @@ def __init__( bias=True, gather_output=False, return_bias=False, + quant_config=quant_config, + prefix=f"{prefix}.add_k_proj", ) self.add_v_proj = ColumnParallelLinear( added_kv_proj_dim, @@ -525,6 +563,8 @@ def __init__( bias=True, gather_output=False, return_bias=False, + quant_config=quant_config, + prefix=f"{prefix}.add_v_proj", ) if get_tensor_model_parallel_world_size() > 1: self.norm_added_k = DistributedRMSNorm(self.tp_inner_dim, eps=eps) @@ -542,6 +582,8 @@ def __init__( bias=True, input_is_parallel=True, return_bias=False, + quant_config=quant_config, + prefix=f"{prefix}.to_out", ) self.dropout = nn.Dropout(dropout) @@ -560,6 +602,9 @@ def forward( hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, ) -> torch.Tensor: + # Ensure contiguous for FP8 quantized linear layers + hidden_states = hidden_states.contiguous() + encoder_hidden_states = encoder_hidden_states.contiguous() # Handle I2V case where encoder_hidden_states contains both image and text encoder_hidden_states_img = None if self.add_k_proj is not None: @@ -626,6 +671,8 @@ def __init__( eps: float = 1e-6, added_kv_proj_dim: int | None = None, cross_attn_norm: bool = False, + quant_config: "QuantizationConfig | None" = None, + prefix: str = "", ): super().__init__() @@ -638,6 +685,8 @@ def __init__( num_heads=num_heads, head_dim=head_dim, eps=eps, + quant_config=quant_config, + prefix=f"{prefix}.attn1", ) # 2. Cross-attention @@ -647,11 +696,15 @@ def __init__( head_dim=head_dim, eps=eps, added_kv_proj_dim=added_kv_proj_dim, + quant_config=quant_config, + prefix=f"{prefix}.attn2", ) self.norm2 = LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity() # 3. Feed-forward - self.ffn = WanFeedForward(dim=dim, inner_dim=ffn_dim, dim_out=dim) + self.ffn = WanFeedForward( + dim=dim, inner_dim=ffn_dim, dim_out=dim, quant_config=quant_config, prefix=f"{prefix}.ffn" + ) self.norm3 = AdaLayerNorm(dim, elementwise_affine=False, eps=eps) # Scale-shift table for modulation @@ -807,6 +860,7 @@ def __init__( added_kv_proj_dim: int | None = None, rope_max_seq_len: int = 1024, pos_embed_seq_len: int | None = None, + quant_config: "QuantizationConfig | None" = None, ): super().__init__() @@ -858,8 +912,17 @@ def __init__( # 3. Transformer blocks self.blocks = nn.ModuleList( [ - WanTransformerBlock(inner_dim, ffn_dim, num_attention_heads, eps, added_kv_proj_dim, cross_attn_norm) - for _ in range(num_layers) + WanTransformerBlock( + inner_dim, + ffn_dim, + num_attention_heads, + eps, + added_kv_proj_dim, + cross_attn_norm, + quant_config=quant_config, + prefix=f"blocks.{i}", + ) + for i in range(num_layers) ] ) diff --git a/vllm_omni/diffusion/models/wan2_2/wan2_2_vace_transformer.py b/vllm_omni/diffusion/models/wan2_2/wan2_2_vace_transformer.py index c48938e1baa..1fc8ebc754a 100644 --- a/vllm_omni/diffusion/models/wan2_2/wan2_2_vace_transformer.py +++ b/vllm_omni/diffusion/models/wan2_2/wan2_2_vace_transformer.py @@ -5,7 +5,7 @@ from __future__ import annotations -from typing import Any +from typing import TYPE_CHECKING, Any import torch import torch.nn as nn @@ -20,6 +20,9 @@ WanTransformerBlock, ) +if TYPE_CHECKING: + from vllm.model_executor.layers.quantization.base_config import QuantizationConfig + class VaceWanTransformerBlock(WanTransformerBlock): """VACE variant of WanTransformerBlock with proj_in/proj_out for skip connections.""" @@ -33,8 +36,19 @@ def __init__( added_kv_proj_dim: int | None = None, cross_attn_norm: bool = False, block_id: int = 0, + quant_config: QuantizationConfig | None = None, + prefix: str = "", ): - super().__init__(dim, ffn_dim, num_heads, eps, added_kv_proj_dim, cross_attn_norm) + super().__init__( + dim, + ffn_dim, + num_heads, + eps, + added_kv_proj_dim, + cross_attn_norm, + quant_config=quant_config, + prefix=prefix, + ) self.proj_in = nn.Linear(dim, dim) if block_id == 0 else None self.proj_out = nn.Linear(dim, dim) @@ -83,9 +97,10 @@ def __init__( *, vace_layers: list[int] | None = None, vace_in_channels: int | None = None, + quant_config: QuantizationConfig | None = None, **kwargs, ): - super().__init__(**kwargs) + super().__init__(quant_config=quant_config, **kwargs) self.vace_blocks = None self.vace_patch_embedding = None @@ -118,6 +133,8 @@ def __init__( self.config.added_kv_proj_dim, self.config.cross_attn_norm, block_id=i, + quant_config=quant_config, + prefix=f"vace_blocks.{i}", ) for i in range(len(vace_layers)) ] From 6cc547812378337288af1133b87b23f55a837ea9 Mon Sep 17 00:00:00 2001 From: lishunyang Date: Mon, 20 Apr 2026 05:32:42 +0800 Subject: [PATCH 20/24] [Quant] ModelOpt FP8 calibration script + stage config for Wan2.2 TI2V-5B examples/quantization/quantize_wan2_2_modelopt_fp8.py: Offline calibration helper that produces a ModelOpt FP8 diffusers checkpoint for Wan2.2 TI2V-5B (the dense 5B variant that fits 80GB BF16). Same design as the HunyuanVideo-1.5 calibrator (#2924): force-export FP8 weights, patch quant_algo: FP8 into config.json, hide quantizers during save. Skips Wan2.2's precision-sensitive layers (condition_embedder, patch_embedding, proj_out, scale_shift_table, SP helpers). MHA quantizers off by default. vllm_omni/model_executor/stage_configs/wan2_2_ti2v_dit_fp8.yaml: Stage config for serving the calibrated checkpoint via vllm-omni. Signed-off-by: lishunyang --- .../quantize_wan2_2_modelopt_fp8.py | 432 ++++++++++++++++++ .../stage_configs/wan2_2_ti2v_dit_fp8.yaml | 34 ++ 2 files changed, 466 insertions(+) create mode 100644 examples/quantization/quantize_wan2_2_modelopt_fp8.py create mode 100644 vllm_omni/model_executor/stage_configs/wan2_2_ti2v_dit_fp8.yaml diff --git a/examples/quantization/quantize_wan2_2_modelopt_fp8.py b/examples/quantization/quantize_wan2_2_modelopt_fp8.py new file mode 100644 index 00000000000..e489271b04f --- /dev/null +++ b/examples/quantization/quantize_wan2_2_modelopt_fp8.py @@ -0,0 +1,432 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Quantize Wan2.2 (TI2V-5B, 704x1280 T2V) to a ModelOpt FP8 Hugging Face checkpoint. + +Calibrates the DiT transformer using a small video prompt set and exports a +diffusers-style directory whose transformer carries ModelOpt FP8 metadata. +The exported checkpoint is consumable by vllm-omni's ModelOpt FP8 adapter +(see vllm_omni/diffusion/model_loader/checkpoint_adapters/modelopt_fp8.py). + +Layers kept full precision match the #2728 / #2795 pattern: condition embedder +(time/text/image), patch embedding, modulation (scale_shift_table), final +norm + proj_out, and sequence-parallel helpers. All attention + FFN linears +are quantized — static calibration handles the numerics that online FP8 +couldn't (see #2920 ablation). + +Default target is `Wan-AI/Wan2.2-TI2V-5B-Diffusers`, the dense 5B variant that +fits 80GB BF16. The A14B MoE variants need 2+ GPUs and are out of scope here. + +Example: + python examples/quantization/quantize_wan2_2_modelopt_fp8.py \\ + --model Wan-AI/Wan2.2-TI2V-5B-Diffusers \\ + --output ./wan22-ti2v-modelopt-fp8 \\ + --overwrite +""" + +from __future__ import annotations + +import argparse +import copy +import json +import re +import shutil +import sys +from pathlib import Path +from typing import Any + +import torch +from diffusers import DiffusionPipeline + +DEFAULT_PROMPTS = [ + "A dog running across a field of golden wheat.", + "An astronaut riding a horse across the surface of Mars, red dust swirling, cinematic wide shot.", + "A hummingbird hovering in front of a vibrant red flower, slow motion, macro shot.", + "A crackling campfire at night under a starry sky, sparks rising into the dark.", + "An underwater shot of a coral reef with tropical fish swimming by, sun rays piercing the water.", + "A close-up of a blooming rose covered in morning dew, soft natural light.", + "A peaceful mountain village at dawn, mist rolling over the rooftops, cinematic establishing shot.", + "A skateboarder doing a kickflip in an urban plaza, slow motion, golden hour lighting.", +] + + +def _build_parser() -> argparse.ArgumentParser: + p = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) + p.add_argument("--model", required=True, help="Input Wan2.2 diffusers directory or HF id.") + p.add_argument("--output", required=True, help="Output directory for the ModelOpt FP8 checkpoint.") + p.add_argument("--dtype", choices=("bfloat16", "float16"), default="bfloat16") + p.add_argument("--height", type=int, default=704, help="Calibration video height (Wan2.2 TI2V-5B native: 704).") + p.add_argument("--width", type=int, default=1280, help="Calibration video width (Wan2.2 TI2V-5B native: 1280).") + p.add_argument( + "--num-frames", + type=int, + default=49, + help="Frames per calibration sample. 49 matches the typical short benchmark; " + "use 17 to reduce memory pressure during calibration.", + ) + p.add_argument("--guidance-scale", type=float, default=5.0) + p.add_argument( + "--calib-steps", + type=int, + default=10, + help="Denoising steps per calibration prompt (10 is enough for amax statistics).", + ) + p.add_argument("--calib-size", type=int, default=8, help="How many prompts to use for calibration.") + p.add_argument("--seed", type=int, default=42) + p.add_argument( + "--prompt", + action="append", + default=[], + help="Custom calibration prompt. Repeat to provide multiple.", + ) + p.add_argument( + "--quantize-mha", + action="store_true", + help="Enable FP8 attention K/V/softmax quantizers. Off by default — Wan2.2's long attention " + "sequences amplified FP8 drift in the online ablation (see #2920).", + ) + p.add_argument( + "--weight-block-size", + type=str, + default=None, + help="Per-block weight quantization as 'M,N' (e.g. '128,128'). Default per-tensor. " + "Note: vllm-omni's ModelOpt adapter may not yet dispatch block-wise scales — check #2924 " + "for the HV-1.5 investigation status before relying on this.", + ) + p.add_argument("--overwrite", action="store_true", help="Replace an existing output directory.") + return p + + +def _parse_block_size(spec: str | None) -> list[int] | None: + if spec is None: + return None + parts = [int(x) for x in spec.split(",") if x.strip()] + if len(parts) != 2: + raise SystemExit(f"--weight-block-size must be 'M,N' (2 ints), got {spec!r}") + return parts + + +def _require_modelopt() -> Any: + try: + import modelopt.torch.quantization as mtq + except ModuleNotFoundError as exc: + raise SystemExit( + "NVIDIA ModelOpt is not installed. Install with:\n" + " pip install 'nvidia-modelopt[all]'\n" + f"Original error: {exc}" + ) from exc + return mtq + + +def _ensure_paths(args: argparse.Namespace) -> tuple[str, Path]: + model_path = args.model + output_dir = Path(args.output).expanduser().resolve() + if output_dir.exists(): + if not args.overwrite: + raise SystemExit(f"Output directory already exists: {output_dir}\nPass --overwrite to replace it.") + shutil.rmtree(output_dir) + return model_path, output_dir + + +def _select_dtype(name: str) -> torch.dtype: + return {"bfloat16": torch.bfloat16, "float16": torch.float16}[name] + + +def _build_prompts(args: argparse.Namespace) -> list[str]: + prompts = args.prompt or DEFAULT_PROMPTS + if args.calib_size <= 0: + raise SystemExit("--calib-size must be positive.") + if len(prompts) < args.calib_size: + repeats = (args.calib_size + len(prompts) - 1) // len(prompts) + prompts = (prompts * repeats)[: args.calib_size] + return prompts[: args.calib_size] + + +# Layers to KEEP at full precision. Wan2.2's module naming: +# - condition_embedder: time_embedder, time_proj, text_embedder, image_embedder (I2V) +# - patch_embedding: Conv3dLayer (already not Linear, belt-and-suspenders skip) +# - scale_shift_table: nn.Parameter modulation (not Linear, but pattern guard) +# - norm_out: AdaLayerNorm final +# - proj_out: final nn.Linear +# - timestep_proj_prepare / output_scale_shift_prepare: SP helpers +def _filter_func_wan22(name: str) -> bool: + pattern = re.compile( + r"(proj_out.*|" + r".*(condition_embedder|patch_embedding|" + r"norm_out|scale_shift_table|" + r"timestep_proj_prepare|output_scale_shift_prepare).*)" + ) + return pattern.match(name) is not None + + +def _mha_filter_func(name: str) -> bool: + pattern = re.compile( + r".*(q_bmm_quantizer|k_bmm_quantizer|v_bmm_quantizer|softmax_quantizer|bmm2_output_quantizer).*" + ) + return pattern.match(name) is not None + + +def _disable_known_problematic_quantizers(mtq: Any, backbone: torch.nn.Module, *, quantize_mha: bool) -> None: + if not hasattr(mtq, "disable_quantizer"): + return + mtq.disable_quantizer(backbone, _filter_func_wan22) + if not quantize_mha: + mtq.disable_quantizer(backbone, _mha_filter_func) + + +def _load_pipeline(model_path: str, dtype: torch.dtype) -> DiffusionPipeline: + pipe = DiffusionPipeline.from_pretrained(model_path, torch_dtype=dtype) + if hasattr(pipe, "set_progress_bar_config"): + pipe.set_progress_bar_config(disable=True) + pipe.to("cuda") + return pipe + + +def _build_forward_loop(pipe: DiffusionPipeline, args: argparse.Namespace, prompts: list[str]): + generator = torch.Generator(device="cuda") + + # Try setting guidance on the pipeline's guider if present (newer diffusers APIs). + guider = getattr(pipe, "guider", None) + if guider is not None and hasattr(guider, "guidance_scale"): + try: + guider.guidance_scale = args.guidance_scale + except Exception: + pass + + base_kwargs = dict( + height=args.height, + width=args.width, + num_frames=args.num_frames, + num_inference_steps=args.calib_steps, + output_type="latent", + ) + + def forward_loop(*_unused_args, **_unused_kwargs) -> None: + with torch.inference_mode(): + for idx, prompt in enumerate(prompts): + generator.manual_seed(args.seed + idx) + # Try with guidance_scale first; fall back without on TypeError + # for pipelines that take CFG via guider config only. + try: + pipe(prompt=prompt, generator=generator, guidance_scale=args.guidance_scale, **base_kwargs) + except TypeError as exc: + if "guidance_scale" not in str(exc): + raise + pipe(prompt=prompt, generator=generator, **base_kwargs) + + return forward_loop + + +def _summarize_export(output_dir: Path) -> None: + cfg_path = output_dir / "transformer" / "config.json" + if not cfg_path.exists(): + print(f"[warn] {cfg_path} missing.", file=sys.stderr) + return + with cfg_path.open(encoding="utf-8") as f: + cfg = json.load(f) + qc = cfg.get("quantization_config") + if not isinstance(qc, dict): + print("[warn] No quantization_config in transformer/config.json.", file=sys.stderr) + return + print("Export summary:") + print(f" quant_method: {qc.get('quant_method')}") + print(f" quant_algo: {qc.get('quant_algo')}") + producer = qc.get("producer") + if isinstance(producer, dict): + print(f" producer: {producer.get('name')} {producer.get('version')}") + print(f" config path: {cfg_path}") + + +def _force_export_quantized_weights(backbone: torch.nn.Module, dtype: torch.dtype) -> int: + """Convert in-memory weights of quantized modules to actual FP8 storage. + + `export_hf_checkpoint` skips this step for unknown model types (Wan2.2 isn't + in ModelOpt's recognized-model registry), so we must call the per-weight + export helper ourselves. Same workaround as the HunyuanVideo-1.5 / HunyuanImage-3 + calibration helpers. + """ + from modelopt.torch.export.quant_utils import ( + QUANTIZATION_NONE, + get_quantization_format, + quantizer_attr_names, + weight_attr_names, + ) + from modelopt.torch.export.unified_export_hf import _export_quantized_weight + + exported = 0 + for name, module in backbone.named_modules(): + try: + quantization_format = get_quantization_format(module) + except Exception as exc: + print(f"[warn] Could not inspect quantization format for {name}: {exc}", file=sys.stderr) + continue + if quantization_format == QUANTIZATION_NONE: + continue + for weight_name in weight_attr_names(module): + quantizer_attrs = quantizer_attr_names(weight_name) + weight_quantizer = getattr(module, quantizer_attrs.weight_quantizer, None) + if weight_quantizer is None or not getattr(weight_quantizer, "is_enabled", False): + continue + _export_quantized_weight(module, dtype, weight_name) + exported += 1 + return exported + + +def _wan22_quant_config_block(weight_block_size: list[int] | None = None) -> dict: + """Mirror ModelOpt FP8 metadata expected by vllm-omni's adapter (#2913).""" + weights_cfg: dict = {"dynamic": False, "num_bits": 8, "type": "float"} + if weight_block_size is not None: + weights_cfg["strategy"] = "block" + weights_cfg["block_structure"] = f"{weight_block_size[0]}x{weight_block_size[1]}" + return { + "config_groups": { + "group_0": { + "input_activations": {"dynamic": False, "num_bits": 8, "type": "float"}, + "weights": weights_cfg, + "targets": ["Linear"], + } + }, + "ignore": [ + "condition_embedder*", + "norm_out*", + "output_scale_shift_prepare*", + "patch_embedding*", + "proj_out*", + "scale_shift_table*", + "timestep_proj_prepare*", + ], + "producer": {"name": "modelopt"}, + "quant_algo": "FP8", + "quant_method": "modelopt", + } + + +def _patch_quant_config(output_dir: Path, weight_block_size: list[int] | None = None) -> None: + """Inject quant_algo: FP8 + config_groups into transformer/config.json so + vllm-omni's adapter (#2913) recognises the checkpoint as ModelOpt FP8.""" + cfg_path = output_dir / "transformer" / "config.json" + with cfg_path.open(encoding="utf-8") as f: + cfg = json.load(f) + + new_qc = _wan22_quant_config_block(weight_block_size=weight_block_size) + existing = cfg.get("quantization_config") + if isinstance(existing, dict): + producer = existing.get("producer") + if isinstance(producer, dict): + new_qc["producer"] = producer + + cfg["quantization_config"] = new_qc + with cfg_path.open("w", encoding="utf-8") as f: + json.dump(cfg, f, indent=2) + + +def _save_pipeline_with_fp8_transformer( + pipe: DiffusionPipeline, + model_path: str, + output_dir: Path, + max_shard_size: str = "5GB", +) -> None: + """Copy source dir verbatim minus transformer/, then save the quantized transformer.""" + from modelopt.torch.export.diffusers_utils import hide_quantizers_from_state_dict + + src = Path(model_path) + if not src.exists(): + from huggingface_hub import snapshot_download + + src = Path(snapshot_download(model_path)) + + if output_dir.exists(): + shutil.rmtree(output_dir) + shutil.copytree(src, output_dir, ignore=shutil.ignore_patterns("transformer", "transformer_2")) + + transformer_out = output_dir / "transformer" + # Pass the nn.Module (transformer), not the Pipeline wrapper. + with hide_quantizers_from_state_dict(pipe.transformer): + pipe.transformer.save_pretrained( + str(transformer_out), + safe_serialization=True, + max_shard_size=max_shard_size, + ) + + +def main() -> None: + args = _build_parser().parse_args() + if not torch.cuda.is_available(): + raise SystemExit("CUDA is required for ModelOpt FP8 quantization.") + + mtq = _require_modelopt() + model_path, output_dir = _ensure_paths(args) + dtype = _select_dtype(args.dtype) + prompts = _build_prompts(args) + weight_block_size = _parse_block_size(args.weight_block_size) + + print("Quantization plan:") + print(f" input: {args.model}") + print(f" output: {output_dir}") + print(f" dtype: {dtype}") + print(f" height/width: {args.height}x{args.width}") + print(f" num_frames: {args.num_frames}") + print(f" calib_size: {len(prompts)}") + print(f" calib_steps: {args.calib_steps}") + print(f" quantize_mha: {args.quantize_mha}") + print( + f" weight strategy: {'block-wise ' + str(weight_block_size) if weight_block_size else 'per-tensor (default)'}" + ) + + pipe = _load_pipeline(model_path, dtype) + backbone = pipe.transformer + + quant_config = copy.deepcopy(mtq.FP8_DEFAULT_CFG) + if weight_block_size is not None: + quant_config["quant_cfg"]["*weight_quantizer"] = { + "num_bits": (4, 3), + "block_sizes": {-1: weight_block_size[1], -2: weight_block_size[0]}, + } + print( + f" -> overriding weight quantizer with block_sizes={weight_block_size} " + f"({weight_block_size[0]}x{weight_block_size[1]} tiles)" + ) + + forward_loop = _build_forward_loop(pipe, args, prompts) + quantized = mtq.quantize(backbone, quant_config, forward_loop) + if quantized is not None: + pipe.transformer = quantized + backbone = quantized + + _disable_known_problematic_quantizers(mtq, backbone, quantize_mha=args.quantize_mha) + + print("\nForcing FP8 weight serialization (Wan2.2 isn't in ModelOpt's recognized-model registry,") + print("so we have to call the per-weight export helper ourselves)...") + exported = _force_export_quantized_weights(backbone, dtype) + print(f" -> {exported} weights converted to FP8 in memory") + if exported == 0: + raise SystemExit( + "No quantized weights were exported. Calibration may have skipped every layer " + "(check the disable_quantizer regex) or `mtq.quantize` did not actually wrap any " + "weight quantizers." + ) + + print("\nSaving pipeline with FP8 transformer...") + _save_pipeline_with_fp8_transformer(pipe, model_path, output_dir) + _patch_quant_config(output_dir, weight_block_size=weight_block_size) + print(f"Saved to: {output_dir}") + _summarize_export(output_dir) + + print("\nNext: validate the checkpoint with vllm-omni:") + print( + " python examples/offline_inference/text_to_video/text_to_video.py \\\n" + f" --model {output_dir} \\\n" + " --quantization fp8 \\\n" + " --prompt 'A dog running across a field of golden wheat.' \\\n" + f" --height {args.height} --width {args.width} --num-frames {args.num_frames} \\\n" + " --num-inference-steps 30 --guidance-scale 5.0 --seed 42 \\\n" + " --output outputs/wan22_modelopt_fp8.mp4" + ) + print( + "\n (--quantization fp8 is auto-upgraded to ModelOpt FP8 at runtime because the " + "checkpoint's config.json has modelopt metadata.)" + ) + + +if __name__ == "__main__": + main() diff --git a/vllm_omni/model_executor/stage_configs/wan2_2_ti2v_dit_fp8.yaml b/vllm_omni/model_executor/stage_configs/wan2_2_ti2v_dit_fp8.yaml new file mode 100644 index 00000000000..b6291073f12 --- /dev/null +++ b/vllm_omni/model_executor/stage_configs/wan2_2_ti2v_dit_fp8.yaml @@ -0,0 +1,34 @@ +# Stage config for running Wan2.2 TI2V-5B DiT with ModelOpt FP8 auto-detect. +# Single GPU (TI2V-5B fits 80GB BF16; FP8 drops by ~half). +# For the A14B MoE variants, bump `tensor_parallel_size` and `devices`. +# +# Use with a ModelOpt FP8 checkpoint (e.g. produced by +# examples/quantization/quantize_wan2_2_modelopt_fp8.py). + +stage_args: + - stage_id: 0 + stage_type: diffusion + runtime: + devices: "0" + max_batch_size: 1 + engine_args: + model_stage: dit + model_class_name: Wan22TI2VPipeline + max_num_seqs: 1 + enforce_eager: true + trust_remote_code: true + distributed_executor_backend: "mp" + parallel_config: + tensor_parallel_size: 1 + + final_output: true + final_output_type: video + is_comprehension: false + default_sampling_params: + seed: 42 + +runtime: + enabled: true + defaults: + window_size: -1 + max_inflight: 1 From 4e6f5574f07179c0b82ab7e5eb742b9b686ef825 Mon Sep 17 00:00:00 2001 From: lishunyang Date: Mon, 20 Apr 2026 05:57:12 +0800 Subject: [PATCH 21/24] [Quant] Fix Wan2.2 ModelOpt FP8 FFN scale-name remap (ffn.net.0 -> ffn.net_0) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Wan2.2 ModelOpt FP8 checkpoint has diffusers-style dotted FFN names (ffn.net.0.proj, ffn.net.2) but vllm-omni's WanFeedForward uses underscored names (ffn.net_0.proj, ffn.net_2). The transformer's load_weights remaps these for .weight tensors, but the ModelOpt adapter resolves scale tensor names independently via WeightsMapper and was missing the remap — all 120 FFN scale tensors (30 blocks x 2 linears x 2 scales) silently fell through, leaving FP8 weights with no valid scales at serving time (visible as pure noise output). Fix: - Add hf_to_vllm_mapper class attribute on WanTransformer3DModel with the ffn remap. - Extend ModelOptFp8CheckpointAdapter._get_weights_mapper to merge a model's hf_to_vllm_mapper (if present) into the resolution map. Models can now register arbitrary substring remaps via this standard vLLM attribute. Signed-off-by: lishunyang --- .../model_loader/checkpoint_adapters/modelopt_fp8.py | 8 ++++++++ .../diffusion/models/wan2_2/wan2_2_transformer.py | 12 ++++++++++++ 2 files changed, 20 insertions(+) diff --git a/vllm_omni/diffusion/model_loader/checkpoint_adapters/modelopt_fp8.py b/vllm_omni/diffusion/model_loader/checkpoint_adapters/modelopt_fp8.py index 64661a15f29..44b75d5cb60 100644 --- a/vllm_omni/diffusion/model_loader/checkpoint_adapters/modelopt_fp8.py +++ b/vllm_omni/diffusion/model_loader/checkpoint_adapters/modelopt_fp8.py @@ -105,6 +105,14 @@ def _get_weights_mapper(cls, model: nn.Module) -> WeightsMapper: for shard_name in shard_names: orig_to_new_substr[f".{shard_name}."] = f".{packed_name}." orig_to_new_prefix[f"{shard_name}."] = f"{packed_name}." + + # Let models extend the remap with arbitrary diffusers→vllm-omni + # substring translations (e.g. Wan2.2's `.ffn.net.0.` → `.ffn.net_0.`). + model_mapper = getattr(model, "hf_to_vllm_mapper", None) + if model_mapper is not None: + orig_to_new_substr.update(getattr(model_mapper, "orig_to_new_substr", None) or {}) + orig_to_new_prefix.update(getattr(model_mapper, "orig_to_new_prefix", None) or {}) + return WeightsMapper( orig_to_new_substr=orig_to_new_substr, orig_to_new_prefix=orig_to_new_prefix, diff --git a/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py b/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py index d39d3bee5ae..591c1421e20 100644 --- a/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py +++ b/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py @@ -20,6 +20,7 @@ from vllm.model_executor.layers.conv import Conv3dLayer from vllm.model_executor.layers.linear import ColumnParallelLinear, QKVParallelLinear, RowParallelLinear from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.utils import WeightsMapper from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata from vllm_omni.diffusion.attention.layer import Attention @@ -794,6 +795,17 @@ class WanTransformer3DModel(nn.Module): "to_qkv": ["to_q", "to_k", "to_v"], } + # Diffusers → vllm-omni weight-name translation. Used by the ModelOpt FP8 + # checkpoint adapter (#2913) to remap scale tensor names: diffusers stores + # FFN as `ffn.net.0.proj` / `ffn.net.2` (dotted), vllm-omni's WanFeedForward + # uses `ffn.net_0.proj` / `ffn.net_2` (underscore). + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={ + ".ffn.net.0.": ".ffn.net_0.", + ".ffn.net.2.": ".ffn.net_2.", + }, + ) + @staticmethod def _is_transformer_block(name: str, module) -> bool: """Match transformer blocks for HSDP sharding (e.g., blocks.0, blocks.1).""" From a5fb789001b454618b7ecdaf3d419a05d0252e79 Mon Sep 17 00:00:00 2001 From: lishunyang Date: Mon, 20 Apr 2026 06:16:09 +0800 Subject: [PATCH 22/24] [Quant] ModelOpt FP8 adapter: log first 3 skipped scales for diagnostics Helps diagnose name-mismatch between checkpoint keys and model parameters (e.g. diffusers .ffn.net.0. vs vllm-omni .ffn.net_0.). Signed-off-by: lishunyang --- .../checkpoint_adapters/modelopt_fp8.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/vllm_omni/diffusion/model_loader/checkpoint_adapters/modelopt_fp8.py b/vllm_omni/diffusion/model_loader/checkpoint_adapters/modelopt_fp8.py index 44b75d5cb60..c7524a3dd9b 100644 --- a/vllm_omni/diffusion/model_loader/checkpoint_adapters/modelopt_fp8.py +++ b/vllm_omni/diffusion/model_loader/checkpoint_adapters/modelopt_fp8.py @@ -183,6 +183,19 @@ def _handle_scale_tensor( state.scale_tensors[name] = tensor if target_name is None: state.skipped_scales += 1 + # Diagnostic: log first few skipped scales with a sample of loadable keys that + # contain a similar substring so we can tell if the hf_to_vllm_mapper is even loaded. + if state.skipped_scales <= 3: + hint_substr = name.rsplit(".", 2)[0] if "." in name else name + similar = [k for k in self._loadable_tensors if k.endswith(name.split(".")[-1])][:3] + logger.warning( + "ModelOpt FP8 adapter: skipping scale %r (no target). " + "Similar loadable params by suffix: %r. " + "Hint: the checkpoint key uses a name that doesn't match any model parameter. " + "Check hf_to_vllm_mapper on the model class.", + name, + similar, + ) else: yield name, tensor yield from self._flush_pending_weights(name, state) From 7e57b69053c21346cdf1741c59e6b7310431755e Mon Sep 17 00:00:00 2001 From: lishunyang Date: Mon, 20 Apr 2026 06:20:34 +0800 Subject: [PATCH 23/24] [Quant] Walk submodules when collecting hf_to_vllm_mapper for ModelOpt FP8 adapter The adapter is instantiated with the whole Pipeline, not just the DiT. Only checking the top-level model means hf_to_vllm_mapper defined on a sub-module (e.g. WanTransformer3DModel inside Wan22TI2VPipeline) was invisible. Walk named_modules() and aggregate any mappers found. Signed-off-by: lishunyang --- .../checkpoint_adapters/modelopt_fp8.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/vllm_omni/diffusion/model_loader/checkpoint_adapters/modelopt_fp8.py b/vllm_omni/diffusion/model_loader/checkpoint_adapters/modelopt_fp8.py index c7524a3dd9b..b3580c70f0e 100644 --- a/vllm_omni/diffusion/model_loader/checkpoint_adapters/modelopt_fp8.py +++ b/vllm_omni/diffusion/model_loader/checkpoint_adapters/modelopt_fp8.py @@ -106,12 +106,17 @@ def _get_weights_mapper(cls, model: nn.Module) -> WeightsMapper: orig_to_new_substr[f".{shard_name}."] = f".{packed_name}." orig_to_new_prefix[f"{shard_name}."] = f"{packed_name}." - # Let models extend the remap with arbitrary diffusers→vllm-omni - # substring translations (e.g. Wan2.2's `.ffn.net.0.` → `.ffn.net_0.`). - model_mapper = getattr(model, "hf_to_vllm_mapper", None) - if model_mapper is not None: - orig_to_new_substr.update(getattr(model_mapper, "orig_to_new_substr", None) or {}) - orig_to_new_prefix.update(getattr(model_mapper, "orig_to_new_prefix", None) or {}) + # Collect `hf_to_vllm_mapper` from `model` AND any submodule that defines one. + # The adapter may be called with a whole Pipeline; its transformer submodule + # (e.g. WanTransformer3DModel) is where the model-specific mapper lives. + collected: set[int] = set() + for m in (model, *(sm for _, sm in model.named_modules())): + mp = getattr(m, "hf_to_vllm_mapper", None) + if mp is None or id(mp) in collected: + continue + collected.add(id(mp)) + orig_to_new_substr.update(getattr(mp, "orig_to_new_substr", None) or {}) + orig_to_new_prefix.update(getattr(mp, "orig_to_new_prefix", None) or {}) return WeightsMapper( orig_to_new_substr=orig_to_new_substr, From 2b543e91415d26ea458173262c8ff726978c8d74 Mon Sep 17 00:00:00 2001 From: Shuaiwei Huang <112870732+ArtificialRay@users.noreply.github.com> Date: Sun, 17 May 2026 10:21:21 -0700 Subject: [PATCH 24/24] Phase1 (video-gen) ModelOpt FP8 Follow-ups (#57) * update quant wan2-2 modelopt to support A14B model * update wan2-2 modelopt quant script * update two gpu quantization quality script * update i2v modelopt quant script * update hunyuanvideo and wan2.2 vace modelopt script * add quantization config parsing image2video script * update vae-use-tiling for quantization quality script to avoid cuda oom for bf16 model * update vae-use-tiling for quantization quality script to avoid cuda oom for fp8 model(for vae) * update quantization quality script to support i2v videogen task * fix modelopt fp8 quantization script and quality script in T2V * update per-block quant * update vace videogen script * update quantization quality script to support model load and throughput calculation, and rewrite quant_quality script to automate model offline quant * fix quantization quality script in hunyuanvideo1.5 * update modelopt check script * update remote transmisson to bench_quant_videogen * update check_quant_videogen * update bench quant videogen script * update quality bench scripts to add negative prompt to wan2.2 I2V * update quality bench script for wan2.2 i2v * update quality bench script to add denoise throughput(s/it) * quant_quality script update for image gen model * del unrelative scripts * del unrelative scripts * update recommend test cmd after quantization for wan models Signed-off-by: ArtificialRay --------- Signed-off-by: ArtificialRay --- benchmarks/diffusion/quantization_quality.py | 214 +++++- .../image_to_video/image_to_video.py | 36 +- .../vace/vace_video_generation.py | 8 + .../quantization/check_modelopt_fp8_export.py | 102 ++- .../quantize_hunyuanvideo_15_modelopt_fp8.py | 208 +++++- .../quantize_wan2_2_modelopt_fp8.py | 434 +++++++++-- .../quantize_wan2_2_vace_modelopt_fp8.py | 701 ++++++++++++++++++ .../checkpoint_adapters/modelopt_fp8.py | 7 +- 8 files changed, 1549 insertions(+), 161 deletions(-) create mode 100644 examples/quantization/quantize_wan2_2_vace_modelopt_fp8.py diff --git a/benchmarks/diffusion/quantization_quality.py b/benchmarks/diffusion/quantization_quality.py index 4a916e7ea62..978b69d5c5d 100644 --- a/benchmarks/diffusion/quantization_quality.py +++ b/benchmarks/diffusion/quantization_quality.py @@ -34,6 +34,33 @@ --height 720 --width 1280 \ --num-frames 81 --num-inference-steps 40 --seed 42 +Video example (text-to-video) with offline quant: + python benchmarks/diffusion/quantization_quality.py \ + --use-offline-quant \ + --model Wan-AI/Wan2.2-T2V-A14B-Diffusers \ + --model-quant-checkpoint /vllm-omni/wan22-t2v-modelopt-fp8\ + --task t2v \ + --quantization fp8 \ + --prompts \ + "A serene lakeside sunrise with mist over the water" \ + "A cat walking across a wooden bridge in autumn" \ + --height 720 --width 1280 \ + --num-frames 81 --num-inference-steps 40 --seed 42 + +Video example (image-to-video) with offline quant: + python benchmarks/diffusion/quantization_quality.py \ + --use-offline-quant \ + --model Wan-AI/Wan2.2-I2V-A14B-Diffusers \ + --model-quant-checkpoint /vllm-omni/wan22-i2v-modelopt-fp8\ + --task i2v \ + --quantization fp8 \ + --prompts \ + "An astronaut riding a horse across the surface of Mars, red dust swirling, cinematic wide shot." \ + "A skateboarder doing a kickflip in an urban plaza, slow motion, golden hour lighting." \ + --height 720 --width 1280 \ + --image /path/to/ref_images/ \ + --num-frames 81 --num-inference-steps 40 --seed 42 --vae-use-tiling + Multiple quantization methods: python benchmarks/diffusion/quantization_quality.py \ --model Tongyi-MAI/Z-Image-Turbo \ @@ -54,7 +81,7 @@ import gc import time from pathlib import Path - +from typing import Any import numpy as np import torch @@ -135,18 +162,60 @@ def _build_omni_kwargs(args, quantization=None): ring_degree=args.ring_degree, tensor_parallel_size=args.tensor_parallel_size, ) - kwargs = { - "model": args.model, - "parallel_config": parallel_config, - "enforce_eager": args.enforce_eager, - } if quantization: - kwargs["quantization_config"] = quantization + kwargs = { + "model": args.model_quant_checkpoint if args.use_offline_quant else args.model, + "parallel_config": parallel_config, + "enforce_eager": args.enforce_eager, + "quantization_config":quantization, + "vae_use_tiling":args.vae_use_tiling, + "enable_diffusion_pipeline_profiler": True, + } + else: + kwargs = { + "model": args.model, + "parallel_config": parallel_config, + "enforce_eager": args.enforce_eager, + "vae_use_tiling":args.vae_use_tiling, + "enable_diffusion_pipeline_profiler": True, + } + return kwargs +def _load_reference_images(spec: str | None) -> list[Any]: + """Load PIL.Image list from a directory or a single file path.""" + if spec is None: + return [] + from PIL import Image + + p = Path(spec).expanduser() + if not p.exists(): + raise SystemExit(f"--reference-images path not found: {p}") + if p.is_file(): + return [Image.open(p).convert("RGB")] + image_paths = sorted( + f for f in p.iterdir() if f.is_file() and f.suffix.lower() in (".jpg", ".jpeg", ".png", ".webp") + ) + if not image_paths: + raise SystemExit(f"No image files (jpg/jpeg/png/webp) found in {p}") + return [Image.open(f).convert("RGB") for f in image_paths] + +def _extract_denoise_seconds(first) -> float | None: + """Pull denoise-only time out of OmniRequestOutput.stage_durations. + + The pipeline profiler records keys like '.diffuse' + (cuda-synced timer wrapped around the denoising loop). Match by suffix so + this stays model-agnostic. + """ + stage_durations = getattr(first, "stage_durations", {}) or {} + return next( + (v for k, v in stage_durations.items() if k.endswith(".diffuse")), + None, + ) + def _generate_image(omni, args, prompt, seed): - """Generate a single image and return (PIL.Image, time_seconds, memory_gib).""" + """Generate a single image and return (PIL.Image, wall_seconds, memory_gib, denoise_seconds).""" from vllm_omni.inputs.data import OmniDiffusionSamplingParams from vllm_omni.platforms import current_omni_platform @@ -166,22 +235,27 @@ def _generate_image(omni, args, prompt, seed): peak_mem = torch.cuda.max_memory_allocated() / (1024**3) first = outputs[0] - req_out = first.request_output[0] if hasattr(first, "request_output") else first + peak_mem = getattr(first, "peak_memory_mb", 0.0) / 1024 # MB -> GiB + denoise_seconds = _extract_denoise_seconds(first) + req_out = first.request_output if hasattr(first, "request_output") else first img = req_out.images[0] - return img, elapsed, peak_mem + return img, elapsed, peak_mem, denoise_seconds -def _generate_video(omni, args, prompt, seed): - """Generate a video and return (np.ndarray [F,H,W,C], time_seconds, memory_gib).""" +def _generate_video(omni, args, prompt, seed,image=None): + """Generate a video and return (np.ndarray [F,H,W,C], wall_seconds, memory_gib, denoise_seconds).""" from vllm_omni.inputs.data import OmniDiffusionSamplingParams from vllm_omni.outputs import OmniRequestOutput from vllm_omni.platforms import current_omni_platform + request = {"prompt": prompt, "negative_prompt": args.negative_prompt} + if image is not None: + request["multi_modal_data"] = {"image": image} generator = torch.Generator(device=current_omni_platform.device_type).manual_seed(seed) torch.cuda.reset_peak_memory_stats() start = time.perf_counter() outputs = omni.generate( - {"prompt": prompt, "negative_prompt": ""}, + request, OmniDiffusionSamplingParams( height=args.height, width=args.width, @@ -192,9 +266,10 @@ def _generate_video(omni, args, prompt, seed): ), ) elapsed = time.perf_counter() - start - peak_mem = torch.cuda.max_memory_allocated() / (1024**3) first = outputs[0] + peak_mem = getattr(first, "peak_memory_mb", 0.0) / 1024 # MB -> GiB + denoise_seconds = _extract_denoise_seconds(first) if hasattr(first, "request_output") and isinstance(first.request_output, list): inner = first.request_output[0] if isinstance(inner, OmniRequestOutput) and hasattr(inner, "images"): @@ -217,10 +292,12 @@ def _generate_video(omni, args, prompt, seed): frames_array = video.float().numpy() else: frames_array = np.asarray(frames) + if frames_array.ndim ==6:# wan2.2: inner.images = [ndarray[1,F,H,W,C]] + frames_array = frames_array[0] if frames_array.ndim == 5: frames_array = frames_array[0] - return frames_array, elapsed, peak_mem + return frames_array, elapsed, peak_mem, denoise_seconds def _unload_omni(omni): @@ -238,9 +315,11 @@ def run_benchmark(args): output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) - is_video = args.task == "t2v" + is_video = args.task in ("t2v", "i2v", "ti2v") + is_image_conditioned = args.task == "i2v" or (args.task == "ti2v" and args.images is not None) prompts = args.prompts seed = args.seed + input_images = _load_reference_images(args.images) if is_image_conditioned else None # Determine configs to benchmark configs = [] # list of (label, quantization_method) @@ -254,18 +333,23 @@ def run_benchmark(args): bl_kwargs = _build_omni_kwargs(args, quantization=None) omni_bl = Omni(**bl_kwargs) - baseline_outputs = {} # prompt -> (output, time, mem) - for prompt in prompts: + baseline_outputs = {} # prompt -> (output, wall_time, mem, denoise_seconds) + for i,prompt in enumerate(prompts): print(f" Generating: {prompt[:60]}...") - if is_video: - out, t, mem = _generate_video(omni_bl, args, prompt, seed) + if is_video and input_images is not None: + out, t, mem, dt = _generate_video(omni_bl, args, prompt, seed,input_images[i % len(input_images)]) + elif is_video: + out, t, mem, dt = _generate_video(omni_bl, args, prompt, seed) else: - out, t, mem = _generate_image(omni_bl, args, prompt, seed) - baseline_outputs[prompt] = (out, t, mem) + out, t, mem, dt = _generate_image(omni_bl, args, prompt, seed) + baseline_outputs[prompt] = (out, t, mem, dt) bl_avg_time = np.mean([v[1] for v in baseline_outputs.values()]) + bl_denoise_times = [v[3] for v in baseline_outputs.values() if v[3] is not None] + bl_avg_denoise = float(np.mean(bl_denoise_times)) if bl_denoise_times else None bl_mem = baseline_outputs[prompts[0]][2] # use first prompt's memory _unload_omni(omni_bl) + del omni_bl # must explicitly del omni_bl otherwise will cause oom when running next model # Save baseline outputs bl_dir = output_dir / "baseline" @@ -295,18 +379,22 @@ def run_benchmark(args): omni_qt = Omni(**qt_kwargs) qt_outputs = {} - for prompt in prompts: + for i,prompt in enumerate(prompts): print(f" Generating: {prompt[:60]}...") - if is_video: - out, t, mem = _generate_video(omni_qt, args, prompt, seed) + if is_video and input_images is not None: + out, t, mem, dt = _generate_video(omni_qt, args, prompt, seed,input_images[i % len(input_images)]) + elif is_video: + out, t, mem, dt = _generate_video(omni_qt, args, prompt, seed) else: - out, t, mem = _generate_image(omni_qt, args, prompt, seed) - qt_outputs[prompt] = (out, t, mem) + out, t, mem, dt = _generate_image(omni_qt, args, prompt, seed) + qt_outputs[prompt] = (out, t, mem, dt) qt_avg_time = np.mean([v[1] for v in qt_outputs.values()]) + qt_denoise_times = [v[3] for v in qt_outputs.values() if v[3] is not None] + qt_avg_denoise = float(np.mean(qt_denoise_times)) if qt_denoise_times else None qt_mem = qt_outputs[prompts[0]][2] _unload_omni(omni_qt) - + del omni_qt # must explicitly del omni_qt otherwise will cause oom when running next model # Save quantized outputs qt_dir = output_dir / config_label.replace(" ", "_") qt_dir.mkdir(parents=True, exist_ok=True) @@ -333,12 +421,18 @@ def run_benchmark(args): mean_lpips = np.mean([p["lpips"] for p in per_prompt]) speedup = bl_avg_time / qt_avg_time if qt_avg_time > 0 else float("inf") mem_reduction = (bl_mem - qt_mem) / bl_mem * 100 + # Throughput uses denoise-only time (cuda-synced via DiffusionPipelineProfiler); + # falls back to wall time only if the profiler didn't surface it. + qt_throughput_basis = qt_avg_denoise if qt_avg_denoise is not None else qt_avg_time + qt_throughput = qt_throughput_basis / args.num_inference_steps if args.num_inference_steps > 0 else float("inf") all_results.append( { "config": config_label, "avg_time": qt_avg_time, + "avg_denoise": qt_avg_denoise, "speedup": speedup, + "throughput_its": qt_throughput, "memory_gib": qt_mem, "mem_reduction_pct": mem_reduction, "mean_lpips": mean_lpips, @@ -353,8 +447,9 @@ def run_benchmark(args): print("=" * 80) # Summary table + model = args.model_quant_checkpoint if args.use_offline_quant else args.model lines = [] - lines.append(f"## Quantization Quality Benchmark — {args.model.split('/')[-1]}") + lines.append(f"## Quantization Quality Benchmark — {model.split('/')[-1]}") lines.append( f"Setup: {args.height}x{args.width}, {args.num_inference_steps} steps, " f"seed={args.seed}, LPIPS ({args.lpips_net})" @@ -364,17 +459,42 @@ def run_benchmark(args): lines.append("") lines.append("### Summary") lines.append("") - lines.append("| Config | Avg Time | Speedup | Memory (GiB) | Mem Reduction | Mean LPIPS |") - lines.append("|--------|----------|---------|--------------|---------------|------------|") - lines.append(f"| BF16 baseline | {bl_avg_time:.2f}s | 1.00x | {bl_mem:.2f} | — | (ref) |") + # Throughput uses denoise-only time (cuda-synced via DiffusionPipelineProfiler); + # falls back to wall time only if the profiler didn't surface it. + bl_throughput_basis = bl_avg_denoise if bl_avg_denoise is not None else bl_avg_time + bl_throughput = bl_throughput_basis / args.num_inference_steps if args.num_inference_steps > 0 else float("inf") + lines.append( + "| Config | Avg Time | Speedup | Throughput (s/it) " + "| Peak VRAM (GiB) | Peak VRAM Reduction | Mean LPIPS |" + ) + lines.append( + "|--------|----------|---------|--------------------" + "|-----------------|---------------------|------------|" + ) + lines.append( + f"| BF16 baseline | {bl_avg_time:.2f}s | 1.00x | {bl_throughput:.3f} " + f"| {bl_mem:.2f} | — | (ref) |" + ) for r in all_results: lines.append( f"| {r['config']} | {r['avg_time']:.2f}s | {r['speedup']:.2f}x " + f"| {r['throughput_its']:.3f} " f"| {r['memory_gib']:.2f} | {r['mem_reduction_pct']:.0f}% " f"| {r['mean_lpips']:.4f} |" ) lines.append("") lines.append("> LPIPS < 0.01 = imperceptible, > 0.1 = clearly noticeable.") + lines.append( + "> Throughput (s/it) = denoise time / num_inference_steps " + "(cuda-synced via DiffusionPipelineProfiler; excludes text encode + VAE decode)." + ) + lines.append( + "> Peak VRAM = `max_memory_allocated` during one generate (model + activations + latents + VAE buffers)." + ) + lines.append( + "> For model resident size on disk/VRAM, run `examples/quantization/check_modelopt_fp8_export.py " + "--output --baseline `." + ) lines.append("") # Per-prompt table @@ -415,11 +535,13 @@ def parse_args(): formatter_class=argparse.RawDescriptionHelpFormatter, ) parser.add_argument("--model", required=True, help="Model name or local path.") + parser.add_argument("--model-quant-checkpoint",default=None, help="Offline quantization model checkpoint") + parser.add_argument("--use-offline-quant",action="store_true",help="compare with offline quantized model checkpoint") parser.add_argument( "--task", default="t2i", - choices=["t2i", "t2v"], - help="Task type: t2i (text-to-image) or t2v (text-to-video).", + choices=["t2i", "t2v", "i2v", "ti2v"], + help="Task type: t2i (text-to-image), t2v (text-to-video), i2v / ti2v (image-to-video).", ) parser.add_argument( "--quantization", @@ -433,6 +555,11 @@ def parse_args(): default=["a cup of coffee on the table"], help="One or more prompts to generate.", ) + parser.add_argument( + "--images", + default=None, + help="Path to input images (required for i2v and ti2v tasks).", + ) parser.add_argument("--seed", type=int, default=42) parser.add_argument("--height", type=int, default=1024) parser.add_argument("--width", type=int, default=1024) @@ -440,6 +567,13 @@ def parse_args(): parser.add_argument("--num-frames", type=int, default=81, help="Number of video frames (t2v only).") parser.add_argument("--fps", type=int, default=24, help="Video FPS for saving (t2v only).") parser.add_argument("--guidance-scale", type=float, default=4.0, help="CFG scale (used for video).") + parser.add_argument( + "--negative-prompt", + type=str, + default="", + help="Negative prompt for video generation. Wan2.2 I2V degenerates to a static frame " + "without the official anti-static negative prompt; other pipelines work with empty.", + ) parser.add_argument("--output-dir", type=str, default="./quant_bench_output", help="Directory to save outputs.") parser.add_argument( "--lpips-net", @@ -452,7 +586,17 @@ def parse_args(): parser.add_argument("--ring-degree", type=int, default=1) parser.add_argument("--tensor-parallel-size", type=int, default=1) parser.add_argument("--enforce-eager", action="store_true") - return parser.parse_args() + parser.add_argument( + "--vae-use-tiling", + action="store_true", + help="Enable VAE tiling for memory optimization.Specifically for bf16 model", + ) + args = parser.parse_args() + if args.task in ("i2v") and args.images is None: + parser.error(f"--task {args.task} requires --images") + if args.use_offline_quant and not args.model_quant_checkpoint: + parser.error("--use-offline-quant requires --model-quant-checkpoint") + return args if __name__ == "__main__": diff --git a/examples/offline_inference/image_to_video/image_to_video.py b/examples/offline_inference/image_to_video/image_to_video.py index 53319c82211..3b8c1956663 100644 --- a/examples/offline_inference/image_to_video/image_to_video.py +++ b/examples/offline_inference/image_to_video/image_to_video.py @@ -36,6 +36,7 @@ import os import time from pathlib import Path +from typing import Any import numpy as np import PIL.Image @@ -194,6 +195,23 @@ def parse_args() -> argparse.Namespace: action="store_true", help="Enable diffusion pipeline profiler to display stage durations.", ) + parser.add_argument( + "--quantization", + type=str, + default=None, + choices=["fp8"], + help="Quantization method for the transformer. " + "Options: 'fp8' (FP8 W8A8 on Ada/Hopper, weight-only on older GPUs). " + "Default: None (no quantization, uses BF16).", + ) + parser.add_argument( + "--ignored-layers", + type=str, + default=None, + help="Comma-separated list of layer name patterns to skip quantization. " + "Only used when --quantization is set. " + "Example: --ignored-layers 'to_qkv,to_out'", + ) return parser.parse_args() @@ -286,6 +304,18 @@ def main(): hsdp_shard_size=args.hsdp_shard_size, hsdp_replicate_size=args.hsdp_replicate_size, ) + + # Build quantization kwargs + quant_kwargs: dict[str, Any] = {} + ignored_layers = [s.strip() for s in args.ignored_layers.split(",") if s.strip()] if args.ignored_layers else None + if args.quantization and ignored_layers: + quant_kwargs["quantization_config"] = { + "method": args.quantization, + "ignored_layers": ignored_layers, + } + elif args.quantization: + quant_kwargs["quantization"] = args.quantization + omni = Omni( model=args.model, enable_layerwise_offload=args.enable_layerwise_offload, @@ -300,6 +330,7 @@ def main(): cache_backend=args.cache_backend, cache_config=cache_config, enable_diffusion_pipeline_profiler=args.enable_diffusion_pipeline_profiler, + **quant_kwargs, ) if profiler_enabled: @@ -317,6 +348,9 @@ def main(): f" Parallel configuration: cfg_parallel_size={args.cfg_parallel_size}," f" tensor_parallel_size={args.tensor_parallel_size}, vae_patch_parallel_size={args.vae_patch_parallel_size}" ) + print(f" Quantization: {args.quantization if args.quantization else 'None (BF16)'}") + if ignored_layers: + print(f" Ignored layers: {ignored_layers}") print(f" Video size: {args.width}x{args.height}") print(f"{'=' * 60}\n") @@ -538,4 +572,4 @@ def _ensure_frame_list(video_array): if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/examples/offline_inference/vace/vace_video_generation.py b/examples/offline_inference/vace/vace_video_generation.py index 6ca0d74c52e..e183e0d5b58 100644 --- a/examples/offline_inference/vace/vace_video_generation.py +++ b/examples/offline_inference/vace/vace_video_generation.py @@ -52,6 +52,13 @@ def parse_args() -> argparse.Namespace: choices=["t2v", "i2v", "v2lf", "flf2v", "inpaint", "r2v"], help="Generation mode.", ) + parser.add_argument( + "--quantization", + type=str, + default=None, + choices=["fp8", "gguf"], + help="Quantization method for the transformer (fp8 for online FP8 quantization).", + ) parser.add_argument("--prompt", default="A cat walking in a garden", help="Text prompt.") parser.add_argument("--negative-prompt", default="", help="Negative prompt.") parser.add_argument("--image", type=str, default=None, help="Input image path (for I2V, R2V, FLF2V, inpaint).") @@ -159,6 +166,7 @@ def main(): flow_shift=args.flow_shift, enforce_eager=args.enforce_eager, parallel_config=parallel_config, + quantization = args.quantization ) prompt_data = build_prompts(args) diff --git a/examples/quantization/check_modelopt_fp8_export.py b/examples/quantization/check_modelopt_fp8_export.py index eafa3b62fb2..bab1211a6a1 100644 --- a/examples/quantization/check_modelopt_fp8_export.py +++ b/examples/quantization/check_modelopt_fp8_export.py @@ -49,15 +49,46 @@ def _check_config(transformer_dir: Path) -> int: issues = [] if qc.get("quant_method") != "modelopt": issues.append(f"quant_method={qc.get('quant_method')!r} (expected 'modelopt')") - if qc.get("quant_algo") != "FP8": - issues.append(f"quant_algo={qc.get('quant_algo')!r} (expected 'FP8' — vllm-omni adapter may not auto-detect)") + + quant_algo = qc.get("quant_algo") + if quant_algo not in ("FP8", "FP8_PB_WO"): + issues.append( + f"quant_algo={quant_algo!r} (expected 'FP8' for per-tensor or " + "'FP8_PB_WO' for 128x128 block-wise — other algos aren't routed by " + "vllm-omni's adapter today)" + ) + + # Cross-check that the saved weight strategy and the dispatch field agree. + # Producer scripts can in principle drift apart (e.g. metadata says "block" + # but quant_algo still claims "FP8"), and that lands as an AssertionError at + # weight load time because the runtime LinearMethod expects scalar scales but + # finds 4D block ones. Failing here is much friendlier. + cfg_groups = qc.get("config_groups", {}) + weight_strategies = { + (group or {}).get("weights", {}).get("strategy") + for group in cfg_groups.values() + if isinstance(group, dict) + } + weight_strategies.discard(None) + if weight_strategies == {"block"} and quant_algo != "FP8_PB_WO": + issues.append( + f"weights.strategy='block' but quant_algo={quant_algo!r}. Per-block " + "weight scales require FP8_PB_WO so upstream vLLM dispatches to " + "ModelOptFp8PbWoLinearMethod; FP8 routes to per-tensor and crashes " + "on the 4D weight_scale at weight load time." + ) + elif quant_algo == "FP8_PB_WO" and weight_strategies != {"block"}: + issues.append( + f"quant_algo='FP8_PB_WO' but weights.strategy={weight_strategies!r} " + "(expected {'block'}). FP8_PB_WO consumers expect 4D per-block scales." + ) if issues: print("[A] WARN — config looks incomplete:") for issue in issues: print(f" - {issue}") return 2 - print("[A] PASS — config looks correct.") + print(f"[A] PASS — config looks correct (quant_algo={quant_algo}).") return 0 @@ -161,10 +192,27 @@ def _disk_size_gib(p: Path) -> float: return sum(f.stat().st_size for f in p.rglob("*") if f.is_file()) / (1024**3) +def _transformer_subdirs(root: Path) -> list[Path]: + """Return [/transformer, /transformer_2] for those that exist. + + Wan2.2 MoE A14B (T2V/I2V) and Wan2.2-VACE-A14B export TWO transformer + subfolders; single-transformer checkpoints just have `transformer/`. + Falls back to `[root]` if neither exists (e.g., a baseline directory + that wasn't structured as a diffusers repo). + """ + found = [root / name for name in ("transformer", "transformer_2") if (root / name).is_dir()] + return found if found else [root] + + def _check_size_vs_baseline(transformer_dir: Path, baseline: str | None) -> int: """Returns 0 always (informational only).""" - fp8_size = _disk_size_gib(transformer_dir) - print(f"\n[C] FP8 transformer disk size: {fp8_size:.2f} GiB") + # transformer_dir is /transformer; walk one level up so we can + # also pick up transformer_2/ for Wan2.2 MoE A14B checkpoints. + fp8_root = transformer_dir.parent + fp8_subdirs = _transformer_subdirs(fp8_root) + fp8_size = sum(_disk_size_gib(p) for p in fp8_subdirs) + fp8_label = " + ".join(p.name for p in fp8_subdirs) + print(f"\n[C] FP8 transformer disk size ({fp8_label}): {fp8_size:.2f} GiB") if baseline is None: print("[C] SKIP — pass --baseline to compare against BF16.") @@ -172,26 +220,54 @@ def _check_size_vs_baseline(transformer_dir: Path, baseline: str | None) -> int: baseline_path = Path(baseline) if not baseline_path.exists(): - # Try HF download. + # Treat `baseline` as an HF repo id and read from the local cache. + # Don't trigger a download: this script is meant to run AFTER + # quantize_*_modelopt_fp8.py, which already pulled the whole repo + # into the cache. local_files_only=True makes that assumption + # explicit — if the cache is empty we surface a clear error rather + # than silently kicking off a multi-GB download. try: from huggingface_hub import snapshot_download + from huggingface_hub.errors import LocalEntryNotFoundError except ImportError: print("[C] SKIP — huggingface_hub not installed and baseline not a local path.") return 0 - print(f" Downloading baseline transformer from HF: {baseline}") - baseline_path = Path(snapshot_download(baseline, allow_patterns=["transformer/*"])) + try: + baseline_path = Path(snapshot_download(baseline, local_files_only=True)) + except LocalEntryNotFoundError: + print( + f"[C] SKIP — '{baseline}' not found in local HF cache. " + "Run the matching quantize_*_modelopt_fp8.py first (it caches the BF16 repo), " + "or pass --baseline ." + ) + return 0 + print(f" Resolved baseline from HF cache: {baseline_path}") - bf16_dir = baseline_path / "transformer" if (baseline_path / "transformer").exists() else baseline_path - bf16_size = _disk_size_gib(bf16_dir) + bf16_subdirs = _transformer_subdirs(baseline_path) + bf16_size = sum(_disk_size_gib(p) for p in bf16_subdirs) if bf16_size == 0: - print(f"[C] WARN — baseline transformer dir empty: {bf16_dir}") + print(f"[C] WARN — baseline transformer dir empty: {baseline_path}") return 0 + bf16_label = " + ".join(p.name for p in bf16_subdirs) reduction = (1 - fp8_size / bf16_size) * 100 - print(f"[C] BF16 baseline transformer disk size: {bf16_size:.2f} GiB ({bf16_dir})") - print(f"[C] Disk reduction: {reduction:.1f}% (FP8 transformer is {fp8_size / bf16_size:.0%} of BF16)") + print(f"[C] BF16 baseline transformer disk size ({bf16_label}): {bf16_size:.2f} GiB ({baseline_path})") + print(f"[C] Disk reduction: {reduction:.1f}% (FP8 is {fp8_size / bf16_size:.0%} of BF16)") if reduction < 30: print("[C] WARN — FP8 should typically reduce disk by ~40-50%; <30% suggests partial quantization.") + + # Whole-repo view: includes VAE / text_encoder / tokenizer / scheduler / + # top-level metadata. Quantization only touches transformer(s) so this + # reduction is always smaller than the transformer-only one — but it's + # what the deployment footprint actually is. + fp8_total = _disk_size_gib(fp8_root) + bf16_total = _disk_size_gib(baseline_path) + if bf16_total > 0: + total_reduction = (1 - fp8_total / bf16_total) * 100 + print( + f"[C] Whole-repo: FP8 {fp8_total:.2f} GiB / BF16 {bf16_total:.2f} GiB " + f"(reduction {total_reduction:.1f}%, deployment footprint)" + ) return 0 diff --git a/examples/quantization/quantize_hunyuanvideo_15_modelopt_fp8.py b/examples/quantization/quantize_hunyuanvideo_15_modelopt_fp8.py index 95a5cf91364..27643f05be8 100644 --- a/examples/quantization/quantize_hunyuanvideo_15_modelopt_fp8.py +++ b/examples/quantization/quantize_hunyuanvideo_15_modelopt_fp8.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Quantize HunyuanVideo-1.5 (480p T2V) to a ModelOpt FP8 Hugging Face checkpoint. +"""Quantize HunyuanVideo-1.5 to a ModelOpt FP8 Hugging Face checkpoint. Calibrates the DiT transformer using a small video prompt set and exports a diffusers-style directory whose transformer carries ModelOpt FP8 metadata. @@ -13,10 +13,36 @@ and final proj_out. MHA quantizers are off by default; HV-1.5 self-attention empirically degrades under FP8 (see #2920 ablation). -Example: - python examples/quantization/quantize_hunyuanvideo_15_modelopt_fp8.py \\ - --model hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-480p_t2v \\ - --output ./hv15-480p-modelopt-fp8 \\ +Supported targets (T2V uses HunyuanVideo15Pipeline; I2V uses +HunyuanVideo15ImageToVideoPipeline. `--variant auto` detects from the loaded +class, but you can pin it with `--variant t2v|i2v`.): +- `hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-480p_t2v` +- `hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-720p_t2v` +- `hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-480p_i2v` +- `hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-720p_i2v` + +For I2V variants, diffusers' HunyuanVideo15ImageToVideoPipeline takes a +required `image` kwarg (and derives height/width from the image), so +calibration must pair every prompt with a reference image — pass +`--reference-images `. + +Recommended resolutions per variant (CLI overrides accepted; T2V uses these +defaults, I2V derives from the reference image and ignores --height/--width): +- 480p: --height 480 --width 832 (default) +- 720p: --height 720 --width 1280 + +Example (480p T2V): + python examples/quantization/quantize_hunyuanvideo_15_modelopt_fp8.py \ + --model hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-480p_t2v \ + --output ./hv15-480p-t2v-modelopt-fp8 \ + --overwrite + +Example (480p I2V): + python examples/quantization/quantize_hunyuanvideo_15_modelopt_fp8.py \ + --model hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-480p_i2v \ + --variant i2v \ + --reference-images /path/to/ref_images \ + --output ./hv15-480p-i2v-modelopt-fp8 \ --overwrite """ @@ -74,6 +100,23 @@ def _build_parser() -> argparse.ArgumentParser: default=[], help="Custom calibration prompt. Repeat to provide multiple.", ) + p.add_argument( + "--variant", + choices=("auto", "t2v", "i2v"), + default="auto", + help="HunyuanVideo-1.5 pipeline variant. `auto` detects from the loaded pipeline class " + "(HunyuanVideo15Pipeline -> t2v, HunyuanVideo15ImageToVideoPipeline -> i2v). " + "Pass `i2v` only if you also pass --reference-images.", + ) + p.add_argument( + "--reference-images", + type=str, + default=None, + help="Required for i2v variants. Directory of jpg/jpeg/png/webp files (or a single image). " + "Every calibration sample is paired with a cycled ref image since `image` is a required " + "kwarg, not optional, in HunyuanVideo15ImageToVideoPipeline. The pipeline derives " + "height/width from the image, so --height/--width are ignored under i2v.", + ) p.add_argument( "--quantize-mha", action="store_true", @@ -83,9 +126,10 @@ def _build_parser() -> argparse.ArgumentParser: "--weight-block-size", type=str, default=None, - help="Per-block weight quantization as 'M,N' (e.g. '128,128' for 128x128 tiles). " - "Default: per-tensor (one scale per linear). Block-wise typically gives tighter quality at " - "negligible memory cost. Static FP8 is exempt from upstream vLLM's online block-wise gate.", + help="Per-block weight quantization as 'M,N'. Only '128,128' is accepted because upstream " + "vLLM's ModelOptFp8PbWoLinearMethod hardcodes that block shape. Default: per-tensor. " + "Block-wise saves checkpoints with FP8_PB_WO routing (per-block static weights + per-token-" + "group dynamic activations); per-tensor uses static FP8 with calibrated activation scales.", ) p.add_argument("--overwrite", action="store_true", help="Replace an existing output directory.") return p @@ -136,6 +180,47 @@ def _build_prompts(args: argparse.Namespace) -> list[str]: return prompts[: args.calib_size] +def _load_reference_images(spec: str | None) -> list[Any]: + """Load PIL.Image list from a directory or a single file path.""" + if spec is None: + return [] + from PIL import Image + + p = Path(spec).expanduser() + if not p.exists(): + raise SystemExit(f"--reference-images path not found: {p}") + if p.is_file(): + return [Image.open(p).convert("RGB")] + image_paths = sorted( + f for f in p.iterdir() if f.is_file() and f.suffix.lower() in (".jpg", ".jpeg", ".png", ".webp") + ) + if not image_paths: + raise SystemExit(f"No image files (jpg/jpeg/png/webp) found in {p}") + return [Image.open(f).convert("RGB") for f in image_paths] + + +def _resolve_variant(pipe: DiffusionPipeline, requested: str) -> str: + """Resolve --variant auto by inspecting the loaded pipeline class. + + HunyuanVideo15ImageToVideoPipeline -> i2v + HunyuanVideo15Pipeline (or anything else with no `image` kwarg) -> t2v + """ + if requested != "auto": + return requested + cls_name = pipe.__class__.__name__ + if "ImageToVideo" in cls_name: + return "i2v" + return "t2v" + + +def _build_calib_samples(prompts: list[str], variant: str, ref_images: list[Any]) -> list[tuple[str, Any]]: + """Pair each calibration prompt with a ref image (i2v) or None (t2v).""" + if variant == "i2v": + # ref_images is guaranteed non-empty by main()'s validation. + return [(prompt, ref_images[i % len(ref_images)]) for i, prompt in enumerate(prompts)] + return [(prompt, None) for prompt in prompts] + + # Layers to KEEP at full precision (mirror of the #2920 wiring + #2728/#2795 skip pattern). # - x_embedder, image_embedder, context_embedder*, time_embed*, cond_type_embed: entry/embedding # - norm_out, norm1*.linear, norm1_context*.linear, norm2*, norm2_context*: AdaLayerNorm modulation @@ -175,7 +260,18 @@ def _load_pipeline(model_path: str, dtype: torch.dtype) -> DiffusionPipeline: return pipe -def _build_forward_loop(pipe: DiffusionPipeline, args: argparse.Namespace, prompts: list[str]): +def _build_forward_loop( + pipe: DiffusionPipeline, + args: argparse.Namespace, + samples: list[tuple[str, Any]], + variant: str, +): + """Build a forward_loop over (prompt, ref_image) calibration samples. + + For i2v: HunyuanVideo15ImageToVideoPipeline derives height/width from the + image, so we pass `image=` and drop --height/--width. For t2v: standard + prompt-only path with --height/--width honored. + """ generator = torch.Generator(device="cuda") # Try to set guidance on the pipeline's guider object up front (modern @@ -188,26 +284,32 @@ def _build_forward_loop(pipe: DiffusionPipeline, args: argparse.Namespace, promp except Exception: pass - base_kwargs = dict( - height=args.height, - width=args.width, + base_kwargs: dict[str, Any] = dict( num_frames=args.num_frames, num_inference_steps=args.calib_steps, output_type="latent", ) + if variant != "i2v": + # I2V pipeline derives height/width from the input image and rejects + # these kwargs; only set them on T2V. + base_kwargs["height"] = args.height + base_kwargs["width"] = args.width def forward_loop(*_unused_args, **_unused_kwargs) -> None: with torch.inference_mode(): - for idx, prompt in enumerate(prompts): + for idx, (prompt, ref_image) in enumerate(samples): generator.manual_seed(args.seed + idx) + kwargs = dict(base_kwargs) + if ref_image is not None: + kwargs["image"] = ref_image # Try with guidance_scale first; fall back without on TypeError # for pipelines (like HV-1.5) that take CFG via guider config. try: - pipe(prompt=prompt, generator=generator, guidance_scale=args.guidance_scale, **base_kwargs) + pipe(prompt=prompt, generator=generator, guidance_scale=args.guidance_scale, **kwargs) except TypeError as exc: if "guidance_scale" not in str(exc): raise - pipe(prompt=prompt, generator=generator, **base_kwargs) + pipe(prompt=prompt, generator=generator, **kwargs) return forward_loop @@ -270,10 +372,15 @@ def _force_export_quantized_weights(backbone: torch.nn.Module, dtype: torch.dtyp def _hv15_quant_config_block(weight_block_size: list[int] | None = None) -> dict: """Mirror ModelOpt FP8 metadata expected by vllm-omni's adapter (#2913). - Same shape as the HunyuanImage-3 author's _hunyuan_quant_config(). When - `weight_block_size` is given, advertise block-wise weight quantization in - the saved metadata (so consumers know to expect multi-element scale tensors). + For per-block weight quantization,upstream's FP8_PB_WO hardcodes _WEIGHT_BLOCK_SIZE = (128, 128), so any other + block shape produces a checkpoint vLLM cannot serve. """ + if weight_block_size is not None and tuple(weight_block_size) != (128, 128): + raise ValueError( + f"--weight-block-size {tuple(weight_block_size)} not supported: upstream vLLM's " + "ModelOptFp8PbWoLinearMethod hardcodes (128, 128). Pass '128,128' or omit the flag." + ) + weights_cfg: dict = {"dynamic": False, "num_bits": 8, "type": "float"} if weight_block_size is not None: weights_cfg["strategy"] = "block" @@ -302,7 +409,7 @@ def _hv15_quant_config_block(weight_block_size: list[int] | None = None) -> dict "x_embedder*", ], "producer": {"name": "modelopt"}, - "quant_algo": "FP8", + "quant_algo": "FP8_PB_WO" if weight_block_size is not None else "FP8", "quant_method": "modelopt", } @@ -370,23 +477,41 @@ def main() -> None: model_path, output_dir = _ensure_paths(args) dtype = _select_dtype(args.dtype) prompts = _build_prompts(args) - - print("Quantization plan:") weight_block_size = _parse_block_size(args.weight_block_size) + if args.reference_images is not None and args.variant == "t2v": + raise SystemExit("--reference-images is only meaningful with --variant i2v (or auto-detected i2v).") + + pipe = _load_pipeline(model_path, dtype) + variant = _resolve_variant(pipe, args.variant) + if variant == "i2v" and args.reference_images is None: + raise SystemExit( + "i2v variant requires --reference-images: HunyuanVideo15ImageToVideoPipeline " + "takes a required `image` kwarg, so calibration must pair every prompt with a " + "reference image." + ) + ref_images = _load_reference_images(args.reference_images) if variant == "i2v" else [] + samples = _build_calib_samples(prompts, variant, ref_images) + sample_label = f"i2v={len(samples)}" if variant == "i2v" else f"t2v={len(samples)}" + + print("Quantization plan:") print(f" input: {args.model}") print(f" output: {output_dir}") print(f" dtype: {dtype}") - print(f" height/width: {args.height}x{args.width}") + print(f" variant: {variant} (requested={args.variant}, class={pipe.__class__.__name__})") + if variant == "i2v": + print(" height/width: derived from reference image (i2v ignores --height/--width)") + print(f" reference imgs: {len(ref_images)}") + else: + print(f" height/width: {args.height}x{args.width}") print(f" num_frames: {args.num_frames}") - print(f" calib_size: {len(prompts)}") + print(f" calib_size: {len(samples)} ({sample_label})") print(f" calib_steps: {args.calib_steps}") print(f" quantize_mha: {args.quantize_mha}") print( f" weight strategy: {'block-wise ' + str(weight_block_size) if weight_block_size else 'per-tensor (default)'}" ) - pipe = _load_pipeline(model_path, dtype) backbone = pipe.transformer quant_config = copy.deepcopy(mtq.FP8_DEFAULT_CFG) @@ -402,7 +527,7 @@ def main() -> None: f"({weight_block_size[0]}x{weight_block_size[1]} tiles)" ) - forward_loop = _build_forward_loop(pipe, args, prompts) + forward_loop = _build_forward_loop(pipe, args, samples, variant) quantized = mtq.quantize(backbone, quant_config, forward_loop) if quantized is not None: pipe.transformer = quantized @@ -428,16 +553,29 @@ def main() -> None: _summarize_export(output_dir) print("\nNext: validate the checkpoint with vllm-omni:") - print( - " python examples/offline_inference/text_to_video/text_to_video.py \\\n" - f" --model {output_dir} \\\n" - " --quantization fp8 \\\n" - " --prompt 'A dog running across a field of golden wheat.' \\\n" - f" --height {args.height} --width {args.width} --num-frames {args.num_frames} \\\n" - " --num-inference-steps 30 --guidance-scale 6.0 --seed 42 \\\n" - " --output outputs/hv15_modelopt_fp8.mp4 \\\n" - " --enforce-eager" - ) + if variant == "i2v": + print( + " python examples/offline_inference/image_to_video/image_to_video.py \\\n" + f" --model {output_dir} \\\n" + " --quantization fp8 \\\n" + " --prompt 'A subject from the reference image moves through the scene.' \\\n" + " --image \\\n" + f" --num-frames {args.num_frames} \\\n" + " --num-inference-steps 30 --guidance-scale 6.0 --seed 42 \\\n" + " --output outputs/hv15_i2v_modelopt_fp8.mp4 \\\n" + " --enforce-eager" + ) + else: + print( + " python examples/offline_inference/text_to_video/text_to_video.py \\\n" + f" --model {output_dir} \\\n" + " --quantization fp8 \\\n" + " --prompt 'A dog running across a field of golden wheat.' \\\n" + f" --height {args.height} --width {args.width} --num-frames {args.num_frames} \\\n" + " --num-inference-steps 30 --guidance-scale 6.0 --seed 42 \\\n" + " --output outputs/hv15_t2v_modelopt_fp8.mp4 \\\n" + " --enforce-eager" + ) print( "\n (--quantization fp8 is auto-upgraded to ModelOpt FP8 at runtime because the " "checkpoint's config.json has modelopt metadata.)" diff --git a/examples/quantization/quantize_wan2_2_modelopt_fp8.py b/examples/quantization/quantize_wan2_2_modelopt_fp8.py index e489271b04f..c688bde2d43 100644 --- a/examples/quantization/quantize_wan2_2_modelopt_fp8.py +++ b/examples/quantization/quantize_wan2_2_modelopt_fp8.py @@ -1,10 +1,10 @@ #!/usr/bin/env python3 # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Quantize Wan2.2 (TI2V-5B, 704x1280 T2V) to a ModelOpt FP8 Hugging Face checkpoint. +"""Quantize Wan2.2 to a ModelOpt FP8 Hugging Face checkpoint. -Calibrates the DiT transformer using a small video prompt set and exports a -diffusers-style directory whose transformer carries ModelOpt FP8 metadata. +Calibrates the DiT transformer(s) using a small video prompt set and exports a +diffusers-style directory whose transformer(s) carry ModelOpt FP8 metadata. The exported checkpoint is consumable by vllm-omni's ModelOpt FP8 adapter (see vllm_omni/diffusion/model_loader/checkpoint_adapters/modelopt_fp8.py). @@ -14,14 +14,41 @@ are quantized — static calibration handles the numerics that online FP8 couldn't (see #2920 ablation). -Default target is `Wan-AI/Wan2.2-TI2V-5B-Diffusers`, the dense 5B variant that -fits 80GB BF16. The A14B MoE variants need 2+ GPUs and are out of scope here. +Supported targets: +- `Wan-AI/Wan2.2-TI2V-5B-Diffusers` (single-transformer, 80GB BF16 fits one GPU) +- `Wan-AI/Wan2.2-T2V-A14B-Diffusers` (MoE, two transformers, needs 2+ GPUs BF16) +- `Wan-AI/Wan2.2-I2V-A14B-Diffusers` (MoE, two transformers, needs 2+ GPUs BF16) -Example: - python examples/quantization/quantize_wan2_2_modelopt_fp8.py \\ - --model Wan-AI/Wan2.2-TI2V-5B-Diffusers \\ - --output ./wan22-ti2v-modelopt-fp8 \\ +For VACE variants (Wan-AI/Wan2.X-VACE-*), use the dedicated script +`quantize_wan2_2_vace_modelopt_fp8.py` instead + +For MoE A14B variants the diffusers pipeline routes between `transformer` (high +noise, t >= boundary_timestep) and `transformer_2` (low noise) automatically +based on `boundary_ratio` from `model_index.json`. A single calibration run +collects amax statistics for both via timestep-conditioned forward passes. + +For I2V variants diffusers' WanImageToVideoPipeline takes a required `image` +kwarg, so calibration must pair every prompt with a reference image — pass +`--is-i2v` together with `--reference-images `. + +Example(TI2V-5B): + python examples/quantization/quantize_wan2_2_modelopt_fp8.py \ + --model Wan-AI/Wan2.2-TI2V-5B-Diffusers \ + --output ./wan22-ti2v-modelopt-fp8 \ --overwrite +Example(T2V-A14B): + python examples/quantization/quantize_wan2_2_modelopt_fp8.py \ + --model Wan-AI/Wan2.2-T2V-A14B-Diffusers \ + --output ./wan22-t2v-modelopt-fp8 \ + --calib-boundary-ratio 0.5 \ + --overwrite +Example(I2V-A14B): + python examples/quantization/quantize_wan2_2_modelopt_fp8.py \ + --model Wan-AI/Wan2.2-I2V-A14B-Diffusers \ + --output ./wan22-i2v-modelopt-fp8 \ + --is-i2v --reference-images /path/to/ref_images/ \ + --calib-boundary-ratio 0.5 \ + --overwrite """ from __future__ import annotations @@ -71,7 +98,13 @@ def _build_parser() -> argparse.ArgumentParser: default=10, help="Denoising steps per calibration prompt (10 is enough for amax statistics).", ) - p.add_argument("--calib-size", type=int, default=8, help="How many prompts to use for calibration.") + p.add_argument( + "--calib-size", + type=int, + default=8, + help="How many prompts to use for calibration. It is now decoupled with " + "number of DEFAULT_PROMPTS, i.e. type any size you like", + ) p.add_argument("--seed", type=int, default=42) p.add_argument( "--prompt", @@ -89,9 +122,36 @@ def _build_parser() -> argparse.ArgumentParser: "--weight-block-size", type=str, default=None, - help="Per-block weight quantization as 'M,N' (e.g. '128,128'). Default per-tensor. " - "Note: vllm-omni's ModelOpt adapter may not yet dispatch block-wise scales — check #2924 " - "for the HV-1.5 investigation status before relying on this.", + help="Per-block weight quantization as 'M,N'. Only '128,128' is accepted because upstream " + "vLLM's ModelOptFp8PbWoLinearMethod hardcodes that block shape. Default: per-tensor. " + "Block-wise saves checkpoints with FP8_PB_WO routing (per-block static weights + per-token-" + "group dynamic activations); per-tensor uses static FP8 with calibrated activation scales.", + ) + p.add_argument( + "--calib-boundary-ratio", + type=float, + default=None, + help="Pass-1-only boundary_ratio override for Wan2.2 MoE calibration. Only takes " + "effect when the loaded pipeline has transformer_2. Lowering it (e.g. 0.5) shifts " + "more denoising steps onto `transformer` so its quantizers see a richer amax " + "sample WITHOUT bumping --calib-steps. Pass 2 always restores the model's " + "production boundary_ratio (A14B = 0.875) to keep transformer_2's amax in " + "production distribution. If unset, both passes use the production value (default).", + ) + p.add_argument( + "--is-i2v", + action="store_true", + help="Set when quantizing a Wan2.2 I2V model (e.g. Wan2.2-I2V-A14B-Diffusers). " + "diffusers' WanImageToVideoPipeline takes a required `image` kwarg, so calibration " + "must pair every prompt with a reference image — pass --reference-images.", + ) + p.add_argument( + "--reference-images", + type=str, + default=None, + help="Requires --is-i2v. Directory of jpg/jpeg/png/webp files (or a single image). " + "Every calibration sample is paired with a cycled ref image since image_embedder " + "is required, not optional, in I2V pipelines. Warning: one image per sample", ) p.add_argument("--overwrite", action="store_true", help="Replace an existing output directory.") return p @@ -132,14 +192,46 @@ def _select_dtype(name: str) -> torch.dtype: return {"bfloat16": torch.bfloat16, "float16": torch.float16}[name] -def _build_prompts(args: argparse.Namespace) -> list[str]: - prompts = args.prompt or DEFAULT_PROMPTS +def _load_reference_images(spec: str | None) -> list[Any]: + """Load PIL.Image list from a directory or a single file path.""" + if spec is None: + return [] + from PIL import Image + + p = Path(spec).expanduser() + if not p.exists(): + raise SystemExit(f"--reference-images path not found: {p}") + if p.is_file(): + return [Image.open(p).convert("RGB")] + image_paths = sorted( + f for f in p.iterdir() if f.is_file() and f.suffix.lower() in (".jpg", ".jpeg", ".png", ".webp") + ) + if not image_paths: + raise SystemExit(f"No image files (jpg/jpeg/png/webp) found in {p}") + return [Image.open(f).convert("RGB") for f in image_paths] + + +def _build_calib_samples( + args: argparse.Namespace, + is_i2v: bool, + ref_images: list[Any], +) -> list[tuple[str, Any]]: + """Build calibration (prompt, reference_image_or_None) pairs. + + - Non-I2V (T2V/TI2V/A14B-T2V): every sample is (prompt, None). + - I2V: every sample paired with a cycled ref image (image kwarg is required + by diffusers' WanImageToVideoPipeline). Prompt pool is DEFAULT_PROMPTS + since the image dominates the visual signal — text mainly drives motion. + """ if args.calib_size <= 0: raise SystemExit("--calib-size must be positive.") - if len(prompts) < args.calib_size: - repeats = (args.calib_size + len(prompts) - 1) // len(prompts) - prompts = (prompts * repeats)[: args.calib_size] - return prompts[: args.calib_size] + + prompts = args.prompt or DEFAULT_PROMPTS + if is_i2v: + # ref_images is guaranteed non-empty by main()'s validation (--is-i2v + # requires --reference-images). + return [(prompt, ref_images[i % len(ref_images)]) for i, prompt in enumerate(prompts)] + return [(prompt, None) for prompt in prompts] # Layers to KEEP at full precision. Wan2.2's module naming: @@ -174,15 +266,74 @@ def _disable_known_problematic_quantizers(mtq: Any, backbone: torch.nn.Module, * mtq.disable_quantizer(backbone, _mha_filter_func) +def _move_tensor(value: Any, device: torch.device) -> Any: + if isinstance(value, torch.Tensor): + return value.to(device) + if isinstance(value, (tuple, list)): + moved = [_move_tensor(v, device) for v in value] + return type(value)(moved) + return value + + +def _make_input_device_hook(target_device: torch.device): + """Pre-hook that moves all tensor args/kwargs onto the module's device.""" + + def pre_hook(_module, args, kwargs): + new_args = tuple(_move_tensor(a, target_device) for a in args) + new_kwargs = {k: _move_tensor(v, target_device) for k, v in kwargs.items()} + return new_args, new_kwargs + + return pre_hook + + +def _make_output_device_hook(primary_device: torch.device): + """Post-hook that moves outputs back to the pipeline's primary device.""" + + def post_hook(_module, _args, output): + return _move_tensor(output, primary_device) + + return post_hook + + def _load_pipeline(model_path: str, dtype: torch.dtype) -> DiffusionPipeline: pipe = DiffusionPipeline.from_pretrained(model_path, torch_dtype=dtype) if hasattr(pipe, "set_progress_bar_config"): pipe.set_progress_bar_config(disable=True) - pipe.to("cuda") + + transformer_2 = getattr(pipe, "transformer_2", None) + if transformer_2 is not None and torch.cuda.device_count() >= 2: + # diffusers' WanPipeline routes between the two by boundary_timestep but does + # NOT transfer activations across devices, so this case bridge transformer_2 with + # forward hooks: pre-hook moves inputs cuda:0 -> cuda:1, post-hook moves + # outputs back cuda:1 -> cuda:0. The pipeline then sees a uniform cuda:0 + # state and scheduler.step works without modification. + primary = torch.device("cuda:0") + secondary = torch.device("cuda:1") + pipe.transformer.to(primary) + transformer_2.to(secondary) + for component_name in ("text_encoder", "vae", "image_encoder"): + component = getattr(pipe, component_name, None) + if component is not None: + component.to(primary) + transformer_2.register_forward_pre_hook(_make_input_device_hook(secondary), with_kwargs=True) + transformer_2.register_forward_hook(_make_output_device_hook(primary)) + print(f" device map: transformer={primary}, transformer_2={secondary} (cross-device hooks installed)") + else: + pipe.to("cuda") return pipe -def _build_forward_loop(pipe: DiffusionPipeline, args: argparse.Namespace, prompts: list[str]): +def _build_forward_loop( + pipe: DiffusionPipeline, + args: argparse.Namespace, + samples: list[tuple[str, Any]], +): + """Build a forward_loop that drives `pipe` over the calibration samples. + + Samples carrying a reference image are forwarded with `image=PIL.Image` + (the kwarg expected by diffusers' WanImageToVideoPipeline). Samples with + ref=None call pipe(prompt=...) — the standard T2V path. + """ generator = torch.Generator(device="cuda") # Try setting guidance on the pipeline's guider if present (newer diffusers APIs). @@ -203,22 +354,25 @@ def _build_forward_loop(pipe: DiffusionPipeline, args: argparse.Namespace, promp def forward_loop(*_unused_args, **_unused_kwargs) -> None: with torch.inference_mode(): - for idx, prompt in enumerate(prompts): + for idx, (prompt, ref_image) in enumerate(samples): generator.manual_seed(args.seed + idx) + kwargs = dict(base_kwargs) + if ref_image is not None: + kwargs["image"] = ref_image # Try with guidance_scale first; fall back without on TypeError # for pipelines that take CFG via guider config only. try: - pipe(prompt=prompt, generator=generator, guidance_scale=args.guidance_scale, **base_kwargs) + pipe(prompt=prompt, generator=generator, guidance_scale=args.guidance_scale, **kwargs) except TypeError as exc: if "guidance_scale" not in str(exc): raise - pipe(prompt=prompt, generator=generator, **base_kwargs) + pipe(prompt=prompt, generator=generator, **kwargs) return forward_loop -def _summarize_export(output_dir: Path) -> None: - cfg_path = output_dir / "transformer" / "config.json" +def _summarize_export(output_dir: Path, subfolder: str = "transformer") -> None: + cfg_path = output_dir / subfolder / "config.json" if not cfg_path.exists(): print(f"[warn] {cfg_path} missing.", file=sys.stderr) return @@ -226,9 +380,9 @@ def _summarize_export(output_dir: Path) -> None: cfg = json.load(f) qc = cfg.get("quantization_config") if not isinstance(qc, dict): - print("[warn] No quantization_config in transformer/config.json.", file=sys.stderr) + print(f"[warn] No quantization_config in {subfolder}/config.json.", file=sys.stderr) return - print("Export summary:") + print(f"Export summary ({subfolder}):") print(f" quant_method: {qc.get('quant_method')}") print(f" quant_algo: {qc.get('quant_algo')}") producer = qc.get("producer") @@ -273,7 +427,17 @@ def _force_export_quantized_weights(backbone: torch.nn.Module, dtype: torch.dtyp def _wan22_quant_config_block(weight_block_size: list[int] | None = None) -> dict: - """Mirror ModelOpt FP8 metadata expected by vllm-omni's adapter (#2913).""" + """Mirror ModelOpt FP8 metadata expected by vllm-omni's adapter (#2913). + + For per-block weight quantization,upstream's FP8_PB_WO hardcodes _WEIGHT_BLOCK_SIZE = (128, 128), so any other + block shape produces a checkpoint vLLM cannot serve. + """ + if weight_block_size is not None and tuple(weight_block_size) != (128, 128): + raise ValueError( + f"--weight-block-size {tuple(weight_block_size)} not supported: upstream vLLM's " + "ModelOptFp8PbWoLinearMethod hardcodes (128, 128). Pass '128,128' or omit the flag." + ) + weights_cfg: dict = {"dynamic": False, "num_bits": 8, "type": "float"} if weight_block_size is not None: weights_cfg["strategy"] = "block" @@ -296,15 +460,23 @@ def _wan22_quant_config_block(weight_block_size: list[int] | None = None) -> dic "timestep_proj_prepare*", ], "producer": {"name": "modelopt"}, - "quant_algo": "FP8", + "quant_algo": "FP8_PB_WO" if weight_block_size is not None else "FP8", "quant_method": "modelopt", } -def _patch_quant_config(output_dir: Path, weight_block_size: list[int] | None = None) -> None: - """Inject quant_algo: FP8 + config_groups into transformer/config.json so - vllm-omni's adapter (#2913) recognises the checkpoint as ModelOpt FP8.""" - cfg_path = output_dir / "transformer" / "config.json" +def _patch_quant_config( + output_dir: Path, + subfolder: str = "transformer", + weight_block_size: list[int] | None = None, +) -> None: + """Inject quant_algo: FP8 + config_groups into /config.json so + vllm-omni's adapter (#2913) recognises the checkpoint as ModelOpt FP8. + + For Wan2.2 MoE (T2V/I2V-A14B), call once per transformer subfolder + (`transformer` and `transformer_2`). + """ + cfg_path = output_dir / subfolder / "config.json" with cfg_path.open(encoding="utf-8") as f: cfg = json.load(f) @@ -320,13 +492,17 @@ def _patch_quant_config(output_dir: Path, weight_block_size: list[int] | None = json.dump(cfg, f, indent=2) -def _save_pipeline_with_fp8_transformer( +def _save_pipeline_with_fp8_transformers( pipe: DiffusionPipeline, model_path: str, output_dir: Path, max_shard_size: str = "5GB", ) -> None: - """Copy source dir verbatim minus transformer/, then save the quantized transformer.""" + """Copy source dir verbatim minus transformer/(_2), then save quantized transformer(s). + + For Wan2.2 MoE (T2V/I2V-A14B), `pipe.transformer_2` is also saved into the + `transformer_2/` subfolder. Single-transformer variants (TI2V-5B) skip it. + """ from modelopt.torch.export.diffusers_utils import hide_quantizers_from_state_dict src = Path(model_path) @@ -339,13 +515,56 @@ def _save_pipeline_with_fp8_transformer( shutil.rmtree(output_dir) shutil.copytree(src, output_dir, ignore=shutil.ignore_patterns("transformer", "transformer_2")) - transformer_out = output_dir / "transformer" - # Pass the nn.Module (transformer), not the Pipeline wrapper. - with hide_quantizers_from_state_dict(pipe.transformer): - pipe.transformer.save_pretrained( - str(transformer_out), - safe_serialization=True, - max_shard_size=max_shard_size, + backbones: list[tuple[str, torch.nn.Module]] = [("transformer", pipe.transformer)] + transformer_2 = getattr(pipe, "transformer_2", None) + if transformer_2 is not None: + backbones.append(("transformer_2", transformer_2)) + + for subfolder, backbone in backbones: + out = output_dir / subfolder + with hide_quantizers_from_state_dict(backbone): + backbone.save_pretrained( + str(out), + safe_serialization=True, + max_shard_size=max_shard_size, + ) + + +def _calibrate( + backbone: torch.nn.Module, + label: str, + *, + mtq: Any, + quant_config: dict, + forward_loop, + quantize_mha: bool, +) -> torch.nn.Module: + """Wrap one transformer backbone with quantizers and run calibration. + + Returns the (possibly replaced) backbone module so the caller can rebind + `pipe.transformer` / `pipe.transformer_2` to the wrapped instance. The + backbone's weights remain in their original dtype here — call + `_force_export` afterwards to commit FP8 storage. + """ + print(f"\nCalibrating {label}...") + quantized = mtq.quantize(backbone, quant_config, forward_loop) + if quantized is not None: + backbone = quantized + _disable_known_problematic_quantizers(mtq, backbone, quantize_mha=quantize_mha) + return backbone + + +def _force_export(backbone: torch.nn.Module, label: str, dtype: torch.dtype) -> None: + """Convert calibrated weights to actual FP8 storage.""" + print(f"\nForcing FP8 weight serialization for {label} (Wan2.2 isn't in ModelOpt's") + print("recognized-model registry, so we call the per-weight export helper ourselves)...") + exported = _force_export_quantized_weights(backbone, dtype) + print(f" -> {exported} weights converted to FP8 in {label}") + if exported == 0: + raise SystemExit( + f"No quantized weights were exported in {label}. Calibration may have skipped every " + "layer (check the disable_quantizer regex) or `mtq.quantize` did not actually wrap " + "any weight quantizers." ) @@ -357,24 +576,47 @@ def main() -> None: mtq = _require_modelopt() model_path, output_dir = _ensure_paths(args) dtype = _select_dtype(args.dtype) - prompts = _build_prompts(args) weight_block_size = _parse_block_size(args.weight_block_size) + if args.reference_images is not None and not args.is_i2v: + raise SystemExit("--reference-images requires --is-i2v.") + if args.is_i2v and args.reference_images is None: + raise SystemExit( + "--is-i2v requires --reference-images: diffusers' WanImageToVideoPipeline " + "takes a required `image` kwarg, so calibration must pair every prompt with " + "a reference image." + ) + ref_images = _load_reference_images(args.reference_images) if args.is_i2v else [] + samples = _build_calib_samples(args, args.is_i2v, ref_images) + sample_label = f"I2V={len(samples)}" if args.is_i2v else f"T2V={len(samples)}" + print("Quantization plan:") print(f" input: {args.model}") print(f" output: {output_dir}") print(f" dtype: {dtype}") print(f" height/width: {args.height}x{args.width}") print(f" num_frames: {args.num_frames}") - print(f" calib_size: {len(prompts)}") + print(f" calib_size: {len(samples)} ({sample_label})") print(f" calib_steps: {args.calib_steps}") print(f" quantize_mha: {args.quantize_mha}") + print(f" is_i2v: {args.is_i2v}") + if args.is_i2v: + print(f" reference imgs: {len(ref_images)}") print( f" weight strategy: {'block-wise ' + str(weight_block_size) if weight_block_size else 'per-tensor (default)'}" ) pipe = _load_pipeline(model_path, dtype) - backbone = pipe.transformer + is_dual = getattr(pipe, "transformer_2", None) is not None + if is_dual: + print(" detected MoE A14B variant (transformer + transformer_2)") + + # Capture the model's production boundary_ratio (from model_index.json) so + # we can restore it before pass 2. --calib-boundary-ratio only overrides + # pass 1 to give `transformer` more amax samples; pass 2 must run at the + # production boundary so `transformer_2` calibrates on the same noise + # distribution it will see at inference time. + production_boundary = pipe.config.get("boundary_ratio") if is_dual else None quant_config = copy.deepcopy(mtq.FP8_DEFAULT_CFG) if weight_block_size is not None: @@ -387,41 +629,83 @@ def main() -> None: f"({weight_block_size[0]}x{weight_block_size[1]} tiles)" ) - forward_loop = _build_forward_loop(pipe, args, prompts) - quantized = mtq.quantize(backbone, quant_config, forward_loop) - if quantized is not None: - pipe.transformer = quantized - backbone = quantized - - _disable_known_problematic_quantizers(mtq, backbone, quantize_mha=args.quantize_mha) + forward_loop = _build_forward_loop(pipe, args, samples) + + # Single-transformer (TI2V-5B) does one pass; MoE A14B variants do two. + # The diffusers Wan22 pipeline routes between transformer (high noise) and + # transformer_2 (low noise) by boundary_timestep, so each forward_loop run + # exercises the backbone currently being calibrated. mtq.quantize wraps + # quantizers and then drives the forward_loop to collect amax statistics. + # + # Calibration must complete for BOTH backbones BEFORE any force_export call: + # Before _force_export, transformer's weights must still be BF16 at that point. + if is_dual and args.calib_boundary_ratio is not None: + pipe.register_to_config(boundary_ratio=args.calib_boundary_ratio) + print( + f"\n pass 1 boundary_ratio: {args.calib_boundary_ratio} " + f"(override of production {production_boundary} for transformer sample boost)" + ) - print("\nForcing FP8 weight serialization (Wan2.2 isn't in ModelOpt's recognized-model registry,") - print("so we have to call the per-weight export helper ourselves)...") - exported = _force_export_quantized_weights(backbone, dtype) - print(f" -> {exported} weights converted to FP8 in memory") - if exported == 0: - raise SystemExit( - "No quantized weights were exported. Calibration may have skipped every layer " - "(check the disable_quantizer regex) or `mtq.quantize` did not actually wrap any " - "weight quantizers." + pipe.transformer = _calibrate( + pipe.transformer, + "transformer", + mtq=mtq, + quant_config=quant_config, + forward_loop=forward_loop, + quantize_mha=args.quantize_mha, + ) + if is_dual: + if args.calib_boundary_ratio is not None: + pipe.register_to_config(boundary_ratio=production_boundary) + print( + f"\n pass 2 boundary_ratio: {production_boundary} " + "(restored to production for transformer_2 in-distribution calibration)" + ) + pipe.transformer_2 = _calibrate( + pipe.transformer_2, + "transformer_2", + mtq=mtq, + quant_config=quant_config, + forward_loop=forward_loop, + quantize_mha=args.quantize_mha, ) - print("\nSaving pipeline with FP8 transformer...") - _save_pipeline_with_fp8_transformer(pipe, model_path, output_dir) - _patch_quant_config(output_dir, weight_block_size=weight_block_size) + _force_export(pipe.transformer, "transformer", dtype) + if is_dual: + _force_export(pipe.transformer_2, "transformer_2", dtype) + + print("\nSaving pipeline with FP8 transformer(s)...") + _save_pipeline_with_fp8_transformers(pipe, model_path, output_dir) + _patch_quant_config(output_dir, subfolder="transformer", weight_block_size=weight_block_size) + if is_dual: + _patch_quant_config(output_dir, subfolder="transformer_2", weight_block_size=weight_block_size) print(f"Saved to: {output_dir}") - _summarize_export(output_dir) + _summarize_export(output_dir, subfolder="transformer") + if is_dual: + _summarize_export(output_dir, subfolder="transformer_2") print("\nNext: validate the checkpoint with vllm-omni:") - print( - " python examples/offline_inference/text_to_video/text_to_video.py \\\n" - f" --model {output_dir} \\\n" - " --quantization fp8 \\\n" - " --prompt 'A dog running across a field of golden wheat.' \\\n" - f" --height {args.height} --width {args.width} --num-frames {args.num_frames} \\\n" - " --num-inference-steps 30 --guidance-scale 5.0 --seed 42 \\\n" - " --output outputs/wan22_modelopt_fp8.mp4" - ) + if args.is_i2v: + print( + " python examples/offline_inference/image_to_video/image_to_video.py \\\n" + f" --model {output_dir} \\\n" + " --quantization fp8 \\\n" + " --prompt 'A subject from the reference image moves through the scene.' \\\n" + " --image \\\n" + f" --height {args.height} --width {args.width} --num-frames {args.num_frames} \\\n" + " --num-inference-steps 30 --guidance-scale 5.0 --seed 42 \\\n" + " --output outputs/wan22_i2v_modelopt_fp8.mp4" + ) + else: + print( + " python examples/offline_inference/text_to_video/text_to_video.py \\\n" + f" --model {output_dir} \\\n" + " --quantization fp8 \\\n" + " --prompt 'A dog running across a field of golden wheat.' \\\n" + f" --height {args.height} --width {args.width} --num-frames {args.num_frames} \\\n" + " --num-inference-steps 30 --guidance-scale 5.0 --seed 42 \\\n" + " --output outputs/wan22_modelopt_fp8.mp4" + ) print( "\n (--quantization fp8 is auto-upgraded to ModelOpt FP8 at runtime because the " "checkpoint's config.json has modelopt metadata.)" diff --git a/examples/quantization/quantize_wan2_2_vace_modelopt_fp8.py b/examples/quantization/quantize_wan2_2_vace_modelopt_fp8.py new file mode 100644 index 00000000000..6a3f9143001 --- /dev/null +++ b/examples/quantization/quantize_wan2_2_vace_modelopt_fp8.py @@ -0,0 +1,701 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Quantize Wan VACE models to a ModelOpt FP8 Hugging Face checkpoint. Support modes: + + - T2V : prompt only (vace_context auto-filled with zeros + mask=1) + - R2V : prompt + reference_images=[PIL.Image] + - I2V : prompt + reference_images=[PIL.Image], same as R2V in calibration step + +This script currently calibrates with **T2V + R2V + I2V** samples. The other modes +(I2V/FLF2V/inpaint) require encoded video + mask inputs and can be wired in +later by extending `_build_calib_samples`. + +Layers kept full precision match the Wan2.2 pattern: condition embedder +(time/text/image), patch embedding, modulation (scale_shift_table), final +norm + proj_out, and sequence-parallel helpers. All attention + FFN linears +are quantized — including the vace_blocks' own attention/FFN linears (since +they're standard `WanTransformerBlock` subclasses). + +Supported targets: +- `Wan-AI/Wan2.1-VACE-1.3B-diffusers` (single-transformer, ~10GB BF16) +- `Wan-AI/Wan2.1-VACE-14B-diffusers` (single-transformer, ~38GB BF16) +- `Wan-AI/Wan2.2-VACE-A14B-Diffusers` (MoE + VACE, dual-transformer; needs 2+ GPUs BF16; Model not released yet, but the wiring is ready) + +For dual-transformer VACE the diffusers pipeline routes between `transformer` +and `transformer_2` by `boundary_timestep` exactly like Wan2.2 MoE T2V/I2V. + +Example (VACE T2V calibration, no reference images): + python examples/quantization/quantize_wan2_2_vace_modelopt_fp8.py \ + --model Wan-AI/Wan2.1-VACE-1.3B-diffusers \ + --output ./wan21-vace-1.3b-fp8 \ + --overwrite + +Example (VACE T2V + R2V mix, with reference images): + python examples/quantization/quantize_wan2_2_vace_modelopt_fp8.py \ + --model Wan-AI/Wan2.1-VACE-14B-diffusers \ + --output ./wan21-vace-14b-fp8 \ + --reference-images /path/to/ref_images/ \ + --overwrite +""" + +from __future__ import annotations + +import argparse +import copy +import json +import re +import shutil +import sys +from pathlib import Path +from typing import Any + +import torch +from diffusers import DiffusionPipeline + +DEFAULT_PROMPTS = [ + "A dog running across a field of golden wheat.", + "An astronaut riding a horse across the surface of Mars, red dust swirling, cinematic wide shot.", + "A hummingbird hovering in front of a vibrant red flower, slow motion, macro shot.", + "A crackling campfire at night under a starry sky, sparks rising into the dark.", + "An underwater shot of a coral reef with tropical fish swimming by, sun rays piercing the water.", + "A close-up of a blooming rose covered in morning dew, soft natural light.", + "A peaceful mountain village at dawn, mist rolling over the rooftops, cinematic establishing shot.", + "A skateboarder doing a kickflip in an urban plaza, slow motion, golden hour lighting.", +] + +# R2V prompts pair with --reference-images. Phrasing explicitly references "the +# subject from the reference image" so prompt and ref_image are semantically +# coupled — mimics how users actually write R2V prompts in production. +VACE_DEFAULT_PROMPTS_R2V = [ + "The subject from the reference image walks confidently through a snowy forest at dusk.", + "Recreate the reference subject dancing under spinning disco lights in a vibrant nightclub.", + "The reference subject sails across a calm ocean at golden hour, sun glinting off the water.", + "Render the reference subject in a cyberpunk cityscape at night, neon reflections on rainy streets.", +] + + +def _build_parser() -> argparse.ArgumentParser: + p = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) + p.add_argument("--model", required=True, help="Input Wan VACE diffusers directory or HF id.") + p.add_argument("--output", required=True, help="Output directory for the ModelOpt FP8 checkpoint.") + p.add_argument("--dtype", choices=("bfloat16", "float16"), default="bfloat16") + p.add_argument("--height", type=int, default=480, help="Calibration video height (VACE 480p default).") + p.add_argument("--width", type=int, default=832, help="Calibration video width (VACE 480p default).") + p.add_argument( + "--num-frames", + type=int, + default=33, + help="Frames per calibration sample. Smaller frame counts reduce memory pressure during " + "calibration; amax statistics are largely independent of frame count.", + ) + p.add_argument("--guidance-scale", type=float, default=5.0) + p.add_argument( + "--calib-steps", + type=int, + default=10, + help="Denoising steps per calibration prompt (10 is enough for amax statistics).", + ) + p.add_argument("--calib-size", type=int, default=8, help="How many prompts to use for calibration.") + p.add_argument("--seed", type=int, default=42) + p.add_argument( + "--prompt", + action="append", + default=[], + help="Custom calibration prompt. Repeat to provide multiple. When --reference-images is " + "set, every custom prompt is paired with a cycled ref image (assumes R2V phrasing).", + ) + p.add_argument( + "--quantize-mha", + action="store_true", + help="Enable FP8 attention K/V/softmax quantizers. Off by default — Wan's long attention " + "sequences amplified FP8 drift in the online ablation (see #2920).", + ) + p.add_argument( + "--weight-block-size", + type=str, + default=None, + help="Per-block weight quantization as 'M,N'. Only '128,128' is accepted because upstream " + "vLLM's ModelOptFp8PbWoLinearMethod hardcodes that block shape. Default: per-tensor. " + "Block-wise saves checkpoints with FP8_PB_WO routing (per-block static weights + per-token-" + "group dynamic activations); per-tensor uses static FP8 with calibrated activation scales.", + ) + p.add_argument( + "--calib-boundary-ratio", + type=float, + default=None, + help="Pass-1-only boundary_ratio override for dual-transformer VACE (e.g. Wan2.2-VACE-A14B). " + "Lowering it (e.g. 0.5) shifts more denoising steps onto `transformer` so its quantizers " + "see a richer amax sample WITHOUT bumping --calib-steps. Pass 2 always restores the " + "model's production boundary_ratio. No-op for single-transformer VACE (Wan2.1-VACE-*).", + ) + p.add_argument( + "--reference-images", + type=str, + default=None, + help="Optional. Directory of jpg/jpeg/png/webp files (or a single image). When provided, " + "half the calibration samples become R2V (paired with cycled ref images, using " + "VACE_DEFAULT_PROMPTS_R2V) so vace_blocks' amax covers real ref-image latent " + "distributions; the other half stay T2V (zero-conditioning). When omitted, calibration " + "runs T2V-only — vace_blocks see auto-generated zero vace_context, which works but " + "amax is conservative for R2V-mode inference.", + ) + p.add_argument("--overwrite", action="store_true", help="Replace an existing output directory.") + return p + + +def _parse_block_size(spec: str | None) -> list[int] | None: + if spec is None: + return None + parts = [int(x) for x in spec.split(",") if x.strip()] + if len(parts) != 2: + raise SystemExit(f"--weight-block-size must be 'M,N' (2 ints), got {spec!r}") + return parts + + +def _require_modelopt() -> Any: + try: + import modelopt.torch.quantization as mtq + except ModuleNotFoundError as exc: + raise SystemExit( + "NVIDIA ModelOpt is not installed. Install with:\n" + " pip install 'nvidia-modelopt[all]'\n" + f"Original error: {exc}" + ) from exc + return mtq + + +def _ensure_paths(args: argparse.Namespace) -> tuple[str, Path]: + model_path = args.model + output_dir = Path(args.output).expanduser().resolve() + if output_dir.exists(): + if not args.overwrite: + raise SystemExit(f"Output directory already exists: {output_dir}\nPass --overwrite to replace it.") + shutil.rmtree(output_dir) + return model_path, output_dir + + +def _select_dtype(name: str) -> torch.dtype: + return {"bfloat16": torch.bfloat16, "float16": torch.float16}[name] + + +def _load_reference_images(spec: str | None) -> list[Any]: + """Load PIL.Image list from a directory or a single file path.""" + if spec is None: + return [] + from PIL import Image + + p = Path(spec).expanduser() + if not p.exists(): + raise SystemExit(f"--reference-images path not found: {p}") + if p.is_file(): + return [Image.open(p).convert("RGB")] + image_paths = sorted( + f for f in p.iterdir() if f.is_file() and f.suffix.lower() in (".jpg", ".jpeg", ".png", ".webp") + ) + if not image_paths: + raise SystemExit(f"No image files (jpg/jpeg/png/webp) found in {p}") + return [Image.open(f).convert("RGB") for f in image_paths] + + +def _cycle_to_size(items: list, size: int) -> list: + if not items: + raise SystemExit("Cannot build calibration prompts: pool is empty.") + repeats = (size + len(items) - 1) // len(items) + return (items * repeats)[:size] + + +def _build_calib_samples( + args: argparse.Namespace, + ref_images: list[Any], +) -> list[tuple[str, Any]]: + """Build calibration (prompt, reference_image_or_None) pairs for VACE. + + - No --reference-images: T2V-only calibration. vace_blocks see auto-generated + zero vace_context (vae.encode(zeros) + mask=1). + - With --reference-images, no --prompt: half samples are T2V (DEFAULT_PROMPTS), + half are R2V (VACE_DEFAULT_PROMPTS_R2V paired with cycled ref images), + covering both zero- and real-conditioning extremes for vace_blocks' amax. + - With --reference-images and --prompt: every user prompt is paired with a + cycled ref image (assumes the user wrote R2V-style prompts). + """ + if args.calib_size <= 0: + raise SystemExit("--calib-size must be positive.") + + if not ref_images: + prompts = args.prompt or DEFAULT_PROMPTS + return [(p, None) for p in _cycle_to_size(prompts, args.calib_size)] + + custom_prompts = args.prompt or [] + if custom_prompts: + pool = _cycle_to_size(custom_prompts, args.calib_size) + return [(prompt, ref_images[i % len(ref_images)]) for i, prompt in enumerate(pool)] + + n_r2v = min(args.calib_size // 2, len(ref_images)) + n_t2v = args.calib_size - n_r2v + t2v_pool = _cycle_to_size(DEFAULT_PROMPTS, n_t2v) if n_t2v > 0 else [] + r2v_pool = _cycle_to_size(VACE_DEFAULT_PROMPTS_R2V, n_r2v) if n_r2v > 0 else [] + samples: list[tuple[str, Any]] = [(p, None) for p in t2v_pool] + samples.extend((prompt, ref_images[i % len(ref_images)]) for i, prompt in enumerate(r2v_pool)) + return samples + + +# Layers to KEEP at full precision. Wan VACE inherits the base Wan module +# naming (condition_embedder, patch_embedding, scale_shift_table, norm_out, +# proj_out, timestep_proj_prepare/output_scale_shift_prepare). vace_blocks +# carry their own proj_in/proj_out Linears (full path: vace_blocks.{i}.proj_*), +# which the regex below intentionally does NOT match — they are quantized +# alongside the rest of the vace_blocks' attention/FFN linears. +def _filter_func_wan22(name: str) -> bool: + pattern = re.compile( + r"(proj_out.*|" + r".*(condition_embedder|patch_embedding|" + r"norm_out|scale_shift_table|" + r"timestep_proj_prepare|output_scale_shift_prepare).*)" + ) + return pattern.match(name) is not None + + +def _mha_filter_func(name: str) -> bool: + pattern = re.compile( + r".*(q_bmm_quantizer|k_bmm_quantizer|v_bmm_quantizer|softmax_quantizer|bmm2_output_quantizer).*" + ) + return pattern.match(name) is not None + + +def _disable_known_problematic_quantizers(mtq: Any, backbone: torch.nn.Module, *, quantize_mha: bool) -> None: + if not hasattr(mtq, "disable_quantizer"): + return + mtq.disable_quantizer(backbone, _filter_func_wan22) + if not quantize_mha: + mtq.disable_quantizer(backbone, _mha_filter_func) + + +def _move_tensor(value: Any, device: torch.device) -> Any: + if isinstance(value, torch.Tensor): + return value.to(device) + if isinstance(value, (tuple, list)): + moved = [_move_tensor(v, device) for v in value] + return type(value)(moved) + return value + + +def _make_input_device_hook(target_device: torch.device): + """Pre-hook that moves all tensor args/kwargs onto the module's device.""" + + def pre_hook(_module, args, kwargs): + new_args = tuple(_move_tensor(a, target_device) for a in args) + new_kwargs = {k: _move_tensor(v, target_device) for k, v in kwargs.items()} + return new_args, new_kwargs + + return pre_hook + + +def _make_output_device_hook(primary_device: torch.device): + """Post-hook that moves outputs back to the pipeline's primary device.""" + + def post_hook(_module, _args, output): + return _move_tensor(output, primary_device) + + return post_hook + + +def _load_pipeline(model_path: str, dtype: torch.dtype) -> DiffusionPipeline: + pipe = DiffusionPipeline.from_pretrained(model_path, torch_dtype=dtype) + if hasattr(pipe, "set_progress_bar_config"): + pipe.set_progress_bar_config(disable=True) + + transformer_2 = getattr(pipe, "transformer_2", None) + if transformer_2 is not None and torch.cuda.device_count() >= 2: + # diffusers' WanPipeline routes between the two by boundary_timestep but + # does NOT transfer activations across devices; bridge transformer_2 with + # forward hooks: pre-hook moves inputs cuda:0 -> cuda:1, post-hook moves + # outputs back cuda:1 -> cuda:0. The pipeline then sees a uniform cuda:0 + # state and scheduler.step works without modification. + primary = torch.device("cuda:0") + secondary = torch.device("cuda:1") + pipe.transformer.to(primary) + transformer_2.to(secondary) + for component_name in ("text_encoder", "vae", "image_encoder"): + component = getattr(pipe, component_name, None) + if component is not None: + component.to(primary) + transformer_2.register_forward_pre_hook(_make_input_device_hook(secondary), with_kwargs=True) + transformer_2.register_forward_hook(_make_output_device_hook(primary)) + print(f" device map: transformer={primary}, transformer_2={secondary} (cross-device hooks installed)") + else: + pipe.to("cuda") + return pipe + + +def _build_forward_loop( + pipe: DiffusionPipeline, + args: argparse.Namespace, + samples: list[tuple[str, Any]], +): + """Build a forward_loop that drives `pipe` over the calibration samples. + + Samples carrying a reference image are forwarded with `reference_images=[img]` + (the kwarg expected by diffusers' WanVACEPipeline). Samples with ref=None + call pipe(prompt=...) — diffusers VACE pipeline auto-fills vace_context with + zeros + mask=1 in this case (T2V mode). + """ + generator = torch.Generator(device="cuda") + + # Try setting guidance on the pipeline's guider if present (newer diffusers APIs). + guider = getattr(pipe, "guider", None) + if guider is not None and hasattr(guider, "guidance_scale"): + try: + guider.guidance_scale = args.guidance_scale + except Exception: + pass + + base_kwargs = dict( + height=args.height, + width=args.width, + num_frames=args.num_frames, + num_inference_steps=args.calib_steps, + output_type="latent", + ) + + def forward_loop(*_unused_args, **_unused_kwargs) -> None: + with torch.inference_mode(): + for idx, (prompt, ref_image) in enumerate(samples): + generator.manual_seed(args.seed + idx) + kwargs = dict(base_kwargs) + if ref_image is not None: + kwargs["reference_images"] = [ref_image] + # Try with guidance_scale first; fall back without on TypeError + # for pipelines that take CFG via guider config only. + try: + pipe(prompt=prompt, generator=generator, guidance_scale=args.guidance_scale, **kwargs) + except TypeError as exc: + if "guidance_scale" not in str(exc): + raise + pipe(prompt=prompt, generator=generator, **kwargs) + + return forward_loop + + +def _summarize_export(output_dir: Path, subfolder: str = "transformer") -> None: + cfg_path = output_dir / subfolder / "config.json" + if not cfg_path.exists(): + print(f"[warn] {cfg_path} missing.", file=sys.stderr) + return + with cfg_path.open(encoding="utf-8") as f: + cfg = json.load(f) + qc = cfg.get("quantization_config") + if not isinstance(qc, dict): + print(f"[warn] No quantization_config in {subfolder}/config.json.", file=sys.stderr) + return + print(f"Export summary ({subfolder}):") + print(f" quant_method: {qc.get('quant_method')}") + print(f" quant_algo: {qc.get('quant_algo')}") + producer = qc.get("producer") + if isinstance(producer, dict): + print(f" producer: {producer.get('name')} {producer.get('version')}") + print(f" config path: {cfg_path}") + + +def _force_export_quantized_weights(backbone: torch.nn.Module, dtype: torch.dtype) -> int: + """Convert in-memory weights of quantized modules to actual FP8 storage. + + `export_hf_checkpoint` skips this step for unknown model types (Wan VACE + isn't in ModelOpt's recognized-model registry), so we must call the + per-weight export helper ourselves. + """ + from modelopt.torch.export.quant_utils import ( + QUANTIZATION_NONE, + get_quantization_format, + quantizer_attr_names, + weight_attr_names, + ) + from modelopt.torch.export.unified_export_hf import _export_quantized_weight + + exported = 0 + for name, module in backbone.named_modules(): + try: + quantization_format = get_quantization_format(module) + except Exception as exc: + print(f"[warn] Could not inspect quantization format for {name}: {exc}", file=sys.stderr) + continue + if quantization_format == QUANTIZATION_NONE: + continue + for weight_name in weight_attr_names(module): + quantizer_attrs = quantizer_attr_names(weight_name) + weight_quantizer = getattr(module, quantizer_attrs.weight_quantizer, None) + if weight_quantizer is None or not getattr(weight_quantizer, "is_enabled", False): + continue + _export_quantized_weight(module, dtype, weight_name) + exported += 1 + return exported + + +def _wan22_quant_config_block(weight_block_size: list[int] | None = None) -> dict: + """Mirror ModelOpt FP8 metadata expected by vllm-omni's adapter (#2913). + For per-block weight quantization, upstream's FP8_PB_WO hardcodes _WEIGHT_BLOCK_SIZE = (128, 128), so any other + block shape produces a checkpoint vLLM cannot serve. + """ + if weight_block_size is not None and tuple(weight_block_size) != (128, 128): + raise ValueError( + f"--weight-block-size {tuple(weight_block_size)} not supported: upstream vLLM's " + "ModelOptFp8PbWoLinearMethod hardcodes (128, 128). Pass '128,128' or omit the flag." + ) + + weights_cfg: dict = {"dynamic": False, "num_bits": 8, "type": "float"} + if weight_block_size is not None: + weights_cfg["strategy"] = "block" + weights_cfg["block_structure"] = f"{weight_block_size[0]}x{weight_block_size[1]}" + return { + "config_groups": { + "group_0": { + "input_activations": {"dynamic": False, "num_bits": 8, "type": "float"}, + "weights": weights_cfg, + "targets": ["Linear"], + } + }, + "ignore": [ + "condition_embedder*", + "norm_out*", + "output_scale_shift_prepare*", + "patch_embedding*", + "proj_out*", + "scale_shift_table*", + "timestep_proj_prepare*", + ], + "producer": {"name": "modelopt"}, + "quant_algo": "FP8_PB_WO" if weight_block_size is not None else "FP8", + "quant_method": "modelopt", + } + + +def _patch_quant_config( + output_dir: Path, + subfolder: str = "transformer", + weight_block_size: list[int] | None = None, +) -> None: + """Inject quant_algo: FP8 + config_groups into /config.json so + vllm-omni's adapter (#2913) recognises the checkpoint as ModelOpt FP8. + """ + cfg_path = output_dir / subfolder / "config.json" + with cfg_path.open(encoding="utf-8") as f: + cfg = json.load(f) + + new_qc = _wan22_quant_config_block(weight_block_size=weight_block_size) + existing = cfg.get("quantization_config") + if isinstance(existing, dict): + producer = existing.get("producer") + if isinstance(producer, dict): + new_qc["producer"] = producer + + cfg["quantization_config"] = new_qc + with cfg_path.open("w", encoding="utf-8") as f: + json.dump(cfg, f, indent=2) + + +def _save_pipeline_with_fp8_transformers( + pipe: DiffusionPipeline, + model_path: str, + output_dir: Path, + max_shard_size: str = "5GB", +) -> None: + """Copy source dir verbatim minus transformer/(_2), then save quantized transformer(s).""" + from modelopt.torch.export.diffusers_utils import hide_quantizers_from_state_dict + + src = Path(model_path) + if not src.exists(): + from huggingface_hub import snapshot_download + + src = Path(snapshot_download(model_path)) + + if output_dir.exists(): + shutil.rmtree(output_dir) + shutil.copytree(src, output_dir, ignore=shutil.ignore_patterns("transformer", "transformer_2")) + + backbones: list[tuple[str, torch.nn.Module]] = [("transformer", pipe.transformer)] + transformer_2 = getattr(pipe, "transformer_2", None) + if transformer_2 is not None: + backbones.append(("transformer_2", transformer_2)) + + for subfolder, backbone in backbones: + out = output_dir / subfolder + with hide_quantizers_from_state_dict(backbone): + backbone.save_pretrained( + str(out), + safe_serialization=True, + max_shard_size=max_shard_size, + ) + + +def _calibrate( + backbone: torch.nn.Module, + label: str, + *, + mtq: Any, + quant_config: dict, + forward_loop, + quantize_mha: bool, +) -> torch.nn.Module: + """Wrap one transformer backbone with quantizers and run calibration.""" + print(f"\nCalibrating {label}...") + quantized = mtq.quantize(backbone, quant_config, forward_loop) + if quantized is not None: + backbone = quantized + _disable_known_problematic_quantizers(mtq, backbone, quantize_mha=quantize_mha) + return backbone + + +def _force_export(backbone: torch.nn.Module, label: str, dtype: torch.dtype) -> None: + """Convert calibrated weights to actual FP8 storage.""" + print(f"\nForcing FP8 weight serialization for {label} (Wan VACE isn't in ModelOpt's") + print("recognized-model registry, so we call the per-weight export helper ourselves)...") + exported = _force_export_quantized_weights(backbone, dtype) + print(f" -> {exported} weights converted to FP8 in {label}") + if exported == 0: + raise SystemExit( + f"No quantized weights were exported in {label}. Calibration may have skipped every " + "layer (check the disable_quantizer regex) or `mtq.quantize` did not actually wrap " + "any weight quantizers." + ) + + +def main() -> None: + args = _build_parser().parse_args() + if not torch.cuda.is_available(): + raise SystemExit("CUDA is required for ModelOpt FP8 quantization.") + + mtq = _require_modelopt() + model_path, output_dir = _ensure_paths(args) + dtype = _select_dtype(args.dtype) + weight_block_size = _parse_block_size(args.weight_block_size) + + ref_images = _load_reference_images(args.reference_images) + samples = _build_calib_samples(args, ref_images) + n_r2v = sum(1 for _, ref in samples if ref is not None) + n_t2v = len(samples) - n_r2v + + print("Quantization plan:") + print(f" input: {args.model}") + print(f" output: {output_dir}") + print(f" dtype: {dtype}") + print(f" height/width: {args.height}x{args.width}") + print(f" num_frames: {args.num_frames}") + print(f" calib_size: {len(samples)} (T2V={n_t2v}, R2V={n_r2v})") + print(f" calib_steps: {args.calib_steps}") + print(f" quantize_mha: {args.quantize_mha}") + print(f" reference imgs: {len(ref_images)}") + print( + f" weight strategy: {'block-wise ' + str(weight_block_size) if weight_block_size else 'per-tensor (default)'}" + ) + + pipe = _load_pipeline(model_path, dtype) + is_dual = getattr(pipe, "transformer_2", None) is not None + if is_dual: + print(" detected dual-transformer VACE variant (transformer + transformer_2)") + + # Production boundary_ratio captured from model_index.json so pass 2 can + # restore it after a --calib-boundary-ratio override on pass 1. + production_boundary = pipe.config.get("boundary_ratio") if is_dual else None + + quant_config = copy.deepcopy(mtq.FP8_DEFAULT_CFG) + if weight_block_size is not None: + quant_config["quant_cfg"]["*weight_quantizer"] = { + "num_bits": (4, 3), + "block_sizes": {-1: weight_block_size[1], -2: weight_block_size[0]}, + } + print( + f" -> overriding weight quantizer with block_sizes={weight_block_size} " + f"({weight_block_size[0]}x{weight_block_size[1]} tiles)" + ) + + forward_loop = _build_forward_loop(pipe, args, samples) + + # Single-transformer VACE does one pass; dual-transformer (Wan2.2-VACE-A14B) + # does two. Calibration must complete for BOTH backbones BEFORE any + # _force_export call — transformer's weights must still be BF16 at that point. + if is_dual and args.calib_boundary_ratio is not None: + pipe.register_to_config(boundary_ratio=args.calib_boundary_ratio) + print( + f"\n pass 1 boundary_ratio: {args.calib_boundary_ratio} " + f"(override of production {production_boundary} for transformer sample boost)" + ) + + pipe.transformer = _calibrate( + pipe.transformer, + "transformer", + mtq=mtq, + quant_config=quant_config, + forward_loop=forward_loop, + quantize_mha=args.quantize_mha, + ) + if is_dual: + if args.calib_boundary_ratio is not None: + pipe.register_to_config(boundary_ratio=production_boundary) + print( + f"\n pass 2 boundary_ratio: {production_boundary} " + "(restored to production for transformer_2 in-distribution calibration)" + ) + pipe.transformer_2 = _calibrate( + pipe.transformer_2, + "transformer_2", + mtq=mtq, + quant_config=quant_config, + forward_loop=forward_loop, + quantize_mha=args.quantize_mha, + ) + + _force_export(pipe.transformer, "transformer", dtype) + if is_dual: + _force_export(pipe.transformer_2, "transformer_2", dtype) + + print("\nSaving pipeline with FP8 transformer(s)...") + _save_pipeline_with_fp8_transformers(pipe, model_path, output_dir) + _patch_quant_config(output_dir, subfolder="transformer", weight_block_size=weight_block_size) + if is_dual: + _patch_quant_config(output_dir, subfolder="transformer_2", weight_block_size=weight_block_size) + print(f"Saved to: {output_dir}") + _summarize_export(output_dir, subfolder="transformer") + if is_dual: + _summarize_export(output_dir, subfolder="transformer_2") + + print("\nNext: validate the checkpoint with vllm-omni:") + if n_r2v > 0: + print( + " python examples/offline_inference/vace/vace_video_generation.py \\\n" + " --mode r2v \\\n" + f" --model {output_dir} \\\n" + " --quantization fp8 \\\n" + " --prompt 'The subject from the reference image walks through a snowy forest at dusk.' \\\n" + " --image \\\n" + f" --height {args.height} --width {args.width} --num-frames {args.num_frames} \\\n" + " --num-inference-steps 30 --guidance-scale 5.0 \\\n" + " --output outputs/wan_vace_r2v_modelopt_fp8.mp4" + ) + print( + "\n (T2V also works with this checkpoint — drop --mode r2v / --image and pass a " + "plain prompt; vace_blocks were calibrated on both zero- and real-conditioning samples.)" + ) + else: + print( + " python examples/offline_inference/vace/vace_video_generation.py \\\n" + " --mode t2v \\\n" + f" --model {output_dir} \\\n" + " --quantization fp8 \\\n" + " --prompt 'A dog running across a field of golden wheat.' \\\n" + f" --height {args.height} --width {args.width} --num-frames {args.num_frames} \\\n" + " --num-inference-steps 30 --guidance-scale 5.0 \\\n" + " --output outputs/wan_vace_modelopt_fp8.mp4" + ) + print( + "\n (R2V/I2V inference will still work but vace_blocks' amax was calibrated on " + "zero vace_context only — re-run quantization with --reference-images for tighter " + "R2V scales.)" + ) + print( + "\n (--quantization fp8 is auto-upgraded to ModelOpt FP8 at runtime because the " + "checkpoint's config.json has modelopt metadata.)" + ) + + +if __name__ == "__main__": + main() diff --git a/vllm_omni/diffusion/model_loader/checkpoint_adapters/modelopt_fp8.py b/vllm_omni/diffusion/model_loader/checkpoint_adapters/modelopt_fp8.py index b3580c70f0e..1863045877c 100644 --- a/vllm_omni/diffusion/model_loader/checkpoint_adapters/modelopt_fp8.py +++ b/vllm_omni/diffusion/model_loader/checkpoint_adapters/modelopt_fp8.py @@ -54,9 +54,12 @@ def is_compatible( @staticmethod def _is_transformer_source(source: object) -> bool: - if getattr(source, "subfolder", None) == "transformer": + # Wan2.2 MoE variants (T2V/I2V-A14B) load a second backbone under + # `transformer_2/` with prefix `transformer_2.` — accept both subfolders. + if getattr(source, "subfolder", None) in ("transformer", "transformer_2"): return True - return str(getattr(source, "prefix", "")).startswith("transformer.") + prefix = str(getattr(source, "prefix", "")) + return prefix.startswith("transformer.") or prefix.startswith("transformer_2.") @staticmethod def _is_checkpoint_quant_config(quant_config: object | None) -> bool: