diff --git a/docs/user_guide/quantization/fp8.md b/docs/user_guide/quantization/fp8.md index e89bc76ca77..7373a39ffb4 100644 --- a/docs/user_guide/quantization/fp8.md +++ b/docs/user_guide/quantization/fp8.md @@ -32,15 +32,15 @@ guide. FP8 on Ampere may use a weight-only path where available. ### Diffusion Model (Qwen-Image, Wan2.2) -| Model | HF models | Online | Pre-calibrated | Recommendation | `ignored_layers` | -|-------|-----------|:-------:|:------:|----------------|------------------| -| Qwen-Image | `Qwen/Qwen-Image`, `Qwen/Qwen-Image-2512` | Yes | Yes | Skip sensitive image-stream MLPs when quality regresses | `img_mlp` | -| Wan2.2 | Wan2.2 diffusion pipelines | Not validated | Not validated | Validate against BF16 before documenting as supported | TBD | -| Z-Image | `Tongyi-MAI/Z-Image-Turbo` | Yes | Yes | All layers | None | -| FLUX.1 | `black-forest-labs/FLUX.1-dev`, `black-forest-labs/FLUX.1-schnell` | Yes | Yes | All layers | None | -| FLUX.2-klein | `black-forest-labs/FLUX.2-klein-4B` | Yes | Yes | All layers | None | -| HunyuanImage-3.0 | `tencent/HunyuanImage-3.0`, `tencent/HunyuanImage-3.0-Instruct` | Yes | Yes | All layers; use the Hunyuan stage config for multi-stage runs | None | -| HunyuanVideo-1.5 | `hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-480p_t2v`, `720p_t2v`, `480p_i2v` | Yes | Yes | All layers | None | +| Model | HF models | Online | Pre-calibrated | Recommendation | `ignored_layers` | Text-Encoder quantization | +|-------|-----------|:-------:|:------:|----------------|------------------|------------------| +| Qwen-Image | `Qwen/Qwen-Image`, `Qwen/Qwen-Image-2512` | Yes | Yes | Skip sensitive image-stream MLPs when quality regresses | `img_mlp` | | +| Wan2.2 | Wan2.2 diffusion pipelines | Not validated | Not validated | Validate against BF16 before documenting as supported | TBD | | +| Z-Image | `Tongyi-MAI/Z-Image-Turbo` | Yes | Yes | All layers | None | ✅︎ | +| FLUX.1 | `black-forest-labs/FLUX.1-dev`, 
@classmethod
def from_config(cls, *args: Any, **kwargs: Any):
    """Build the autoencoder from a config and run its distributed setup.

    All arguments are forwarded unchanged to the parent ``from_config``.
    ``init_distributed()`` is then called on the freshly built model, so an
    instance created without pretrained weights gets the same setup call as
    one created from a checkpoint.

    Returns:
        Whatever the parent ``from_config`` returned. That is normally the
        model; diffusers documents a ``(model, unused_kwargs)`` tuple when
        ``return_unused_kwargs=True`` is passed, and that tuple is returned
        untouched.
    """
    result = super().from_config(*args, **kwargs)
    # With return_unused_kwargs=True the parent hands back a tuple whose first
    # item is the model; calling init_distributed() on the tuple would fail.
    model = result[0] if isinstance(result, tuple) else result
    model.init_distributed()
    return result
