Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions docs/user_guide/quantization/fp8.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,15 @@ guide. FP8 on Ampere may use a weight-only path where available.

### Diffusion Model (Qwen-Image, Wan2.2)

| Model | HF models | Online | Pre-calibrated | Recommendation | `ignored_layers` |
|-------|-----------|:-------:|:------:|----------------|------------------|
| Qwen-Image | `Qwen/Qwen-Image`, `Qwen/Qwen-Image-2512` | Yes | Yes | Skip sensitive image-stream MLPs when quality regresses | `img_mlp` |
| Wan2.2 | Wan2.2 diffusion pipelines | Not validated | Not validated | Validate against BF16 before documenting as supported | TBD |
| Z-Image | `Tongyi-MAI/Z-Image-Turbo` | Yes | Yes | All layers | None |
| FLUX.1 | `black-forest-labs/FLUX.1-dev`, `black-forest-labs/FLUX.1-schnell` | Yes | Yes | All layers | None |
| FLUX.2-klein | `black-forest-labs/FLUX.2-klein-4B` | Yes | Yes | All layers | None |
| HunyuanImage-3.0 | `tencent/HunyuanImage-3.0`, `tencent/HunyuanImage-3.0-Instruct` | Yes | Yes | All layers; use the Hunyuan stage config for multi-stage runs | None |
| HunyuanVideo-1.5 | `hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-480p_t2v`, `720p_t2v`, `480p_i2v` | Yes | Yes | All layers | None |
| Model | HF models | Online | Pre-calibrated | Recommendation | `ignored_layers` | Text-Encoder quantization |
|-------|-----------|:-------:|:------:|----------------|------------------|------------------|
| Qwen-Image | `Qwen/Qwen-Image`, `Qwen/Qwen-Image-2512` | Yes | Yes | Skip sensitive image-stream MLPs when quality regresses | `img_mlp` | |
| Wan2.2 | Wan2.2 diffusion pipelines | Not validated | Not validated | Validate against BF16 before documenting as supported | TBD | |
| Z-Image | `Tongyi-MAI/Z-Image-Turbo` | Yes | Yes | All layers | None | ✅︎ |
| FLUX.1 | `black-forest-labs/FLUX.1-dev`, `black-forest-labs/FLUX.1-schnell` | Yes | Yes | All layers | None | |
| FLUX.2-klein | `black-forest-labs/FLUX.2-klein-4B` | Yes | Yes | All layers | None | |
| HunyuanImage-3.0 | `tencent/HunyuanImage-3.0`, `tencent/HunyuanImage-3.0-Instruct` | Yes | Yes | All layers; use the Hunyuan stage config for multi-stage runs | None | |
| HunyuanVideo-1.5 | `hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-480p_t2v`, `720p_t2v`, `480p_i2v` | Yes | Yes | All layers | None | |

### Multi-Stage Omni/TTS Model (Qwen3-Omni, Qwen3-TTS)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ def from_pretrained(cls, *args: Any, **kwargs: Any):
model.init_distributed()
return model

@classmethod
def from_config(cls, *args: Any, **kwargs: Any):
model = super().from_config(*args, **kwargs)
model.init_distributed()
return model

def tile_split(self, z: torch.Tensor) -> tuple[list[TileTask], GridSpec]:
# mostly copy from AutoencoderKL
overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
Expand Down
23 changes: 20 additions & 3 deletions vllm_omni/diffusion/model_loader/diffusers_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
multi_thread_safetensors_weights_iterator,
safetensors_weights_iterator,
)
from vllm.transformers_utils.repo_utils import file_exists
from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.utils.torch_utils import set_default_torch_dtype

Expand All @@ -49,6 +50,8 @@ def _natural_sort_key(filepath: str) -> list:

MODEL_INDEX = "model_index.json"
DIFFUSION_MODEL_WEIGHTS_INDEX = "diffusion_pytorch_model.safetensors.index.json"
TRANSFORMER_WEIGHTS_INDEX = "model.safetensors.index.json"
INDEX_FILES = [DIFFUSION_MODEL_WEIGHTS_INDEX, TRANSFORMER_WEIGHTS_INDEX]


