diff --git a/docs/.nav.yml b/docs/.nav.yml index 0740142090..7fe8f988a1 100644 --- a/docs/.nav.yml +++ b/docs/.nav.yml @@ -46,6 +46,7 @@ nav: - Quantization: - Overview: user_guide/diffusion/quantization/overview.md - FP8: user_guide/diffusion/quantization/fp8.md + - GGUF: user_guide/diffusion/quantization/gguf.md - Parallelism Acceleration: user_guide/diffusion/parallelism_acceleration.md - CPU Offloading: user_guide/diffusion/cpu_offload_diffusion.md - LoRA: user_guide/diffusion/lora.md diff --git a/docs/mkdocs/hooks/generate_argparse.py b/docs/mkdocs/hooks/generate_argparse.py index b4eb441292..6cef7cfbd2 100644 --- a/docs/mkdocs/hooks/generate_argparse.py +++ b/docs/mkdocs/hooks/generate_argparse.py @@ -124,6 +124,7 @@ def add_parser(self, name, **kwargs): "logger": logger, "DummySubparsers": DummySubparsers, "argparse": __import__("argparse"), + "json": __import__("json"), "DESCRIPTION": DESCRIPTION, } exec(code, exec_globals, local_vars) diff --git a/docs/user_guide/diffusion/quantization/gguf.md b/docs/user_guide/diffusion/quantization/gguf.md new file mode 100644 index 0000000000..025180e185 --- /dev/null +++ b/docs/user_guide/diffusion/quantization/gguf.md @@ -0,0 +1,185 @@ +# GGUF Quantization + +## Goals +1. Reuse vLLM quantization configs and weight loaders as much as possible. +2. Add native GGUF support to diffusion transformers without changing model definitions. +3. Keep user-facing knobs minimal and consistent across offline and online flows. + +## Scope +1. Models: Z-Image, and Flux2-klein. +2. Components: diffusion transformer weights, loader paths, and quantization configs. +3. Modes: native GGUF (transformer-only weights). + +## Architecture Overview +1. `OmniDiffusionConfig` accepts `quantization` or `quantization_config`. +2. Diffusion quantization wrapper (`DiffusionGgufConfig`) produces vLLM `QuantizationConfig` objects for linear layers. +3. `DiffusersPipelineLoader` branches on quantization method and loads either HF weights or GGUF weights for the transformer. +4. GGUF transformer loading is routed through model-specific adapters (e.g., Flux2Klein). +5. vLLM GGUF path uses `GGUFConfig` and `GGUFLinearMethod` for matmul. + +## Call Chain (Offline) +``` +CLI (examples/offline_inference/text_to_image/text_to_image.py) + | + v +Omni (vllm_omni/entrypoints/omni.py) + | + v +OmniStage (diffusion) + | + v +DiffusionWorker + | + v +DiffusionModelRunner + | + v +DiffusersPipelineLoader + | + v +Pipeline.forward (Flux2/Qwen/Z-Image) + | + v +DiffusionEngine + | + v +OmniRequestOutput + | + v +Client (saved PNG) +``` + +## Call Chain (Online) +``` +Client + | + | POST /v1/images/generations + v +APIServer (vllm_omni/entrypoints/openai/api_server.py) + | + v +_generate_with_async_omni + | + v +AsyncOmni + | + v +DiffusionEngine + | + v +OmniRequestOutput + | + v +encode_image_base64 + | + v +ImageGenerationResponse + | + v +Client +``` + +## Call Chain (GGUF Operator Path) +``` +Pipeline.forward (Flux2/Qwen/Z-Image) + | + v +Transformer blocks + | + v +QKVParallelLinear / ColumnParallelLinear / RowParallelLinear + | + v +LinearBase.forward + | + v +QuantMethod.apply (GGUFLinearMethod.apply) + | + v +fused_mul_mat_gguf + | + v +_fused_mul_mat_gguf (custom op) + | + v +ops.ggml_dequantize + | + v +x @ weight.T +``` + +## GGUF Weight Loading Path (Transformer-Only) +1. `DiffusersPipelineLoader.load_model` detects `quantization_config.method == "gguf"`. +2. `gguf_model` is resolved as one of: local file, `repo/file.gguf`, or `repo:quant_type`. +3. 
GGUF weights are routed through adapters in `vllm_omni/diffusion/model_loader/gguf_adapters/`.
+4. Name mapping is applied per-architecture (Z-Image, Flux2-Klein).
+5. GGUF weights are loaded into the transformer modules; the remaining non-transformer weights come from the HF checkpoint.
+
+## GGUF Adapter Design
+1. `GGUFAdapter` is an abstract base class for model-specific adapters.
+2. `Flux2KleinGGUFAdapter` implements Flux2-Klein name remapping, QKV splitting, and the adaLN shift/scale swap.
+3. `ZImageGGUFAdapter` implements Z-Image QKV and FFN shard handling and linear qweight routing.
+4. `get_gguf_adapter(...)` strictly selects by model class/config; unsupported models raise an error (no fallback adapter).
+
+Adapter paths:
+- Base: `vllm_omni/diffusion/model_loader/gguf_adapters/base.py`
+- Z-Image: `vllm_omni/diffusion/model_loader/gguf_adapters/z_image.py`
+- Flux2-Klein: `vllm_omni/diffusion/model_loader/gguf_adapters/flux2_klein.py`
+
+## User Usage (Offline)
+
+### Baseline BF16
+```bash
+python examples/offline_inference/text_to_image/text_to_image.py \
+    --model /workspace/models/black-forest-labs/FLUX.2-klein-4B \
+    --prompt "a photo of a forest with mist swirling around the tree trunks. The word 'FLUX.2' is painted over it in big, red brush strokes with visible texture" \
+    --height 768 \
+    --width 1360 \
+    --seed 42 \
+    --cfg_scale 4.0 \
+    --num_images_per_prompt 1 \
+    --num_inference_steps 4 \
+    --output outputs/flux2_klein_4b.png
+```
+
+### Native GGUF (Transformer Only)
+```bash
+python examples/offline_inference/text_to_image/text_to_image.py \
+    --model /workspace/models/black-forest-labs/FLUX.2-klein-4B \
+    --gguf-model "/workspace/models/unsloth/FLUX.2-klein-4B-GGUF/flux-2-klein-4b-Q8_0.gguf" \
+    --quantization gguf \
+    --prompt "a photo of a forest with mist swirling around the tree trunks. The word 'FLUX.2' is painted over it in big, red brush strokes with visible texture" \
+    --height 768 \
+    --width 1360 \
+    --seed 42 \
+    --cfg_scale 4.0 \
+    --num_images_per_prompt 1 \
+    --num_inference_steps 4 \
+    --output outputs/flux2_klein_4b_gguf.png
+```
+
+Notes for GGUF:
+1. Many GGUF repos do not ship `model_index.json` and the component configs. Use the base repo for `--model` and pass only the GGUF file via `--gguf-model`.
+2. `gguf_model` accepts a local path, `repo/file.gguf`, or `repo:quant_type` (see the programmatic sketch below).
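+
+The same mapping can also be passed programmatically. A minimal offline sketch, assuming the `Omni` entrypoint (`vllm_omni/entrypoints/omni.py`) accepts the same `quantization_config` mapping that the example script builds; the exact constructor signature may differ:
+
+```python
+from vllm_omni.entrypoints.omni import Omni
+
+# Same dict shape that text_to_image.py builds for `--quantization gguf`.
+quantization_config = {
+    "method": "gguf",
+    # Accepts a local file, "repo_id/filename.gguf", or "repo_id:quant_type".
+    "gguf_model": "/workspace/models/unsloth/FLUX.2-klein-4B-GGUF/flux-2-klein-4b-Q8_0.gguf",
+}
+
+omni = Omni(
+    model="/workspace/models/black-forest-labs/FLUX.2-klein-4B",
+    quantization_config=quantization_config,
+)
+```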
+ +## User Usage (Online) + +### Start Server (Native GGUF via CLI) +```bash +vllm serve /workspace/models/black-forest-labs/FLUX.2-klein-4B \ + --omni \ + --port 8000 \ + --quantization-config '{"method":"gguf","gguf_model":"/workspace/models/unsloth/FLUX.2-klein-4B-GGUF/flux-2-klein-4b-Q8_0.gguf"}' +``` + +### Online Request (Images API) +```bash +curl -X POST http://localhost:8000/v1/images/generations \ + -H "Content-Type: application/json" \ + -d '{ + "prompt": "a dragon laying over the spine of the Green Mountains of Vermont", + "size": "1024x1024", + "seed": 42, + "num_inference_steps": 4 + }' +``` diff --git a/docs/user_guide/diffusion/quantization/overview.md b/docs/user_guide/diffusion/quantization/overview.md index 7dede292fc..e4ce69677c 100644 --- a/docs/user_guide/diffusion/quantization/overview.md +++ b/docs/user_guide/diffusion/quantization/overview.md @@ -7,6 +7,7 @@ vLLM-Omni supports quantization of DiT linear layers to reduce memory usage and | Method | Guide | |--------|-------| | FP8 | [FP8](fp8.md) | +| GGUF | [GGUF](gguf.md) | ## Device Compatibility diff --git a/examples/offline_inference/text_to_image/text_to_image.py b/examples/offline_inference/text_to_image/text_to_image.py index 78865388fa..3829716068 100644 --- a/examples/offline_inference/text_to_image/text_to_image.py +++ b/examples/offline_inference/text_to_image/text_to_image.py @@ -132,10 +132,18 @@ def parse_args() -> argparse.Namespace: "--quantization", type=str, default=None, - choices=["fp8"], - help="Quantization method for the transformer. " - "Options: 'fp8' (FP8 W8A8 on Ada/Hopper, weight-only on older GPUs). " - "Default: None (no quantization, uses BF16).", + choices=["fp8", "gguf"], + help=( + "Quantization method for the transformer. " + "Options: 'fp8' (FP8 W8A8), 'gguf' (GGUF quantized weights). " + "Default: None (no quantization, uses BF16)." + ), + ) + parser.add_argument( + "--gguf-model", + type=str, + default=None, + help=("GGUF file path or HF reference for transformer weights. 
Required when --quantization gguf is set."), ) parser.add_argument( "--ignored-layers", @@ -265,7 +273,14 @@ def main(): # ignored_layers is specified so the list flows through OmniDiffusionConfig quant_kwargs: dict[str, Any] = {} ignored_layers = [s.strip() for s in args.ignored_layers.split(",") if s.strip()] if args.ignored_layers else None - if args.quantization and ignored_layers: + if args.quantization == "gguf": + if not args.gguf_model: + raise ValueError("--gguf-model is required when --quantization gguf is set.") + quant_kwargs["quantization_config"] = { + "method": "gguf", + "gguf_model": args.gguf_model, + } + elif args.quantization and ignored_layers: quant_kwargs["quantization_config"] = { "method": args.quantization, "ignored_layers": ignored_layers, diff --git a/tests/diffusion/test_diffusers_loader.py b/tests/diffusion/test_diffusers_loader.py new file mode 100644 index 0000000000..3f63960274 --- /dev/null +++ b/tests/diffusion/test_diffusers_loader.py @@ -0,0 +1,73 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch +import torch.nn as nn + +from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader + +pytestmark = [pytest.mark.core_model, pytest.mark.diffusion, pytest.mark.cpu] + + +class _DummyPipelineModel(nn.Module): + def __init__(self, *, source_prefix: str): + super().__init__() + self.transformer = nn.Linear(2, 2, bias=False) + self.vae = nn.Linear(2, 2, bias=False) + self.weights_sources = [ + DiffusersPipelineLoader.ComponentSource( + model_or_path="dummy", + subfolder="transformer", + revision=None, + prefix=source_prefix, + fall_back_to_pt=True, + ) + ] + + def load_weights(self, weights): + params = dict(self.named_parameters()) + loaded: set[str] = set() + for name, tensor in weights: + if name not in params: + continue + params[name].data.copy_(tensor.to(dtype=params[name].dtype)) + loaded.add(name) + return loaded + + +def _make_loader_with_weights(weight_names: list[str]) -> DiffusersPipelineLoader: + loader = object.__new__(DiffusersPipelineLoader) + loader.counter_before_loading_weights = 0.0 + loader.counter_after_loading_weights = 0.0 + + def _iter_weights(_model): + for name in weight_names: + yield name, torch.zeros((2, 2)) + + loader.get_all_weights = _iter_weights # type: ignore[assignment] + return loader + + +def test_strict_check_only_validates_source_prefix_parameters(): + model = _DummyPipelineModel(source_prefix="transformer.") + loader = _make_loader_with_weights(["transformer.weight"]) + + # Should not require VAE parameters because they are outside weights_sources. 
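+    # (If the strict check still covered vae.weight, this call would raise ValueError.)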
+ loader.load_weights(model) + + +def test_strict_check_raises_when_source_parameters_are_missing(): + model = _DummyPipelineModel(source_prefix="transformer.") + loader = _make_loader_with_weights([]) + + with pytest.raises(ValueError, match="transformer.weight"): + loader.load_weights(model) + + +def test_empty_source_prefix_keeps_full_model_strict_check(): + model = _DummyPipelineModel(source_prefix="") + loader = _make_loader_with_weights(["transformer.weight"]) + + with pytest.raises(ValueError, match="vae.weight"): + loader.load_weights(model) diff --git a/vllm_omni/diffusion/data.py b/vllm_omni/diffusion/data.py index 703dba2585..bdf877a6e0 100644 --- a/vllm_omni/diffusion/data.py +++ b/vllm_omni/diffusion/data.py @@ -6,7 +6,7 @@ import random from collections.abc import Callable, Mapping from dataclasses import dataclass, field, fields -from typing import Any +from typing import TYPE_CHECKING, Any import torch from pydantic import model_validator @@ -20,6 +20,12 @@ ) from vllm_omni.diffusion.utils.network_utils import is_port_available +if TYPE_CHECKING: + from vllm_omni.diffusion.quantization import DiffusionQuantizationConfig + +# Import after TYPE_CHECKING to avoid circular imports at runtime +# The actual import is deferred to __post_init__ to avoid import order issues + logger = init_logger(__name__) @@ -527,8 +533,12 @@ def __post_init__(self): # If it's neither dict nor DiffusionCacheConfig, convert to empty config self.cache_config = DiffusionCacheConfig() - # Convert quantization config + # Convert quantization config (deferred import to avoid circular imports) if self.quantization is not None or self.quantization_config is not None: + from vllm_omni.diffusion.quantization import ( + DiffusionQuantizationConfig, + ) + # Handle dict or DictConfig (from OmegaConf) - use Mapping for broader compatibility if isinstance(self.quantization_config, Mapping): # Convert DictConfig to dict if needed (OmegaConf compatibility) diff --git a/vllm_omni/diffusion/model_loader/diffusers_loader.py b/vllm_omni/diffusion/model_loader/diffusers_loader.py index 4f9a37c7a0..c4271c3383 100644 --- a/vllm_omni/diffusion/model_loader/diffusers_loader.py +++ b/vllm_omni/diffusion/model_loader/diffusers_loader.py @@ -9,12 +9,14 @@ from typing import cast import torch +from huggingface_hub import hf_hub_download from torch import nn from vllm.config import ModelConfig from vllm.config.load import LoadConfig from vllm.logger import init_logger from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase from vllm.model_executor.model_loader.weight_utils import ( + download_gguf, download_safetensors_index_file_from_hf, download_weights_from_hf, filter_duplicate_safetensors_files, @@ -27,6 +29,7 @@ from vllm_omni.diffusion.data import OmniDiffusionConfig from vllm_omni.diffusion.distributed.hsdp import HSDPInferenceConfig +from vllm_omni.diffusion.model_loader.gguf_adapters import get_gguf_adapter from vllm_omni.diffusion.registry import initialize_model logger = init_logger(__name__) @@ -194,13 +197,36 @@ def get_all_weights( self, model: nn.Module, ) -> Generator[tuple[str, torch.Tensor], None, None]: - sources = cast( - Iterable[DiffusersPipelineLoader.ComponentSource], - getattr(model, "weights_sources", ()), - ) + sources = self._get_weight_sources(model) for source in sources: yield from self._get_weights_iterator(source) + def _get_weight_sources(self, model: nn.Module) -> tuple["ComponentSource", ...]: + return tuple( + cast( + 
Iterable[DiffusersPipelineLoader.ComponentSource], + getattr(model, "weights_sources", ()), + ) + ) + + def _get_expected_parameter_names(self, model: nn.Module) -> set[str]: + """Return parameter names that should be covered by strict load checks.""" + all_parameter_names = {name for name, _ in model.named_parameters()} + sources = self._get_weight_sources(model) + + # Keep strict behavior if no source metadata exists. + if not sources: + return all_parameter_names + + # Empty prefix means "root" source, i.e. entire model should be covered. + if any(source.prefix == "" for source in sources): + return all_parameter_names + + source_prefixes = tuple(source.prefix for source in sources if source.prefix) + if not source_prefixes: + return all_parameter_names + return {name for name in all_parameter_names if name.startswith(source_prefixes)} + def download_model(self, model_config: ModelConfig) -> None: self._prepare_weights( model_name_or_path=model_config.model, @@ -232,8 +258,11 @@ def load_model( model_cls = resolve_obj_by_qualname(custom_pipeline_name) model = model_cls(od_config=od_config) logger.debug("Loading weights on %s ...", load_device) - # Quantization does not happen in `load_weights` but after it - self.load_weights(model) + if self._is_gguf_quantization(od_config): + self._load_weights_with_gguf(model, od_config) + else: + # Quantization does not happen in `load_weights` but after it + self.load_weights(model) # Process weights after loading for quantization (e.g., FP8 online quantization) # This is needed for vLLM's quantization methods that need to transform weights @@ -265,7 +294,7 @@ def _process_weights_after_loading(self, model: nn.Module, target_device: torch. module.to(module_device) def load_weights(self, model: nn.Module) -> None: - weights_to_load = {name for name, _ in model.named_parameters()} + weights_to_load = self._get_expected_parameter_names(model) loaded_weights = model.load_weights(self.get_all_weights(model)) self.counter_after_loading_weights = time.perf_counter() @@ -278,12 +307,122 @@ def load_weights(self, model: nn.Module) -> None: # We only enable strict check for non-quantized models # that have loaded weights tracking currently. 
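+        # The strict check below now raises on any expected parameter that was not loaded.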
         if loaded_weights is not None:
-            _ = weights_to_load - loaded_weights
-            # if weights_not_loaded:
-            #     raise ValueError(
-            #         "Following weights were not initialized from "
-            #         f"checkpoint: {weights_not_loaded}"
-            #     )
+            weights_not_loaded = weights_to_load - loaded_weights
+            if weights_not_loaded:
+                raise ValueError(f"Following weights were not initialized from checkpoint: {weights_not_loaded}")
+
+    def _is_gguf_quantization(self, od_config: OmniDiffusionConfig) -> bool:
+        quant_config = od_config.quantization_config
+        if quant_config is None:
+            return False
+        # Fast path: mapping-style config (e.g., DictConfig)
+        if isinstance(quant_config, dict):
+            method = str(quant_config.get("method", "")).lower()
+            if method != "gguf":
+                return False
+            gguf_model = quant_config.get("gguf_model")
+            if not gguf_model:
+                raise ValueError("GGUF quantization requires quantization_config.gguf_model")
+            return True
+
+        # Normal path: DiffusionQuantizationConfig
+        if not hasattr(quant_config, "get_name"):
+            # Fallback: if it carries gguf_model, treat as GGUF
+            gguf_model = getattr(quant_config, "gguf_model", None)
+            return bool(gguf_model)
+        is_gguf = quant_config.get_name() == "gguf"
+        if not is_gguf:
+            return False
+        gguf_model = getattr(quant_config, "gguf_model", None)
+        if gguf_model is None:
+            raise ValueError("GGUF quantization requires quantization_config.gguf_model")
+        return True
+
+    def _is_transformer_source(self, source: "ComponentSource") -> bool:
+        if source.subfolder == "transformer":
+            return True
+        return source.prefix.startswith("transformer.")
+
+    def _get_model_loadable_names(self, model: nn.Module) -> set[str]:
+        # Avoid model.state_dict() here because GGUF uses UninitializedParameter
+        # which raises during detach(). Collect names directly.
+        names = {name for name, _ in model.named_parameters()}
+        names.update(name for name, _ in model.named_buffers())
+        return names
+
+    def _resolve_gguf_model_path(self, gguf_model: str, revision: str | None) -> str:
+        if os.path.isfile(gguf_model):
+            return gguf_model
+        # repo_id/filename.gguf
+        if "/" in gguf_model and gguf_model.endswith(".gguf"):
+            repo_id, filename = gguf_model.rsplit("/", 1)
+            return hf_hub_download(
+                repo_id=repo_id,
+                filename=filename,
+                revision=revision,
+                cache_dir=self.load_config.download_dir,
+            )
+        # repo_id:quant_type
+        if "/" in gguf_model and ":" in gguf_model:
+            repo_id, quant_type = gguf_model.rsplit(":", 1)
+            return download_gguf(
+                repo_id,
+                quant_type,
+                cache_dir=self.load_config.download_dir,
+                revision=revision,
+                ignore_patterns=self.load_config.ignore_patterns,
+            )
+        raise ValueError(
+            f"Unrecognized GGUF reference: {gguf_model!r} (expected a local file, "
+            "<repo_id>/<filename>.gguf, or <repo_id>:<quant_type>)"
+        )
+
+    def _get_gguf_weights_iterator(
+        self,
+        source: "ComponentSource",
+        model: nn.Module,
+        od_config: OmniDiffusionConfig,
+    ) -> Generator[tuple[str, torch.Tensor], None, None]:
+        quant_config = od_config.quantization_config
+        gguf_model = getattr(quant_config, "gguf_model", None)
+        if gguf_model is None:
+            raise ValueError("GGUF quantization requires quantization_config.gguf_model")
+        gguf_file = self._resolve_gguf_model_path(gguf_model, od_config.revision)
+        adapter = get_gguf_adapter(gguf_file, model, source, od_config)
+        weights_iter = adapter.weights_iterator()
+        return ((source.prefix + name, tensor) for (name, tensor) in weights_iter)
+
+    def _load_weights_with_gguf(self, model: nn.Module, od_config: OmniDiffusionConfig) -> set[str]:
+        sources = self._get_weight_sources(model)
+        loaded: set[str] = set()
+        loadable_names: 
set[str] | None = None + + for source in sources: + if self._is_transformer_source(source): + loaded |= model.load_weights(self._get_gguf_weights_iterator(source, model, od_config)) + + # GGUF checkpoints can be transformer-only or partially quantized. + # Only fall back to HF if this source still has missing loadable weights. + loadable_names = loadable_names or self._get_model_loadable_names(model) + has_missing_for_source = any( + name.startswith(source.prefix) and name not in loaded for name in loadable_names + ) + if not has_missing_for_source: + continue + + hf_iter = self._get_weights_iterator(source) + hf_iter = ( + (name, tensor) for (name, tensor) in hf_iter if name in loadable_names and name not in loaded + ) + loaded |= model.load_weights(hf_iter) + else: + loaded |= model.load_weights(self._get_weights_iterator(source)) + + weights_to_load = self._get_expected_parameter_names(model) + weights_not_loaded = weights_to_load - loaded + if weights_not_loaded: + raise ValueError(f"Following weights were not initialized from checkpoint: {weights_not_loaded}") + return loaded def _load_model_with_hsdp( self, diff --git a/vllm_omni/diffusion/model_loader/gguf_adapters/__init__.py b/vllm_omni/diffusion/model_loader/gguf_adapters/__init__.py new file mode 100644 index 0000000000..416ebc7a84 --- /dev/null +++ b/vllm_omni/diffusion/model_loader/gguf_adapters/__init__.py @@ -0,0 +1,45 @@ +# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch + +from .base import GGUFAdapter +from .flux2_klein import Flux2KleinGGUFAdapter +from .z_image import ZImageGGUFAdapter + +if TYPE_CHECKING: + from vllm_omni.diffusion.data import OmniDiffusionConfig + from vllm_omni.diffusion.model_loader.diffusers_loader import ( + DiffusersPipelineLoader, + ) + + +def get_gguf_adapter( + gguf_file: str, + model: torch.nn.Module, + source: DiffusersPipelineLoader.ComponentSource, + od_config: OmniDiffusionConfig, +) -> GGUFAdapter: + adapter_classes = (ZImageGGUFAdapter, Flux2KleinGGUFAdapter) + for adapter_cls in adapter_classes: + if adapter_cls.is_compatible(od_config, model, source): + return adapter_cls(gguf_file, model, source, od_config) + model_type = None + if od_config.tf_model_config is not None: + model_type = od_config.tf_model_config.get("model_type") + supported = ", ".join(cls.__name__ for cls in adapter_classes) + raise ValueError( + "No GGUF adapter matched diffusion model " + f"(model_class_name={od_config.model_class_name!r}, model_type={model_type!r}). " + f"Supported adapters: {supported}." 
+ ) + + +__all__ = [ + "GGUFAdapter", + "Flux2KleinGGUFAdapter", + "ZImageGGUFAdapter", + "get_gguf_adapter", +] diff --git a/vllm_omni/diffusion/model_loader/gguf_adapters/base.py b/vllm_omni/diffusion/model_loader/gguf_adapters/base.py new file mode 100644 index 0000000000..8794ecff73 --- /dev/null +++ b/vllm_omni/diffusion/model_loader/gguf_adapters/base.py @@ -0,0 +1,245 @@ +# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +from abc import ABC, abstractmethod +from collections.abc import Generator +from dataclasses import dataclass +from typing import Any + +import gguf +import numpy as np +import torch + + +@dataclass +class MappedTensor: + name: str + tensor: Any + tensor_type: Any + row_slice: slice | None = None + swap_scale_shift: bool = False + + +class GGUFAdapter(ABC): + """Base class for model-specific GGUF adapters.""" + + _include_qkv_virtuals: bool = False + _include_add_kv_proj_virtuals: bool = False + _include_to_out_virtuals: bool = False + _include_w13_virtuals: bool = False + _shard_tokens: tuple[str, ...] = () + _prefer_exact_qweight: bool = True + + def __init__(self, gguf_file: str, model: torch.nn.Module, source, od_config) -> None: + self.gguf_file = gguf_file + self.model = model + self.source = source + self.od_config = od_config + + @staticmethod + def is_compatible(od_config, model: torch.nn.Module, source) -> bool: + return False + + @abstractmethod + def weights_iterator(self) -> Generator[tuple[str, torch.Tensor], None, None]: + raise NotImplementedError + + def _get_target_module(self) -> torch.nn.Module: + prefix = getattr(self.source, "prefix", "") + return self.model.get_submodule(prefix.rstrip(".")) if prefix else self.model + + def _build_allowed_names(self) -> set[str]: + target = self._get_target_module() + allowed = {name for name, _ in target.named_parameters()} + allowed.update(name for name, _ in target.named_buffers()) + for name in list(allowed): + if name.endswith(".qweight"): + allowed.add(name.replace(".qweight", ".weight")) + elif name.endswith(".qweight_type"): + allowed.add(name.replace(".qweight_type", ".weight")) + + virtual_names = set() + for name in allowed: + if self._include_qkv_virtuals and ".to_qkv." in name: + virtual_names.add(name.replace(".to_qkv.", ".to_q.")) + virtual_names.add(name.replace(".to_qkv.", ".to_k.")) + virtual_names.add(name.replace(".to_qkv.", ".to_v.")) + if self._include_add_kv_proj_virtuals and ".add_kv_proj." in name: + virtual_names.add(name.replace(".add_kv_proj.", ".add_q_proj.")) + virtual_names.add(name.replace(".add_kv_proj.", ".add_k_proj.")) + virtual_names.add(name.replace(".add_kv_proj.", ".add_v_proj.")) + if self._include_w13_virtuals and ".w13." in name: + virtual_names.add(name.replace(".w13.", ".w1.")) + virtual_names.add(name.replace(".w13.", ".w3.")) + if self._include_to_out_virtuals and ".to_out." in name: + virtual_names.add(name.replace(".to_out.", ".to_out.0.")) + allowed.update(virtual_names) + return allowed + + def _build_param_names(self) -> set[str]: + target = self._get_target_module() + return {name for name, _ in target.named_parameters()} + + def _resolve_linear_qweight(self, name: str, param_names: set[str]) -> str | None: + if not name.endswith(".weight"): + return None + if self._prefer_exact_qweight: + candidate = name.replace(".weight", ".qweight") + if candidate in param_names: + return candidate + if ".to_out.0." 
in name: + alt_name = name.replace(".to_out.0.", ".to_out.") + candidate = alt_name.replace(".weight", ".qweight") + if candidate in param_names: + return candidate + name = alt_name + for shard_token in self._shard_tokens: + if shard_token in name: + return name.replace(".weight", ".qweight") + candidate = name.replace(".weight", ".qweight") + if candidate in param_names: + return candidate + return None + + def _build_gguf_name_map(self) -> dict[str, str]: + def resolve_model_type() -> str: + cfg = self.od_config.tf_model_config + model_type = None + if cfg is not None: + model_type = cfg.get("model_type") + if model_type: + return model_type + model_class = self.od_config.model_class_name or "" + if model_class.startswith("QwenImage"): + return "qwen_image" + if model_class.startswith("Flux2"): + return "flux" + raise ValueError("Cannot infer gguf model_type for diffusion model.") + + def resolve_arch(model_type: str): + for key, value in gguf.MODEL_ARCH_NAMES.items(): + if value == model_type: + return key + raise RuntimeError(f"Unknown gguf model_type: {model_type}") + + def resolve_num_layers(target_module: torch.nn.Module) -> int: + if hasattr(target_module, "transformer_blocks"): + return len(getattr(target_module, "transformer_blocks")) + if hasattr(target_module, "double_blocks"): + return len(getattr(target_module, "double_blocks")) + cfg = self.od_config.tf_model_config + if cfg is not None: + for key in ("num_hidden_layers", "num_layers", "n_layers"): + value = cfg.get(key) + if isinstance(value, int) and value > 0: + return value + raise ValueError("Cannot infer gguf num_layers for diffusion model.") + + def get_target_module(root: torch.nn.Module, prefix: str) -> torch.nn.Module: + if not prefix: + return root + prefix = prefix.rstrip(".") + if hasattr(root, "get_submodule"): + return root.get_submodule(prefix) + current = root + for part in prefix.split("."): + current = getattr(current, part) + return current + + def split_name(name: str) -> tuple[str, str]: + if name.endswith("_weight"): + return name[:-7], "weight" + if "." 
in name:
+                base, suffix = name.rsplit(".", 1)
+                return base, suffix
+            return name, ""
+
+        reader = gguf.GGUFReader(self.gguf_file)
+        gguf_tensor_names = {tensor.name for tensor in reader.tensors}
+
+        model_type = resolve_model_type()
+        arch = resolve_arch(model_type)
+        target_module = get_target_module(self.model, self.source.prefix)
+        num_layers = resolve_num_layers(target_module)
+        name_map = gguf.get_tensor_name_map(arch, num_layers)
+
+        gguf_to_model_map: dict[str, str] = {}
+        for name, _ in target_module.named_parameters():
+            base_name, suffix = split_name(name)
+            gguf_base = name_map.get_name(base_name)
+            if gguf_base is None:
+                continue
+            candidates = []
+            if suffix:
+                candidates.append(f"{gguf_base}.{suffix}")
+                if suffix == "weight":
+                    candidates.append(f"{gguf_base}.scale")
+            else:
+                candidates.append(gguf_base)
+            gguf_name = next((c for c in candidates if c in gguf_tensor_names), None)
+            if gguf_name is None:
+                continue
+            gguf_to_model_map[gguf_name] = name
+
+        for name, _ in target_module.named_buffers():
+            base_name, suffix = split_name(name)
+            gguf_base = name_map.get_name(base_name)
+            if gguf_base is None:
+                continue
+            candidates = []
+            if suffix:
+                candidates.append(f"{gguf_base}.{suffix}")
+                if suffix == "weight":
+                    candidates.append(f"{gguf_base}.scale")
+            else:
+                candidates.append(gguf_base)
+            gguf_name = next((c for c in candidates if c in gguf_tensor_names), None)
+            if gguf_name is None:
+                continue
+            gguf_to_model_map[gguf_name] = name
+
+        if not gguf_to_model_map:
+            raise RuntimeError(f"No GGUF tensors were mapped for model_class_name={self.od_config.model_class_name!r}.")
+        return gguf_to_model_map
+
+
+# FIXME(Isotr0py): Sync implementation with upstream vLLM?
+def gguf_quant_weights_iterator(gguf_file: str) -> Generator[tuple[str, torch.Tensor]]:
+    """
+    Iterate over the quant weights in the model gguf files and convert
+    them to torch tensors.
+    Be careful of the order of yielding weight types and weights data:
+    we have to yield all weight types first before yielding any weights.
+    Otherwise it would cause issues when loading weights for a packed
+    layer with different quant types.
+    """
+
+    reader = gguf.GGUFReader(gguf_file)
+
+    for tensor in reader.tensors:
+        weight_type = tensor.tensor_type
+        name = tensor.name
+
+        if weight_type.name not in ("F32", "F16"):
+            weight_type_name = name.replace("weight", "qweight_type")
+            weight_type = torch.tensor(weight_type)
+            yield weight_type_name, weight_type
+
+    for tensor in reader.tensors:
+        weight = tensor.data
+        weight_type = tensor.tensor_type
+        name = tensor.name
+        if weight_type.name not in ("F32", "F16"):
+            name = name.replace("weight", "qweight")
+        if weight_type.name == "BF16" and tensor.data.dtype == np.uint8:
+            # BF16 is currently the only "quantization" type that isn't
+            # actually quantized but is read as a raw byte tensor.
+            # Reinterpret as `torch.bfloat16` tensor.
+            
+ weight = weight.view(np.uint16) + if reader.byte_order == "S": + # GGUF endianness != system endianness + weight = weight.byteswap() + param = torch.tensor(weight).view(torch.bfloat16) + else: + param = torch.tensor(weight) + yield name, param diff --git a/vllm_omni/diffusion/model_loader/gguf_adapters/flux2_klein.py b/vllm_omni/diffusion/model_loader/gguf_adapters/flux2_klein.py new file mode 100644 index 0000000000..4ee6439b28 --- /dev/null +++ b/vllm_omni/diffusion/model_loader/gguf_adapters/flux2_klein.py @@ -0,0 +1,98 @@ +# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +from collections.abc import Iterable + +import torch +from vllm.model_executor.models.utils import WeightsMapper + +from .base import GGUFAdapter, gguf_quant_weights_iterator + +FLUX2_TRANSFORMER_KEYS_RENAME_DICT = { + "single_blocks.": "single_transformer_blocks.", + # Image and text input projections + "img_in": "x_embedder", + "txt_in": "context_embedder", + # Timestep and guidance embeddings + "time_in.in_layer": "time_guidance_embed.timestep_embedder.linear_1", + "time_in.out_layer": "time_guidance_embed.timestep_embedder.linear_2", + "guidance_in.in_layer": "time_guidance_embed.guidance_embedder.linear_1", + "guidance_in.out_layer": "time_guidance_embed.guidance_embedder.linear_2", + # Modulation parameters + "double_stream_modulation_img.lin": "double_stream_modulation_img.linear", + "double_stream_modulation_txt.lin": "double_stream_modulation_txt.linear", + "single_stream_modulation.lin": "single_stream_modulation.linear", + # Final output layer + # "final_layer.adaLN_modulation.1": "norm_out.linear", # Handle separately since we need to swap mod params + "final_layer.linear": "proj_out", +} + +FLUX2_TRANSFORMER_ADA_LAYER_NORM_KEY_MAP = { + "final_layer.adaLN_modulation.1": "norm_out.linear", +} + +FLUX2_TRANSFORMER_DOUBLE_BLOCK_KEY_MAP = { + "double_blocks.": "transformer_blocks.", + # Handle fused QKV projections separately as we need to break into Q, K, V projections + "img_attn.norm.query_norm": "attn.norm_q", + "img_attn.norm.key_norm": "attn.norm_k", + "img_attn.proj": "attn.to_out.0", + "img_mlp.0": "ff.linear_in", + "img_mlp.2": "ff.linear_out", + "txt_attn.norm.query_norm": "attn.norm_added_q", + "txt_attn.norm.key_norm": "attn.norm_added_k", + "txt_attn.proj": "attn.to_add_out", + "txt_mlp.0": "ff_context.linear_in", + "txt_mlp.2": "ff_context.linear_out", + # Additional for fuse qkv + "img_attn.qkv": "attn.to_qkv", + "txt_attn.qkv": "attn.add_kv_proj", +} + +FLUX2_TRANSFORMER_SINGLE_BLOCK_KEY_MAP = { + "linear1": "attn.to_qkv_mlp_proj", + "linear2": "attn.to_out", + "norm.query_norm": "attn.norm_q", + "norm.key_norm": "attn.norm_k", +} + + +class Flux2KleinGGUFAdapter(GGUFAdapter): + """GGUF adapter for Flux2-Klein models with qkv splitting and adaLN swap.""" + + @staticmethod + def is_compatible(od_config, model: torch.nn.Module, source) -> bool: + model_class = od_config.model_class_name or "" + if model_class.startswith("Flux2"): + return True + cfg = od_config.tf_model_config + if cfg is not None: + model_type = str(cfg.get("model_type", "")).lower() + if model_type.startswith("flux"): + return True + return False + + gguf_to_hf_mapper = WeightsMapper( + # double_stream_modulation + orig_to_new_prefix=FLUX2_TRANSFORMER_KEYS_RENAME_DICT | FLUX2_TRANSFORMER_ADA_LAYER_NORM_KEY_MAP, + orig_to_new_substr=FLUX2_TRANSFORMER_DOUBLE_BLOCK_KEY_MAP | FLUX2_TRANSFORMER_SINGLE_BLOCK_KEY_MAP, + ) + + def weights_iterator(self) -> Iterable[tuple[str, torch.Tensor]]: + def 
custom_weights_adapter( + weights: Iterable[tuple[str, torch.Tensor]], + ) -> Iterable[tuple[str, torch.Tensor]]: + for name, weight in weights: + # Handle the special case for adaLN modulation parameters that require swapping shift and scale + if name.endswith(".scale"): + name = name.replace(".scale", ".weight") + if name == "norm_out.linear.weight": + shift, scale = weight.chunk(2, dim=0) + weight = torch.cat([scale, shift], dim=0) + yield name, weight + else: + yield name, weight + + weights = gguf_quant_weights_iterator(self.gguf_file) + weights = self.gguf_to_hf_mapper.apply(weights) + yield from custom_weights_adapter(weights) diff --git a/vllm_omni/diffusion/model_loader/gguf_adapters/z_image.py b/vllm_omni/diffusion/model_loader/gguf_adapters/z_image.py new file mode 100644 index 0000000000..7d89633559 --- /dev/null +++ b/vllm_omni/diffusion/model_loader/gguf_adapters/z_image.py @@ -0,0 +1,43 @@ +# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +from collections.abc import Iterable + +import torch +from vllm.model_executor.models.utils import WeightsMapper + +from .base import GGUFAdapter, gguf_quant_weights_iterator + +Z_IMAGE_KEYS_RENAME_DICT = { + "final_layer.": "all_final_layer.2-1.", + "x_embedder.": "all_x_embedder.2-1.", + ".attention.qkv": ".attention.to_qkv", + ".attention.k_norm": ".attention.norm_k", + ".attention.q_norm": ".attention.norm_q", + ".attention.out": ".attention.to_out.0", + "model.diffusion_model.": "", +} + + +class ZImageGGUFAdapter(GGUFAdapter): + """GGUF adapter for Z-Image models with QKV/FFN shard support.""" + + @staticmethod + def is_compatible(od_config, model: torch.nn.Module, source) -> bool: + model_class = od_config.model_class_name or "" + if model_class.startswith("ZImage"): + return True + cfg = od_config.tf_model_config + if cfg is not None: + model_type = str(cfg.get("model_type", "")).lower() + if model_type in {"z_image", "zimage", "z-image"}: + return True + return False + + gguf_to_hf_mapper = WeightsMapper( + orig_to_new_substr=Z_IMAGE_KEYS_RENAME_DICT, + ) + + def weights_iterator(self) -> Iterable[tuple[str, torch.Tensor]]: + weights = gguf_quant_weights_iterator(self.gguf_file) + yield from self.gguf_to_hf_mapper.apply(weights) diff --git a/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py b/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py index ee10d2e0e4..5ee5ee440a 100644 --- a/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py +++ b/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py @@ -17,7 +17,7 @@ from collections.abc import Iterable from types import SimpleNamespace -from typing import Any +from typing import TYPE_CHECKING, Any import torch import torch.nn as nn @@ -41,6 +41,9 @@ from vllm_omni.diffusion.attention.layer import Attention from vllm_omni.diffusion.layers.rope import RotaryEmbedding +if TYPE_CHECKING: + from vllm.model_executor.layers.quantization.base_config import QuantizationConfig + class Flux2SwiGLU(nn.Module): """SwiGLU activation used by Flux2.""" @@ -62,6 +65,7 @@ def __init__( mult: float = 3.0, inner_dim: int | None = None, bias: bool = False, + quant_config: "QuantizationConfig | None" = None, ): super().__init__() if inner_dim is None: @@ -73,6 +77,7 @@ def __init__( [inner_dim, inner_dim], bias=bias, return_bias=False, + quant_config=quant_config, ) self.act_fn = Flux2SwiGLU() self.linear_out = RowParallelLinear( @@ -81,6 +86,7 @@ def __init__( bias=bias, input_is_parallel=True, return_bias=False, + 
quant_config=quant_config, ) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -103,6 +109,7 @@ def __init__( eps: float = 1e-5, out_dim: int = None, elementwise_affine: bool = True, + quant_config: "QuantizationConfig | None" = None, ): super().__init__() self.head_dim = dim_head @@ -118,6 +125,7 @@ def __init__( head_size=self.head_dim, total_num_heads=self.heads, bias=bias, + quant_config=quant_config, ) self.query_num_heads = self.to_qkv.num_heads self.kv_num_heads = self.to_qkv.num_kv_heads @@ -133,6 +141,7 @@ def __init__( bias=out_bias, input_is_parallel=True, return_bias=False, + quant_config=quant_config, ), nn.Dropout(dropout), ] @@ -146,6 +155,7 @@ def __init__( head_size=self.head_dim, total_num_heads=self.heads, bias=added_proj_bias, + quant_config=quant_config, ) self.add_query_num_heads = self.add_kv_proj.num_heads self.add_kv_num_heads = self.add_kv_proj.num_kv_heads @@ -155,6 +165,7 @@ def __init__( bias=out_bias, input_is_parallel=True, return_bias=False, + quant_config=quant_config, ) self.rope = RotaryEmbedding(is_neox_style=False) @@ -251,6 +262,7 @@ def __init__( elementwise_affine: bool = True, mlp_ratio: float = 4.0, mlp_mult_factor: int = 2, + quant_config: "QuantizationConfig | None" = None, ): super().__init__() self.head_dim = dim_head @@ -269,6 +281,7 @@ def __init__( self.inner_dim * 3 + self.mlp_hidden_dim * self.mlp_mult_factor, bias=bias, gather_output=True, + quant_config=quant_config, ) self.mlp_act_fn = Flux2SwiGLU() @@ -280,6 +293,7 @@ def __init__( self.out_dim, bias=out_bias, gather_output=True, + quant_config=quant_config, ) self.rope = RotaryEmbedding(is_neox_style=False) self.attn = Attention( @@ -342,6 +356,7 @@ def __init__( mlp_ratio: float = 3.0, eps: float = 1e-6, bias: bool = False, + quant_config: "QuantizationConfig | None" = None, ): super().__init__() self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) @@ -355,6 +370,7 @@ def __init__( eps=eps, mlp_ratio=mlp_ratio, mlp_mult_factor=2, + quant_config=quant_config, ) def forward( @@ -402,6 +418,7 @@ def __init__( mlp_ratio: float = 3.0, eps: float = 1e-6, bias: bool = False, + quant_config: "QuantizationConfig | None" = None, ): super().__init__() self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) @@ -417,13 +434,20 @@ def __init__( added_proj_bias=bias, out_bias=bias, eps=eps, + quant_config=quant_config, ) self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) - self.ff = Flux2FeedForward(dim=dim, dim_out=dim, mult=mlp_ratio, bias=bias) + self.ff = Flux2FeedForward(dim=dim, dim_out=dim, mult=mlp_ratio, bias=bias, quant_config=quant_config) self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) - self.ff_context = Flux2FeedForward(dim=dim, dim_out=dim, mult=mlp_ratio, bias=bias) + self.ff_context = Flux2FeedForward( + dim=dim, + dim_out=dim, + mult=mlp_ratio, + bias=bias, + quant_config=quant_config, + ) def forward( self, @@ -580,6 +604,7 @@ def __init__( rope_theta: int = 2000, eps: float = 1e-6, guidance_embeds: bool = True, + quant_config: "QuantizationConfig | None" = None, ): super().__init__() self.out_channels = out_channels or in_channels @@ -625,6 +650,7 @@ def __init__( mlp_ratio=mlp_ratio, eps=eps, bias=False, + quant_config=quant_config, ) for _ in range(num_layers) ] @@ -639,6 +665,7 @@ def __init__( mlp_ratio=mlp_ratio, eps=eps, bias=False, + quant_config=quant_config, ) for _ in range(num_single_layers) ] @@ -728,9 +755,9 @@ def forward( def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> 
set[str]: stacked_params_mapping = [ - (".to_qkv", ".to_q", "q"), - (".to_qkv", ".to_k", "k"), - (".to_qkv", ".to_v", "v"), + (".to_qkv.", ".to_q.", "q"), + (".to_qkv.", ".to_k.", "k"), + (".to_qkv.", ".to_v.", "v"), (".add_kv_proj", ".add_q_proj", "q"), (".add_kv_proj", ".add_k_proj", "k"), (".add_kv_proj", ".add_v_proj", "v"), @@ -744,25 +771,32 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loaded_params: set[str] = set() for name, loaded_weight in weights: - if "to_qkvkv_mlp_proj" in name: - name = name.replace("to_qkvkv_mlp_proj", "to_qkv_mlp_proj") - if "to_qkv_mlp_proj" in name: - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - continue + original_name = name + mapped = False for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: + if weight_name not in original_name: continue - name = name.replace(weight_name, param_name) - param = params_dict[name] + name = original_name.replace(weight_name, param_name) + param = params_dict.get(name) + if param is None: + break weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) + loaded_params.add(name) + mapped = True break - else: - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) + if mapped: + continue + + name = original_name + if name not in params_dict and ".to_out.0." in name: + name = name.replace(".to_out.0.", ".to_out.") + # Some GGUF checkpoints include quantized tensors for modules that + # are intentionally left unquantized in this model. + param = params_dict.get(name) + if param is None: + continue + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params diff --git a/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py b/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py index e1ef706c3f..d43748380b 100644 --- a/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py +++ b/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py @@ -43,6 +43,7 @@ Flux2Transformer2DModel, ) from vllm_omni.diffusion.models.interface import SupportImageInput +from vllm_omni.diffusion.quantization import get_vllm_quant_config_for_layers from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs from vllm_omni.model_executor.model_loader.weight_utils import download_weights_from_hf_specific @@ -230,7 +231,8 @@ def __init__( ).to(self._execution_device) transformer_kwargs = get_transformer_config_kwargs(od_config.tf_model_config, Flux2Transformer2DModel) - self.transformer = Flux2Transformer2DModel(**transformer_kwargs) + quant_config = get_vllm_quant_config_for_layers(od_config.quantization_config) + self.transformer = Flux2Transformer2DModel(quant_config=quant_config, **transformer_kwargs) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = Flux2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) @@ -993,4 +995,8 @@ def forward( def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) - return loader.load_weights(weights) + loaded_weights = loader.load_weights(weights) + # Record components 
loaded by diffusers submodules to satisfy strict checks. + loaded_weights |= {f"vae.{name}" for name, _ in self.vae.named_parameters()} + loaded_weights |= {f"text_encoder.{name}" for name, _ in self.text_encoder.named_parameters()} + return loaded_weights diff --git a/vllm_omni/diffusion/models/stable_audio/pipeline_stable_audio.py b/vllm_omni/diffusion/models/stable_audio/pipeline_stable_audio.py index c48d68efd6..22dfc06c5a 100644 --- a/vllm_omni/diffusion/models/stable_audio/pipeline_stable_audio.py +++ b/vllm_omni/diffusion/models/stable_audio/pipeline_stable_audio.py @@ -30,6 +30,7 @@ from vllm_omni.diffusion.models.interface import SupportAudioOutput from vllm_omni.diffusion.models.stable_audio.stable_audio_transformer import StableAudioDiTModel from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs logger = init_logger(__name__) @@ -127,8 +128,9 @@ def __init__( local_files_only=local_files_only, ).to(self.device) - # Initialize our custom transformer (weights loaded via load_weights) - self.transformer = StableAudioDiTModel(od_config=od_config) + # Initialize transformer from HF config to keep architecture aligned with checkpoint. + transformer_kwargs = get_transformer_config_kwargs(od_config.tf_model_config, StableAudioDiTModel) + self.transformer = StableAudioDiTModel(od_config=od_config, **transformer_kwargs) # Load scheduler self.scheduler = CosineDPMSolverMultistepScheduler.from_pretrained( diff --git a/vllm_omni/diffusion/models/z_image/pipeline_z_image.py b/vllm_omni/diffusion/models/z_image/pipeline_z_image.py index e025440a79..2c6ef4c86f 100644 --- a/vllm_omni/diffusion/models/z_image/pipeline_z_image.py +++ b/vllm_omni/diffusion/models/z_image/pipeline_z_image.py @@ -644,4 +644,8 @@ def forward( def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) - return loader.load_weights(weights) + loaded_weights = loader.load_weights(weights) + # Record components loaded by diffusers submodules to satisfy strict checks. 
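+        # (The VAE and text encoder load their own pretrained weights during pipeline construction.)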
+ loaded_weights |= {f"vae.{name}" for name, _ in self.vae.named_parameters()} + loaded_weights |= {f"text_encoder.{name}" for name, _ in self.text_encoder.named_parameters()} + return loaded_weights diff --git a/vllm_omni/diffusion/models/z_image/z_image_transformer.py b/vllm_omni/diffusion/models/z_image/z_image_transformer.py index a4f073faa5..83d0ecb9b9 100644 --- a/vllm_omni/diffusion/models/z_image/z_image_transformer.py +++ b/vllm_omni/diffusion/models/z_image/z_image_transformer.py @@ -29,6 +29,7 @@ from vllm.model_executor.layers.linear import ( MergedColumnParallelLinear, QKVParallelLinear, + ReplicatedLinear, RowParallelLinear, ) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -207,21 +208,27 @@ def validate_zimage_tp_constraints( class TimestepEmbedder(nn.Module): - def __init__(self, out_size, mid_size=None, frequency_embedding_size=256): + def __init__( + self, out_size, mid_size=None, frequency_embedding_size=256, quant_config: "QuantizationConfig | None" = None + ): super().__init__() if mid_size is None: mid_size = out_size self.mlp = nn.Sequential( - nn.Linear( + ReplicatedLinear( frequency_embedding_size, mid_size, bias=True, + quant_config=quant_config, + return_bias=False, ), nn.SiLU(), - nn.Linear( + ReplicatedLinear( mid_size, out_size, bias=True, + quant_config=quant_config, + return_bias=False, ), ) @@ -241,7 +248,7 @@ def timestep_embedding(t, dim, max_period=10000): def forward(self, t): t_freq = self.timestep_embedding(t, self.frequency_embedding_size) - weight_dtype = self.mlp[0].weight.dtype + weight_dtype = self.mlp[0].bias.dtype if weight_dtype.is_floating_point: t_freq = t_freq.to(weight_dtype) t_emb = self.mlp(t_freq) @@ -420,7 +427,9 @@ def __init__( self.modulation = modulation if modulation: self.adaLN_modulation = nn.Sequential( - nn.Linear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True), + ReplicatedLinear( + min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True, return_bias=False, quant_config=quant_config + ), ) def forward( @@ -473,14 +482,18 @@ def forward( class FinalLayer(nn.Module): - def __init__(self, hidden_size, out_channels): + def __init__(self, hidden_size, out_channels, quant_config: "QuantizationConfig | None" = None): super().__init__() self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - self.linear = nn.Linear(hidden_size, out_channels, bias=True) + self.linear = ReplicatedLinear( + hidden_size, out_channels, bias=True, quant_config=quant_config, return_bias=False + ) self.adaLN_modulation = nn.Sequential( nn.SiLU(), - nn.Linear(min(hidden_size, ADALN_EMBED_DIM), hidden_size, bias=True), + ReplicatedLinear( + min(hidden_size, ADALN_EMBED_DIM), hidden_size, bias=True, quant_config=quant_config, return_bias=False + ), ) def forward(self, x, c): @@ -652,10 +665,18 @@ def __init__( all_x_embedder = {} all_final_layer = {} for patch_idx, (patch_size, f_patch_size) in enumerate(zip(all_patch_size, all_f_patch_size)): - x_embedder = nn.Linear(f_patch_size * patch_size * patch_size * in_channels, dim, bias=True) + x_embedder = ReplicatedLinear( + f_patch_size * patch_size * patch_size * in_channels, + dim, + bias=True, + quant_config=quant_config, + return_bias=False, + ) all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder - final_layer = FinalLayer(dim, patch_size * patch_size * f_patch_size * self.out_channels) + final_layer = FinalLayer( + dim, patch_size * patch_size * f_patch_size * self.out_channels, quant_config=quant_config + ) 
all_final_layer[f"{patch_size}-{f_patch_size}"] = final_layer self.all_x_embedder = nn.ModuleDict(all_x_embedder) @@ -690,10 +711,10 @@ def __init__( for layer_id in range(n_refiner_layers) ] ) - self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024) + self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024, quant_config=quant_config) self.cap_embedder = nn.Sequential( RMSNorm(cap_feat_dim, eps=norm_eps), - nn.Linear(cap_feat_dim, dim, bias=True), + ReplicatedLinear(cap_feat_dim, dim, bias=True, return_bias=False, quant_config=quant_config), ) self.x_pad_token = nn.Parameter(torch.empty((1, dim))) @@ -957,9 +978,9 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) # self-attn - (".to_qkv", ".to_q", "q"), - (".to_qkv", ".to_k", "k"), - (".to_qkv", ".to_v", "v"), + (".to_qkv.", ".to_q.", "q"), + (".to_qkv.", ".to_k.", "k"), + (".to_qkv.", ".to_v.", "v"), # ffn (".w13", ".w1", 0), (".w13", ".w3", 1), diff --git a/vllm_omni/diffusion/quantization/__init__.py b/vllm_omni/diffusion/quantization/__init__.py index cc1bb547f7..d297d51f18 100644 --- a/vllm_omni/diffusion/quantization/__init__.py +++ b/vllm_omni/diffusion/quantization/__init__.py @@ -28,6 +28,7 @@ from .base import DiffusionQuantizationConfig from .fp8 import DiffusionFp8Config +from .gguf import DiffusionGgufConfig if TYPE_CHECKING: from vllm.model_executor.layers.quantization.base_config import ( @@ -40,6 +41,7 @@ # To add a new method, create a new config class and register it here _QUANT_CONFIG_REGISTRY: dict[str, type[DiffusionQuantizationConfig]] = { "fp8": DiffusionFp8Config, + "gguf": DiffusionGgufConfig, } SUPPORTED_QUANTIZATION_METHODS = list(_QUANT_CONFIG_REGISTRY.keys()) @@ -108,6 +110,7 @@ def get_vllm_quant_config_for_layers( __all__ = [ "DiffusionQuantizationConfig", "DiffusionFp8Config", + "DiffusionGgufConfig", "get_diffusion_quant_config", "get_vllm_quant_config_for_layers", "SUPPORTED_QUANTIZATION_METHODS", diff --git a/vllm_omni/diffusion/quantization/base.py b/vllm_omni/diffusion/quantization/base.py index 0cd9e4147e..17e6d32ead 100644 --- a/vllm_omni/diffusion/quantization/base.py +++ b/vllm_omni/diffusion/quantization/base.py @@ -31,15 +31,14 @@ class DiffusionQuantizationConfig(ABC): # The underlying vLLM config instance _vllm_config: "QuantizationConfig | None" = None - @classmethod - def get_name(cls) -> str: + def get_name(self) -> str: """Return the quantization method name (e.g., 'fp8', 'int8'). - By default, delegates to the underlying vLLM config class. + By default, delegates to the underlying vLLM config instance. 
""" - if cls.quant_config_cls is not None: - return cls.quant_config_cls.get_name() - raise NotImplementedError("Subclass must set quant_config_cls or override get_name()") + if self._vllm_config is not None: + return self._vllm_config.get_name() + raise NotImplementedError("Subclass must initialize _vllm_config or override get_name().") def get_vllm_quant_config(self) -> "QuantizationConfig | None": """Return the underlying vLLM QuantizationConfig for linear layers.""" diff --git a/vllm_omni/diffusion/quantization/gguf.py b/vllm_omni/diffusion/quantization/gguf.py new file mode 100644 index 0000000000..85ec2ee33a --- /dev/null +++ b/vllm_omni/diffusion/quantization/gguf.py @@ -0,0 +1,93 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""GGUF quantization config for diffusion transformers.""" + +import gguf +import torch +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.gguf import ( + UNQUANTIZED_TYPES, + GGUFConfig, + GGUFLinearMethod, + LinearBase, + QuantizeMethodBase, + UnquantizedLinearMethod, + is_layer_skipped_gguf, +) + +from .base import DiffusionQuantizationConfig + + +def dequant_gemm_gguf(x: torch.Tensor, qweight: torch.Tensor, qweight_type: int) -> torch.Tensor: + if qweight_type in UNQUANTIZED_TYPES: + return x @ qweight.T + block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type] + shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size) + weight = ops.ggml_dequantize(qweight, qweight_type, *shape, x.dtype) + return x @ weight.T + + +class DiffusionGGUFLinearMethod(GGUFLinearMethod): + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + ) -> torch.Tensor: + # Dequantize + GEMM path: torch.matmul multiplies over the last + # dimension and broadcasts leading dimensions, so no 2D flattening + # is required here. + shard_id = getattr(layer.qweight, "shard_id", None) + + if shard_id: + # dequantize shard weights respectively + shard_id = ["q", "k", "v"] if "q" in shard_id else shard_id + qweight = layer.qweight + result = [] + for idx in shard_id: + start, end, offset = layer.qweight.shard_offset_map[idx] + qweight_type = layer.qweight_type.shard_weight_type[idx] + result.append(dequant_gemm_gguf(x, qweight[start:end, :offset].contiguous(), qweight_type)) + out = torch.cat(result, axis=-1) + else: + qweight = layer.qweight + qweight_type = layer.qweight_type.weight_type + out = dequant_gemm_gguf(x, qweight, qweight_type) + if bias is not None: + out.add_(bias) + return out + + +class _GGUFConfig(GGUFConfig): + def get_quant_method(self, layer: torch.nn.Module, prefix: str) -> "QuantizeMethodBase": + if isinstance(layer, LinearBase): + if is_layer_skipped_gguf(prefix, self.unquantized_modules, self.packed_modules_mapping): + return UnquantizedLinearMethod() + return DiffusionGGUFLinearMethod(self) + return None + + +class DiffusionGgufConfig(DiffusionQuantizationConfig): + """GGUF quantization config for diffusion transformers. + + This is a thin wrapper around vLLM's GGUFConfig and also carries + the GGUF model reference for loader use. + + Args: + gguf_model: GGUF model path or HF reference (repo/file or repo:quant_type) + unquantized_modules: Optional list of module name patterns to skip GGUF + quantization. Note: diffusion linear layers often use short prefixes + (e.g., "to_qkv"), so these patterns are matched as substrings. 
+ """ + + quant_config_cls = GGUFConfig + + def __init__( + self, + gguf_model: str | None = None, + unquantized_modules: list[str] | None = None, + ) -> None: + self.gguf_model = gguf_model + self.unquantized_modules = unquantized_modules or [] + + self._vllm_config = _GGUFConfig(unquantized_modules=self.unquantized_modules) diff --git a/vllm_omni/engine/arg_utils.py b/vllm_omni/engine/arg_utils.py index 6f14aa2643..dac0963399 100644 --- a/vllm_omni/engine/arg_utils.py +++ b/vllm_omni/engine/arg_utils.py @@ -74,6 +74,7 @@ class OmniEngineArgs(EngineArgs): stage_connector_spec: dict[str, Any] = field(default_factory=dict) async_chunk: bool = False omni_kv_config: dict | None = None + quantization_config: Any | None = None worker_type: str | None = None def __post_init__(self) -> None: @@ -230,6 +231,7 @@ class AsyncOmniEngineArgs(AsyncEngineArgs): stage_connector_spec: dict[str, Any] = field(default_factory=dict) async_chunk: bool = False omni_kv_config: dict | None = None + quantization_config: Any | None = None worker_type: str | None = None def __post_init__(self) -> None: diff --git a/vllm_omni/entrypoints/cli/serve.py b/vllm_omni/entrypoints/cli/serve.py index 57e20e409b..c9ed7150f7 100644 --- a/vllm_omni/entrypoints/cli/serve.py +++ b/vllm_omni/entrypoints/cli/serve.py @@ -6,6 +6,7 @@ """ import argparse +import json import os import signal from typing import Any @@ -209,6 +210,15 @@ def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgu help="Ring Sequence Parallelism degree for diffusion models. " "Equivalent to setting DiffusionParallelConfig.ring_degree.", ) + omni_config_group.add_argument( + "--quantization-config", + type=json.loads, + default=None, + help=( + "JSON string for diffusion quantization_config. " + 'Example: \'{"method":"gguf","gguf_model":"/path/to/model.gguf"}\'.' + ), + ) # HSDP (Hybrid Sharded Data Parallel) parameters omni_config_group.add_argument( diff --git a/vllm_omni/entrypoints/omni.py b/vllm_omni/entrypoints/omni.py index 59be966c33..deda227180 100644 --- a/vllm_omni/entrypoints/omni.py +++ b/vllm_omni/entrypoints/omni.py @@ -285,6 +285,13 @@ def _resolve_stage_configs(self, model: str, kwargs: dict[str, Any]) -> tuple[st if lora_scale is not None: if not hasattr(cfg.engine_args, "lora_scale") or cfg.engine_args.lora_scale is None: cfg.engine_args.lora_scale = lora_scale + quantization_config = kwargs.get("quantization_config") + if quantization_config is not None: + if ( + not hasattr(cfg.engine_args, "quantization_config") + or cfg.engine_args.quantization_config is None + ): + cfg.engine_args.quantization_config = quantization_config except Exception as e: logger.warning("Failed to inject LoRA config for stage: %s", e)