"diffusion_pytorch_model.safetensors.index.json" +TRANSFORMER_WEIGHTS_INDEX = "model.safetensors.index.json" +INDEX_FILES = [DIFFUSION_MODEL_WEIGHTS_INDEX, TRANSFORMER_WEIGHTS_INDEX] class DiffusersPipelineLoader: @@ -99,8 +102,17 @@ def _prepare_weights( is_local = os.path.isdir(model_name_or_path) load_format = self.load_config.load_format use_safetensors = False - index_file = DIFFUSION_MODEL_WEIGHTS_INDEX - index_file_with_subfolder = f"{subfolder}/{index_file}" if subfolder else index_file + possible_index_files = [ + f"{subfolder}/{index_file}" if subfolder is not None else index_file for index_file in INDEX_FILES + ] + available_index_file = [ + f for f in possible_index_files if file_exists(model_name_or_path, f, revision=revision) + ] + if len(available_index_file) > 1: + raise ValueError( + f"Multiple index files found in {model_name_or_path} with subfolder {subfolder}: {available_index_file}" + ) + index_file = available_index_file[0] if available_index_file else "" # only hf is supported currently if load_format == "auto": @@ -118,20 +130,21 @@ def _prepare_weights( if allow_patterns_overrides is not None: allow_patterns = allow_patterns_overrides - if subfolder is not None: - allow_patterns = [f"{subfolder}/{pattern}" for pattern in allow_patterns] - if not is_local: hf_folder = download_weights_from_hf( model_name_or_path, self.load_config.download_dir, allow_patterns, revision, + subfolder=subfolder, ignore_patterns=self.load_config.ignore_patterns, ) else: hf_folder = model_name_or_path + if subfolder is not None: + hf_folder = os.path.join(hf_folder, subfolder) + hf_weights_files: list[str] = [] for pattern in allow_patterns: hf_weights_files += glob.glob(os.path.join(hf_folder, pattern)) @@ -149,22 +162,12 @@ def _prepare_weights( if not is_local: download_safetensors_index_file_from_hf( model_name_or_path, - index_file_with_subfolder, - self.load_config.download_dir, - revision, + index_file, + cache_dir=self.load_config.download_dir, + 
subfolder=subfolder, + revision=revision, ) - # Some diffusers pipelines keep component weights under a - # subfolder (e.g. "transformer/") and the corresponding index file - # uses filenames relative to that subfolder. vLLM's - # `filter_duplicate_safetensors_files` expects weight_map entries - # to be relative to the `hf_folder` we pass in, so we point it to - # the component subfolder to avoid filtering out all shards. - filter_folder = os.path.join(hf_folder, subfolder) if subfolder is not None else hf_folder - hf_weights_files = filter_duplicate_safetensors_files( - hf_weights_files, - filter_folder, - index_file, - ) + hf_weights_files = filter_duplicate_safetensors_files(hf_weights_files, hf_folder, index_file) else: hf_weights_files = filter_files_not_needed_for_inference(hf_weights_files) diff --git a/vllm_omni/diffusion/models/utils.py b/vllm_omni/diffusion/models/utils.py index ba0d8dda20c..122646219ff 100644 --- a/vllm_omni/diffusion/models/utils.py +++ b/vllm_omni/diffusion/models/utils.py @@ -5,6 +5,132 @@ import json import os +from typing import TYPE_CHECKING, Literal + +import torch +import torch.nn as nn +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) +from vllm.model_executor.models.transformers.utils import init_on_device_without_buffers +from vllm.model_executor.models.utils import maybe_prefix + +from vllm_omni.diffusion.data import OmniDiffusionConfig +from vllm_omni.quantization import build_quant_config + +if TYPE_CHECKING: + from transformers import PretrainedConfig, PreTrainedModel + from transformers.models.auto.auto_factory import _BaseAutoModelClass + from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, + ) + + +Style = Literal["colwise", "colwise_rep", "rowwise", "rowwise_rep", "replicate"] + + +def replace_linear_class( + linear: nn.Linear, + style: Style = "replicate", + quant_config: QuantizationConfig | None = None, + *, + 
prefix: str = "", +) -> ColumnParallelLinear | RowParallelLinear | ReplicatedLinear: + """ + Replace nn.Linear with one of vLLM's tensor parallel linear classes. + + Args: + linear: `nn.Linear` to be replaced. + style: Tensor parallel style of the new linear, e.g. "colwise". + quant_config: Quantization config for the new linear. + Returns: + The new linear. + """ + + if not isinstance(style, str): + raise ValueError(f"Unsupported parallel style type {type(style)}, expected str") + + vllm_linear_maps = { + "colwise": (ColumnParallelLinear, {}), + "colwise_rep": (ColumnParallelLinear, {"gather_output": True}), + "rowwise": (RowParallelLinear, {}), + "rowwise_rep": (RowParallelLinear, {"input_is_parallel": False}), + "replicate": (ReplicatedLinear, {}), + } + vllm_linear_cls, vllm_linear_kwargs = vllm_linear_maps[style] + + return vllm_linear_cls( + input_size=linear.in_features, + output_size=linear.out_features, + bias=linear.bias is not None, + quant_config=quant_config, + prefix=prefix, + return_bias=False, + **vllm_linear_kwargs, + ) + + +def recursive_replace_linear(model: nn.Module, od_config: OmniDiffusionConfig): + """Recursively replace modules in the model as needed. 
+ Currently, this replaces: + - `nn.Linear` with vLLM's tensor parallel linear classes + """ + # Prefix the patterns because we always start from `self.model` + quant_config = build_quant_config(od_config.quantization_config) + + def _recursive_replace(module: nn.Module, prefix: str): + for child_name, child_module in module.named_children(): + new_module = child_module + qual_name = maybe_prefix(prefix, child_name) + # Replace modules as needed + if isinstance(child_module, nn.Linear): + style = "replicate" + new_module = replace_linear_class(child_module, style, quant_config, prefix=qual_name) + else: + _recursive_replace(child_module, prefix=qual_name) + if new_module is not child_module: + setattr(module, child_name, new_module) + + _recursive_replace(model, prefix="") + + +def init_parameters( + module: nn.Module, + dtype: torch.dtype | None, + device: torch.device | None = None, +): + for name, param in module.named_parameters(recurse=False): + if param.device == torch.device("meta"): + new_param = nn.Parameter( + torch.empty_like( + param.data, + dtype=dtype, + device=device, + ), + requires_grad=param.requires_grad, + ) + setattr(module, name, new_param) + for child in module.children(): + init_parameters(child, dtype, device) + + +def create_transformers_model( + auto_cls: _BaseAutoModelClass, + od_config: OmniDiffusionConfig, + hf_config: PretrainedConfig, + dtype: torch.dtype | None = None, + device: torch.device | None = None, +) -> PreTrainedModel: + """Create a HuggingFace model using the given auto class and model name.""" + dtype = dtype or od_config.dtype + device = device or torch.get_default_device() + with init_on_device_without_buffers("meta"): + model = auto_cls.from_config(hf_config) + recursive_replace_linear(model, od_config) + init_parameters(model, dtype=dtype, device=device) + return model def _load_json(model_path: str, filename: str, local_files_only: bool = True) -> dict: diff --git 
a/vllm_omni/diffusion/models/z_image/pipeline_z_image.py b/vllm_omni/diffusion/models/z_image/pipeline_z_image.py index 296f1e57f9e..5ead64566a2 100644 --- a/vllm_omni/diffusion/models/z_image/pipeline_z_image.py +++ b/vllm_omni/diffusion/models/z_image/pipeline_z_image.py @@ -28,7 +28,7 @@ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler from diffusers.utils import logging from diffusers.utils.torch_utils import randn_tensor -from transformers import AutoModel, AutoTokenizer +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer from vllm.model_executor.models.utils import AutoWeightsLoader from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig @@ -36,6 +36,7 @@ from vllm_omni.diffusion.distributed.utils import get_local_device from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.model_loader.hub_prefetch import prefetch_subfolders +from vllm_omni.diffusion.models.utils import create_transformers_model from vllm_omni.diffusion.models.z_image.z_image_transformer import ( ZImageTransformer2DModel, ) @@ -169,13 +170,25 @@ def __init__( super().__init__() self.od_config = od_config self.weights_sources = [ + DiffusersPipelineLoader.ComponentSource( + model_or_path=od_config.model, + subfolder="text_encoder", + revision=od_config.revision, + prefix="text_encoder.", + ), DiffusersPipelineLoader.ComponentSource( model_or_path=od_config.model, subfolder="transformer", - revision=None, + revision=od_config.revision, prefix="transformer.", fall_back_to_pt=True, - ) + ), + DiffusersPipelineLoader.ComponentSource( + model_or_path=od_config.model, + subfolder="vae", + revision=od_config.revision, + prefix="vae.", + ), ] self._execution_device = get_local_device() model = od_config.model @@ -192,12 +205,19 @@ def __init__( model, subfolder="scheduler", local_files_only=local_files_only ) - self.text_encoder = AutoModel.from_pretrained( + text_encoder_config = 
AutoConfig.from_pretrained( model, subfolder="text_encoder", local_files_only=local_files_only + ) + self.text_encoder = create_transformers_model( + AutoModelForCausalLM, + od_config, + hf_config=text_encoder_config, ).to(self._execution_device) - self.vae = DistributedAutoencoderKL.from_pretrained( - model, subfolder="vae", local_files_only=local_files_only - ).to(self._execution_device) + if text_encoder_config.tie_word_embeddings: + self.text_encoder.lm_head.weight = self.text_encoder.get_input_embeddings().weight + + vae_config = DistributedAutoencoderKL.load_config(model, subfolder="vae", local_files_only=local_files_only) + self.vae = DistributedAutoencoderKL.from_config(vae_config).to(self._execution_device) self.transformer = ZImageTransformer2DModel(quant_config=od_config.quantization_config) self.tokenizer = AutoTokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only)