class DiffusersPipelineLoader:
Expand Down Expand Up @@ -99,8 +102,22 @@ def _prepare_weights(
is_local = os.path.isdir(model_name_or_path)
load_format = self.load_config.load_format
use_safetensors = False
index_file = DIFFUSION_MODEL_WEIGHTS_INDEX
index_file_with_subfolder = f"{subfolder}/{index_file}" if subfolder else index_file
possible_index_files = [
f"{subfolder}/{index_file}" if subfolder is not None else index_file for index_file in INDEX_FILES
]
available_index_file = list(
filter(lambda f: file_exists(model_name_or_path, f, revision=revision), possible_index_files)
)
if len(available_index_file) > 1:
raise ValueError(
f"Multiple index files found in {model_name_or_path} with subfolder {subfolder}: {available_index_file}"
)
index_file_with_subfolder = available_index_file[0] if len(available_index_file) == 1 else None
index_file = (
index_file_with_subfolder.split("/")[-1]
if index_file_with_subfolder and subfolder is not None
else index_file_with_subfolder
)

# only hf is supported currently
if load_format == "auto":
Expand Down Expand Up @@ -163,7 +180,7 @@ def _prepare_weights(
hf_weights_files = filter_duplicate_safetensors_files(
Comment thread
Isotr0py marked this conversation as resolved.
hf_weights_files,
filter_folder,
index_file,
index_file or "",
)
else:
hf_weights_files = filter_files_not_needed_for_inference(hf_weights_files)
Expand Down
126 changes: 126 additions & 0 deletions vllm_omni/diffusion/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,132 @@

import json
import os
from typing import TYPE_CHECKING, Literal

import torch
import torch.nn as nn
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from vllm.model_executor.models.transformers.utils import init_on_device_without_buffers
from vllm.model_executor.models.utils import maybe_prefix

from vllm_omni.diffusion.data import OmniDiffusionConfig
from vllm_omni.diffusion.quantization import get_vllm_quant_config_for_layers

if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedModel
from transformers.models.auto.auto_factory import _BaseAutoModelClass
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig,
)


Style = Literal["colwise", "colwise_rep", "rowwise", "rowwise_rep", "replicate"]


def replace_linear_class(
linear: nn.Linear,
style: Style = "replicate",
quant_config: QuantizationConfig | None = None,
*,
prefix: str = "",
) -> ColumnParallelLinear | RowParallelLinear | ReplicatedLinear:
"""
Replace nn.Linear with one of vLLM's tensor parallel linear classes.

Args:
linear: `nn.Linear` to be replaced.
style: Tensor parallel style of the new linear, e.g. "colwise".
quant_config: Quantization config for the new linear.
Returns:
The new linear.
"""

if not isinstance(style, str):
raise ValueError(f"Unsupported parallel style type {type(style)}, expected str")

vllm_linear_maps = {
"colwise": (ColumnParallelLinear, {}),
"colwise_rep": (ColumnParallelLinear, {"gather_output": True}),
"rowwise": (RowParallelLinear, {}),
"rowwise_rep": (RowParallelLinear, {"input_is_parallel": False}),
"replicate": (ReplicatedLinear, {}),
}
vllm_linear_cls, vllm_linear_kwargs = vllm_linear_maps[style]

return vllm_linear_cls(
input_size=linear.in_features,
output_size=linear.out_features,
bias=linear.bias is not None,
Comment thread
Isotr0py marked this conversation as resolved.
quant_config=quant_config,
prefix=prefix,
return_bias=False,
**vllm_linear_kwargs,
)


def recursive_replace_linear(model: nn.Module, od_config: OmniDiffusionConfig):
"""Recursively replace modules in the model as needed.
Currently, this replaces:
- `nn.Linear` with vLLM's tensor parallel linear classes
"""
# Prefix the patterns because we always start from `self.model`
quant_config = get_vllm_quant_config_for_layers(od_config.quantization_config)

Comment thread
Isotr0py marked this conversation as resolved.
def _recursive_replace(module: nn.Module, prefix: str):
for child_name, child_module in module.named_children():
new_module = child_module
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice utility! One small thing I noticed: recursive_replace_linear always sets style = "replicate" for every nn.Linear in the model. This works correctly for FP8 quantization today, but if this utility is later reused for tensor-parallel text encoders, we'd need per-layer style selection.

Would it be worth accepting an optional style_map: dict[str, Style] parameter (defaulting to None = all replicate) to make this future-proof? Not a blocker at all -- just thinking about reusability since the function name suggests general-purpose use.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but if this utility is later reused for tensor-parallel text encoders, we'd need per-layer style selection.

Would it be worth accepting an optional style_map: dict[str, Style] parameter (defaulting to None = all replicate) to make this future-proof?

We can reuse tp_plan from Transformers model like vLLM's Transformers backend, but I would like to leave it to a following PR because it can make things quite complicated:
https://github.com/vllm-project/vllm/blob/bebfe55b1c17c2e0fedb1b402df1dddfc1a04684/vllm/model_executor/models/transformers/base.py#L285-L296

qual_name = maybe_prefix(prefix, child_name)
# Replace modules as needed
if isinstance(child_module, nn.Linear):
style = "replicate"
new_module = replace_linear_class(child_module, style, quant_config, prefix=qual_name)
else:
_recursive_replace(child_module, prefix=qual_name)
if new_module is not child_module:
setattr(module, child_name, new_module)

_recursive_replace(model, prefix="")


def init_parameters(
module: nn.Module,
dtype: torch.dtype | None,
device: torch.device | None = None,
):
for name, param in module.named_parameters(recurse=False):
if param.device == torch.device("meta"):
new_param = nn.Parameter(
torch.empty_like(
param.data,
Comment thread
Isotr0py marked this conversation as resolved.
dtype=dtype,
device=device,
),
requires_grad=param.requires_grad,
)
setattr(module, name, new_param)
for child in module.children():
init_parameters(child, dtype, device)


def create_transformers_model(
auto_cls: _BaseAutoModelClass,
od_config: OmniDiffusionConfig,
hf_config: PretrainedConfig,
dtype: torch.dtype | None = None,
device: torch.device | None = None,
) -> PreTrainedModel:
"""Create a HuggingFace model using the given auto class and model name."""
dtype = dtype or od_config.dtype
device = device or torch.get_default_device()
with init_on_device_without_buffers("meta"):
model = auto_cls.from_config(hf_config)
recursive_replace_linear(model, od_config)
init_parameters(model, dtype=dtype, device=device)
return model


def _load_json(model_path: str, filename: str, local_files_only: bool = True) -> dict:
Expand Down
34 changes: 27 additions & 7 deletions vllm_omni/diffusion/models/z_image/pipeline_z_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,15 @@
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import logging
from diffusers.utils.torch_utils import randn_tensor
from transformers import AutoModel, AutoTokenizer
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from vllm.model_executor.models.utils import AutoWeightsLoader

from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig
from vllm_omni.diffusion.distributed.autoencoders.autoencoder_kl import DistributedAutoencoderKL
from vllm_omni.diffusion.distributed.utils import get_local_device
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
from vllm_omni.diffusion.model_loader.hub_prefetch import prefetch_subfolders
from vllm_omni.diffusion.models.utils import create_transformers_model
from vllm_omni.diffusion.models.z_image.z_image_transformer import (
ZImageTransformer2DModel,
)
Expand Down Expand Up @@ -169,13 +170,25 @@ def __init__(
super().__init__()
self.od_config = od_config
self.weights_sources = [
DiffusersPipelineLoader.ComponentSource(
model_or_path=od_config.model,
subfolder="text_encoder",
revision=od_config.revision,
prefix="text_encoder.",
),
DiffusersPipelineLoader.ComponentSource(
model_or_path=od_config.model,
subfolder="transformer",
revision=None,
revision=od_config.revision,
prefix="transformer.",
fall_back_to_pt=True,
)
),
DiffusersPipelineLoader.ComponentSource(
model_or_path=od_config.model,
subfolder="vae",
revision=od_config.revision,
prefix="vae.",
),
]
self._execution_device = get_local_device()
model = od_config.model
Expand All @@ -192,12 +205,19 @@ def __init__(
model, subfolder="scheduler", local_files_only=local_files_only
)

self.text_encoder = AutoModel.from_pretrained(
text_encoder_config = AutoConfig.from_pretrained(
model, subfolder="text_encoder", local_files_only=local_files_only
)
self.text_encoder = create_transformers_model(
AutoModelForCausalLM,
od_config,
hf_config=text_encoder_config,
).to(self._execution_device)
self.vae = DistributedAutoencoderKL.from_pretrained(
model, subfolder="vae", local_files_only=local_files_only
).to(self._execution_device)
if text_encoder_config.tie_word_embeddings:
self.text_encoder.lm_head.weight = self.text_encoder.get_input_embeddings().weight

vae_config = DistributedAutoencoderKL.load_config(model, subfolder="vae", local_files_only=local_files_only)
self.vae = DistributedAutoencoderKL.from_config(vae_config).to(self._execution_device)
self.transformer = ZImageTransformer2DModel(quant_config=od_config.quantization_config)
self.tokenizer = AutoTokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only)

Expand Down
Loading