From df1a0bc3148552f9666ca86c09c707465f86475d Mon Sep 17 00:00:00 2001 From: David Chen <530634352@qq.com> Date: Mon, 9 Feb 2026 15:46:12 +0800 Subject: [PATCH 01/62] cherry pick 1034 Signed-off-by: David Chen <530634352@qq.com> --- docs/.nav.yml | 3 + .../diffusion/quantization/fp8_online.md | 75 ++++++++++ .../diffusion/quantization/overview.md | 18 +++ docs/user_guide/diffusion_acceleration.md | 64 +++++--- .../text_to_image/text_to_image.py | 35 +++++ tests/diffusion/quantization/__init__.py | 2 + .../diffusion/quantization/test_fp8_config.py | 137 ++++++++++++++++++ vllm_omni/diffusion/data.py | 47 +++++- .../model_loader/diffusers_loader.py | 29 ++++ .../models/qwen_image/pipeline_qwen_image.py | 6 +- .../qwen_image/qwen_image_transformer.py | 45 +++++- .../models/z_image/pipeline_z_image.py | 5 +- .../models/z_image/z_image_transformer.py | 39 ++++- vllm_omni/diffusion/quantization/__init__.py | 114 +++++++++++++++ vllm_omni/diffusion/quantization/base.py | 61 ++++++++ vllm_omni/diffusion/quantization/fp8.py | 50 +++++++ 16 files changed, 700 insertions(+), 30 deletions(-) create mode 100644 docs/user_guide/diffusion/quantization/fp8_online.md create mode 100644 docs/user_guide/diffusion/quantization/overview.md create mode 100644 tests/diffusion/quantization/__init__.py create mode 100644 tests/diffusion/quantization/test_fp8_config.py create mode 100644 vllm_omni/diffusion/quantization/__init__.py create mode 100644 vllm_omni/diffusion/quantization/base.py create mode 100644 vllm_omni/diffusion/quantization/fp8.py diff --git a/docs/.nav.yml b/docs/.nav.yml index adf7321d7a..06bee0e8de 100644 --- a/docs/.nav.yml +++ b/docs/.nav.yml @@ -42,6 +42,9 @@ nav: - Overview: user_guide/diffusion_acceleration.md - TeaCache: user_guide/diffusion/teacache.md - Cache-DiT: user_guide/diffusion/cache_dit_acceleration.md + - Quantization: + - Overview: user_guide/diffusion/quantization/overview.md + - Online FP8: user_guide/diffusion/quantization/fp8_online.md - Parallelism Acceleration: user_guide/diffusion/parallelism_acceleration.md - CPU Offloading: user_guide/diffusion/cpu_offload_diffusion.md - ComfyUI: features/comfyui.md diff --git a/docs/user_guide/diffusion/quantization/fp8_online.md b/docs/user_guide/diffusion/quantization/fp8_online.md new file mode 100644 index 0000000000..65a6329690 --- /dev/null +++ b/docs/user_guide/diffusion/quantization/fp8_online.md @@ -0,0 +1,75 @@ +# Online FP8 Quantization + +## Overview + +Online FP8 converts BF16/FP16 weights to FP8 at model load time. No calibration or pre-quantized checkpoint needed. + +Depending on the model, either all layers can be quantized, or some sensitive layers should stay in BF16. See the [per-model table](#supported-models) for which case applies. + +## Configuration + +1. **Python API**: set `quantization="fp8"`. To skip sensitive layers, use `quantization_config` with `ignored_layers`. + +```python +from vllm_omni import Omni +from vllm_omni.inputs.data import OmniDiffusionSamplingParams + +# All layers quantized +omni = Omni(model="", quantization="fp8") + +# Skip sensitive layers +omni = Omni( + model="", + quantization_config={ + "method": "fp8", + "ignored_layers": [""], + }, +) + +outputs = omni.generate( + "A cat sitting on a windowsill", + OmniDiffusionSamplingParams(num_inference_steps=50), +) +``` + +2. **CLI**: pass `--quantization fp8` and optionally `--ignored-layers`. 
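In both interfaces, each entry in `ignored_layers` is a name pattern rather than an exact module path. As a rough sketch of the matching behavior (assuming simple substring matching against qualified layer names; the authoritative rule lives in the underlying vLLM quantization config), a skip check behaves like:

```python
def should_skip(layer_name: str, ignored_layers: list[str]) -> bool:
    # Illustrative only: treat each pattern as a substring, so "img_mlp"
    # matches e.g. "transformer_blocks.0.img_mlp.proj".
    return any(pattern in layer_name for pattern in ignored_layers)
```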
+ +```bash +# All layers +python text_to_image.py --model --quantization fp8 + +# Skip sensitive layers +python text_to_image.py --model --quantization fp8 --ignored-layers "img_mlp" + +# Online serving +vllm serve --omni --quantization fp8 +``` + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `method` | str | — | Quantization method (`"fp8"`) | +| `ignored_layers` | list[str] | `[]` | Layer name patterns to keep in BF16 | +| `activation_scheme` | str | `"dynamic"` | `"dynamic"` (no calibration) or `"static"` | +| `weight_block_size` | list[int] \| None | `None` | Block size for block-wise weight quantization | + +The available `ignored_layers` names depend on the model architecture (e.g., `to_qkv`, `to_out`, `img_mlp`, `txt_mlp`). Consult the transformer source for your target model. + +## Supported Models + +| Model | HF Models | Recommendation | `ignored_layers` | +|-------|-----------|---------------|------------------| +| Z-Image | `Tongyi-MAI/Z-Image-Turbo` | All layers | None | +| Qwen-Image | `Qwen/Qwen-Image`, `Qwen/Qwen-Image-2512` | Skip sensitive layers | `img_mlp` | + +## Combining with Other Features + +FP8 quantization can be combined with cache acceleration: + +```python +omni = Omni( + model="", + quantization="fp8", + cache_backend="tea_cache", + cache_config={"rel_l1_thresh": 0.2}, +) +``` diff --git a/docs/user_guide/diffusion/quantization/overview.md b/docs/user_guide/diffusion/quantization/overview.md new file mode 100644 index 0000000000..75aa65f223 --- /dev/null +++ b/docs/user_guide/diffusion/quantization/overview.md @@ -0,0 +1,18 @@ +# Quantization for Diffusion Transformers + +vLLM-Omni supports quantization of DiT linear layers to reduce memory usage and accelerate inference. + +## Supported Methods + +| Method | Guide | +|--------|-------| +| FP8 | [Online FP8](fp8_online.md) | + +## Device Compatibility + +| GPU Generation | Example GPUs | FP8 Mode | +|---------------|-------------------|----------| +| Turing (SM 75+) | T4, RTX 2080 | Weight-only via Marlin kernel | +| Ada/Hopper (SM 89+) | RTX 4090, H100, H200 | Full W8A8 with native hardware | + +Kernel selection is automatic. diff --git a/docs/user_guide/diffusion_acceleration.md b/docs/user_guide/diffusion_acceleration.md index 859f8c0a22..6760194b15 100644 --- a/docs/user_guide/diffusion_acceleration.md +++ b/docs/user_guide/diffusion_acceleration.md @@ -1,6 +1,6 @@ # Diffusion Acceleration Overview -vLLM-Omni supports various cache acceleration methods to speed up diffusion model inference with minimal quality degradation. These methods include **cache methods** that intelligently cache intermediate computations to avoid redundant work across diffusion timesteps, and **parallelism methods** that distribute the computation across multiple devices. +vLLM-Omni supports various acceleration methods to speed up diffusion model inference with minimal quality degradation. These include **cache methods** that intelligently cache intermediate computations to avoid redundant work across diffusion timesteps, **parallelism methods** that distribute the computation across multiple devices, and **quantization methods** that reduce numerical precision of transformer layers. ## Supported Acceleration Methods @@ -14,6 +14,10 @@ vLLM-Omni currently supports two main cache acceleration backends: Both methods can provide significant speedups (typically **1.5x-2.0x**) while maintaining high output quality. +vLLM-Omni also supports quantization methods: + +3. 
**[FP8 Quantization](diffusion/quantization/overview.md)** - Reduces DiT linear layers from BF16 to FP8, providing ~1.28x speedup with minimal quality loss. Supports per-layer skip for sensitive layers. + vLLM-Omni also supports parallelism methods for diffusion models, including: 1. [Ulysses-SP](diffusion/parallelism_acceleration.md#ulysses-sp) - splits the input along the sequence dimension and uses all-to-all communication to allow each device to compute only a subset of attention heads. @@ -31,32 +35,38 @@ vLLM-Omni also supports parallelism methods for diffusion models, including: | **TeaCache** | `cache_backend="tea_cache"` | Simple, adaptive caching with minimal configuration | Quick setup, balanced speed/quality | | **Cache-DiT** | `cache_backend="cache_dit"` | Advanced caching with multiple techniques (DBCache, TaylorSeer, SCM) | Maximum acceleration, fine-grained control | +### Quantization Methods + +| Method | Configuration | Description | Best For | +|--------|--------------|-------------|----------| +| **FP8** | `quantization="fp8"` | FP8 W8A8 on Ada/Hopper, weight-only on older GPUs | Memory reduction, inference speedup | + ## Supported Models The following table shows which models are currently supported by each acceleration method: ### ImageGen -| Model | Model Identifier | TeaCache | Cache-DiT | Ulysses-SP | Ring-Attention | CFG-Parallel | -|-------|------------------|:----------:|:-----------:|:-----------:|:----------------:|:----------------:| -| **LongCat-Image** | `meituan-longcat/LongCat-Image` | ❌ | ✅ | ❌ | ❌ | ✅ | -| **LongCat-Image-Edit** | `meituan-longcat/LongCat-Image-Edit` | ❌ | ✅ | ❌ | ❌ | ✅ | -| **Ovis-Image** | `OvisAI/Ovis-Image` | ❌ | ✅ | ❌ | ❌ | ✅ | -| **Qwen-Image** | `Qwen/Qwen-Image` | ✅ | ✅ | ✅ | ✅ | ✅ | -| **Qwen-Image-2512** | `Qwen/Qwen-Image-2512` | ✅ | ✅ | ✅ | ✅ | ✅ | -| **Qwen-Image-Edit** | `Qwen/Qwen-Image-Edit` | ✅ | ✅ | ✅ | ✅ | ✅ | -| **Qwen-Image-Edit-2509** | `Qwen/Qwen-Image-Edit-2509` | ✅ | ✅ | ✅ | ✅ | ✅ | -| **Qwen-Image-Layered** | `Qwen/Qwen-Image-Layered` | ❌ | ✅ | ✅ | ✅ | ✅ | -| **Z-Image** | `Tongyi-MAI/Z-Image-Turbo` | ✅ | ✅ | ❌ | ❌ | ❌ | -| **Stable-Diffusion3.5** | `stabilityai/stable-diffusion-3.5` | ❌ | ✅ | ❌ | ❌ | ✅ | -| **Bagel** | `ByteDance-Seed/BAGEL-7B-MoT` | ✅ | ✅ | ❌ | ❌ | ❌ | -| **FLUX.1-dev** | `black-forest-labs/FLUX.1-dev` | ❌ | ✅ | ❌ | ❌ | ❌ | +| Model | Model Identifier | TeaCache | Cache-DiT | FP8 | Ulysses-SP | Ring-Attention | CFG-Parallel | +|-------|------------------|:----------:|:-----------:|:---:|:-----------:|:----------------:|:----------------:| +| **LongCat-Image** | `meituan-longcat/LongCat-Image` | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | +| **LongCat-Image-Edit** | `meituan-longcat/LongCat-Image-Edit` | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | +| **Ovis-Image** | `OvisAI/Ovis-Image` | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | +| **Qwen-Image** | `Qwen/Qwen-Image` | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| **Qwen-Image-2512** | `Qwen/Qwen-Image-2512` | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| **Qwen-Image-Edit** | `Qwen/Qwen-Image-Edit` | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | +| **Qwen-Image-Edit-2509** | `Qwen/Qwen-Image-Edit-2509` | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | +| **Qwen-Image-Layered** | `Qwen/Qwen-Image-Layered` | ❌ | ✅ | ❌ | ✅ | ✅ | ✅ | +| **Z-Image** | `Tongyi-MAI/Z-Image-Turbo` | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | +| **Stable-Diffusion3.5** | `stabilityai/stable-diffusion-3.5` | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | +| **Bagel** | `ByteDance-Seed/BAGEL-7B-MoT` | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | +| **FLUX.1-dev** | `black-forest-labs/FLUX.1-dev` | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ### VideoGen -| Model | Model Identifier | 
TeaCache | Cache-DiT | Ulysses-SP | Ring-Attention |CFG-Parallel | -|-------|------------------|:--------:|:---------:|:----------:|:--------------:|:----------------:| -| **Wan2.2** | `Wan-AI/Wan2.2-T2V-A14B-Diffusers` | ❌ | ✅ | ✅ | ✅ | ✅ | +| Model | Model Identifier | TeaCache | Cache-DiT | FP8 | Ulysses-SP | Ring-Attention | CFG-Parallel | +|-------|------------------|:--------:|:---------:|:---:|:----------:|:--------------:|:----------------:| +| **Wan2.2** | `Wan-AI/Wan2.2-T2V-A14B-Diffusers` | ❌ | ✅ | ❌ | ✅ | ✅ | ✅ | ## Performance Benchmarks @@ -227,11 +237,29 @@ outputs = omni.generate( ) ``` +### Using FP8 Quantization + +```python +from vllm_omni import Omni +from vllm_omni.inputs.data import OmniDiffusionSamplingParams + +omni = Omni( + model="", + quantization="fp8", +) + +outputs = omni.generate( + "A cat sitting on a windowsill", + OmniDiffusionSamplingParams(num_inference_steps=50), +) +``` + ## Documentation For detailed information on each acceleration method: - **[TeaCache Guide](diffusion/teacache.md)** - Complete TeaCache documentation, configuration options, and best practices - **[Cache-DiT Acceleration Guide](diffusion/cache_dit_acceleration.md)** - Comprehensive Cache-DiT guide covering DBCache, TaylorSeer, SCM, and configuration parameters +- **[FP8 Quantization Guide](diffusion/quantization/overview.md)** - FP8 quantization for DiT models with per-layer control - **[Sequence Parallelism](diffusion/parallelism_acceleration.md#sequence-parallelism)** - Guidance on how to set sequence parallelism with configuration. - **[CFG-Parallel](diffusion/parallelism_acceleration.md#cfg-parallel)** - Guidance on how to set CFG-Parallel to run positive/negative branches across ranks. diff --git a/examples/offline_inference/text_to_image/text_to_image.py b/examples/offline_inference/text_to_image/text_to_image.py index a79e5d640d..adca5f6f8b 100644 --- a/examples/offline_inference/text_to_image/text_to_image.py +++ b/examples/offline_inference/text_to_image/text_to_image.py @@ -5,6 +5,7 @@ import os import time from pathlib import Path +from typing import Any import torch @@ -124,6 +125,24 @@ def parse_args() -> argparse.Namespace: default=1, help="Number of GPUs used for tensor parallelism (TP) inside the DiT.", ) + parser.add_argument( + "--quantization", + type=str, + default=None, + choices=["fp8"], + help="Quantization method for the transformer. " + "Options: 'fp8' (FP8 W8A8 on Ada/Hopper, weight-only on older GPUs). " + "Default: None (no quantization, uses BF16).", + ) + parser.add_argument( + "--ignored-layers", + type=str, + default=None, + help="Comma-separated list of layer name patterns to skip quantization. " + "Only used when --quantization is set. " + "Available layers: to_qkv, to_out, add_kv_proj, to_add_out, img_mlp, txt_mlp, proj_out. 
" + "Example: --ignored-layers 'add_kv_proj,to_add_out'", + ) parser.add_argument( "--vae_use_slicing", action="store_true", @@ -181,6 +200,18 @@ def main(): # Check if profiling is requested via environment variable profiler_enabled = bool(os.getenv("VLLM_TORCH_PROFILER_DIR")) + # Build quantization kwargs: use quantization_config dict when + # ignored_layers is specified so the list flows through OmniDiffusionConfig + quant_kwargs: dict[str, Any] = {} + ignored_layers = [s.strip() for s in args.ignored_layers.split(",") if s.strip()] if args.ignored_layers else None + if args.quantization and ignored_layers: + quant_kwargs["quantization_config"] = { + "method": args.quantization, + "ignored_layers": ignored_layers, + } + elif args.quantization: + quant_kwargs["quantization"] = args.quantization + omni = Omni( model=args.model, enable_layerwise_offload=args.enable_layerwise_offload, @@ -193,6 +224,7 @@ def main(): parallel_config=parallel_config, enforce_eager=args.enforce_eager, enable_cpu_offload=args.enable_cpu_offload, + **quant_kwargs, ) if profiler_enabled: @@ -205,6 +237,9 @@ def main(): print(f" Model: {args.model}") print(f" Inference steps: {args.num_inference_steps}") print(f" Cache backend: {args.cache_backend if args.cache_backend else 'None (no acceleration)'}") + print(f" Quantization: {args.quantization if args.quantization else 'None (BF16)'}") + if ignored_layers: + print(f" Ignored layers: {ignored_layers}") print( f" Parallel configuration: tensor_parallel_size={args.tensor_parallel_size}, " f"ulysses_degree={args.ulysses_degree}, ring_degree={args.ring_degree}, cfg_parallel_size={args.cfg_parallel_size}" diff --git a/tests/diffusion/quantization/__init__.py b/tests/diffusion/quantization/__init__.py new file mode 100644 index 0000000000..208f01a7cb --- /dev/null +++ b/tests/diffusion/quantization/__init__.py @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project diff --git a/tests/diffusion/quantization/test_fp8_config.py b/tests/diffusion/quantization/test_fp8_config.py new file mode 100644 index 0000000000..57661d2990 --- /dev/null +++ b/tests/diffusion/quantization/test_fp8_config.py @@ -0,0 +1,137 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unit tests for FP8 quantization config.""" + +import pytest + + +def test_fp8_config_creation(): + """Test that FP8 config can be created.""" + from vllm_omni.diffusion.quantization import get_diffusion_quant_config + + config = get_diffusion_quant_config("fp8") + assert config is not None + assert config.get_name() == "fp8" + + +def test_vllm_config_extraction(): + """Test that vLLM config can be extracted from diffusion config.""" + from vllm_omni.diffusion.quantization import ( + get_diffusion_quant_config, + get_vllm_quant_config_for_layers, + ) + + diff_config = get_diffusion_quant_config("fp8") + vllm_config = get_vllm_quant_config_for_layers(diff_config) + assert vllm_config is not None + assert vllm_config.activation_scheme == "dynamic" + + +def test_none_quantization(): + """Test that None quantization returns None config.""" + from vllm_omni.diffusion.quantization import ( + get_diffusion_quant_config, + get_vllm_quant_config_for_layers, + ) + + config = get_diffusion_quant_config(None) + assert config is None + vllm_config = get_vllm_quant_config_for_layers(config) + assert vllm_config is None + + +def test_invalid_quantization(): + """Test that invalid quantization method 
raises error.""" + from vllm_omni.diffusion.quantization import get_diffusion_quant_config + + with pytest.raises(ValueError, match="Unknown quantization method"): + get_diffusion_quant_config("invalid_method") + + +def test_fp8_config_with_custom_params(): + """Test FP8 config with custom parameters.""" + from vllm_omni.diffusion.quantization import get_diffusion_quant_config + + config = get_diffusion_quant_config( + "fp8", + activation_scheme="static", + ignored_layers=["proj_out"], + ) + assert config is not None + assert config.activation_scheme == "static" + assert "proj_out" in config.ignored_layers + + +def test_supported_methods(): + """Test that supported methods list is correct.""" + from vllm_omni.diffusion.quantization import SUPPORTED_QUANTIZATION_METHODS + + assert "fp8" in SUPPORTED_QUANTIZATION_METHODS + + +def test_quantization_integration(): + """Test end-to-end quantization flow through OmniDiffusionConfig.""" + from vllm_omni.diffusion.data import OmniDiffusionConfig + + # Test with quantization string only + config = OmniDiffusionConfig(model="test", quantization="fp8") + assert config.quantization_config is not None + assert config.quantization_config.get_name() == "fp8" + + # Test with quantization_config dict + config2 = OmniDiffusionConfig( + model="test", + quantization_config={"method": "fp8", "activation_scheme": "static"}, + ) + assert config2.quantization_config is not None + assert config2.quantization_config.get_name() == "fp8" + assert config2.quantization_config.activation_scheme == "static" + + # Test that vLLM config can be extracted + vllm_config = config.quantization_config.get_vllm_quant_config() + assert vllm_config is not None + + +def test_quantization_dict_not_mutated(): + """Test that passing a dict to quantization_config doesn't mutate it.""" + from vllm_omni.diffusion.data import OmniDiffusionConfig + + original_dict = {"method": "fp8", "activation_scheme": "static"} + dict_copy = original_dict.copy() + + OmniDiffusionConfig(model="test", quantization_config=original_dict) + + # Original dict should be unchanged + assert original_dict == dict_copy + + +def test_quantization_conflicting_methods_warning(caplog): + """Test warning when quantization and quantization_config['method'] conflict.""" + import logging + + from vllm_omni.diffusion.data import OmniDiffusionConfig + + with caplog.at_level(logging.WARNING): + config = OmniDiffusionConfig( + model="test", + quantization="fp8", # This should be overridden + quantization_config={"method": "fp8", "activation_scheme": "static"}, + ) + # No warning when methods match + assert config.quantization_config is not None + + +def test_fp8_delegates_to_vllm_config(): + """Test that DiffusionFp8Config delegates to vLLM's Fp8Config.""" + from vllm.model_executor.layers.quantization.fp8 import Fp8Config + + from vllm_omni.diffusion.quantization import DiffusionFp8Config + + # Test that quant_config_cls is set correctly + assert DiffusionFp8Config.quant_config_cls is Fp8Config + + # Test that get_name() delegates to vLLM + assert DiffusionFp8Config.get_name() == Fp8Config.get_name() + + # Test that get_min_capability() delegates to vLLM + assert DiffusionFp8Config.get_min_capability() == Fp8Config.get_min_capability() diff --git a/vllm_omni/diffusion/data.py b/vllm_omni/diffusion/data.py index f884fb6f17..ecf78c34d8 100644 --- a/vllm_omni/diffusion/data.py +++ b/vllm_omni/diffusion/data.py @@ -4,9 +4,9 @@ import enum import os import random -from collections.abc import Callable +from collections.abc import 
Callable, Mapping from dataclasses import dataclass, field, fields -from typing import Any +from typing import TYPE_CHECKING, Any import torch from pydantic import model_validator @@ -16,6 +16,12 @@ from vllm_omni.diffusion.utils.network_utils import is_port_available +if TYPE_CHECKING: + from vllm_omni.diffusion.quantization import DiffusionQuantizationConfig + +# Import after TYPE_CHECKING to avoid circular imports at runtime +# The actual import is deferred to __post_init__ to avoid import order issues + logger = init_logger(__name__) @@ -365,6 +371,11 @@ class OmniDiffusionConfig: # Omni configuration (injected from stage config) omni_kv_config: dict[str, Any] = field(default_factory=dict) + # Quantization settings + # Supported methods: "fp8" (FP8 W8A8 on Ada/Hopper, weight-only on older GPUs) + quantization: str | None = None + quantization_config: "DiffusionQuantizationConfig | dict[str, Any] | None" = None + def settle_port(self, port: int, port_inc: int = 42, max_attempts: int = 100) -> int: """ Find an available port with retry logic. @@ -451,6 +462,38 @@ def __post_init__(self): # If it's neither dict nor DiffusionCacheConfig, convert to empty config self.cache_config = DiffusionCacheConfig() + # Convert quantization config (deferred import to avoid circular imports) + if self.quantization is not None or self.quantization_config is not None: + from vllm_omni.diffusion.quantization import ( + DiffusionQuantizationConfig, + get_diffusion_quant_config, + ) + + # Handle dict or DictConfig (from OmegaConf) - use Mapping for broader compatibility + if isinstance(self.quantization_config, Mapping): + # Convert DictConfig to dict if needed (OmegaConf compatibility) + config_dict = dict(self.quantization_config) + # Use get() instead of pop() to avoid mutating original dict + quant_method = config_dict.get("method", self.quantization) + # Filter out "method" key for kwargs + quant_kwargs = {k: v for k, v in config_dict.items() if k != "method"} + + # Validate conflicting methods + if self.quantization is not None and quant_method is not None and quant_method != self.quantization: + logger.warning( + f"Conflicting quantization methods: quantization={self.quantization!r}, " + f"quantization_config['method']={quant_method!r}. Using quantization_config['method']." 
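+                        # When both are set, the dict-provided method takes
+                        # precedence so structured configs stay authoritative.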
+ ) + + self.quantization_config = get_diffusion_quant_config(quant_method, **quant_kwargs) + elif self.quantization_config is None and self.quantization is not None: + self.quantization_config = get_diffusion_quant_config(self.quantization) + elif not isinstance(self.quantization_config, DiffusionQuantizationConfig): + raise TypeError( + f"quantization_config must be a DiffusionQuantizationConfig, dict, or None, " + f"got {type(self.quantization_config)!r}" + ) + if self.max_cpu_loras is None: self.max_cpu_loras = 1 elif self.max_cpu_loras < 1: diff --git a/vllm_omni/diffusion/model_loader/diffusers_loader.py b/vllm_omni/diffusion/model_loader/diffusers_loader.py index 892954ce9e..b61f70b697 100644 --- a/vllm_omni/diffusion/model_loader/diffusers_loader.py +++ b/vllm_omni/diffusion/model_loader/diffusers_loader.py @@ -13,6 +13,7 @@ from vllm.config import ModelConfig from vllm.config.load import LoadConfig from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase from vllm.model_executor.model_loader.weight_utils import ( download_safetensors_index_file_from_hf, download_weights_from_hf, @@ -217,8 +218,36 @@ def load_model(self, od_config: OmniDiffusionConfig, load_device: str) -> nn.Mod logger.debug("Loading weights on %s ...", load_device) # Quantization does not happen in `load_weights` but after it self.load_weights(model) + + # Process weights after loading for quantization (e.g., FP8 online quantization) + # This is needed for vLLM's quantization methods that need to transform weights + self._process_weights_after_loading(model, target_device) + return model.eval() + def _process_weights_after_loading(self, model: nn.Module, target_device: torch.device) -> None: + """Process weights after loading for quantization methods. + + This handles vLLM's quantization methods that need to process weights + after loading (e.g., FP8 online quantization from BF16/FP16 weights). 
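+
+        Modules not already on the target device are moved there for
+        processing and restored to their original device afterwards.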
+ """ + for _, module in model.named_modules(): + quant_method = getattr(module, "quant_method", None) + if isinstance(quant_method, QuantizeMethodBase): + # Move module to target device for processing if needed + module_device = next(module.parameters(), None) + if module_device is not None: + module_device = module_device.device + needs_device_move = module_device != target_device + + if needs_device_move: + module.to(target_device) + + quant_method.process_weights_after_loading(module) + + if needs_device_move: + module.to(module_device) + def load_weights(self, model: nn.Module) -> None: weights_to_load = {name for name, _ in model.named_parameters()} loaded_weights = model.load_weights(self.get_all_weights(model)) diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py index d85d98b5bf..a610443041 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py @@ -33,6 +33,7 @@ from vllm_omni.diffusion.models.qwen_image.qwen_image_transformer import ( QwenImageTransformer2DModel, ) +from vllm_omni.diffusion.quantization import get_vllm_quant_config_for_layers from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs from vllm_omni.model_executor.model_loader.weight_utils import ( @@ -272,7 +273,10 @@ def __init__( self.device ) transformer_kwargs = get_transformer_config_kwargs(od_config.tf_model_config, QwenImageTransformer2DModel) - self.transformer = QwenImageTransformer2DModel(od_config=od_config, **transformer_kwargs) + quant_config = get_vllm_quant_config_for_layers(od_config.quantization_config) + self.transformer = QwenImageTransformer2DModel( + od_config=od_config, quant_config=quant_config, **transformer_kwargs + ) self.tokenizer = Qwen2Tokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only) diff --git a/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py b/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py index 5f2b8c68a2..cc839bd74e 100644 --- a/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py +++ b/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py @@ -4,7 +4,7 @@ from collections.abc import Iterable from functools import lru_cache from math import prod -from typing import Any +from typing import TYPE_CHECKING, Any import torch import torch.nn as nn @@ -23,6 +23,11 @@ ) from vllm.model_executor.model_loader.weight_utils import default_weight_loader +if TYPE_CHECKING: + from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, + ) + from vllm_omni.diffusion.attention.backends.abstract import ( AttentionMetadata, ) @@ -390,7 +395,16 @@ def _compute_video_freqs(self, frame, height, width, idx=0): class ColumnParallelApproxGELU(nn.Module): - def __init__(self, dim_in: int, dim_out: int, *, approximate: str, bias: bool = True): + def __init__( + self, + dim_in: int, + dim_out: int, + *, + approximate: str, + bias: bool = True, + quant_config: "QuantizationConfig | None" = None, + prefix: str = "", + ): super().__init__() self.proj = ColumnParallelLinear( dim_in, @@ -398,6 +412,8 @@ def __init__(self, dim_in: int, dim_out: int, *, approximate: str, bias: bool = bias=bias, gather_output=False, return_bias=False, + quant_config=quant_config, + prefix=prefix, ) self.approximate = approximate @@ -415,6 +431,8 @@ def __init__( 
activation_fn: str = "gelu-approximate", inner_dim: int | None = None, bias: bool = True, + quant_config: "QuantizationConfig | None" = None, + prefix: str = "", ) -> None: super().__init__() @@ -424,13 +442,17 @@ def __init__( dim_out = dim_out or dim layers: list[nn.Module] = [ - ColumnParallelApproxGELU(dim, inner_dim, approximate="tanh", bias=bias), + ColumnParallelApproxGELU( + dim, inner_dim, approximate="tanh", bias=bias, quant_config=quant_config, prefix=prefix + ), nn.Identity(), # placeholder for weight loading RowParallelLinear( inner_dim, dim_out, input_is_parallel=True, return_bias=False, + quant_config=quant_config, + prefix=prefix, ), ] @@ -456,6 +478,7 @@ def __init__( pre_only: bool = False, context_pre_only: bool = False, out_dim: int | None = None, + quant_config: "QuantizationConfig | None" = None, ) -> None: super().__init__() assert dim % num_heads == 0 @@ -471,6 +494,8 @@ def __init__( hidden_size=dim, head_size=self.head_dim, total_num_heads=num_heads, + quant_config=quant_config, + prefix="to_qkv", ) self.query_num_heads = self.to_qkv.num_heads self.kv_num_heads = self.to_qkv.num_kv_heads @@ -485,6 +510,8 @@ def __init__( hidden_size=added_kv_proj_dim, head_size=head_dim, total_num_heads=num_heads, + quant_config=quant_config, + prefix="add_kv_proj", ) self.add_query_num_heads = self.add_kv_proj.num_heads self.add_kv_num_heads = self.add_kv_proj.num_kv_heads @@ -496,6 +523,8 @@ def __init__( bias=out_bias, input_is_parallel=True, return_bias=False, + quant_config=quant_config, + prefix="to_add_out", ) assert not pre_only @@ -505,6 +534,8 @@ def __init__( bias=out_bias, input_is_parallel=True, return_bias=False, + quant_config=quant_config, + prefix="to_out", ) self.norm_added_q = RMSNorm(head_dim, eps=eps) @@ -637,6 +668,7 @@ def __init__( qk_norm: str = "rms_norm", eps: float = 1e-6, zero_cond_t: bool = False, + quant_config: "QuantizationConfig | None" = None, ): super().__init__() @@ -656,9 +688,10 @@ def __init__( added_kv_proj_dim=dim, context_pre_only=False, head_dim=attention_head_dim, + quant_config=quant_config, ) self.img_norm2 = AdaLayerNorm(dim, elementwise_affine=False, eps=eps) - self.img_mlp = FeedForward(dim=dim, dim_out=dim) + self.img_mlp = FeedForward(dim=dim, dim_out=dim, quant_config=quant_config, prefix="img_mlp") # Text processing modules self.txt_mod = nn.Sequential( @@ -668,7 +701,7 @@ def __init__( self.txt_norm1 = AdaLayerNorm(dim, elementwise_affine=False, eps=eps) # Text doesn't need separate attention - it's handled by img_attn joint computation self.txt_norm2 = AdaLayerNorm(dim, elementwise_affine=False, eps=eps) - self.txt_mlp = FeedForward(dim=dim, dim_out=dim) + self.txt_mlp = FeedForward(dim=dim, dim_out=dim, quant_config=quant_config, prefix="txt_mlp") self.zero_cond_t = zero_cond_t @@ -862,6 +895,7 @@ def __init__( zero_cond_t: bool = False, use_additional_t_cond: bool = False, use_layer3d_rope: bool = False, + quant_config: "QuantizationConfig | None" = None, ): super().__init__() self.parallel_config = od_config.parallel_config @@ -891,6 +925,7 @@ def __init__( num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim, zero_cond_t=zero_cond_t, + quant_config=quant_config, ) for _ in range(num_layers) ] diff --git a/vllm_omni/diffusion/models/z_image/pipeline_z_image.py b/vllm_omni/diffusion/models/z_image/pipeline_z_image.py index 2c938f39bf..8a39131953 100644 --- a/vllm_omni/diffusion/models/z_image/pipeline_z_image.py +++ b/vllm_omni/diffusion/models/z_image/pipeline_z_image.py @@ -37,6 +37,7 @@ from 
vllm_omni.diffusion.models.z_image.z_image_transformer import ( ZImageTransformer2DModel, ) +from vllm_omni.diffusion.quantization import get_vllm_quant_config_for_layers from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.model_executor.model_loader.weight_utils import ( download_weights_from_hf_specific, @@ -173,7 +174,9 @@ def __init__( self.vae = AutoencoderKL.from_pretrained(model, subfolder="vae", local_files_only=local_files_only).to( self._execution_device ) - self.transformer = ZImageTransformer2DModel() + # Get vLLM quantization config for linear layers + quant_config = get_vllm_quant_config_for_layers(od_config.quantization_config) + self.transformer = ZImageTransformer2DModel(quant_config=quant_config) self.tokenizer = AutoTokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only) # Note: Context parallelism is applied centrally in registry.initialize_model() diff --git a/vllm_omni/diffusion/models/z_image/z_image_transformer.py b/vllm_omni/diffusion/models/z_image/z_image_transformer.py index 8efb899312..67e6723438 100644 --- a/vllm_omni/diffusion/models/z_image/z_image_transformer.py +++ b/vllm_omni/diffusion/models/z_image/z_image_transformer.py @@ -18,6 +18,7 @@ import math from collections.abc import Iterable +from typing import TYPE_CHECKING import torch import torch.nn as nn @@ -32,6 +33,11 @@ ) from vllm.model_executor.model_loader.weight_utils import default_weight_loader +if TYPE_CHECKING: + from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, + ) + from vllm_omni.diffusion.attention.layer import Attention from vllm_omni.diffusion.cache.base import CachedTransformer from vllm_omni.diffusion.distributed.sp_plan import ( @@ -250,6 +256,7 @@ def __init__( num_kv_heads: int, qk_norm: bool = True, eps: float = 1e-6, + quant_config: "QuantizationConfig | None" = None, ) -> None: super().__init__() self.dim = dim @@ -264,6 +271,7 @@ def __init__( total_num_heads=num_heads, total_num_kv_heads=num_kv_heads, bias=False, + quant_config=quant_config, ) assert qk_norm is True @@ -281,6 +289,7 @@ def __init__( bias=False, input_is_parallel=True, return_bias=False, + quant_config=quant_config, ) ] ) @@ -343,13 +352,19 @@ def forward( class FeedForward(nn.Module): - def __init__(self, dim: int, hidden_dim: int): + def __init__( + self, + dim: int, + hidden_dim: int, + quant_config: "QuantizationConfig | None" = None, + ): super().__init__() self.w13 = MergedColumnParallelLinear( dim, [hidden_dim] * 2, bias=False, return_bias=False, + quant_config=quant_config, ) self.act = SiluAndMul() self.w2 = RowParallelLinear( @@ -358,6 +373,7 @@ def __init__(self, dim: int, hidden_dim: int): bias=False, input_is_parallel=True, return_bias=False, + quant_config=quant_config, ) def forward(self, x): @@ -374,6 +390,7 @@ def __init__( norm_eps: float, qk_norm: bool, modulation=True, + quant_config: "QuantizationConfig | None" = None, ): super().__init__() self.dim = dim @@ -384,9 +401,14 @@ def __init__( num_kv_heads=n_kv_heads, qk_norm=qk_norm, eps=1e-5, + quant_config=quant_config, ) - self.feed_forward = FeedForward(dim=dim, hidden_dim=int(dim / 3 * 8)) + self.feed_forward = FeedForward( + dim=dim, + hidden_dim=int(dim / 3 * 8), + quant_config=quant_config, + ) self.layer_id = layer_id self.attention_norm1 = RMSNorm(dim, eps=norm_eps) @@ -589,6 +611,7 @@ def __init__( t_scale=1000.0, axes_dims=[32, 48, 48], axes_lens=[1024, 512, 512], + quant_config: "QuantizationConfig | None" = None, ) -> None: 
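+        # quant_config is threaded through to the refiner blocks, context
+        # refiner blocks, and main transformer blocks constructed below.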
super().__init__() self.dtype = torch.bfloat16 @@ -648,6 +671,7 @@ def __init__( norm_eps, qk_norm, modulation=True, + quant_config=quant_config, ) for layer_id in range(n_refiner_layers) ] @@ -662,6 +686,7 @@ def __init__( norm_eps, qk_norm, modulation=False, + quant_config=quant_config, ) for layer_id in range(n_refiner_layers) ] @@ -677,7 +702,15 @@ def __init__( self.layers = nn.ModuleList( [ - ZImageTransformerBlock(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm) + ZImageTransformerBlock( + layer_id, + dim, + n_heads, + n_kv_heads, + norm_eps, + qk_norm, + quant_config=quant_config, + ) for layer_id in range(n_layers) ] ) diff --git a/vllm_omni/diffusion/quantization/__init__.py b/vllm_omni/diffusion/quantization/__init__.py new file mode 100644 index 0000000000..cc1bb547f7 --- /dev/null +++ b/vllm_omni/diffusion/quantization/__init__.py @@ -0,0 +1,114 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Quantization support for diffusion models. + +This module provides a unified interface for quantizing diffusion transformers +using various methods (FP8, etc.). It wraps vLLM's quantization infrastructure +while allowing diffusion-model-specific defaults and optimizations. + +Example usage: + from vllm_omni.diffusion.quantization import ( + get_diffusion_quant_config, + get_vllm_quant_config_for_layers, + ) + + # Create FP8 config for diffusion model + diff_config = get_diffusion_quant_config("fp8") + + # Get vLLM config to pass to linear layers + vllm_config = get_vllm_quant_config_for_layers(diff_config) + + # Use in model initialization + linear_layer = QKVParallelLinear(..., quant_config=vllm_config) +""" + +from typing import TYPE_CHECKING + +from vllm.logger import init_logger + +from .base import DiffusionQuantizationConfig +from .fp8 import DiffusionFp8Config + +if TYPE_CHECKING: + from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, + ) + +logger = init_logger(__name__) + +# Registry of supported quantization methods +# To add a new method, create a new config class and register it here +_QUANT_CONFIG_REGISTRY: dict[str, type[DiffusionQuantizationConfig]] = { + "fp8": DiffusionFp8Config, +} + +SUPPORTED_QUANTIZATION_METHODS = list(_QUANT_CONFIG_REGISTRY.keys()) + + +def get_diffusion_quant_config( + quantization: str | None, + **kwargs, +) -> DiffusionQuantizationConfig | None: + """Factory function to create quantization config for diffusion models. + + Args: + quantization: Quantization method name ("fp8", etc.) or None to disable + **kwargs: Method-specific parameters passed to the config constructor + + Returns: + DiffusionQuantizationConfig instance or None if quantization is disabled + + Raises: + ValueError: If the quantization method is not supported + + Example: + # Default FP8 with dynamic activation scaling + config = get_diffusion_quant_config("fp8") + + # FP8 with custom parameters + config = get_diffusion_quant_config( + "fp8", + activation_scheme="static", + ignored_layers=["proj_out"], + ) + """ + if quantization is None or quantization.lower() == "none": + return None + + quantization = quantization.lower() + if quantization not in _QUANT_CONFIG_REGISTRY: + raise ValueError( + f"Unknown quantization method: {quantization!r}. 
Supported methods: {SUPPORTED_QUANTIZATION_METHODS}" + ) + + config_cls = _QUANT_CONFIG_REGISTRY[quantization] + logger.info("Creating diffusion quantization config: %s", quantization) + return config_cls(**kwargs) + + +def get_vllm_quant_config_for_layers( + diffusion_quant_config: DiffusionQuantizationConfig | None, +) -> "QuantizationConfig | None": + """Get the vLLM QuantizationConfig to pass to linear layers. + + This extracts the underlying vLLM config from a DiffusionQuantizationConfig, + which can then be passed to vLLM linear layers (QKVParallelLinear, etc.). + + Args: + diffusion_quant_config: The diffusion quantization config, or None + + Returns: + vLLM QuantizationConfig instance, or None if input is None + """ + if diffusion_quant_config is None: + return None + return diffusion_quant_config.get_vllm_quant_config() + + +__all__ = [ + "DiffusionQuantizationConfig", + "DiffusionFp8Config", + "get_diffusion_quant_config", + "get_vllm_quant_config_for_layers", + "SUPPORTED_QUANTIZATION_METHODS", +] diff --git a/vllm_omni/diffusion/quantization/base.py b/vllm_omni/diffusion/quantization/base.py new file mode 100644 index 0000000000..0cd9e4147e --- /dev/null +++ b/vllm_omni/diffusion/quantization/base.py @@ -0,0 +1,61 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Base class for diffusion model quantization configurations.""" + +from abc import ABC +from typing import TYPE_CHECKING, ClassVar + +import torch + +if TYPE_CHECKING: + from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, + ) + + +class DiffusionQuantizationConfig(ABC): + """Base class for diffusion model quantization configurations. + + This provides a thin wrapper over vLLM's quantization configs, + allowing diffusion-model-specific defaults and future extensibility. + + Subclasses should: + - Set quant_config_cls to the vLLM QuantizationConfig class + - Call super().__init__() after creating self._vllm_config + - Optionally override get_name() and get_min_capability() if needed + """ + + # Subclasses should set this to the vLLM QuantizationConfig class + quant_config_cls: ClassVar[type["QuantizationConfig"] | None] = None + + # The underlying vLLM config instance + _vllm_config: "QuantizationConfig | None" = None + + @classmethod + def get_name(cls) -> str: + """Return the quantization method name (e.g., 'fp8', 'int8'). + + By default, delegates to the underlying vLLM config class. + """ + if cls.quant_config_cls is not None: + return cls.quant_config_cls.get_name() + raise NotImplementedError("Subclass must set quant_config_cls or override get_name()") + + def get_vllm_quant_config(self) -> "QuantizationConfig | None": + """Return the underlying vLLM QuantizationConfig for linear layers.""" + return self._vllm_config + + @classmethod + def get_supported_act_dtypes(cls) -> list[torch.dtype]: + """Return supported activation dtypes.""" + return [torch.bfloat16, torch.float16] + + @classmethod + def get_min_capability(cls) -> int: + """Minimum GPU compute capability required. + + By default, delegates to the underlying vLLM config class. 
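+
+        Falls back to 80 (Ampere) when quant_config_cls is not set.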
+ """ + if cls.quant_config_cls is not None: + return cls.quant_config_cls.get_min_capability() + return 80 # Ampere default diff --git a/vllm_omni/diffusion/quantization/fp8.py b/vllm_omni/diffusion/quantization/fp8.py new file mode 100644 index 0000000000..68abf9c229 --- /dev/null +++ b/vllm_omni/diffusion/quantization/fp8.py @@ -0,0 +1,50 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""FP8 quantization config for diffusion transformers.""" + +from vllm.model_executor.layers.quantization.fp8 import Fp8Config + +from .base import DiffusionQuantizationConfig + + +class DiffusionFp8Config(DiffusionQuantizationConfig): + """FP8 quantization config optimized for diffusion transformers. + + Uses dynamic activation scaling (no calibration dataset needed) and + online weight quantization from BF16/FP16 checkpoints. + + Device Compatibility: + - Turing (SM 75+): Weight-only FP8 via Marlin kernel + - Ada/Hopper (SM 89+): Full W8A8 FP8 with native hardware support + + The kernel selection is automatic based on GPU capability. + + Args: + activation_scheme: Activation quantization scheme. + - "dynamic": Per-token dynamic scaling (default, no calibration) + - "static": Single per-tensor scale (requires calibration) + weight_block_size: Block size for block-wise weight quantization. + Format: [block_n, block_k]. If None, uses per-tensor scaling. + ignored_layers: List of layer name patterns to skip quantization. + """ + + # Tight coupling with vLLM's Fp8Config - delegates get_name() and get_min_capability() + quant_config_cls = Fp8Config + + def __init__( + self, + activation_scheme: str = "dynamic", + weight_block_size: list[int] | None = None, + ignored_layers: list[str] | None = None, + ): + self.activation_scheme = activation_scheme + self.weight_block_size = weight_block_size + self.ignored_layers = ignored_layers or [] + + # Create underlying vLLM FP8 config + self._vllm_config = Fp8Config( + is_checkpoint_fp8_serialized=False, # Online quantization from BF16 + activation_scheme=activation_scheme, + weight_block_size=weight_block_size, + ignored_layers=ignored_layers, + ) From b880190e8be3d276ec4c269342bceed15c47615a Mon Sep 17 00:00:00 2001 From: David Chen <530634352@qq.com> Date: Tue, 10 Feb 2026 16:20:28 +0800 Subject: [PATCH 02/62] support gguf fp8 1 Signed-off-by: David Chen <530634352@qq.com> --- .../model_loader/diffusers_loader.py | 216 +++++++++++++++++- .../flux2_klein/flux2_klein_transformer.py | 33 ++- .../flux2_klein/pipeline_flux2_klein.py | 4 +- vllm_omni/diffusion/quantization/__init__.py | 3 + vllm_omni/diffusion/quantization/fp8.py | 4 +- vllm_omni/diffusion/quantization/gguf.py | 32 +++ 6 files changed, 284 insertions(+), 8 deletions(-) create mode 100644 vllm_omni/diffusion/quantization/gguf.py diff --git a/vllm_omni/diffusion/model_loader/diffusers_loader.py b/vllm_omni/diffusion/model_loader/diffusers_loader.py index b61f70b697..accc61d325 100644 --- a/vllm_omni/diffusion/model_loader/diffusers_loader.py +++ b/vllm_omni/diffusion/model_loader/diffusers_loader.py @@ -6,19 +6,22 @@ import time from collections.abc import Generator, Iterable from pathlib import Path -from typing import cast +from typing import Callable, cast import torch from torch import nn +from huggingface_hub import hf_hub_download from vllm.config import ModelConfig from vllm.config.load import LoadConfig from vllm.logger import init_logger from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase from 
vllm.model_executor.model_loader.weight_utils import ( + download_gguf, download_safetensors_index_file_from_hf, download_weights_from_hf, filter_duplicate_safetensors_files, filter_files_not_needed_for_inference, + gguf_quant_weights_iterator, maybe_download_from_modelscope, safetensors_weights_iterator, ) @@ -216,8 +219,11 @@ def load_model(self, od_config: OmniDiffusionConfig, load_device: str) -> nn.Mod model = initialize_model(od_config) logger.debug("Loading weights on %s ...", load_device) - # Quantization does not happen in `load_weights` but after it - self.load_weights(model) + if self._is_gguf_quantization(od_config): + self._load_weights_with_gguf(model, od_config) + else: + # Quantization does not happen in `load_weights` but after it + self.load_weights(model) # Process weights after loading for quantization (e.g., FP8 online quantization) # This is needed for vLLM's quantization methods that need to transform weights @@ -268,3 +274,207 @@ def load_weights(self, model: nn.Module) -> None: # "Following weights were not initialized from " # f"checkpoint: {weights_not_loaded}" # ) + + def _is_gguf_quantization(self, od_config: OmniDiffusionConfig) -> bool: + quant_config = od_config.quantization_config + if quant_config is None: + return False + try: + is_gguf = quant_config.get_name() == "gguf" + except Exception: + return False + if not is_gguf: + return False + gguf_model = getattr(quant_config, "gguf_model", None) + if gguf_model is None: + raise ValueError("GGUF quantization requires quantization_config.gguf_model") + return True + + def _is_transformer_source(self, source: "ComponentSource") -> bool: + if source.subfolder == "transformer": + return True + return source.prefix.startswith("transformer.") + + def _get_model_loadable_names(self, model: nn.Module) -> set[str]: + # Use state_dict keys to include both parameters and buffers. 
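+        # These names filter the HF float-weight fill-in pass, so buffers
+        # must be loadable alongside parameters.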
+ return set(model.state_dict().keys()) + + def _resolve_gguf_model_path(self, gguf_model: str, revision: str | None) -> str: + if os.path.isfile(gguf_model): + return gguf_model + # raw HTTPS link + if gguf_model.startswith(("http://", "https://")) and gguf_model.endswith(".gguf"): + return hf_hub_download(url=gguf_model) + # repo_id/filename.gguf + if "/" in gguf_model and gguf_model.endswith(".gguf"): + repo_id, filename = gguf_model.rsplit("/", 1) + return hf_hub_download( + repo_id=repo_id, + filename=filename, + revision=revision, + cache_dir=self.load_config.download_dir, + ) + # repo_id:quant_type + if "/" in gguf_model and ":" in gguf_model: + repo_id, quant_type = gguf_model.rsplit(":", 1) + return download_gguf( + repo_id, + quant_type, + cache_dir=self.load_config.download_dir, + revision=revision, + ignore_patterns=self.load_config.ignore_patterns, + ) + raise ValueError( + f"Unrecognized GGUF reference: {gguf_model!r} (expected local file, " + "raw URL, /.gguf, or :)" + ) + + def _get_gguf_name_mapper(self, od_config: OmniDiffusionConfig) -> Callable[[str], str | None]: + model_class = od_config.model_class_name + if model_class in { + "QwenImagePipeline", + "QwenImageEditPipeline", + "QwenImageEditPlusPipeline", + "QwenImageLayeredPipeline", + }: + return lambda name: name + if model_class == "Flux2KleinPipeline": + return self._map_flux2_klein_gguf_name + raise ValueError(f"GGUF mapping is not implemented for model_class_name={model_class!r}") + + @staticmethod + def _map_flux2_klein_gguf_name(name: str) -> str | None: + if name.startswith("double_stream_modulation_img.lin."): + return name.replace("double_stream_modulation_img.lin.", "double_stream_modulation_img.linear.", 1) + if name.startswith("double_stream_modulation_txt.lin."): + return name.replace("double_stream_modulation_txt.lin.", "double_stream_modulation_txt.linear.", 1) + if name.startswith("single_stream_modulation.lin."): + return name.replace("single_stream_modulation.lin.", "single_stream_modulation.linear.", 1) + if name.startswith("img_in."): + return name.replace("img_in.", "x_embedder.", 1) + if name.startswith("txt_in."): + return name.replace("txt_in.", "context_embedder.", 1) + if name.startswith("time_in.in_layer."): + return name.replace( + "time_in.in_layer.", + "time_guidance_embed.timestep_embedder.linear_1.", + 1, + ) + if name.startswith("time_in.out_layer."): + return name.replace( + "time_in.out_layer.", + "time_guidance_embed.timestep_embedder.linear_2.", + 1, + ) + if name.startswith("final_layer.adaLN_modulation.1."): + return name.replace("final_layer.adaLN_modulation.1.", "norm_out.linear.", 1) + if name.startswith("final_layer.linear."): + return name.replace("final_layer.linear.", "proj_out.", 1) + + if name.startswith("double_blocks."): + name = name.replace("double_blocks.", "transformer_blocks.", 1) + if ".img_attn.qkv." in name: + return name.replace(".img_attn.qkv.", ".attn.to_qkv.", 1) + if ".img_attn.proj." in name: + return name.replace(".img_attn.proj.", ".attn.to_out.0.", 1) + if name.endswith(".img_attn.norm.query_norm.scale"): + return name.replace(".img_attn.norm.query_norm.scale", ".attn.norm_q.weight", 1) + if name.endswith(".img_attn.norm.key_norm.scale"): + return name.replace(".img_attn.norm.key_norm.scale", ".attn.norm_k.weight", 1) + if ".txt_attn.qkv." in name: + return name.replace(".txt_attn.qkv.", ".attn.add_kv_proj.", 1) + if ".txt_attn.proj." 
in name: + return name.replace(".txt_attn.proj.", ".attn.to_add_out.", 1) + if name.endswith(".txt_attn.norm.query_norm.scale"): + return name.replace(".txt_attn.norm.query_norm.scale", ".attn.norm_added_q.weight", 1) + if name.endswith(".txt_attn.norm.key_norm.scale"): + return name.replace(".txt_attn.norm.key_norm.scale", ".attn.norm_added_k.weight", 1) + if ".img_mlp.0." in name: + return name.replace(".img_mlp.0.", ".ff.linear_in.", 1) + if ".img_mlp.2." in name: + return name.replace(".img_mlp.2.", ".ff.linear_out.", 1) + if ".txt_mlp.0." in name: + return name.replace(".txt_mlp.0.", ".ff_context.linear_in.", 1) + if ".txt_mlp.2." in name: + return name.replace(".txt_mlp.2.", ".ff_context.linear_out.", 1) + return None + + if name.startswith("single_blocks."): + name = name.replace("single_blocks.", "single_transformer_blocks.", 1) + if ".linear1." in name: + return name.replace(".linear1.", ".attn.to_qkv_mlp_proj.", 1) + if ".linear2." in name: + return name.replace(".linear2.", ".attn.to_out.", 1) + if name.endswith(".norm.query_norm.scale"): + return name.replace(".norm.query_norm.scale", ".attn.norm_q.weight", 1) + if name.endswith(".norm.key_norm.scale"): + return name.replace(".norm.key_norm.scale", ".attn.norm_k.weight", 1) + return None + + return None + + def _build_gguf_name_map( + self, + gguf_file: str, + od_config: OmniDiffusionConfig, + ) -> dict[str, str]: + try: + import gguf # type: ignore + except Exception as exc: # pragma: no cover - dependency error + raise RuntimeError( + "GGUF support requires the 'gguf' package to be installed." + ) from exc + + mapper = self._get_gguf_name_mapper(od_config) + reader = gguf.GGUFReader(gguf_file) + gguf_to_model_map: dict[str, str] = {} + for tensor in reader.tensors: + mapped = mapper(tensor.name) + if mapped is None: + continue + gguf_to_model_map[tensor.name] = mapped + if not gguf_to_model_map: + raise RuntimeError( + f"No GGUF tensors were mapped for model_class_name={od_config.model_class_name!r}." + ) + return gguf_to_model_map + + def _get_gguf_weights_iterator( + self, + source: "ComponentSource", + od_config: OmniDiffusionConfig, + ) -> Generator[tuple[str, torch.Tensor], None, None]: + quant_config = od_config.quantization_config + gguf_model = getattr(quant_config, "gguf_model", None) + if gguf_model is None: + raise ValueError("GGUF quantization requires quantization_config.gguf_model") + gguf_file = self._resolve_gguf_model_path(gguf_model, od_config.revision) + gguf_name_map = self._build_gguf_name_map(gguf_file, od_config) + weights_iter = gguf_quant_weights_iterator(gguf_file, gguf_name_map) + return ((source.prefix + name, tensor) for (name, tensor) in weights_iter) + + def _load_weights_with_gguf(self, model: nn.Module, od_config: OmniDiffusionConfig) -> set[str]: + sources = cast( + Iterable[DiffusersPipelineLoader.ComponentSource], + getattr(model, "weights_sources", ()), + ) + loaded: set[str] = set() + loadable_names: set[str] | None = None + + for source in sources: + if self._is_transformer_source(source): + loaded |= model.load_weights(self._get_gguf_weights_iterator(source, od_config)) + + # Load any remaining float weights (e.g., non-quantized layers) + # from the base HF checkpoint while skipping already-loaded names. 
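+                # loadable_names is computed lazily once and reused across
+                # all transformer sources.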
+ loadable_names = loadable_names or self._get_model_loadable_names(model) + hf_iter = self._get_weights_iterator(source) + hf_iter = ( + (name, tensor) + for (name, tensor) in hf_iter + if name in loadable_names and name not in loaded + ) + loaded |= model.load_weights(hf_iter) + else: + loaded |= model.load_weights(self._get_weights_iterator(source)) + return loaded diff --git a/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py b/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py index ee10d2e0e4..bdf542003e 100644 --- a/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py +++ b/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py @@ -17,7 +17,7 @@ from collections.abc import Iterable from types import SimpleNamespace -from typing import Any +from typing import TYPE_CHECKING, Any import torch import torch.nn as nn @@ -41,6 +41,9 @@ from vllm_omni.diffusion.attention.layer import Attention from vllm_omni.diffusion.layers.rope import RotaryEmbedding +if TYPE_CHECKING: + from vllm.model_executor.layers.quantization.base_config import QuantizationConfig + class Flux2SwiGLU(nn.Module): """SwiGLU activation used by Flux2.""" @@ -62,6 +65,7 @@ def __init__( mult: float = 3.0, inner_dim: int | None = None, bias: bool = False, + quant_config: "QuantizationConfig | None" = None, ): super().__init__() if inner_dim is None: @@ -73,6 +77,7 @@ def __init__( [inner_dim, inner_dim], bias=bias, return_bias=False, + quant_config=quant_config, ) self.act_fn = Flux2SwiGLU() self.linear_out = RowParallelLinear( @@ -81,6 +86,7 @@ def __init__( bias=bias, input_is_parallel=True, return_bias=False, + quant_config=quant_config, ) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -103,6 +109,7 @@ def __init__( eps: float = 1e-5, out_dim: int = None, elementwise_affine: bool = True, + quant_config: "QuantizationConfig | None" = None, ): super().__init__() self.head_dim = dim_head @@ -118,6 +125,7 @@ def __init__( head_size=self.head_dim, total_num_heads=self.heads, bias=bias, + quant_config=quant_config, ) self.query_num_heads = self.to_qkv.num_heads self.kv_num_heads = self.to_qkv.num_kv_heads @@ -133,6 +141,7 @@ def __init__( bias=out_bias, input_is_parallel=True, return_bias=False, + quant_config=quant_config, ), nn.Dropout(dropout), ] @@ -146,6 +155,7 @@ def __init__( head_size=self.head_dim, total_num_heads=self.heads, bias=added_proj_bias, + quant_config=quant_config, ) self.add_query_num_heads = self.add_kv_proj.num_heads self.add_kv_num_heads = self.add_kv_proj.num_kv_heads @@ -155,6 +165,7 @@ def __init__( bias=out_bias, input_is_parallel=True, return_bias=False, + quant_config=quant_config, ) self.rope = RotaryEmbedding(is_neox_style=False) @@ -251,6 +262,7 @@ def __init__( elementwise_affine: bool = True, mlp_ratio: float = 4.0, mlp_mult_factor: int = 2, + quant_config: "QuantizationConfig | None" = None, ): super().__init__() self.head_dim = dim_head @@ -269,6 +281,7 @@ def __init__( self.inner_dim * 3 + self.mlp_hidden_dim * self.mlp_mult_factor, bias=bias, gather_output=True, + quant_config=quant_config, ) self.mlp_act_fn = Flux2SwiGLU() @@ -280,6 +293,7 @@ def __init__( self.out_dim, bias=out_bias, gather_output=True, + quant_config=quant_config, ) self.rope = RotaryEmbedding(is_neox_style=False) self.attn = Attention( @@ -342,6 +356,7 @@ def __init__( mlp_ratio: float = 3.0, eps: float = 1e-6, bias: bool = False, + quant_config: "QuantizationConfig | None" = None, ): super().__init__() self.norm = nn.LayerNorm(dim, 
elementwise_affine=False, eps=eps) @@ -355,6 +370,7 @@ def __init__( eps=eps, mlp_ratio=mlp_ratio, mlp_mult_factor=2, + quant_config=quant_config, ) def forward( @@ -402,6 +418,7 @@ def __init__( mlp_ratio: float = 3.0, eps: float = 1e-6, bias: bool = False, + quant_config: "QuantizationConfig | None" = None, ): super().__init__() self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) @@ -417,13 +434,20 @@ def __init__( added_proj_bias=bias, out_bias=bias, eps=eps, + quant_config=quant_config, ) self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) - self.ff = Flux2FeedForward(dim=dim, dim_out=dim, mult=mlp_ratio, bias=bias) + self.ff = Flux2FeedForward(dim=dim, dim_out=dim, mult=mlp_ratio, bias=bias, quant_config=quant_config) self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) - self.ff_context = Flux2FeedForward(dim=dim, dim_out=dim, mult=mlp_ratio, bias=bias) + self.ff_context = Flux2FeedForward( + dim=dim, + dim_out=dim, + mult=mlp_ratio, + bias=bias, + quant_config=quant_config, + ) def forward( self, @@ -580,6 +604,7 @@ def __init__( rope_theta: int = 2000, eps: float = 1e-6, guidance_embeds: bool = True, + quant_config: "QuantizationConfig | None" = None, ): super().__init__() self.out_channels = out_channels or in_channels @@ -625,6 +650,7 @@ def __init__( mlp_ratio=mlp_ratio, eps=eps, bias=False, + quant_config=quant_config, ) for _ in range(num_layers) ] @@ -639,6 +665,7 @@ def __init__( mlp_ratio=mlp_ratio, eps=eps, bias=False, + quant_config=quant_config, ) for _ in range(num_single_layers) ] diff --git a/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py b/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py index e1ef706c3f..09df716ada 100644 --- a/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py +++ b/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py @@ -42,6 +42,7 @@ from vllm_omni.diffusion.models.flux2_klein.flux2_klein_transformer import ( Flux2Transformer2DModel, ) +from vllm_omni.diffusion.quantization import get_vllm_quant_config_for_layers from vllm_omni.diffusion.models.interface import SupportImageInput from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs @@ -230,7 +231,8 @@ def __init__( ).to(self._execution_device) transformer_kwargs = get_transformer_config_kwargs(od_config.tf_model_config, Flux2Transformer2DModel) - self.transformer = Flux2Transformer2DModel(**transformer_kwargs) + quant_config = get_vllm_quant_config_for_layers(od_config.quantization_config) + self.transformer = Flux2Transformer2DModel(quant_config=quant_config, **transformer_kwargs) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = Flux2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) diff --git a/vllm_omni/diffusion/quantization/__init__.py b/vllm_omni/diffusion/quantization/__init__.py index cc1bb547f7..d297d51f18 100644 --- a/vllm_omni/diffusion/quantization/__init__.py +++ b/vllm_omni/diffusion/quantization/__init__.py @@ -28,6 +28,7 @@ from .base import DiffusionQuantizationConfig from .fp8 import DiffusionFp8Config +from .gguf import DiffusionGgufConfig if TYPE_CHECKING: from vllm.model_executor.layers.quantization.base_config import ( @@ -40,6 +41,7 @@ # To add a new method, create a new config class and register it here _QUANT_CONFIG_REGISTRY: dict[str, type[DiffusionQuantizationConfig]] = { "fp8": 
DiffusionFp8Config, + "gguf": DiffusionGgufConfig, } SUPPORTED_QUANTIZATION_METHODS = list(_QUANT_CONFIG_REGISTRY.keys()) @@ -108,6 +110,7 @@ def get_vllm_quant_config_for_layers( __all__ = [ "DiffusionQuantizationConfig", "DiffusionFp8Config", + "DiffusionGgufConfig", "get_diffusion_quant_config", "get_vllm_quant_config_for_layers", "SUPPORTED_QUANTIZATION_METHODS", diff --git a/vllm_omni/diffusion/quantization/fp8.py b/vllm_omni/diffusion/quantization/fp8.py index 68abf9c229..963dd6c3bc 100644 --- a/vllm_omni/diffusion/quantization/fp8.py +++ b/vllm_omni/diffusion/quantization/fp8.py @@ -36,14 +36,16 @@ def __init__( activation_scheme: str = "dynamic", weight_block_size: list[int] | None = None, ignored_layers: list[str] | None = None, + is_checkpoint_fp8_serialized: bool = False, ): self.activation_scheme = activation_scheme self.weight_block_size = weight_block_size self.ignored_layers = ignored_layers or [] + self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized # Create underlying vLLM FP8 config self._vllm_config = Fp8Config( - is_checkpoint_fp8_serialized=False, # Online quantization from BF16 + is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized, activation_scheme=activation_scheme, weight_block_size=weight_block_size, ignored_layers=ignored_layers, diff --git a/vllm_omni/diffusion/quantization/gguf.py b/vllm_omni/diffusion/quantization/gguf.py new file mode 100644 index 0000000000..fcc4498bfe --- /dev/null +++ b/vllm_omni/diffusion/quantization/gguf.py @@ -0,0 +1,32 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""GGUF quantization config for diffusion transformers.""" + +from vllm.model_executor.layers.quantization.gguf import GGUFConfig + +from .base import DiffusionQuantizationConfig + + +class DiffusionGgufConfig(DiffusionQuantizationConfig): + """GGUF quantization config for diffusion transformers. + + This is a thin wrapper around vLLM's GGUFConfig and also carries + the GGUF model reference for loader use. + + Args: + gguf_model: GGUF model path or HF reference (repo/file or repo:quant_type) + unquantized_modules: Optional list of module name patterns to skip GGUF + quantization. Note: diffusion linear layers often use short prefixes + (e.g., "to_qkv"), so these patterns are matched as substrings. + """ + + quant_config_cls = GGUFConfig + + def __init__( + self, + gguf_model: str | None = None, + unquantized_modules: list[str] | None = None, + ) -> None: + self.gguf_model = gguf_model + self.unquantized_modules = unquantized_modules or [] + self._vllm_config = GGUFConfig(unquantized_modules=self.unquantized_modules) From 2a208d84440f889f4453d6f919e28e7349499d73 Mon Sep 17 00:00:00 2001 From: David Chen <530634352@qq.com> Date: Tue, 10 Feb 2026 16:22:09 +0800 Subject: [PATCH 03/62] support gguf fp8 2 Signed-off-by: David Chen <530634352@qq.com> --- .../text_to_image/text_to_image.py | 23 ++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/examples/offline_inference/text_to_image/text_to_image.py b/examples/offline_inference/text_to_image/text_to_image.py index adca5f6f8b..22a83a17c9 100644 --- a/examples/offline_inference/text_to_image/text_to_image.py +++ b/examples/offline_inference/text_to_image/text_to_image.py @@ -129,11 +129,21 @@ def parse_args() -> argparse.Namespace: "--quantization", type=str, default=None, - choices=["fp8"], + choices=["fp8", "gguf"], help="Quantization method for the transformer. 
" - "Options: 'fp8' (FP8 W8A8 on Ada/Hopper, weight-only on older GPUs). " + "Options: 'fp8' (FP8 W8A8 on Ada/Hopper, weight-only on older GPUs), " + "'gguf' (load transformer weights from a GGUF file). " "Default: None (no quantization, uses BF16).", ) + parser.add_argument( + "--gguf-model", + type=str, + default=None, + help=( + "GGUF file path or HF reference for transformer weights. " + "Required when --quantization gguf is set." + ), + ) parser.add_argument( "--ignored-layers", type=str, @@ -204,7 +214,14 @@ def main(): # ignored_layers is specified so the list flows through OmniDiffusionConfig quant_kwargs: dict[str, Any] = {} ignored_layers = [s.strip() for s in args.ignored_layers.split(",") if s.strip()] if args.ignored_layers else None - if args.quantization and ignored_layers: + if args.quantization == "gguf": + if not args.gguf_model: + raise ValueError("--gguf-model is required when --quantization gguf is set.") + quant_kwargs["quantization_config"] = { + "method": "gguf", + "gguf_model": args.gguf_model, + } + elif args.quantization and ignored_layers: quant_kwargs["quantization_config"] = { "method": args.quantization, "ignored_layers": ignored_layers, From d81c2a90f5fa630edaca1e4c48fca1ad2f344807 Mon Sep 17 00:00:00 2001 From: David Chen <530634352@qq.com> Date: Tue, 10 Feb 2026 16:33:29 +0800 Subject: [PATCH 04/62] support gguf fp8 3 Signed-off-by: David Chen <530634352@qq.com> --- .../diffusion/model_loader/diffusers_loader.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/vllm_omni/diffusion/model_loader/diffusers_loader.py b/vllm_omni/diffusion/model_loader/diffusers_loader.py index accc61d325..77915da61e 100644 --- a/vllm_omni/diffusion/model_loader/diffusers_loader.py +++ b/vllm_omni/diffusion/model_loader/diffusers_loader.py @@ -279,10 +279,23 @@ def _is_gguf_quantization(self, od_config: OmniDiffusionConfig) -> bool: quant_config = od_config.quantization_config if quant_config is None: return False + # Fast path: mapping-style config (e.g., DictConfig) + if isinstance(quant_config, dict): + method = str(quant_config.get("method", "")).lower() + if method != "gguf": + return False + gguf_model = quant_config.get("gguf_model") + if not gguf_model: + raise ValueError("GGUF quantization requires quantization_config.gguf_model") + return True + + # Normal path: DiffusionQuantizationConfig try: is_gguf = quant_config.get_name() == "gguf" except Exception: - return False + # Fallback: if it carries gguf_model, treat as GGUF + gguf_model = getattr(quant_config, "gguf_model", None) + return bool(gguf_model) if not is_gguf: return False gguf_model = getattr(quant_config, "gguf_model", None) From cdd0dfb21c22f52eea6ff0bae92f68015ef40808 Mon Sep 17 00:00:00 2001 From: David Chen <530634352@qq.com> Date: Tue, 10 Feb 2026 16:36:50 +0800 Subject: [PATCH 05/62] support gguf fp8 4 Signed-off-by: David Chen <530634352@qq.com> --- .../diffusion/models/flux2_klein/flux2_klein_transformer.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py b/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py index bdf542003e..5ac0b4e1a9 100644 --- a/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py +++ b/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py @@ -779,8 +779,11 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: weight_loader(param, loaded_weight) loaded_params.add(name) continue + # GGUF fused 
QKV weights already target .to_qkv/.add_kv_proj. + # Avoid substring replacement that would duplicate "qkv". + is_fused_qkv = ".to_qkv." in name or ".add_kv_proj." in name for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: + if is_fused_qkv or weight_name not in name: continue name = name.replace(weight_name, param_name) param = params_dict[name] From 887d1c055a0d4787825e2c443d119cebe75ca4b1 Mon Sep 17 00:00:00 2001 From: David Chen <530634352@qq.com> Date: Tue, 10 Feb 2026 16:41:13 +0800 Subject: [PATCH 06/62] support gguf fp8 5 Signed-off-by: David Chen <530634352@qq.com> --- vllm_omni/diffusion/model_loader/diffusers_loader.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/vllm_omni/diffusion/model_loader/diffusers_loader.py b/vllm_omni/diffusion/model_loader/diffusers_loader.py index 77915da61e..4f8bbdf7ed 100644 --- a/vllm_omni/diffusion/model_loader/diffusers_loader.py +++ b/vllm_omni/diffusion/model_loader/diffusers_loader.py @@ -309,8 +309,11 @@ def _is_transformer_source(self, source: "ComponentSource") -> bool: return source.prefix.startswith("transformer.") def _get_model_loadable_names(self, model: nn.Module) -> set[str]: - # Use state_dict keys to include both parameters and buffers. - return set(model.state_dict().keys()) + # Avoid model.state_dict() here because GGUF uses UninitializedParameter + # which raises during detach(). Collect names directly. + names = {name for name, _ in model.named_parameters()} + names.update(name for name, _ in model.named_buffers()) + return names def _resolve_gguf_model_path(self, gguf_model: str, revision: str | None) -> str: if os.path.isfile(gguf_model): From 4d77a92aba49ec4b2e6d8956db55eba9cb5a9cc8 Mon Sep 17 00:00:00 2001 From: David Chen <530634352@qq.com> Date: Wed, 11 Feb 2026 10:14:21 +0800 Subject: [PATCH 07/62] support gguf fp8 add design doc Signed-off-by: David Chen <530634352@qq.com> --- .../diffusion/quantization/gguf_fp8_design.md | 274 ++++++++++++++++++ 1 file changed, 274 insertions(+) create mode 100644 docs/user_guide/diffusion/quantization/gguf_fp8_design.md diff --git a/docs/user_guide/diffusion/quantization/gguf_fp8_design.md b/docs/user_guide/diffusion/quantization/gguf_fp8_design.md new file mode 100644 index 0000000000..92af0325b3 --- /dev/null +++ b/docs/user_guide/diffusion/quantization/gguf_fp8_design.md @@ -0,0 +1,274 @@ +# Diffusion Quantization Design: Native GGUF + FP8 (Online and Native) + +Date: 2026-02-10 + +## Goals +1. Reuse vLLM quantization configs and weight loaders as much as possible. +2. Add native GGUF and FP8 support to diffusion transformers without changing model definitions. +3. Keep user-facing knobs minimal and consistent across offline and online flows. + +## Scope +1. Models: Qwen-Image and Flux2-klein are first-class targets. +2. Components: diffusion transformer weights, loader paths, and quantization configs. +3. Modes: native GGUF, online FP8, native FP8 (pre-serialized FP8 checkpoint). + +## Architecture Overview +1. `OmniDiffusionConfig` accepts `quantization` or `quantization_config`. +2. Diffusion quantization wrappers (`DiffusionGgufConfig`, `DiffusionFp8Config`) produce vLLM `QuantizationConfig` objects for linear layers. +3. `DiffusersPipelineLoader` branches on quantization method and loads either HF weights or GGUF weights for the transformer. +4. vLLM GGUF path uses `GGUFConfig` and `GGUFLinearMethod` for matmul; FP8 uses `Fp8Config` (online) or `is_checkpoint_fp8_serialized` for native FP8. 
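+
+For concreteness, the hand-off between steps 1-3 can be sketched in a few lines. This is a condensed illustration mirroring the `pipeline_flux2_klein.py` change in this series, not standalone runnable code: `od_config` stands for the already-constructed `OmniDiffusionConfig`, and `transformer_kwargs` for the kwargs derived from the transformer config.
+
+```python
+# Resolve the user-facing quantization_config into a vLLM QuantizationConfig
+# and thread it into the transformer so its parallel Linear layers pick up
+# the matching quant method (GGUF, FP8, or unquantized).
+from vllm_omni.diffusion.quantization import get_vllm_quant_config_for_layers
+
+quant_config = get_vllm_quant_config_for_layers(od_config.quantization_config)
+transformer = Flux2Transformer2DModel(quant_config=quant_config, **transformer_kwargs)
+```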
+ +## Call Chain (Offline) +``` +CLI (examples/offline_inference/text_to_image/text_to_image.py) + | + v +Omni (vllm_omni/entrypoints/omni.py) + | + v +OmniStage (diffusion) + | + v +DiffusionWorker + | + v +DiffusionModelRunner + | + v +DiffusersPipelineLoader + | + v +Pipeline.forward (Flux2/Qwen) + | + v +DiffusionEngine + | + v +OmniRequestOutput + | + v +Client (saved PNG) +``` + +## Call Chain (Online) +``` +Client + | + | POST /v1/images/generations + v +APIServer (vllm_omni/entrypoints/openai/api_server.py) + | + v +_generate_with_async_omni + | + v +AsyncOmni + | + v +DiffusionEngine + | + v +OmniRequestOutput + | + v +encode_image_base64 + | + v +ImageGenerationResponse + | + v +Client +``` + +## Call Chain (GGUF Operator Path) +``` +Pipeline.forward (Flux2/Qwen) + | + v +Transformer blocks + | + v +Flux2Attention / Flux2ParallelSelfAttention + | + v +QKVParallelLinear / ColumnParallelLinear / RowParallelLinear + | + v +LinearBase.forward + | + v +QuantMethod.apply (GGUFLinearMethod.apply) + | + v +fused_mul_mat_gguf + | + v +_fused_mul_mat_gguf (custom op) + | + v +ops.ggml_dequantize + | + v +x @ weight.T +``` + +Notes: +1. GGUF linear inputs are flattened to 2D inside `GGUFLinearMethod.apply` and reshaped back. +2. As of 2026-02-10 in this branch, `_fused_mul_mat_gguf` is forced to the dequantize path. + +## Call Chain (FP8 Operator Path) +``` +Pipeline.forward (Flux2/Qwen) + | + v +Transformer blocks + | + v +QKVParallelLinear / ColumnParallelLinear / RowParallelLinear + | + v +LinearBase.forward + | + v +QuantMethod.apply (Fp8LinearMethod.apply or Fp8OnlineLinearMethod.apply) + | + +--> apply_fp8_marlin_linear (weight-only path on older GPUs) + | + +--> W8A8BlockFp8LinearOp.apply (block quant path) + | + +--> fp8_linear.apply_weights + | + v + init_fp8_linear_kernel + | + v + FlashInferFP8ScaledMMLinearKernel / CutlassFP8ScaledMMLinearKernel / + Torch FP8 ScaledMM kernels +``` + +Notes: +1. Online FP8 differs at load time; runtime operator path matches native FP8. +2. The kernel selection is platform and capability dependent. + +## GGUF Weight Loading Path (Transformer-Only) +1. `DiffusersPipelineLoader.load_model` detects `quantization_config.method == "gguf"`. +2. `gguf_model` is resolved as one of: local file, URL, `repo/file.gguf`, or `repo:quant_type`. +3. Name mapping is applied per-architecture (Qwen-Image, Flux2-klein). +4. GGUF weights are loaded into transformer modules, remaining non-transformer weights come from the HF checkpoint. + +## FP8 Loading Path +1. Online FP8: `quantization="fp8"` or `quantization_config={"method":"fp8", "ignored_layers": [...]}`. +2. Native FP8: `quantization_config={"method":"fp8", "is_checkpoint_fp8_serialized": True}` to load an FP8-serialized checkpoint. + +## User Usage (Offline) + +### Baseline BF16 +```bash +python examples/offline_inference/text_to_image/text_to_image.py \ + --model /workspace/models/black-forest-labs/FLUX.2-klein-4B \ + --prompt "a photo of a forest with mist swirling around the tree trunks. 
The word 'FLUX.2' is painted over it in big, red brush strokes with visible texture" \ + --height 768 \ + --width 1360 \ + --seed 42 \ + --cfg_scale 4.0 \ + --num_images_per_prompt 1 \ + --num_inference_steps 4 \ + --output outputs/flux2_klein_4b.png +``` + +### Native GGUF (Transformer Only) +```bash +python examples/offline_inference/text_to_image/text_to_image.py \ + --model /workspace/models/black-forest-labs/FLUX.2-klein-4B \ + --gguf-model "/workspace/models/unsloth/FLUX.2-klein-4B-GGUF/flux-2-klein-4b-Q8_0.gguf" \ + --quantization gguf \ + --prompt "a photo of a forest with mist swirling around the tree trunks. The word 'FLUX.2' is painted over it in big, red brush strokes with visible texture" \ + --height 768 \ + --width 1360 \ + --seed 42 \ + --cfg_scale 4.0 \ + --num_images_per_prompt 1 \ + --num_inference_steps 4 \ + --output outputs/flux2_klein_4b_gguf.png +``` + +Notes for GGUF: +1. Many GGUF repos do not ship `model_index.json` and configs. Use the base repo for `--model` and only pass the GGUF file via `--gguf-model`. +2. `gguf_model` supports local path, URL, `repo/file.gguf`, or `repo:quant_type`. + +### Online FP8 (Runtime Quantization) +```bash +python examples/offline_inference/text_to_image/text_to_image.py \ + --model Qwen/Qwen-Image \ + --quantization fp8 \ + --prompt "a cup of coffee on the table" \ + --height 1024 \ + --width 1024 +``` + +### Native FP8 (Serialized Checkpoint) +Use the Python API to pass `is_checkpoint_fp8_serialized`. +```python +from vllm_omni import Omni +from vllm_omni.inputs.data import OmniDiffusionSamplingParams + +omni = Omni( + model="/path/to/fp8-checkpoint", + quantization_config={ + "method": "fp8", + "is_checkpoint_fp8_serialized": True, + }, +) + +outputs = omni.generate( + "a cup of coffee on the table", + OmniDiffusionSamplingParams(num_inference_steps=4), +) +``` + +## User Usage (Online) + +### Start Server (Online FP8) +```bash +vllm serve Qwen/Qwen-Image --omni --port 8000 --quantization fp8 +``` + +### Start Server (Native GGUF via Stage Config) +Create a stage config YAML that injects `quantization_config` into `engine_args`. +```yaml +stage_args: + - stage_id: 0 + runtime: + process: true + devices: "0" + max_batch_size: 1 + engine_args: + model_stage: diffusion + quantization_config: + method: gguf + gguf_model: /workspace/models/unsloth/FLUX.2-klein-4B-GGUF/flux-2-klein-4b-Q8_0.gguf +``` +Then run: +```bash +vllm serve /workspace/models/black-forest-labs/FLUX.2-klein-4B \ + --omni \ + --port 8000 \ + --stage-configs-path /path/to/diffusion_gguf_stage.yaml +``` + +### Online Request (Images API) +```bash +curl -X POST http://localhost:8000/v1/images/generations \ + -H "Content-Type: application/json" \ + -d '{ + "prompt": "a dragon laying over the spine of the Green Mountains of Vermont", + "size": "1024x1024", + "seed": 42, + "num_inference_steps": 4 + }' +``` + +## Validation Checklist +1. Fix the date in logs and docs for comparisons. +2. Use the same prompt, size, steps, and seed for BF16 vs GGUF/FP8 comparisons. +3. Expect accuracy differences for Q8_0 GGUF; verify mapping with F16/BF16 GGUF if needed. 
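+
+## Appendix: GGUF Dequant Fallback in Isolation
+
+The bottom of the GGUF operator path can be exercised standalone for debugging. The snippet below reproduces the `dequant_gemm_gguf` helper introduced later in this series; the inputs are assumed to be a BF16/FP16 activation `x` and a packed GGUF `qweight` together with its GGML `qweight_type` id.
+
+```python
+import gguf
+import torch
+from vllm import _custom_ops as ops
+
+
+def dequant_gemm_gguf(x: torch.Tensor, qweight: torch.Tensor, qweight_type: int) -> torch.Tensor:
+    # Recover the logical (out_features, in_features) shape from the packed
+    # GGUF layout, dequantize to x's dtype, then run a plain GEMM.
+    block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
+    shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
+    weight = ops.ggml_dequantize(qweight, qweight_type, *shape, x.dtype)
+    return x @ weight.T
+```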
From 2b35f3bf1dacf0c925564dbde5a91883c48e17bb Mon Sep 17 00:00:00 2001 From: David Chen <530634352@qq.com> Date: Wed, 11 Feb 2026 10:35:12 +0800 Subject: [PATCH 08/62] support gguf fp8 add design doc 2 Signed-off-by: David Chen <530634352@qq.com> --- .../diffusion/quantization/gguf_fp8_design.md | 27 +++++++------------ vllm_omni/entrypoints/cli/serve.py | 10 +++++++ vllm_omni/entrypoints/omni.py | 7 +++++ 3 files changed, 27 insertions(+), 17 deletions(-) diff --git a/docs/user_guide/diffusion/quantization/gguf_fp8_design.md b/docs/user_guide/diffusion/quantization/gguf_fp8_design.md index 92af0325b3..9fb20e8c53 100644 --- a/docs/user_guide/diffusion/quantization/gguf_fp8_design.md +++ b/docs/user_guide/diffusion/quantization/gguf_fp8_design.md @@ -233,27 +233,20 @@ outputs = omni.generate( vllm serve Qwen/Qwen-Image --omni --port 8000 --quantization fp8 ``` -### Start Server (Native GGUF via Stage Config) -Create a stage config YAML that injects `quantization_config` into `engine_args`. -```yaml -stage_args: - - stage_id: 0 - runtime: - process: true - devices: "0" - max_batch_size: 1 - engine_args: - model_stage: diffusion - quantization_config: - method: gguf - gguf_model: /workspace/models/unsloth/FLUX.2-klein-4B-GGUF/flux-2-klein-4b-Q8_0.gguf -``` -Then run: +### Start Server (Native GGUF via CLI) ```bash vllm serve /workspace/models/black-forest-labs/FLUX.2-klein-4B \ --omni \ --port 8000 \ - --stage-configs-path /path/to/diffusion_gguf_stage.yaml + --quantization-config '{"method":"gguf","gguf_model":"/workspace/models/unsloth/FLUX.2-klein-4B-GGUF/flux-2-klein-4b-Q8_0.gguf"}' +``` + +### Start Server (Native FP8 via CLI) +```bash +vllm serve /path/to/fp8-checkpoint \ + --omni \ + --port 8000 \ + --quantization-config '{"method":"fp8","is_checkpoint_fp8_serialized":true}' ``` ### Online Request (Images API) diff --git a/vllm_omni/entrypoints/cli/serve.py b/vllm_omni/entrypoints/cli/serve.py index 1be3b02f1f..4904781945 100644 --- a/vllm_omni/entrypoints/cli/serve.py +++ b/vllm_omni/entrypoints/cli/serve.py @@ -6,6 +6,7 @@ """ import argparse +import json import uvloop from vllm.entrypoints.cli.types import CLISubcommand @@ -161,6 +162,15 @@ def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgu help="Ring Sequence Parallelism degree for diffusion models. " "Equivalent to setting DiffusionParallelConfig.ring_degree.", ) + omni_config_group.add_argument( + "--quantization-config", + type=json.loads, + default=None, + help=( + "JSON string for diffusion quantization_config. " + "Example: '{\"method\":\"gguf\",\"gguf_model\":\"/path/to/model.gguf\"}'." 
+ ), + ) # Cache optimization parameters omni_config_group.add_argument( diff --git a/vllm_omni/entrypoints/omni.py b/vllm_omni/entrypoints/omni.py index 97357dc3b3..61e45dbdf8 100644 --- a/vllm_omni/entrypoints/omni.py +++ b/vllm_omni/entrypoints/omni.py @@ -231,6 +231,13 @@ def _initialize_stages(self, model: str, kwargs: dict[str, Any]) -> None: if lora_scale is not None: if not hasattr(cfg.engine_args, "lora_scale") or cfg.engine_args.lora_scale is None: cfg.engine_args.lora_scale = lora_scale + quantization_config = kwargs.get("quantization_config") + if quantization_config is not None: + if ( + not hasattr(cfg.engine_args, "quantization_config") + or cfg.engine_args.quantization_config is None + ): + cfg.engine_args.quantization_config = quantization_config except Exception as e: logger.warning("Failed to inject LoRA config for stage: %s", e) From f5c6900b482be9e994353c478e8d7987acc15ca0 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Wed, 11 Feb 2026 11:24:58 +0800 Subject: [PATCH 09/62] patch Signed-off-by: Isotr0py --- vllm_omni/diffusion/quantization/gguf.py | 60 +++++++++++++++++++++++- 1 file changed, 58 insertions(+), 2 deletions(-) diff --git a/vllm_omni/diffusion/quantization/gguf.py b/vllm_omni/diffusion/quantization/gguf.py index fcc4498bfe..ddedc3354e 100644 --- a/vllm_omni/diffusion/quantization/gguf.py +++ b/vllm_omni/diffusion/quantization/gguf.py @@ -2,11 +2,66 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """GGUF quantization config for diffusion transformers.""" -from vllm.model_executor.layers.quantization.gguf import GGUFConfig +import torch +import gguf +from vllm.model_executor.layers.quantization.gguf import GGUFConfig, GGUFLinearMethod, is_layer_skipped_gguf, LinearBase, QuantizeMethodBase, UnquantizedLinearMethod +from vllm import _custom_ops as ops from .base import DiffusionQuantizationConfig +def dequant_gemm_gguf(x: torch.Tensor, qweight: torch.Tensor, qweight_type: int) -> torch.Tensor: + block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type] + shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size) + weight = ops.ggml_dequantize(qweight, qweight_type, *shape, x.dtype) + return x @ weight.T + + +class DiffusionGGUFLinearMethod(GGUFLinearMethod): + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + ) -> torch.Tensor: + shard_id = layer.qweight.shard_id + + if shard_id: + # dequantize shard weights respectively + shard_id = ["q", "k", "v"] if "q" in shard_id else shard_id + qweight = layer.qweight + result = [] + for idx in shard_id: + start, end, offset = layer.qweight.shard_offset_map[idx] + qweight_type = layer.qweight_type.shard_weight_type[idx] + result.append( + dequant_gemm_gguf( + x, qweight[start:end, :offset].contiguous(), qweight_type + ) + ) + out = torch.cat(result, axis=-1) + else: + qweight = layer.qweight + qweight_type = layer.qweight_type.weight_type + out = dequant_gemm_gguf(x, qweight, qweight_type) + if bias is not None: + out.add_(bias) + return out + + +class _GGUFConfig(GGUFConfig): + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> "QuantizeMethodBase": + if isinstance(layer, LinearBase): + if is_layer_skipped_gguf( + prefix, self.unquantized_modules, self.packed_modules_mapping + ): + return UnquantizedLinearMethod() + return DiffusionGGUFLinearMethod(self) + return None + + class DiffusionGgufConfig(DiffusionQuantizationConfig): """GGUF quantization config for diffusion transformers. 
@@ -29,4 +84,5 @@ def __init__( ) -> None: self.gguf_model = gguf_model self.unquantized_modules = unquantized_modules or [] - self._vllm_config = GGUFConfig(unquantized_modules=self.unquantized_modules) + + self._vllm_config = _GGUFConfig(unquantized_modules=self.unquantized_modules) From 6fc0f8c65b7fce93794f760e6f99f324be64a5a3 Mon Sep 17 00:00:00 2001 From: David Chen <530634352@qq.com> Date: Wed, 11 Feb 2026 15:25:35 +0800 Subject: [PATCH 10/62] support gguf fp8 6 Signed-off-by: David Chen <530634352@qq.com> --- .../model_loader/diffusers_loader.py | 191 +++++++++--------- 1 file changed, 99 insertions(+), 92 deletions(-) diff --git a/vllm_omni/diffusion/model_loader/diffusers_loader.py b/vllm_omni/diffusion/model_loader/diffusers_loader.py index 4f8bbdf7ed..eea7711bce 100644 --- a/vllm_omni/diffusion/model_loader/diffusers_loader.py +++ b/vllm_omni/diffusion/model_loader/diffusers_loader.py @@ -6,7 +6,7 @@ import time from collections.abc import Generator, Iterable from pathlib import Path -from typing import Callable, cast +from typing import cast import torch from torch import nn @@ -345,93 +345,11 @@ def _resolve_gguf_model_path(self, gguf_model: str, revision: str | None) -> str "raw URL, /.gguf, or :)" ) - def _get_gguf_name_mapper(self, od_config: OmniDiffusionConfig) -> Callable[[str], str | None]: - model_class = od_config.model_class_name - if model_class in { - "QwenImagePipeline", - "QwenImageEditPipeline", - "QwenImageEditPlusPipeline", - "QwenImageLayeredPipeline", - }: - return lambda name: name - if model_class == "Flux2KleinPipeline": - return self._map_flux2_klein_gguf_name - raise ValueError(f"GGUF mapping is not implemented for model_class_name={model_class!r}") - - @staticmethod - def _map_flux2_klein_gguf_name(name: str) -> str | None: - if name.startswith("double_stream_modulation_img.lin."): - return name.replace("double_stream_modulation_img.lin.", "double_stream_modulation_img.linear.", 1) - if name.startswith("double_stream_modulation_txt.lin."): - return name.replace("double_stream_modulation_txt.lin.", "double_stream_modulation_txt.linear.", 1) - if name.startswith("single_stream_modulation.lin."): - return name.replace("single_stream_modulation.lin.", "single_stream_modulation.linear.", 1) - if name.startswith("img_in."): - return name.replace("img_in.", "x_embedder.", 1) - if name.startswith("txt_in."): - return name.replace("txt_in.", "context_embedder.", 1) - if name.startswith("time_in.in_layer."): - return name.replace( - "time_in.in_layer.", - "time_guidance_embed.timestep_embedder.linear_1.", - 1, - ) - if name.startswith("time_in.out_layer."): - return name.replace( - "time_in.out_layer.", - "time_guidance_embed.timestep_embedder.linear_2.", - 1, - ) - if name.startswith("final_layer.adaLN_modulation.1."): - return name.replace("final_layer.adaLN_modulation.1.", "norm_out.linear.", 1) - if name.startswith("final_layer.linear."): - return name.replace("final_layer.linear.", "proj_out.", 1) - - if name.startswith("double_blocks."): - name = name.replace("double_blocks.", "transformer_blocks.", 1) - if ".img_attn.qkv." in name: - return name.replace(".img_attn.qkv.", ".attn.to_qkv.", 1) - if ".img_attn.proj." 
in name: - return name.replace(".img_attn.proj.", ".attn.to_out.0.", 1) - if name.endswith(".img_attn.norm.query_norm.scale"): - return name.replace(".img_attn.norm.query_norm.scale", ".attn.norm_q.weight", 1) - if name.endswith(".img_attn.norm.key_norm.scale"): - return name.replace(".img_attn.norm.key_norm.scale", ".attn.norm_k.weight", 1) - if ".txt_attn.qkv." in name: - return name.replace(".txt_attn.qkv.", ".attn.add_kv_proj.", 1) - if ".txt_attn.proj." in name: - return name.replace(".txt_attn.proj.", ".attn.to_add_out.", 1) - if name.endswith(".txt_attn.norm.query_norm.scale"): - return name.replace(".txt_attn.norm.query_norm.scale", ".attn.norm_added_q.weight", 1) - if name.endswith(".txt_attn.norm.key_norm.scale"): - return name.replace(".txt_attn.norm.key_norm.scale", ".attn.norm_added_k.weight", 1) - if ".img_mlp.0." in name: - return name.replace(".img_mlp.0.", ".ff.linear_in.", 1) - if ".img_mlp.2." in name: - return name.replace(".img_mlp.2.", ".ff.linear_out.", 1) - if ".txt_mlp.0." in name: - return name.replace(".txt_mlp.0.", ".ff_context.linear_in.", 1) - if ".txt_mlp.2." in name: - return name.replace(".txt_mlp.2.", ".ff_context.linear_out.", 1) - return None - - if name.startswith("single_blocks."): - name = name.replace("single_blocks.", "single_transformer_blocks.", 1) - if ".linear1." in name: - return name.replace(".linear1.", ".attn.to_qkv_mlp_proj.", 1) - if ".linear2." in name: - return name.replace(".linear2.", ".attn.to_out.", 1) - if name.endswith(".norm.query_norm.scale"): - return name.replace(".norm.query_norm.scale", ".attn.norm_q.weight", 1) - if name.endswith(".norm.key_norm.scale"): - return name.replace(".norm.key_norm.scale", ".attn.norm_k.weight", 1) - return None - - return None - def _build_gguf_name_map( self, gguf_file: str, + model: nn.Module, + source: "ComponentSource", od_config: OmniDiffusionConfig, ) -> dict[str, str]: try: @@ -441,14 +359,102 @@ def _build_gguf_name_map( "GGUF support requires the 'gguf' package to be installed." 
) from exc - mapper = self._get_gguf_name_mapper(od_config) + def resolve_model_type() -> str: + cfg = od_config.tf_model_config + model_type = None + if cfg is not None: + model_type = cfg.get("model_type") + if model_type: + return model_type + model_class = od_config.model_class_name or "" + if model_class.startswith("QwenImage"): + return "qwen_image" + if model_class.startswith("Flux2"): + return "flux" + raise ValueError("Cannot infer gguf model_type for diffusion model.") + + def resolve_arch(model_type: str): + for key, value in gguf.MODEL_ARCH_NAMES.items(): + if value == model_type: + return key + raise RuntimeError(f"Unknown gguf model_type: {model_type}") + + def resolve_num_layers(target_module: nn.Module) -> int: + if hasattr(target_module, "transformer_blocks"): + return len(getattr(target_module, "transformer_blocks")) + if hasattr(target_module, "double_blocks"): + return len(getattr(target_module, "double_blocks")) + cfg = od_config.tf_model_config + if cfg is not None: + for key in ("num_hidden_layers", "num_layers", "n_layers"): + value = cfg.get(key) + if isinstance(value, int) and value > 0: + return value + raise ValueError("Cannot infer gguf num_layers for diffusion model.") + + def get_target_module(root: nn.Module, prefix: str) -> nn.Module: + if not prefix: + return root + prefix = prefix.rstrip(".") + if hasattr(root, "get_submodule"): + return root.get_submodule(prefix) + current = root + for part in prefix.split("."): + current = getattr(current, part) + return current + + def split_name(name: str) -> tuple[str, str]: + if name.endswith("_weight"): + return name[:-7], "weight" + if "." in name: + base, suffix = name.rsplit(".", 1) + return base, suffix + return name, "" + reader = gguf.GGUFReader(gguf_file) + gguf_tensor_names = {tensor.name for tensor in reader.tensors} + + model_type = resolve_model_type() + arch = resolve_arch(model_type) + target_module = get_target_module(model, source.prefix) + num_layers = resolve_num_layers(target_module) + name_map = gguf.get_tensor_name_map(arch, num_layers) + gguf_to_model_map: dict[str, str] = {} - for tensor in reader.tensors: - mapped = mapper(tensor.name) - if mapped is None: + for name, _ in target_module.named_parameters(): + base_name, suffix = split_name(name) + gguf_base = name_map.get_name(base_name) + if gguf_base is None: continue - gguf_to_model_map[tensor.name] = mapped + candidates = [] + if suffix: + candidates.append(f"{gguf_base}.{suffix}") + if suffix == "weight": + candidates.append(f"{gguf_base}.scale") + else: + candidates.append(gguf_base) + gguf_name = next((c for c in candidates if c in gguf_tensor_names), None) + if gguf_name is None: + continue + gguf_to_model_map[gguf_name] = name + + for name, _ in target_module.named_buffers(): + base_name, suffix = split_name(name) + gguf_base = name_map.get_name(base_name) + if gguf_base is None: + continue + candidates = [] + if suffix: + candidates.append(f"{gguf_base}.{suffix}") + if suffix == "weight": + candidates.append(f"{gguf_base}.scale") + else: + candidates.append(gguf_base) + gguf_name = next((c for c in candidates if c in gguf_tensor_names), None) + if gguf_name is None: + continue + gguf_to_model_map[gguf_name] = name + if not gguf_to_model_map: raise RuntimeError( f"No GGUF tensors were mapped for model_class_name={od_config.model_class_name!r}." 
@@ -458,6 +464,7 @@ def _build_gguf_name_map( def _get_gguf_weights_iterator( self, source: "ComponentSource", + model: nn.Module, od_config: OmniDiffusionConfig, ) -> Generator[tuple[str, torch.Tensor], None, None]: quant_config = od_config.quantization_config @@ -465,7 +472,7 @@ def _get_gguf_weights_iterator( if gguf_model is None: raise ValueError("GGUF quantization requires quantization_config.gguf_model") gguf_file = self._resolve_gguf_model_path(gguf_model, od_config.revision) - gguf_name_map = self._build_gguf_name_map(gguf_file, od_config) + gguf_name_map = self._build_gguf_name_map(gguf_file, model, source, od_config) weights_iter = gguf_quant_weights_iterator(gguf_file, gguf_name_map) return ((source.prefix + name, tensor) for (name, tensor) in weights_iter) @@ -479,7 +486,7 @@ def _load_weights_with_gguf(self, model: nn.Module, od_config: OmniDiffusionConf for source in sources: if self._is_transformer_source(source): - loaded |= model.load_weights(self._get_gguf_weights_iterator(source, od_config)) + loaded |= model.load_weights(self._get_gguf_weights_iterator(source, model, od_config)) # Load any remaining float weights (e.g., non-quantized layers) # from the base HF checkpoint while skipping already-loaded names. From 769a5fb19ac167ffb023d784d828551286b609d2 Mon Sep 17 00:00:00 2001 From: David Chen <530634352@qq.com> Date: Wed, 11 Feb 2026 15:54:31 +0800 Subject: [PATCH 11/62] support gguf fp8 7 Signed-off-by: David Chen <530634352@qq.com> --- .../model_loader/diffusers_loader.py | 122 +--------- .../model_loader/gguf_adapters/__init__.py | 15 ++ .../model_loader/gguf_adapters/base.py | 137 +++++++++++ .../model_loader/gguf_adapters/flux2.py | 224 ++++++++++++++++++ 4 files changed, 379 insertions(+), 119 deletions(-) create mode 100644 vllm_omni/diffusion/model_loader/gguf_adapters/__init__.py create mode 100644 vllm_omni/diffusion/model_loader/gguf_adapters/base.py create mode 100644 vllm_omni/diffusion/model_loader/gguf_adapters/flux2.py diff --git a/vllm_omni/diffusion/model_loader/diffusers_loader.py b/vllm_omni/diffusion/model_loader/diffusers_loader.py index eea7711bce..4928993e62 100644 --- a/vllm_omni/diffusion/model_loader/diffusers_loader.py +++ b/vllm_omni/diffusion/model_loader/diffusers_loader.py @@ -21,13 +21,13 @@ download_weights_from_hf, filter_duplicate_safetensors_files, filter_files_not_needed_for_inference, - gguf_quant_weights_iterator, maybe_download_from_modelscope, safetensors_weights_iterator, ) from vllm.utils.torch_utils import set_default_torch_dtype from vllm_omni.diffusion.data import OmniDiffusionConfig +from vllm_omni.diffusion.model_loader.gguf_adapters import get_gguf_adapter from vllm_omni.diffusion.registry import initialize_model logger = init_logger(__name__) @@ -345,122 +345,6 @@ def _resolve_gguf_model_path(self, gguf_model: str, revision: str | None) -> str "raw URL, /.gguf, or :)" ) - def _build_gguf_name_map( - self, - gguf_file: str, - model: nn.Module, - source: "ComponentSource", - od_config: OmniDiffusionConfig, - ) -> dict[str, str]: - try: - import gguf # type: ignore - except Exception as exc: # pragma: no cover - dependency error - raise RuntimeError( - "GGUF support requires the 'gguf' package to be installed." 
- ) from exc - - def resolve_model_type() -> str: - cfg = od_config.tf_model_config - model_type = None - if cfg is not None: - model_type = cfg.get("model_type") - if model_type: - return model_type - model_class = od_config.model_class_name or "" - if model_class.startswith("QwenImage"): - return "qwen_image" - if model_class.startswith("Flux2"): - return "flux" - raise ValueError("Cannot infer gguf model_type for diffusion model.") - - def resolve_arch(model_type: str): - for key, value in gguf.MODEL_ARCH_NAMES.items(): - if value == model_type: - return key - raise RuntimeError(f"Unknown gguf model_type: {model_type}") - - def resolve_num_layers(target_module: nn.Module) -> int: - if hasattr(target_module, "transformer_blocks"): - return len(getattr(target_module, "transformer_blocks")) - if hasattr(target_module, "double_blocks"): - return len(getattr(target_module, "double_blocks")) - cfg = od_config.tf_model_config - if cfg is not None: - for key in ("num_hidden_layers", "num_layers", "n_layers"): - value = cfg.get(key) - if isinstance(value, int) and value > 0: - return value - raise ValueError("Cannot infer gguf num_layers for diffusion model.") - - def get_target_module(root: nn.Module, prefix: str) -> nn.Module: - if not prefix: - return root - prefix = prefix.rstrip(".") - if hasattr(root, "get_submodule"): - return root.get_submodule(prefix) - current = root - for part in prefix.split("."): - current = getattr(current, part) - return current - - def split_name(name: str) -> tuple[str, str]: - if name.endswith("_weight"): - return name[:-7], "weight" - if "." in name: - base, suffix = name.rsplit(".", 1) - return base, suffix - return name, "" - - reader = gguf.GGUFReader(gguf_file) - gguf_tensor_names = {tensor.name for tensor in reader.tensors} - - model_type = resolve_model_type() - arch = resolve_arch(model_type) - target_module = get_target_module(model, source.prefix) - num_layers = resolve_num_layers(target_module) - name_map = gguf.get_tensor_name_map(arch, num_layers) - - gguf_to_model_map: dict[str, str] = {} - for name, _ in target_module.named_parameters(): - base_name, suffix = split_name(name) - gguf_base = name_map.get_name(base_name) - if gguf_base is None: - continue - candidates = [] - if suffix: - candidates.append(f"{gguf_base}.{suffix}") - if suffix == "weight": - candidates.append(f"{gguf_base}.scale") - else: - candidates.append(gguf_base) - gguf_name = next((c for c in candidates if c in gguf_tensor_names), None) - if gguf_name is None: - continue - gguf_to_model_map[gguf_name] = name - - for name, _ in target_module.named_buffers(): - base_name, suffix = split_name(name) - gguf_base = name_map.get_name(base_name) - if gguf_base is None: - continue - candidates = [] - if suffix: - candidates.append(f"{gguf_base}.{suffix}") - if suffix == "weight": - candidates.append(f"{gguf_base}.scale") - else: - candidates.append(gguf_base) - gguf_name = next((c for c in candidates if c in gguf_tensor_names), None) - if gguf_name is None: - continue - gguf_to_model_map[gguf_name] = name - - if not gguf_to_model_map: - raise RuntimeError( - f"No GGUF tensors were mapped for model_class_name={od_config.model_class_name!r}." 
- ) - return gguf_to_model_map - def _get_gguf_weights_iterator( self, source: "ComponentSource", @@ -472,8 +356,8 @@ def _get_gguf_weights_iterator( if gguf_model is None: raise ValueError("GGUF quantization requires quantization_config.gguf_model") gguf_file = self._resolve_gguf_model_path(gguf_model, od_config.revision) - gguf_name_map = self._build_gguf_name_map(gguf_file, model, source, od_config) - weights_iter = gguf_quant_weights_iterator(gguf_file, gguf_name_map) + adapter = get_gguf_adapter(gguf_file, model, source, od_config) + weights_iter = adapter.weights_iterator() return ((source.prefix + name, tensor) for (name, tensor) in weights_iter) def _load_weights_with_gguf(self, model: nn.Module, od_config: OmniDiffusionConfig) -> set[str]: diff --git a/vllm_omni/diffusion/model_loader/gguf_adapters/__init__.py b/vllm_omni/diffusion/model_loader/gguf_adapters/__init__.py new file mode 100644 index 0000000000..5383b6e15d --- /dev/null +++ b/vllm_omni/diffusion/model_loader/gguf_adapters/__init__.py @@ -0,0 +1,15 @@ +# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +from .base import GGUFAdapter +from .flux2 import Flux2GGUFAdapter + + +def get_gguf_adapter(gguf_file: str, model, source, od_config) -> GGUFAdapter: + for adapter_cls in (Flux2GGUFAdapter,): + if adapter_cls.is_compatible(od_config, model, source): + return adapter_cls(gguf_file, model, source, od_config) + return GGUFAdapter(gguf_file, model, source, od_config) + + +__all__ = ["GGUFAdapter", "Flux2GGUFAdapter", "get_gguf_adapter"] diff --git a/vllm_omni/diffusion/model_loader/gguf_adapters/base.py b/vllm_omni/diffusion/model_loader/gguf_adapters/base.py new file mode 100644 index 0000000000..2eae17a1f4 --- /dev/null +++ b/vllm_omni/diffusion/model_loader/gguf_adapters/base.py @@ -0,0 +1,137 @@ +# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +from collections.abc import Generator + +import torch + +from vllm.model_executor.model_loader.weight_utils import gguf_quant_weights_iterator + + +class GGUFAdapter: + """Default GGUF adapter using gguf-py's tensor name mapping.""" + + def __init__(self, gguf_file: str, model: torch.nn.Module, source, od_config) -> None: + self.gguf_file = gguf_file + self.model = model + self.source = source + self.od_config = od_config + + @staticmethod + def is_compatible(od_config, model: torch.nn.Module, source) -> bool: + # Default adapter matches any model. + return True + + def weights_iterator(self) -> Generator[tuple[str, torch.Tensor], None, None]: + name_map = self._build_gguf_name_map() + return gguf_quant_weights_iterator(self.gguf_file, name_map) + + def _build_gguf_name_map(self) -> dict[str, str]: + try: + import gguf # type: ignore + except Exception as exc: # pragma: no cover - dependency error + raise RuntimeError( + "GGUF support requires the 'gguf' package to be installed." 
+ ) from exc + + def resolve_model_type() -> str: + cfg = self.od_config.tf_model_config + model_type = None + if cfg is not None: + model_type = cfg.get("model_type") + if model_type: + return model_type + model_class = self.od_config.model_class_name or "" + if model_class.startswith("QwenImage"): + return "qwen_image" + if model_class.startswith("Flux2"): + return "flux" + raise ValueError("Cannot infer gguf model_type for diffusion model.") + + def resolve_arch(model_type: str): + for key, value in gguf.MODEL_ARCH_NAMES.items(): + if value == model_type: + return key + raise RuntimeError(f"Unknown gguf model_type: {model_type}") + + def resolve_num_layers(target_module: torch.nn.Module) -> int: + if hasattr(target_module, "transformer_blocks"): + return len(getattr(target_module, "transformer_blocks")) + if hasattr(target_module, "double_blocks"): + return len(getattr(target_module, "double_blocks")) + cfg = self.od_config.tf_model_config + if cfg is not None: + for key in ("num_hidden_layers", "num_layers", "n_layers"): + value = cfg.get(key) + if isinstance(value, int) and value > 0: + return value + raise ValueError("Cannot infer gguf num_layers for diffusion model.") + + def get_target_module(root: torch.nn.Module, prefix: str) -> torch.nn.Module: + if not prefix: + return root + prefix = prefix.rstrip(".") + if hasattr(root, "get_submodule"): + return root.get_submodule(prefix) + current = root + for part in prefix.split("."): + current = getattr(current, part) + return current + + def split_name(name: str) -> tuple[str, str]: + if name.endswith("_weight"): + return name[:-7], "weight" + if "." in name: + base, suffix = name.rsplit(".", 1) + return base, suffix + return name, "" + + reader = gguf.GGUFReader(self.gguf_file) + gguf_tensor_names = {tensor.name for tensor in reader.tensors} + + model_type = resolve_model_type() + arch = resolve_arch(model_type) + target_module = get_target_module(self.model, self.source.prefix) + num_layers = resolve_num_layers(target_module) + name_map = gguf.get_tensor_name_map(arch, num_layers) + + gguf_to_model_map: dict[str, str] = {} + for name, _ in target_module.named_parameters(): + base_name, suffix = split_name(name) + gguf_base = name_map.get_name(base_name) + if gguf_base is None: + continue + candidates = [] + if suffix: + candidates.append(f"{gguf_base}.{suffix}") + if suffix == "weight": + candidates.append(f"{gguf_base}.scale") + else: + candidates.append(gguf_base) + gguf_name = next((c for c in candidates if c in gguf_tensor_names), None) + if gguf_name is None: + continue + gguf_to_model_map[gguf_name] = name + + for name, _ in target_module.named_buffers(): + base_name, suffix = split_name(name) + gguf_base = name_map.get_name(base_name) + if gguf_base is None: + continue + candidates = [] + if suffix: + candidates.append(f"{gguf_base}.{suffix}") + if suffix == "weight": + candidates.append(f"{gguf_base}.scale") + else: + candidates.append(gguf_base) + gguf_name = next((c for c in candidates if c in gguf_tensor_names), None) + if gguf_name is None: + continue + gguf_to_model_map[gguf_name] = name + + if not gguf_to_model_map: + raise RuntimeError( + f"No GGUF tensors were mapped for model_class_name={self.od_config.model_class_name!r}." 
+ ) + return gguf_to_model_map diff --git a/vllm_omni/diffusion/model_loader/gguf_adapters/flux2.py b/vllm_omni/diffusion/model_loader/gguf_adapters/flux2.py new file mode 100644 index 0000000000..6cacb1e095 --- /dev/null +++ b/vllm_omni/diffusion/model_loader/gguf_adapters/flux2.py @@ -0,0 +1,224 @@ +# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +from collections.abc import Generator +from dataclasses import dataclass + +import numpy as np +import torch + +from .base import GGUFAdapter + + +@dataclass +class _MappedTensor: + name: str + tensor + tensor_type + row_slice: slice | None = None + swap_scale_shift: bool = False + + +class Flux2GGUFAdapter(GGUFAdapter): + """GGUF adapter for Flux2 models with qkv splitting and adaLN swap.""" + + @staticmethod + def is_compatible(od_config, model: torch.nn.Module, source) -> bool: + model_class = od_config.model_class_name or "" + if model_class.startswith("Flux2"): + return True + cfg = od_config.tf_model_config + if cfg is not None: + model_type = str(cfg.get("model_type", "")).lower() + if model_type.startswith("flux2"): + return True + # Fallback: Flux2 transformer has single_transformer_blocks + prefix = getattr(source, "prefix", "") + target = model.get_submodule(prefix.rstrip(".")) if prefix else model + return hasattr(target, "single_transformer_blocks") + + def weights_iterator(self) -> Generator[tuple[str, torch.Tensor], None, None]: + try: + import gguf # type: ignore + except Exception as exc: # pragma: no cover - dependency error + raise RuntimeError( + "GGUF support requires the 'gguf' package to be installed." + ) from exc + + reader = gguf.GGUFReader(self.gguf_file) + allowed_names = self._build_allowed_names() + mapped: list[_MappedTensor] = [] + + for tensor in reader.tensors: + for mapped_tensor in self._map_tensor_name(tensor): + if mapped_tensor.name not in allowed_names: + continue + mapped.append(mapped_tensor) + + if not mapped: + raise RuntimeError( + "No GGUF tensors were mapped for Flux2 GGUF loader. " + "Please verify the GGUF file and model structure." + ) + + for item in mapped: + weight_type = item.tensor_type + if weight_type.name not in ("F32", "BF16", "F16"): + weight_type_name = item.name.replace("weight", "qweight_type") + yield weight_type_name, torch.tensor(weight_type) + + for item in mapped: + weight = item.tensor.data + if item.row_slice is not None: + weight = weight[item.row_slice] + weight_type = item.tensor_type + if weight_type.name not in ("F32", "BF16", "F16"): + name = item.name.replace("weight", "qweight") + else: + name = item.name + + if weight_type.name == "BF16" and weight.dtype == np.uint8: + weight = weight.view(np.uint16) + if reader.byte_order == "S": + weight = weight.byteswap() + param = torch.tensor(weight).view(torch.bfloat16) + else: + param = torch.tensor(weight) + + if item.swap_scale_shift: + shift, scale = param.chunk(2, dim=0) + param = torch.cat([scale, shift], dim=0) + + yield name, param + + def _build_allowed_names(self) -> set[str]: + prefix = getattr(self.source, "prefix", "") + target = self.model.get_submodule(prefix.rstrip(".")) if prefix else self.model + allowed = {name for name, _ in target.named_parameters()} + allowed.update(name for name, _ in target.named_buffers()) + + virtual_names = set() + for name in allowed: + if ".to_qkv." in name: + virtual_names.add(name.replace(".to_qkv.", ".to_q.")) + virtual_names.add(name.replace(".to_qkv.", ".to_k.")) + virtual_names.add(name.replace(".to_qkv.", ".to_v.")) + if ".add_kv_proj." 
in name: + virtual_names.add(name.replace(".add_kv_proj.", ".add_q_proj.")) + virtual_names.add(name.replace(".add_kv_proj.", ".add_k_proj.")) + virtual_names.add(name.replace(".add_kv_proj.", ".add_v_proj.")) + allowed.update(virtual_names) + return allowed + + def _map_tensor_name(self, tensor) -> list[_MappedTensor]: + name = tensor.name + + if name.startswith("double_blocks."): + return self._map_double_blocks(tensor) + if name.startswith("single_blocks."): + return self._map_single_blocks(tensor) + if name.startswith("final_layer.adaLN_modulation.1") and name.endswith(".weight"): + return [ + _MappedTensor( + name="norm_out.linear.weight", + tensor=tensor, + tensor_type=tensor.tensor_type, + swap_scale_shift=True, + ) + ] + + for src, dst in _FLUX2_TRANSFORMER_KEYS_RENAME_DICT.items(): + name = name.replace(src, dst) + + return [ + _MappedTensor( + name=name, + tensor=tensor, + tensor_type=tensor.tensor_type, + ) + ] + + def _map_double_blocks(self, tensor) -> list[_MappedTensor]: + name = tensor.name + parts = name.split(".") + block_idx = parts[1] + within_block_name = ".".join(parts[2:-1]) + param_type = parts[-1] + if param_type == "scale": + param_type = "weight" + + if "qkv" in within_block_name: + if "img_attn" in within_block_name: + q_name = f"transformer_blocks.{block_idx}.attn.to_q.{param_type}" + k_name = f"transformer_blocks.{block_idx}.attn.to_k.{param_type}" + v_name = f"transformer_blocks.{block_idx}.attn.to_v.{param_type}" + elif "txt_attn" in within_block_name: + q_name = f"transformer_blocks.{block_idx}.attn.add_q_proj.{param_type}" + k_name = f"transformer_blocks.{block_idx}.attn.add_k_proj.{param_type}" + v_name = f"transformer_blocks.{block_idx}.attn.add_v_proj.{param_type}" + else: + return [] + + weight = tensor.data + dim0 = weight.shape[0] + split = dim0 // 3 + return [ + _MappedTensor(q_name, tensor, tensor.tensor_type, slice(0, split)), + _MappedTensor(k_name, tensor, tensor.tensor_type, slice(split, 2 * split)), + _MappedTensor(v_name, tensor, tensor.tensor_type, slice(2 * split, 3 * split)), + ] + + mapped_name = _FLUX2_TRANSFORMER_DOUBLE_BLOCK_KEY_MAP.get(within_block_name) + if mapped_name is None: + return [] + target = f"transformer_blocks.{block_idx}.{mapped_name}.{param_type}" + return [_MappedTensor(target, tensor, tensor.tensor_type)] + + def _map_single_blocks(self, tensor) -> list[_MappedTensor]: + name = tensor.name + parts = name.split(".") + block_idx = parts[1] + within_block_name = ".".join(parts[2:-1]) + param_type = parts[-1] + if param_type == "scale": + param_type = "weight" + + mapped_name = _FLUX2_TRANSFORMER_SINGLE_BLOCK_KEY_MAP.get(within_block_name) + if mapped_name is None: + return [] + target = f"single_transformer_blocks.{block_idx}.{mapped_name}.{param_type}" + return [_MappedTensor(target, tensor, tensor.tensor_type)] + + +_FLUX2_TRANSFORMER_KEYS_RENAME_DICT = { + "img_in": "x_embedder", + "txt_in": "context_embedder", + "time_in.in_layer": "time_guidance_embed.timestep_embedder.linear_1", + "time_in.out_layer": "time_guidance_embed.timestep_embedder.linear_2", + "guidance_in.in_layer": "time_guidance_embed.guidance_embedder.linear_1", + "guidance_in.out_layer": "time_guidance_embed.guidance_embedder.linear_2", + "double_stream_modulation_img.lin": "double_stream_modulation_img.linear", + "double_stream_modulation_txt.lin": "double_stream_modulation_txt.linear", + "single_stream_modulation.lin": "single_stream_modulation.linear", + "final_layer.linear": "proj_out", +} + +_FLUX2_TRANSFORMER_DOUBLE_BLOCK_KEY_MAP = { + 
"img_attn.norm.query_norm": "attn.norm_q", + "img_attn.norm.key_norm": "attn.norm_k", + "img_attn.proj": "attn.to_out.0", + "img_mlp.0": "ff.linear_in", + "img_mlp.2": "ff.linear_out", + "txt_attn.norm.query_norm": "attn.norm_added_q", + "txt_attn.norm.key_norm": "attn.norm_added_k", + "txt_attn.proj": "attn.to_add_out", + "txt_mlp.0": "ff_context.linear_in", + "txt_mlp.2": "ff_context.linear_out", +} + +_FLUX2_TRANSFORMER_SINGLE_BLOCK_KEY_MAP = { + "linear1": "attn.to_qkv_mlp_proj", + "linear2": "attn.to_out", + "norm.query_norm": "attn.norm_q", + "norm.key_norm": "attn.norm_k", +} From ce19e1b91ab852803e00bcfbf771655837aa4cf5 Mon Sep 17 00:00:00 2001 From: David Chen <530634352@qq.com> Date: Wed, 11 Feb 2026 15:57:21 +0800 Subject: [PATCH 12/62] support gguf fp8 8 Signed-off-by: David Chen <530634352@qq.com> --- vllm_omni/diffusion/model_loader/gguf_adapters/flux2.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm_omni/diffusion/model_loader/gguf_adapters/flux2.py b/vllm_omni/diffusion/model_loader/gguf_adapters/flux2.py index 6cacb1e095..9a8a6e6f6f 100644 --- a/vllm_omni/diffusion/model_loader/gguf_adapters/flux2.py +++ b/vllm_omni/diffusion/model_loader/gguf_adapters/flux2.py @@ -3,6 +3,7 @@ from collections.abc import Generator from dataclasses import dataclass +from typing import Any import numpy as np import torch @@ -13,8 +14,8 @@ @dataclass class _MappedTensor: name: str - tensor - tensor_type + tensor: Any + tensor_type: Any row_slice: slice | None = None swap_scale_shift: bool = False From f535ca1324e5a0d1ceba377c777ffa379c2a89b2 Mon Sep 17 00:00:00 2001 From: David Chen <530634352@qq.com> Date: Wed, 11 Feb 2026 16:04:18 +0800 Subject: [PATCH 13/62] support gguf fp8 9 Signed-off-by: David Chen <530634352@qq.com> --- .../model_loader/gguf_adapters/flux2.py | 29 +++++++++++++++---- 1 file changed, 24 insertions(+), 5 deletions(-) diff --git a/vllm_omni/diffusion/model_loader/gguf_adapters/flux2.py b/vllm_omni/diffusion/model_loader/gguf_adapters/flux2.py index 9a8a6e6f6f..cde0accd25 100644 --- a/vllm_omni/diffusion/model_loader/gguf_adapters/flux2.py +++ b/vllm_omni/diffusion/model_loader/gguf_adapters/flux2.py @@ -48,6 +48,7 @@ def weights_iterator(self) -> Generator[tuple[str, torch.Tensor], None, None]: reader = gguf.GGUFReader(self.gguf_file) allowed_names = self._build_allowed_names() + param_names = self._build_param_names() mapped: list[_MappedTensor] = [] for tensor in reader.tensors: @@ -63,17 +64,25 @@ def weights_iterator(self) -> Generator[tuple[str, torch.Tensor], None, None]: ) for item in mapped: - weight_type = item.tensor_type - if weight_type.name not in ("F32", "BF16", "F16"): - weight_type_name = item.name.replace("weight", "qweight_type") - yield weight_type_name, torch.tensor(weight_type) + is_linear_weight = ( + item.name.endswith(".weight") + and item.name.replace(".weight", ".qweight") in param_names + ) + if not is_linear_weight: + continue + weight_type_name = item.name.replace("weight", "qweight_type") + yield weight_type_name, torch.tensor(item.tensor_type) for item in mapped: weight = item.tensor.data if item.row_slice is not None: weight = weight[item.row_slice] weight_type = item.tensor_type - if weight_type.name not in ("F32", "BF16", "F16"): + is_linear_weight = ( + item.name.endswith(".weight") + and item.name.replace(".weight", ".qweight") in param_names + ) + if is_linear_weight: name = item.name.replace("weight", "qweight") else: name = item.name @@ -97,6 +106,11 @@ def _build_allowed_names(self) -> set[str]: target 
= self.model.get_submodule(prefix.rstrip(".")) if prefix else self.model allowed = {name for name, _ in target.named_parameters()} allowed.update(name for name, _ in target.named_buffers()) + for name in list(allowed): + if name.endswith(".qweight"): + allowed.add(name.replace(".qweight", ".weight")) + elif name.endswith(".qweight_type"): + allowed.add(name.replace(".qweight_type", ".weight")) virtual_names = set() for name in allowed: @@ -111,6 +125,11 @@ def _build_allowed_names(self) -> set[str]: allowed.update(virtual_names) return allowed + def _build_param_names(self) -> set[str]: + prefix = getattr(self.source, "prefix", "") + target = self.model.get_submodule(prefix.rstrip(".")) if prefix else self.model + return {name for name, _ in target.named_parameters()} + def _map_tensor_name(self, tensor) -> list[_MappedTensor]: name = tensor.name From 11ca22fbf84c7dce0c81ababadf19fa87311da1d Mon Sep 17 00:00:00 2001 From: David Chen <530634352@qq.com> Date: Wed, 11 Feb 2026 16:08:36 +0800 Subject: [PATCH 14/62] support gguf fp8 10 Signed-off-by: David Chen <530634352@qq.com> --- .../model_loader/gguf_adapters/flux2.py | 42 ++++++++++++++----- 1 file changed, 31 insertions(+), 11 deletions(-) diff --git a/vllm_omni/diffusion/model_loader/gguf_adapters/flux2.py b/vllm_omni/diffusion/model_loader/gguf_adapters/flux2.py index cde0accd25..c6dcef13e3 100644 --- a/vllm_omni/diffusion/model_loader/gguf_adapters/flux2.py +++ b/vllm_omni/diffusion/model_loader/gguf_adapters/flux2.py @@ -53,7 +53,11 @@ def weights_iterator(self) -> Generator[tuple[str, torch.Tensor], None, None]: for tensor in reader.tensors: for mapped_tensor in self._map_tensor_name(tensor): - if mapped_tensor.name not in allowed_names: + if ( + mapped_tensor.name not in allowed_names + and self._resolve_linear_qweight(mapped_tensor.name, param_names) + is None + ): continue mapped.append(mapped_tensor) @@ -64,13 +68,11 @@ def weights_iterator(self) -> Generator[tuple[str, torch.Tensor], None, None]: ) for item in mapped: - is_linear_weight = ( - item.name.endswith(".weight") - and item.name.replace(".weight", ".qweight") in param_names - ) + linear_qweight = self._resolve_linear_qweight(item.name, param_names) + is_linear_weight = linear_qweight is not None if not is_linear_weight: continue - weight_type_name = item.name.replace("weight", "qweight_type") + weight_type_name = linear_qweight.replace("qweight", "qweight_type") yield weight_type_name, torch.tensor(item.tensor_type) for item in mapped: @@ -78,12 +80,10 @@ def weights_iterator(self) -> Generator[tuple[str, torch.Tensor], None, None]: if item.row_slice is not None: weight = weight[item.row_slice] weight_type = item.tensor_type - is_linear_weight = ( - item.name.endswith(".weight") - and item.name.replace(".weight", ".qweight") in param_names - ) + linear_qweight = self._resolve_linear_qweight(item.name, param_names) + is_linear_weight = linear_qweight is not None if is_linear_weight: - name = item.name.replace("weight", "qweight") + name = linear_qweight else: name = item.name @@ -130,6 +130,26 @@ def _build_param_names(self) -> set[str]: target = self.model.get_submodule(prefix.rstrip(".")) if prefix else self.model return {name for name, _ in target.named_parameters()} + def _resolve_linear_qweight(self, name: str, param_names: set[str]) -> str | None: + if not name.endswith(".weight"): + return None + candidate = name.replace(".weight", ".qweight") + if candidate in param_names: + return candidate + for src, dst in ( + (".to_q.", ".to_qkv."), + (".to_k.", ".to_qkv."), 
+ (".to_v.", ".to_qkv."), + (".add_q_proj.", ".add_kv_proj."), + (".add_k_proj.", ".add_kv_proj."), + (".add_v_proj.", ".add_kv_proj."), + ): + if src in name: + candidate = name.replace(src, dst).replace(".weight", ".qweight") + if candidate in param_names: + return candidate + return None + def _map_tensor_name(self, tensor) -> list[_MappedTensor]: name = tensor.name From b2916b1354ad35eb7f9a52c28e0430d23bd47ed1 Mon Sep 17 00:00:00 2001 From: David Chen <530634352@qq.com> Date: Wed, 11 Feb 2026 16:14:17 +0800 Subject: [PATCH 15/62] support gguf fp8 11 Signed-off-by: David Chen <530634352@qq.com> --- .../model_loader/gguf_adapters/flux2.py | 23 +++++++++---------- 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/vllm_omni/diffusion/model_loader/gguf_adapters/flux2.py b/vllm_omni/diffusion/model_loader/gguf_adapters/flux2.py index c6dcef13e3..12b7cc94f8 100644 --- a/vllm_omni/diffusion/model_loader/gguf_adapters/flux2.py +++ b/vllm_omni/diffusion/model_loader/gguf_adapters/flux2.py @@ -133,21 +133,20 @@ def _build_param_names(self) -> set[str]: def _resolve_linear_qweight(self, name: str, param_names: set[str]) -> str | None: if not name.endswith(".weight"): return None + # Keep QKV shard names so load_weights can attach shard_id correctly. + for shard_token in ( + ".to_q.", + ".to_k.", + ".to_v.", + ".add_q_proj.", + ".add_k_proj.", + ".add_v_proj.", + ): + if shard_token in name: + return name.replace(".weight", ".qweight") candidate = name.replace(".weight", ".qweight") if candidate in param_names: return candidate - for src, dst in ( - (".to_q.", ".to_qkv."), - (".to_k.", ".to_qkv."), - (".to_v.", ".to_qkv."), - (".add_q_proj.", ".add_kv_proj."), - (".add_k_proj.", ".add_kv_proj."), - (".add_v_proj.", ".add_kv_proj."), - ): - if src in name: - candidate = name.replace(src, dst).replace(".weight", ".qweight") - if candidate in param_names: - return candidate return None def _map_tensor_name(self, tensor) -> list[_MappedTensor]: From 929f1f7620d4522ca686ea5726f3bd471e2965a3 Mon Sep 17 00:00:00 2001 From: David Chen <530634352@qq.com> Date: Wed, 11 Feb 2026 16:20:55 +0800 Subject: [PATCH 16/62] support gguf fp8 12 Signed-off-by: David Chen <530634352@qq.com> --- vllm_omni/diffusion/model_loader/gguf_adapters/__init__.py | 6 +++--- .../model_loader/gguf_adapters/{flux2.py => flux2_klein.py} | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) rename vllm_omni/diffusion/model_loader/gguf_adapters/{flux2.py => flux2_klein.py} (98%) diff --git a/vllm_omni/diffusion/model_loader/gguf_adapters/__init__.py b/vllm_omni/diffusion/model_loader/gguf_adapters/__init__.py index 5383b6e15d..737132ead6 100644 --- a/vllm_omni/diffusion/model_loader/gguf_adapters/__init__.py +++ b/vllm_omni/diffusion/model_loader/gguf_adapters/__init__.py @@ -2,14 +2,14 @@ from __future__ import annotations from .base import GGUFAdapter -from .flux2 import Flux2GGUFAdapter +from .flux2_klein import Flux2KleinGGUFAdapter def get_gguf_adapter(gguf_file: str, model, source, od_config) -> GGUFAdapter: - for adapter_cls in (Flux2GGUFAdapter,): + for adapter_cls in (Flux2KleinGGUFAdapter,): if adapter_cls.is_compatible(od_config, model, source): return adapter_cls(gguf_file, model, source, od_config) return GGUFAdapter(gguf_file, model, source, od_config) -__all__ = ["GGUFAdapter", "Flux2GGUFAdapter", "get_gguf_adapter"] +__all__ = ["GGUFAdapter", "Flux2KleinGGUFAdapter", "get_gguf_adapter"] diff --git a/vllm_omni/diffusion/model_loader/gguf_adapters/flux2.py 
b/vllm_omni/diffusion/model_loader/gguf_adapters/flux2_klein.py similarity index 98% rename from vllm_omni/diffusion/model_loader/gguf_adapters/flux2.py rename to vllm_omni/diffusion/model_loader/gguf_adapters/flux2_klein.py index 12b7cc94f8..0bf0ff2214 100644 --- a/vllm_omni/diffusion/model_loader/gguf_adapters/flux2.py +++ b/vllm_omni/diffusion/model_loader/gguf_adapters/flux2_klein.py @@ -20,13 +20,13 @@ class _MappedTensor: swap_scale_shift: bool = False -class Flux2GGUFAdapter(GGUFAdapter): - """GGUF adapter for Flux2 models with qkv splitting and adaLN swap.""" +class Flux2KleinGGUFAdapter(GGUFAdapter): + """GGUF adapter for Flux2-Klein models with qkv splitting and adaLN swap.""" @staticmethod def is_compatible(od_config, model: torch.nn.Module, source) -> bool: model_class = od_config.model_class_name or "" - if model_class.startswith("Flux2"): + if model_class.startswith("Flux2Klein"): return True cfg = od_config.tf_model_config if cfg is not None: From a4cefac7251f864f161f5d8a730825ce1c3ed5df Mon Sep 17 00:00:00 2001 From: David Chen <530634352@qq.com> Date: Wed, 11 Feb 2026 17:22:38 +0800 Subject: [PATCH 17/62] support gguf fp8 add design doc 3 Signed-off-by: David Chen <530634352@qq.com> --- .../diffusion/quantization/gguf_fp8_design.md | 59 ++++++++++++++++++- 1 file changed, 57 insertions(+), 2 deletions(-) diff --git a/docs/user_guide/diffusion/quantization/gguf_fp8_design.md b/docs/user_guide/diffusion/quantization/gguf_fp8_design.md index 9fb20e8c53..8669ca7a60 100644 --- a/docs/user_guide/diffusion/quantization/gguf_fp8_design.md +++ b/docs/user_guide/diffusion/quantization/gguf_fp8_design.md @@ -1,6 +1,6 @@ # Diffusion Quantization Design: Native GGUF + FP8 (Online and Native) -Date: 2026-02-10 +Date: 2026-02-11 ## Goals 1. Reuse vLLM quantization configs and weight loaders as much as possible. @@ -16,6 +16,7 @@ Date: 2026-02-10 1. `OmniDiffusionConfig` accepts `quantization` or `quantization_config`. 2. Diffusion quantization wrappers (`DiffusionGgufConfig`, `DiffusionFp8Config`) produce vLLM `QuantizationConfig` objects for linear layers. 3. `DiffusersPipelineLoader` branches on quantization method and loads either HF weights or GGUF weights for the transformer. +4. GGUF transformer loading is routed through model-specific adapters (e.g., Flux2Klein). 4. vLLM GGUF path uses `GGUFConfig` and `GGUFLinearMethod` for matmul; FP8 uses `Fp8Config` (online) or `is_checkpoint_fp8_serialized` for native FP8. ## Call Chain (Offline) @@ -153,9 +154,63 @@ Notes: ## GGUF Weight Loading Path (Transformer-Only) 1. `DiffusersPipelineLoader.load_model` detects `quantization_config.method == "gguf"`. 2. `gguf_model` is resolved as one of: local file, URL, `repo/file.gguf`, or `repo:quant_type`. -3. Name mapping is applied per-architecture (Qwen-Image, Flux2-klein). +3. GGUF weights are routed through adapters in `vllm_omni/diffusion/model_loader/gguf_adapters/`. +4. Name mapping is applied per-architecture (Qwen-Image, Flux2Klein). 4. GGUF weights are loaded into transformer modules, remaining non-transformer weights come from the HF checkpoint. +## GGUF Adapter Design +1. `GGUFAdapter` (base) implements default gguf-py tensor name mapping. +2. `Flux2KleinGGUFAdapter` implements Flux2-Klein remapping + qkv split + adaLN swap. +3. `get_gguf_adapter(...)` selects the adapter by model class/config and returns an iterator of `(name, tensor)`. 
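+
+A minimal sketch of how a caller might drive the selected adapter
+(`load_gguf_into_transformer` is a hypothetical name, and the `weight_loader`
+attribute is assumed from vLLM's parameter convention; the real wiring lives
+in `DiffusersPipelineLoader` and may differ in detail):
+
+```python
+from vllm_omni.diffusion.model_loader.gguf_adapters import get_gguf_adapter
+
+
+def load_gguf_into_transformer(gguf_file, model, source, od_config):
+    adapter = get_gguf_adapter(gguf_file, model, source, od_config)
+    params = dict(model.named_parameters())
+    for name, tensor in adapter.weights_iterator():
+        # Linear layers receive `qweight`/`qweight_type`; norm and bias
+        # tensors keep their plain `.weight`/`.bias` names.
+        param = params.get(name)
+        if param is None:
+            continue  # buffers and non-transformer weights load elsewhere
+        loader = getattr(param, "weight_loader", None)
+        if loader is not None:
+            loader(param, tensor)  # per-parameter loader, vLLM style
+        else:
+            param.data.copy_(tensor)
+```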
+ +Adapter paths: +- Base: `vllm_omni/diffusion/model_loader/gguf_adapters/base.py` +- Flux2-Klein: `vllm_omni/diffusion/model_loader/gguf_adapters/flux2_klein.py` + +## Flux2-Klein GGUF Mapping (Key Rules) +1. **Core rename (diffusers-compatible)**: + - `img_in` -> `x_embedder` + - `txt_in` -> `context_embedder` + - `time_in.*` -> `time_guidance_embed.timestep_embedder.*` + - `guidance_in.*` -> `time_guidance_embed.guidance_embedder.*` + - `double_stream_modulation_*` -> `double_stream_modulation_*.linear` + - `single_stream_modulation.lin` -> `single_stream_modulation.linear` + - `final_layer.linear` -> `proj_out` +2. **Double blocks (img/txt)**: + - `double_blocks.{i}.img_attn.qkv.weight` + -> `transformer_blocks.{i}.attn.to_q/to_k/to_v` + - `double_blocks.{i}.txt_attn.qkv.weight` + -> `transformer_blocks.{i}.attn.add_q_proj/add_k_proj/add_v_proj` + - Other mappings: + - `img_attn.norm.query_norm` -> `attn.norm_q` + - `img_attn.norm.key_norm` -> `attn.norm_k` + - `img_attn.proj` -> `attn.to_out.0` + - `img_mlp.0` -> `ff.linear_in` + - `img_mlp.2` -> `ff.linear_out` + - `txt_attn.norm.query_norm` -> `attn.norm_added_q` + - `txt_attn.norm.key_norm` -> `attn.norm_added_k` + - `txt_attn.proj` -> `attn.to_add_out` + - `txt_mlp.0` -> `ff_context.linear_in` + - `txt_mlp.2` -> `ff_context.linear_out` +3. **Single blocks**: + - `single_blocks.{i}.linear1` -> `single_transformer_blocks.{i}.attn.to_qkv_mlp_proj` + - `single_blocks.{i}.linear2` -> `single_transformer_blocks.{i}.attn.to_out` + - `single_blocks.{i}.norm.query_norm` -> `single_transformer_blocks.{i}.attn.norm_q` + - `single_blocks.{i}.norm.key_norm` -> `single_transformer_blocks.{i}.attn.norm_k` +4. **AdaLN swap**: + - `final_layer.adaLN_modulation.1.weight` -> `norm_out.linear.weight` with (shift, scale) swapped. + +## Flux2-Klein GGUF Loader Logic +1. **Iterator flow**: + - Read GGUF tensors via `gguf.GGUFReader`. + - Apply Flux2-Klein mapping rules to produce diffusers-style names. + - For QKV tensors, split along dim0 into Q/K/V shards. +2. **Linear weights go to `qweight`** (both quantized and BF16/F16): + - Always emit `qweight_type` for linear weights. + - Use shard names (`to_q.qweight`, `to_k.qweight`, `to_v.qweight`) so vLLM can reassemble into `to_qkv.qweight`. +3. **Non-linear weights** (norm/bias/scale) keep `.weight`/`.bias` names. +4. **Remaining HF weights** are loaded after GGUF to fill gaps. + ## FP8 Loading Path 1. Online FP8: `quantization="fp8"` or `quantization_config={"method":"fp8", "ignored_layers": [...]}`. 2. Native FP8: `quantization_config={"method":"fp8", "is_checkpoint_fp8_serialized": True}` to load an FP8-serialized checkpoint. From f39760b5b1c8dab87d8f1a7dbb50670dc411a456 Mon Sep 17 00:00:00 2001 From: David Chen <530634352@qq.com> Date: Thu, 12 Feb 2026 09:26:56 +0800 Subject: [PATCH 18/62] support gguf fp8 add design doc 4 Signed-off-by: David Chen <530634352@qq.com> --- .../diffusion/quantization/fp8_online.md | 75 ------------------- 1 file changed, 75 deletions(-) delete mode 100644 docs/user_guide/diffusion/quantization/fp8_online.md diff --git a/docs/user_guide/diffusion/quantization/fp8_online.md b/docs/user_guide/diffusion/quantization/fp8_online.md deleted file mode 100644 index 65a6329690..0000000000 --- a/docs/user_guide/diffusion/quantization/fp8_online.md +++ /dev/null @@ -1,75 +0,0 @@ -# Online FP8 Quantization - -## Overview - -Online FP8 converts BF16/FP16 weights to FP8 at model load time. No calibration or pre-quantized checkpoint needed. 
- -Depending on the model, either all layers can be quantized, or some sensitive layers should stay in BF16. See the [per-model table](#supported-models) for which case applies. - -## Configuration - -1. **Python API**: set `quantization="fp8"`. To skip sensitive layers, use `quantization_config` with `ignored_layers`. - -```python -from vllm_omni import Omni -from vllm_omni.inputs.data import OmniDiffusionSamplingParams - -# All layers quantized -omni = Omni(model="", quantization="fp8") - -# Skip sensitive layers -omni = Omni( - model="", - quantization_config={ - "method": "fp8", - "ignored_layers": [""], - }, -) - -outputs = omni.generate( - "A cat sitting on a windowsill", - OmniDiffusionSamplingParams(num_inference_steps=50), -) -``` - -2. **CLI**: pass `--quantization fp8` and optionally `--ignored-layers`. - -```bash -# All layers -python text_to_image.py --model --quantization fp8 - -# Skip sensitive layers -python text_to_image.py --model --quantization fp8 --ignored-layers "img_mlp" - -# Online serving -vllm serve --omni --quantization fp8 -``` - -| Parameter | Type | Default | Description | -|-----------|------|---------|-------------| -| `method` | str | — | Quantization method (`"fp8"`) | -| `ignored_layers` | list[str] | `[]` | Layer name patterns to keep in BF16 | -| `activation_scheme` | str | `"dynamic"` | `"dynamic"` (no calibration) or `"static"` | -| `weight_block_size` | list[int] \| None | `None` | Block size for block-wise weight quantization | - -The available `ignored_layers` names depend on the model architecture (e.g., `to_qkv`, `to_out`, `img_mlp`, `txt_mlp`). Consult the transformer source for your target model. - -## Supported Models - -| Model | HF Models | Recommendation | `ignored_layers` | -|-------|-----------|---------------|------------------| -| Z-Image | `Tongyi-MAI/Z-Image-Turbo` | All layers | None | -| Qwen-Image | `Qwen/Qwen-Image`, `Qwen/Qwen-Image-2512` | Skip sensitive layers | `img_mlp` | - -## Combining with Other Features - -FP8 quantization can be combined with cache acceleration: - -```python -omni = Omni( - model="", - quantization="fp8", - cache_backend="tea_cache", - cache_config={"rel_l1_thresh": 0.2}, -) -``` From f63d509fe3c31a92ff78e3f021bd63cea7a52526 Mon Sep 17 00:00:00 2001 From: David Chen <530634352@qq.com> Date: Thu, 12 Feb 2026 09:45:34 +0800 Subject: [PATCH 19/62] support gguf fp8 add design doc 5 Signed-off-by: David Chen <530634352@qq.com> --- docs/user_guide/diffusion/quantization/fp8.md | 11 ++++++++--- .../diffusion/quantization/gguf_fp8_design.md | 2 +- docs/user_guide/diffusion/quantization/overview.md | 1 + .../offline_inference/text_to_image/text_to_image.py | 9 --------- 4 files changed, 10 insertions(+), 13 deletions(-) diff --git a/docs/user_guide/diffusion/quantization/fp8.md b/docs/user_guide/diffusion/quantization/fp8.md index 2c065546bf..61557faa75 100644 --- a/docs/user_guide/diffusion/quantization/fp8.md +++ b/docs/user_guide/diffusion/quantization/fp8.md @@ -1,4 +1,4 @@ -# FP8 Quantization +# FP8 Quantization ## Overview @@ -34,7 +34,8 @@ outputs = omni.generate( ) ``` -2. **CLI**: pass `--quantization fp8` and optionally `--ignored-layers`. +2. **CLI**: pass `--quantization fp8` and optionally `--ignored-layers`. You can also pass +`--quantization-config` as a JSON string for more control. 
```bash # All layers @@ -45,11 +46,14 @@ python text_to_image.py --model --quantization fp8 --ignored-layers # Online serving vllm serve --omni --quantization fp8 + +# Online serving with quantization-config (same effect, explicit JSON) +vllm serve --omni --quantization-config '{"method":"fp8","ignored_layers":["img_mlp"]}' ``` | Parameter | Type | Default | Description | |-----------|------|---------|-------------| -| `method` | str | — | Quantization method (`"fp8"`) | +| `method` | str | `None` | Quantization method (`"fp8"`) | | `ignored_layers` | list[str] | `[]` | Layer name patterns to keep in BF16 | | `activation_scheme` | str | `"dynamic"` | `"dynamic"` (no calibration) or `"static"` | | `weight_block_size` | list[int] \| None | `None` | Block size for block-wise weight quantization | @@ -75,3 +79,4 @@ omni = Omni( cache_config={"rel_l1_thresh": 0.2}, ) ``` + diff --git a/docs/user_guide/diffusion/quantization/gguf_fp8_design.md b/docs/user_guide/diffusion/quantization/gguf_fp8_design.md index 8669ca7a60..ce0aa23e94 100644 --- a/docs/user_guide/diffusion/quantization/gguf_fp8_design.md +++ b/docs/user_guide/diffusion/quantization/gguf_fp8_design.md @@ -1,6 +1,6 @@ # Diffusion Quantization Design: Native GGUF + FP8 (Online and Native) -Date: 2026-02-11 +Date: 2026-02-12 ## Goals 1. Reuse vLLM quantization configs and weight loaders as much as possible. diff --git a/docs/user_guide/diffusion/quantization/overview.md b/docs/user_guide/diffusion/quantization/overview.md index 7dede292fc..6526f3843d 100644 --- a/docs/user_guide/diffusion/quantization/overview.md +++ b/docs/user_guide/diffusion/quantization/overview.md @@ -7,6 +7,7 @@ vLLM-Omni supports quantization of DiT linear layers to reduce memory usage and | Method | Guide | |--------|-------| | FP8 | [FP8](fp8.md) | +| GGUF | [GGUF + FP8 Design](gguf_fp8_design.md) | ## Device Compatibility diff --git a/examples/offline_inference/text_to_image/text_to_image.py b/examples/offline_inference/text_to_image/text_to_image.py index b544fcbefe..6a1dea646d 100644 --- a/examples/offline_inference/text_to_image/text_to_image.py +++ b/examples/offline_inference/text_to_image/text_to_image.py @@ -119,15 +119,6 @@ def parse_args() -> argparse.Namespace: default=1, help="Number of ready layers (blocks) to keep on GPU during generation.", ) - parser.add_argument( - "--quantization", - type=str, - default=None, - choices=["fp8"], - help="Quantization method for the transformer. " - "Options: 'fp8' (FP8 W8A8 on Ada/Hopper, weight-only on older GPUs). 
" - "Default: None (no quantization, uses BF16).", - ) parser.add_argument( "--quantization", type=str, From ceb8c11fbdacbe4b15120869fd8fbd46bbcab6de Mon Sep 17 00:00:00 2001 From: David Chen <530634352@qq.com> Date: Thu, 12 Feb 2026 09:57:35 +0800 Subject: [PATCH 20/62] support gguf fp8 add qwen-image Signed-off-by: David Chen <530634352@qq.com> --- .../diffusion/quantization/gguf_fp8_design.md | 1 + .../model_loader/gguf_adapters/__init__.py | 10 +- .../model_loader/gguf_adapters/qwen_image.py | 255 ++++++++++++++++++ 3 files changed, 264 insertions(+), 2 deletions(-) create mode 100644 vllm_omni/diffusion/model_loader/gguf_adapters/qwen_image.py diff --git a/docs/user_guide/diffusion/quantization/gguf_fp8_design.md b/docs/user_guide/diffusion/quantization/gguf_fp8_design.md index ce0aa23e94..1b7f269c92 100644 --- a/docs/user_guide/diffusion/quantization/gguf_fp8_design.md +++ b/docs/user_guide/diffusion/quantization/gguf_fp8_design.md @@ -165,6 +165,7 @@ Notes: Adapter paths: - Base: `vllm_omni/diffusion/model_loader/gguf_adapters/base.py` +- Qwen-Image: `vllm_omni/diffusion/model_loader/gguf_adapters/qwen_image.py` - Flux2-Klein: `vllm_omni/diffusion/model_loader/gguf_adapters/flux2_klein.py` ## Flux2-Klein GGUF Mapping (Key Rules) diff --git a/vllm_omni/diffusion/model_loader/gguf_adapters/__init__.py b/vllm_omni/diffusion/model_loader/gguf_adapters/__init__.py index 737132ead6..d7f01fb3c5 100644 --- a/vllm_omni/diffusion/model_loader/gguf_adapters/__init__.py +++ b/vllm_omni/diffusion/model_loader/gguf_adapters/__init__.py @@ -3,13 +3,19 @@ from .base import GGUFAdapter from .flux2_klein import Flux2KleinGGUFAdapter +from .qwen_image import QwenImageGGUFAdapter def get_gguf_adapter(gguf_file: str, model, source, od_config) -> GGUFAdapter: - for adapter_cls in (Flux2KleinGGUFAdapter,): + for adapter_cls in (QwenImageGGUFAdapter, Flux2KleinGGUFAdapter): if adapter_cls.is_compatible(od_config, model, source): return adapter_cls(gguf_file, model, source, od_config) return GGUFAdapter(gguf_file, model, source, od_config) -__all__ = ["GGUFAdapter", "Flux2KleinGGUFAdapter", "get_gguf_adapter"] +__all__ = [ + "GGUFAdapter", + "Flux2KleinGGUFAdapter", + "QwenImageGGUFAdapter", + "get_gguf_adapter", +] diff --git a/vllm_omni/diffusion/model_loader/gguf_adapters/qwen_image.py b/vllm_omni/diffusion/model_loader/gguf_adapters/qwen_image.py new file mode 100644 index 0000000000..298e5c2044 --- /dev/null +++ b/vllm_omni/diffusion/model_loader/gguf_adapters/qwen_image.py @@ -0,0 +1,255 @@ +# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +from collections.abc import Generator +from dataclasses import dataclass +from typing import Any + +import numpy as np +import torch + +from .base import GGUFAdapter + + +@dataclass +class _MappedTensor: + name: str + tensor: Any + tensor_type: Any + + +class QwenImageGGUFAdapter(GGUFAdapter): + """GGUF adapter for Qwen-Image models with QKV shard support.""" + + @staticmethod + def is_compatible(od_config, model: torch.nn.Module, source) -> bool: + model_class = od_config.model_class_name or "" + if model_class.startswith("QwenImage"): + return True + cfg = od_config.tf_model_config + if cfg is not None: + model_type = str(cfg.get("model_type", "")).lower() + if model_type.startswith("qwen_image"): + return True + return False + + def weights_iterator(self) -> Generator[tuple[str, torch.Tensor], None, None]: + try: + import gguf # type: ignore + except Exception as exc: # pragma: no cover - dependency error + raise RuntimeError( + "GGUF 
support requires the 'gguf' package to be installed." + ) from exc + + reader = gguf.GGUFReader(self.gguf_file) + gguf_name_map = self._build_gguf_name_map(reader) + allowed_names = self._build_allowed_names() + param_names = self._build_param_names() + mapped: list[_MappedTensor] = [] + + for tensor in reader.tensors: + mapped_name = gguf_name_map.get(tensor.name) + if mapped_name is None: + mapped_name = self._normalize_name(tensor.name) + if ( + mapped_name not in allowed_names + and self._resolve_linear_qweight(mapped_name, param_names) is None + ): + continue + mapped.append( + _MappedTensor( + name=mapped_name, + tensor=tensor, + tensor_type=tensor.tensor_type, + ) + ) + + if not mapped: + raise RuntimeError( + "No GGUF tensors were mapped for Qwen-Image GGUF loader. " + "Please verify the GGUF file and model structure." + ) + + for item in mapped: + linear_qweight = self._resolve_linear_qweight(item.name, param_names) + if linear_qweight is None: + continue + weight_type_name = linear_qweight.replace("qweight", "qweight_type") + yield weight_type_name, torch.tensor(item.tensor_type) + + for item in mapped: + weight = item.tensor.data + weight_type = item.tensor_type + linear_qweight = self._resolve_linear_qweight(item.name, param_names) + if linear_qweight is not None: + name = linear_qweight + else: + name = item.name + + if weight_type.name == "BF16" and weight.dtype == np.uint8: + weight = weight.view(np.uint16) + if reader.byte_order == "S": + weight = weight.byteswap() + param = torch.tensor(weight).view(torch.bfloat16) + else: + param = torch.tensor(weight) + + yield name, param + + def _normalize_name(self, name: str) -> str: + if name.endswith(".scale"): + name = name[:-6] + ".weight" + if name.endswith("_weight"): + name = name[:-7] + ".weight" + if ".to_out.0." in name: + name = name.replace(".to_out.0.", ".to_out.") + return name + + def _build_allowed_names(self) -> set[str]: + prefix = getattr(self.source, "prefix", "") + target = self.model.get_submodule(prefix.rstrip(".")) if prefix else self.model + allowed = {name for name, _ in target.named_parameters()} + allowed.update(name for name, _ in target.named_buffers()) + for name in list(allowed): + if name.endswith(".qweight"): + allowed.add(name.replace(".qweight", ".weight")) + elif name.endswith(".qweight_type"): + allowed.add(name.replace(".qweight_type", ".weight")) + + virtual_names = set() + for name in allowed: + if ".to_qkv." in name: + virtual_names.add(name.replace(".to_qkv.", ".to_q.")) + virtual_names.add(name.replace(".to_qkv.", ".to_k.")) + virtual_names.add(name.replace(".to_qkv.", ".to_v.")) + if ".add_kv_proj." in name: + virtual_names.add(name.replace(".add_kv_proj.", ".add_q_proj.")) + virtual_names.add(name.replace(".add_kv_proj.", ".add_k_proj.")) + virtual_names.add(name.replace(".add_kv_proj.", ".add_v_proj.")) + if ".to_out." in name: + virtual_names.add(name.replace(".to_out.", ".to_out.0.")) + allowed.update(virtual_names) + return allowed + + def _build_param_names(self) -> set[str]: + prefix = getattr(self.source, "prefix", "") + target = self.model.get_submodule(prefix.rstrip(".")) if prefix else self.model + return {name for name, _ in target.named_parameters()} + + def _resolve_linear_qweight(self, name: str, param_names: set[str]) -> str | None: + if not name.endswith(".weight"): + return None + if ".to_out.0." 
in name: + name = name.replace(".to_out.0.", ".to_out.") + for shard_token in ( + ".to_q.", + ".to_k.", + ".to_v.", + ".add_q_proj.", + ".add_k_proj.", + ".add_v_proj.", + ): + if shard_token in name: + return name.replace(".weight", ".qweight") + candidate = name.replace(".weight", ".qweight") + if candidate in param_names: + return candidate + return None + + def _build_gguf_name_map(self, reader) -> dict[str, str]: + try: + import gguf # type: ignore + except Exception as exc: # pragma: no cover - dependency error + raise RuntimeError( + "GGUF support requires the 'gguf' package to be installed." + ) from exc + + gguf_tensor_names = {tensor.name for tensor in reader.tensors} + + def resolve_model_type() -> str: + cfg = self.od_config.tf_model_config + model_type = None + if cfg is not None: + model_type = cfg.get("model_type") + if model_type: + return model_type + model_class = self.od_config.model_class_name or "" + if model_class.startswith("QwenImage"): + return "qwen_image" + raise ValueError("Cannot infer gguf model_type for Qwen-Image.") + + def resolve_arch(model_type: str): + for key, value in gguf.MODEL_ARCH_NAMES.items(): + if value == model_type: + return key + raise RuntimeError(f"Unknown gguf model_type: {model_type}") + + def resolve_num_layers(target_module: torch.nn.Module) -> int: + if hasattr(target_module, "transformer_blocks"): + return len(getattr(target_module, "transformer_blocks")) + cfg = self.od_config.tf_model_config + if cfg is not None: + for key in ("num_hidden_layers", "num_layers", "n_layers"): + value = cfg.get(key) + if isinstance(value, int) and value > 0: + return value + raise ValueError("Cannot infer gguf num_layers for Qwen-Image.") + + def get_target_module(root: torch.nn.Module, prefix: str) -> torch.nn.Module: + if not prefix: + return root + prefix = prefix.rstrip(".") + if hasattr(root, "get_submodule"): + return root.get_submodule(prefix) + current = root + for part in prefix.split("."): + current = getattr(current, part) + return current + + def split_name(name: str) -> tuple[str, str]: + if name.endswith("_weight"): + return name[:-7], "weight" + if "." in name: + base, suffix = name.rsplit(".", 1) + return base, suffix + return name, "" + + model_type = resolve_model_type() + arch = resolve_arch(model_type) + target_module = get_target_module(self.model, self.source.prefix) + num_layers = resolve_num_layers(target_module) + name_map = gguf.get_tensor_name_map(arch, num_layers) + + candidate_names = {name for name, _ in target_module.named_parameters()} + candidate_names.update(name for name, _ in target_module.named_buffers()) + for name in list(candidate_names): + if ".to_qkv." in name: + candidate_names.add(name.replace(".to_qkv.", ".to_q.")) + candidate_names.add(name.replace(".to_qkv.", ".to_k.")) + candidate_names.add(name.replace(".to_qkv.", ".to_v.")) + if ".add_kv_proj." in name: + candidate_names.add(name.replace(".add_kv_proj.", ".add_q_proj.")) + candidate_names.add(name.replace(".add_kv_proj.", ".add_k_proj.")) + candidate_names.add(name.replace(".add_kv_proj.", ".add_v_proj.")) + if ".to_out." 
in name: + candidate_names.add(name.replace(".to_out.", ".to_out.0.")) + + gguf_to_model_map: dict[str, str] = {} + for name in candidate_names: + base_name, suffix = split_name(name) + gguf_base = name_map.get_name(base_name) + if gguf_base is None: + continue + candidates = [] + if suffix: + candidates.append(f"{gguf_base}.{suffix}") + if suffix == "weight": + candidates.append(f"{gguf_base}.scale") + else: + candidates.append(gguf_base) + gguf_name = next((c for c in candidates if c in gguf_tensor_names), None) + if gguf_name is None: + continue + gguf_to_model_map[gguf_name] = name + + return gguf_to_model_map From c5b1e75d49e712e86217fc893ccf7e1d520cc6fc Mon Sep 17 00:00:00 2001 From: David Chen <530634352@qq.com> Date: Thu, 12 Feb 2026 10:33:32 +0800 Subject: [PATCH 21/62] support gguf fp8 add z-image Signed-off-by: David Chen <530634352@qq.com> --- .../diffusion/quantization/gguf_fp8_design.md | 1 + .../model_loader/gguf_adapters/__init__.py | 4 +- .../model_loader/gguf_adapters/z_image.py | 252 ++++++++++++++++++ 3 files changed, 256 insertions(+), 1 deletion(-) create mode 100644 vllm_omni/diffusion/model_loader/gguf_adapters/z_image.py diff --git a/docs/user_guide/diffusion/quantization/gguf_fp8_design.md b/docs/user_guide/diffusion/quantization/gguf_fp8_design.md index 1b7f269c92..d1ad4d1785 100644 --- a/docs/user_guide/diffusion/quantization/gguf_fp8_design.md +++ b/docs/user_guide/diffusion/quantization/gguf_fp8_design.md @@ -166,6 +166,7 @@ Notes: Adapter paths: - Base: `vllm_omni/diffusion/model_loader/gguf_adapters/base.py` - Qwen-Image: `vllm_omni/diffusion/model_loader/gguf_adapters/qwen_image.py` +- Z-Image: `vllm_omni/diffusion/model_loader/gguf_adapters/z_image.py` - Flux2-Klein: `vllm_omni/diffusion/model_loader/gguf_adapters/flux2_klein.py` ## Flux2-Klein GGUF Mapping (Key Rules) diff --git a/vllm_omni/diffusion/model_loader/gguf_adapters/__init__.py b/vllm_omni/diffusion/model_loader/gguf_adapters/__init__.py index d7f01fb3c5..7abceda31b 100644 --- a/vllm_omni/diffusion/model_loader/gguf_adapters/__init__.py +++ b/vllm_omni/diffusion/model_loader/gguf_adapters/__init__.py @@ -4,10 +4,11 @@ from .base import GGUFAdapter from .flux2_klein import Flux2KleinGGUFAdapter from .qwen_image import QwenImageGGUFAdapter +from .z_image import ZImageGGUFAdapter def get_gguf_adapter(gguf_file: str, model, source, od_config) -> GGUFAdapter: - for adapter_cls in (QwenImageGGUFAdapter, Flux2KleinGGUFAdapter): + for adapter_cls in (QwenImageGGUFAdapter, ZImageGGUFAdapter, Flux2KleinGGUFAdapter): if adapter_cls.is_compatible(od_config, model, source): return adapter_cls(gguf_file, model, source, od_config) return GGUFAdapter(gguf_file, model, source, od_config) @@ -17,5 +18,6 @@ def get_gguf_adapter(gguf_file: str, model, source, od_config) -> GGUFAdapter: "GGUFAdapter", "Flux2KleinGGUFAdapter", "QwenImageGGUFAdapter", + "ZImageGGUFAdapter", "get_gguf_adapter", ] diff --git a/vllm_omni/diffusion/model_loader/gguf_adapters/z_image.py b/vllm_omni/diffusion/model_loader/gguf_adapters/z_image.py new file mode 100644 index 0000000000..eedf136120 --- /dev/null +++ b/vllm_omni/diffusion/model_loader/gguf_adapters/z_image.py @@ -0,0 +1,252 @@ +# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +from collections.abc import Generator +from dataclasses import dataclass +from typing import Any + +import numpy as np +import torch + +from .base import GGUFAdapter + + +@dataclass +class _MappedTensor: + name: str + tensor: Any + tensor_type: Any + + +class 
ZImageGGUFAdapter(GGUFAdapter): + """GGUF adapter for Z-Image models with QKV/FFN shard support.""" + + @staticmethod + def is_compatible(od_config, model: torch.nn.Module, source) -> bool: + model_class = od_config.model_class_name or "" + if model_class.startswith("ZImage"): + return True + cfg = od_config.tf_model_config + if cfg is not None: + model_type = str(cfg.get("model_type", "")).lower() + if model_type in {"z_image", "zimage", "z-image"}: + return True + return False + + def weights_iterator(self) -> Generator[tuple[str, torch.Tensor], None, None]: + try: + import gguf # type: ignore + except Exception as exc: # pragma: no cover - dependency error + raise RuntimeError( + "GGUF support requires the 'gguf' package to be installed." + ) from exc + + reader = gguf.GGUFReader(self.gguf_file) + gguf_name_map = self._build_gguf_name_map(reader) + allowed_names = self._build_allowed_names() + param_names = self._build_param_names() + mapped: list[_MappedTensor] = [] + + for tensor in reader.tensors: + mapped_name = gguf_name_map.get(tensor.name) + if mapped_name is None: + mapped_name = self._normalize_name(tensor.name) + if ( + mapped_name not in allowed_names + and self._resolve_linear_qweight(mapped_name, param_names) is None + ): + continue + mapped.append( + _MappedTensor( + name=mapped_name, + tensor=tensor, + tensor_type=tensor.tensor_type, + ) + ) + + if not mapped: + raise RuntimeError( + "No GGUF tensors were mapped for Z-Image GGUF loader. " + "Please verify the GGUF file and model structure." + ) + + for item in mapped: + linear_qweight = self._resolve_linear_qweight(item.name, param_names) + if linear_qweight is None: + continue + weight_type_name = linear_qweight.replace("qweight", "qweight_type") + yield weight_type_name, torch.tensor(item.tensor_type) + + for item in mapped: + weight = item.tensor.data + weight_type = item.tensor_type + linear_qweight = self._resolve_linear_qweight(item.name, param_names) + if linear_qweight is not None: + name = linear_qweight + else: + name = item.name + + if weight_type.name == "BF16" and weight.dtype == np.uint8: + weight = weight.view(np.uint16) + if reader.byte_order == "S": + weight = weight.byteswap() + param = torch.tensor(weight).view(torch.bfloat16) + else: + param = torch.tensor(weight) + + yield name, param + + def _normalize_name(self, name: str) -> str: + if name.endswith(".scale"): + name = name[:-6] + ".weight" + if name.endswith("_weight"): + name = name[:-7] + ".weight" + if ".to_out.0." in name: + name = name.replace(".to_out.0.", ".to_out.") + return name + + def _build_allowed_names(self) -> set[str]: + prefix = getattr(self.source, "prefix", "") + target = self.model.get_submodule(prefix.rstrip(".")) if prefix else self.model + allowed = {name for name, _ in target.named_parameters()} + allowed.update(name for name, _ in target.named_buffers()) + for name in list(allowed): + if name.endswith(".qweight"): + allowed.add(name.replace(".qweight", ".weight")) + elif name.endswith(".qweight_type"): + allowed.add(name.replace(".qweight_type", ".weight")) + + virtual_names = set() + for name in allowed: + if ".to_qkv." in name: + virtual_names.add(name.replace(".to_qkv.", ".to_q.")) + virtual_names.add(name.replace(".to_qkv.", ".to_k.")) + virtual_names.add(name.replace(".to_qkv.", ".to_v.")) + if ".w13." in name: + virtual_names.add(name.replace(".w13.", ".w1.")) + virtual_names.add(name.replace(".w13.", ".w3.")) + if ".to_out." 
in name: + virtual_names.add(name.replace(".to_out.", ".to_out.0.")) + allowed.update(virtual_names) + return allowed + + def _build_param_names(self) -> set[str]: + prefix = getattr(self.source, "prefix", "") + target = self.model.get_submodule(prefix.rstrip(".")) if prefix else self.model + return {name for name, _ in target.named_parameters()} + + def _resolve_linear_qweight(self, name: str, param_names: set[str]) -> str | None: + if not name.endswith(".weight"): + return None + if ".to_out.0." in name: + name = name.replace(".to_out.0.", ".to_out.") + for shard_token in ( + ".to_q.", + ".to_k.", + ".to_v.", + ".w1.", + ".w3.", + ): + if shard_token in name: + return name.replace(".weight", ".qweight") + candidate = name.replace(".weight", ".qweight") + if candidate in param_names: + return candidate + return None + + def _build_gguf_name_map(self, reader) -> dict[str, str]: + try: + import gguf # type: ignore + except Exception as exc: # pragma: no cover - dependency error + raise RuntimeError( + "GGUF support requires the 'gguf' package to be installed." + ) from exc + + gguf_tensor_names = {tensor.name for tensor in reader.tensors} + + def resolve_model_type() -> str: + cfg = self.od_config.tf_model_config + model_type = None + if cfg is not None: + model_type = cfg.get("model_type") + if model_type: + return model_type + model_class = self.od_config.model_class_name or "" + if model_class.startswith("ZImage"): + return "z_image" + raise ValueError("Cannot infer gguf model_type for Z-Image.") + + def resolve_arch(model_type: str): + for key, value in gguf.MODEL_ARCH_NAMES.items(): + if value == model_type: + return key + raise RuntimeError(f"Unknown gguf model_type: {model_type}") + + def resolve_num_layers(target_module: torch.nn.Module) -> int: + if hasattr(target_module, "layers"): + return len(getattr(target_module, "layers")) + cfg = self.od_config.tf_model_config + if cfg is not None: + for key in ("num_hidden_layers", "num_layers", "n_layers"): + value = cfg.get(key) + if isinstance(value, int) and value > 0: + return value + raise ValueError("Cannot infer gguf num_layers for Z-Image.") + + def get_target_module(root: torch.nn.Module, prefix: str) -> torch.nn.Module: + if not prefix: + return root + prefix = prefix.rstrip(".") + if hasattr(root, "get_submodule"): + return root.get_submodule(prefix) + current = root + for part in prefix.split("."): + current = getattr(current, part) + return current + + def split_name(name: str) -> tuple[str, str]: + if name.endswith("_weight"): + return name[:-7], "weight" + if "." in name: + base, suffix = name.rsplit(".", 1) + return base, suffix + return name, "" + + model_type = resolve_model_type() + arch = resolve_arch(model_type) + target_module = get_target_module(self.model, self.source.prefix) + num_layers = resolve_num_layers(target_module) + name_map = gguf.get_tensor_name_map(arch, num_layers) + + candidate_names = {name for name, _ in target_module.named_parameters()} + candidate_names.update(name for name, _ in target_module.named_buffers()) + for name in list(candidate_names): + if ".to_qkv." in name: + candidate_names.add(name.replace(".to_qkv.", ".to_q.")) + candidate_names.add(name.replace(".to_qkv.", ".to_k.")) + candidate_names.add(name.replace(".to_qkv.", ".to_v.")) + if ".w13." in name: + candidate_names.add(name.replace(".w13.", ".w1.")) + candidate_names.add(name.replace(".w13.", ".w3.")) + if ".to_out." 
in name: + candidate_names.add(name.replace(".to_out.", ".to_out.0.")) + + gguf_to_model_map: dict[str, str] = {} + for name in candidate_names: + base_name, suffix = split_name(name) + gguf_base = name_map.get_name(base_name) + if gguf_base is None: + continue + candidates = [] + if suffix: + candidates.append(f"{gguf_base}.{suffix}") + if suffix == "weight": + candidates.append(f"{gguf_base}.scale") + else: + candidates.append(gguf_base) + gguf_name = next((c for c in candidates if c in gguf_tensor_names), None) + if gguf_name is None: + continue + gguf_to_model_map[gguf_name] = name + + return gguf_to_model_map From 599bafb95233c9b5f7f71f184add16e1fd397b2b Mon Sep 17 00:00:00 2001 From: David Chen <530634352@qq.com> Date: Thu, 12 Feb 2026 10:40:18 +0800 Subject: [PATCH 22/62] support gguf only Signed-off-by: David Chen <530634352@qq.com> --- docs/user_guide/diffusion/quantization/fp8.md | 11 +- .../{gguf_fp8_design.md => gguf.md} | 110 +++--------------- .../diffusion/quantization/overview.md | 2 +- vllm_omni/diffusion/quantization/fp8.py | 4 +- 4 files changed, 19 insertions(+), 108 deletions(-) rename docs/user_guide/diffusion/quantization/{gguf_fp8_design.md => gguf.md} (69%) diff --git a/docs/user_guide/diffusion/quantization/fp8.md b/docs/user_guide/diffusion/quantization/fp8.md index 61557faa75..3a3132677e 100644 --- a/docs/user_guide/diffusion/quantization/fp8.md +++ b/docs/user_guide/diffusion/quantization/fp8.md @@ -34,8 +34,7 @@ outputs = omni.generate( ) ``` -2. **CLI**: pass `--quantization fp8` and optionally `--ignored-layers`. You can also pass -`--quantization-config` as a JSON string for more control. +2. **CLI**: pass `--quantization fp8` and optionally `--ignored-layers`. ```bash # All layers @@ -46,14 +45,11 @@ python text_to_image.py --model --quantization fp8 --ignored-layers # Online serving vllm serve --omni --quantization fp8 - -# Online serving with quantization-config (same effect, explicit JSON) -vllm serve --omni --quantization-config '{"method":"fp8","ignored_layers":["img_mlp"]}' ``` | Parameter | Type | Default | Description | |-----------|------|---------|-------------| -| `method` | str | `None` | Quantization method (`"fp8"`) | +| `method` | str | — | Quantization method (`"fp8"`) | | `ignored_layers` | list[str] | `[]` | Layer name patterns to keep in BF16 | | `activation_scheme` | str | `"dynamic"` | `"dynamic"` (no calibration) or `"static"` | | `weight_block_size` | list[int] \| None | `None` | Block size for block-wise weight quantization | @@ -78,5 +74,4 @@ omni = Omni( cache_backend="tea_cache", cache_config={"rel_l1_thresh": 0.2}, ) -``` - +``` \ No newline at end of file diff --git a/docs/user_guide/diffusion/quantization/gguf_fp8_design.md b/docs/user_guide/diffusion/quantization/gguf.md similarity index 69% rename from docs/user_guide/diffusion/quantization/gguf_fp8_design.md rename to docs/user_guide/diffusion/quantization/gguf.md index d1ad4d1785..3b18a31be0 100644 --- a/docs/user_guide/diffusion/quantization/gguf_fp8_design.md +++ b/docs/user_guide/diffusion/quantization/gguf.md @@ -1,23 +1,23 @@ -# Diffusion Quantization Design: Native GGUF + FP8 (Online and Native) +# Diffusion Quantization Design: Native GGUF Date: 2026-02-12 ## Goals 1. Reuse vLLM quantization configs and weight loaders as much as possible. -2. Add native GGUF and FP8 support to diffusion transformers without changing model definitions. +2. Add native GGUF support to diffusion transformers without changing model definitions. 3. 
Keep user-facing knobs minimal and consistent across offline and online flows. ## Scope -1. Models: Qwen-Image and Flux2-klein are first-class targets. +1. Models: Qwen-Image, Z-Image, and Flux2-klein. 2. Components: diffusion transformer weights, loader paths, and quantization configs. -3. Modes: native GGUF, online FP8, native FP8 (pre-serialized FP8 checkpoint). +3. Modes: native GGUF (transformer-only weights). ## Architecture Overview 1. `OmniDiffusionConfig` accepts `quantization` or `quantization_config`. -2. Diffusion quantization wrappers (`DiffusionGgufConfig`, `DiffusionFp8Config`) produce vLLM `QuantizationConfig` objects for linear layers. +2. Diffusion quantization wrapper (`DiffusionGgufConfig`) produces vLLM `QuantizationConfig` objects for linear layers. 3. `DiffusersPipelineLoader` branches on quantization method and loads either HF weights or GGUF weights for the transformer. 4. GGUF transformer loading is routed through model-specific adapters (e.g., Flux2Klein). -4. vLLM GGUF path uses `GGUFConfig` and `GGUFLinearMethod` for matmul; FP8 uses `Fp8Config` (online) or `is_checkpoint_fp8_serialized` for native FP8. +5. vLLM GGUF path uses `GGUFConfig` and `GGUFLinearMethod` for matmul. ## Call Chain (Offline) ``` @@ -39,7 +39,7 @@ DiffusionModelRunner DiffusersPipelineLoader | v -Pipeline.forward (Flux2/Qwen) +Pipeline.forward (Flux2/Qwen/Z-Image) | v DiffusionEngine @@ -83,15 +83,12 @@ Client ## Call Chain (GGUF Operator Path) ``` -Pipeline.forward (Flux2/Qwen) +Pipeline.forward (Flux2/Qwen/Z-Image) | v Transformer blocks | v -Flux2Attention / Flux2ParallelSelfAttention - | - v QKVParallelLinear / ColumnParallelLinear / RowParallelLinear | v @@ -117,51 +114,19 @@ Notes: 1. GGUF linear inputs are flattened to 2D inside `GGUFLinearMethod.apply` and reshaped back. 2. As of 2026-02-10 in this branch, `_fused_mul_mat_gguf` is forced to the dequantize path. -## Call Chain (FP8 Operator Path) -``` -Pipeline.forward (Flux2/Qwen) - | - v -Transformer blocks - | - v -QKVParallelLinear / ColumnParallelLinear / RowParallelLinear - | - v -LinearBase.forward - | - v -QuantMethod.apply (Fp8LinearMethod.apply or Fp8OnlineLinearMethod.apply) - | - +--> apply_fp8_marlin_linear (weight-only path on older GPUs) - | - +--> W8A8BlockFp8LinearOp.apply (block quant path) - | - +--> fp8_linear.apply_weights - | - v - init_fp8_linear_kernel - | - v - FlashInferFP8ScaledMMLinearKernel / CutlassFP8ScaledMMLinearKernel / - Torch FP8 ScaledMM kernels -``` - -Notes: -1. Online FP8 differs at load time; runtime operator path matches native FP8. -2. The kernel selection is platform and capability dependent. - ## GGUF Weight Loading Path (Transformer-Only) 1. `DiffusersPipelineLoader.load_model` detects `quantization_config.method == "gguf"`. 2. `gguf_model` is resolved as one of: local file, URL, `repo/file.gguf`, or `repo:quant_type`. 3. GGUF weights are routed through adapters in `vllm_omni/diffusion/model_loader/gguf_adapters/`. -4. Name mapping is applied per-architecture (Qwen-Image, Flux2Klein). -4. GGUF weights are loaded into transformer modules, remaining non-transformer weights come from the HF checkpoint. +4. Name mapping is applied per-architecture (Qwen-Image, Z-Image, Flux2Klein). +5. GGUF weights are loaded into transformer modules, remaining non-transformer weights come from the HF checkpoint. ## GGUF Adapter Design 1. `GGUFAdapter` (base) implements default gguf-py tensor name mapping. 2. `Flux2KleinGGUFAdapter` implements Flux2-Klein remapping + qkv split + adaLN swap. -3. 
`get_gguf_adapter(...)` selects the adapter by model class/config and returns an iterator of `(name, tensor)`. +3. `QwenImageGGUFAdapter` implements Qwen-Image qkv shard handling and linear qweight routing. +4. `ZImageGGUFAdapter` implements Z-Image qkv + ffn shard handling and linear qweight routing. +5. `get_gguf_adapter(...)` selects the adapter by model class/config and returns an iterator of `(name, tensor)`. Adapter paths: - Base: `vllm_omni/diffusion/model_loader/gguf_adapters/base.py` @@ -213,10 +178,6 @@ Adapter paths: 3. **Non-linear weights** (norm/bias/scale) keep `.weight`/`.bias` names. 4. **Remaining HF weights** are loaded after GGUF to fill gaps. -## FP8 Loading Path -1. Online FP8: `quantization="fp8"` or `quantization_config={"method":"fp8", "ignored_layers": [...]}`. -2. Native FP8: `quantization_config={"method":"fp8", "is_checkpoint_fp8_serialized": True}` to load an FP8-serialized checkpoint. - ## User Usage (Offline) ### Baseline BF16 @@ -253,43 +214,8 @@ Notes for GGUF: 1. Many GGUF repos do not ship `model_index.json` and configs. Use the base repo for `--model` and only pass the GGUF file via `--gguf-model`. 2. `gguf_model` supports local path, URL, `repo/file.gguf`, or `repo:quant_type`. -### Online FP8 (Runtime Quantization) -```bash -python examples/offline_inference/text_to_image/text_to_image.py \ - --model Qwen/Qwen-Image \ - --quantization fp8 \ - --prompt "a cup of coffee on the table" \ - --height 1024 \ - --width 1024 -``` - -### Native FP8 (Serialized Checkpoint) -Use the Python API to pass `is_checkpoint_fp8_serialized`. -```python -from vllm_omni import Omni -from vllm_omni.inputs.data import OmniDiffusionSamplingParams - -omni = Omni( - model="/path/to/fp8-checkpoint", - quantization_config={ - "method": "fp8", - "is_checkpoint_fp8_serialized": True, - }, -) - -outputs = omni.generate( - "a cup of coffee on the table", - OmniDiffusionSamplingParams(num_inference_steps=4), -) -``` - ## User Usage (Online) -### Start Server (Online FP8) -```bash -vllm serve Qwen/Qwen-Image --omni --port 8000 --quantization fp8 -``` - ### Start Server (Native GGUF via CLI) ```bash vllm serve /workspace/models/black-forest-labs/FLUX.2-klein-4B \ @@ -298,14 +224,6 @@ vllm serve /workspace/models/black-forest-labs/FLUX.2-klein-4B \ --quantization-config '{"method":"gguf","gguf_model":"/workspace/models/unsloth/FLUX.2-klein-4B-GGUF/flux-2-klein-4b-Q8_0.gguf"}' ``` -### Start Server (Native FP8 via CLI) -```bash -vllm serve /path/to/fp8-checkpoint \ - --omni \ - --port 8000 \ - --quantization-config '{"method":"fp8","is_checkpoint_fp8_serialized":true}' -``` - ### Online Request (Images API) ```bash curl -X POST http://localhost:8000/v1/images/generations \ @@ -320,5 +238,5 @@ curl -X POST http://localhost:8000/v1/images/generations \ ## Validation Checklist 1. Fix the date in logs and docs for comparisons. -2. Use the same prompt, size, steps, and seed for BF16 vs GGUF/FP8 comparisons. +2. Use the same prompt, size, steps, and seed for BF16 vs GGUF comparisons. 3. Expect accuracy differences for Q8_0 GGUF; verify mapping with F16/BF16 GGUF if needed. 
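+
+## Offline Python Sketch (Native GGUF)
+For completeness, a minimal offline sketch of the same GGUF flow, assuming the
+placeholder paths from the serving example; `gguf_model` accepts the same four
+forms listed in the loading-path section:
+
+```python
+from vllm_omni import Omni
+from vllm_omni.inputs.data import OmniDiffusionSamplingParams
+
+omni = Omni(
+    model="/workspace/models/black-forest-labs/FLUX.2-klein-4B",
+    quantization_config={
+        "method": "gguf",
+        # Local file, URL, `repo/file.gguf`, or `repo:quant_type`.
+        "gguf_model": "/workspace/models/unsloth/FLUX.2-klein-4B-GGUF/flux-2-klein-4b-Q8_0.gguf",
+    },
+)
+
+outputs = omni.generate(
+    "a cup of coffee on the table",
+    OmniDiffusionSamplingParams(num_inference_steps=4),
+)
+```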
diff --git a/docs/user_guide/diffusion/quantization/overview.md b/docs/user_guide/diffusion/quantization/overview.md index 6526f3843d..e4ce69677c 100644 --- a/docs/user_guide/diffusion/quantization/overview.md +++ b/docs/user_guide/diffusion/quantization/overview.md @@ -7,7 +7,7 @@ vLLM-Omni supports quantization of DiT linear layers to reduce memory usage and | Method | Guide | |--------|-------| | FP8 | [FP8](fp8.md) | -| GGUF | [GGUF + FP8 Design](gguf_fp8_design.md) | +| GGUF | [GGUF](gguf.md) | ## Device Compatibility diff --git a/vllm_omni/diffusion/quantization/fp8.py b/vllm_omni/diffusion/quantization/fp8.py index 963dd6c3bc..f07bf430be 100644 --- a/vllm_omni/diffusion/quantization/fp8.py +++ b/vllm_omni/diffusion/quantization/fp8.py @@ -36,16 +36,14 @@ def __init__( activation_scheme: str = "dynamic", weight_block_size: list[int] | None = None, ignored_layers: list[str] | None = None, - is_checkpoint_fp8_serialized: bool = False, ): self.activation_scheme = activation_scheme self.weight_block_size = weight_block_size self.ignored_layers = ignored_layers or [] - self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized # Create underlying vLLM FP8 config self._vllm_config = Fp8Config( - is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized, + is_checkpoint_fp8_serialized=False, activation_scheme=activation_scheme, weight_block_size=weight_block_size, ignored_layers=ignored_layers, From 9f8438761f4b7cd69b9d9146e70d2734d79ea138 Mon Sep 17 00:00:00 2001 From: David Chen <530634352@qq.com> Date: Thu, 12 Feb 2026 10:42:32 +0800 Subject: [PATCH 23/62] support gguf 1 Signed-off-by: David Chen <530634352@qq.com> --- vllm_omni/diffusion/quantization/fp8.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_omni/diffusion/quantization/fp8.py b/vllm_omni/diffusion/quantization/fp8.py index f07bf430be..68abf9c229 100644 --- a/vllm_omni/diffusion/quantization/fp8.py +++ b/vllm_omni/diffusion/quantization/fp8.py @@ -43,7 +43,7 @@ def __init__( # Create underlying vLLM FP8 config self._vllm_config = Fp8Config( - is_checkpoint_fp8_serialized=False, + is_checkpoint_fp8_serialized=False, # Online quantization from BF16 activation_scheme=activation_scheme, weight_block_size=weight_block_size, ignored_layers=ignored_layers, From 68e5345e77b157f53941ef86173be8a92d085ee5 Mon Sep 17 00:00:00 2001 From: David Chen <530634352@qq.com> Date: Thu, 12 Feb 2026 11:02:53 +0800 Subject: [PATCH 24/62] support gguf fp8 add qwen-image Signed-off-by: David Chen <530634352@qq.com> --- .../diffusion/model_loader/gguf_adapters/qwen_image.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/vllm_omni/diffusion/model_loader/gguf_adapters/qwen_image.py b/vllm_omni/diffusion/model_loader/gguf_adapters/qwen_image.py index 298e5c2044..472b93cc70 100644 --- a/vllm_omni/diffusion/model_loader/gguf_adapters/qwen_image.py +++ b/vllm_omni/diffusion/model_loader/gguf_adapters/qwen_image.py @@ -215,7 +215,12 @@ def split_name(name: str) -> tuple[str, str]: return name, "" model_type = resolve_model_type() - arch = resolve_arch(model_type) + try: + arch = resolve_arch(model_type) + except RuntimeError: + # Fallback: some gguf versions may not register qwen_image arch. + # In that case, rely on direct tensor names from the GGUF file. 
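+            # An empty map makes weights_iterator fall back to
+            # _normalize_name() for every tensor name.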
+ return {} target_module = get_target_module(self.model, self.source.prefix) num_layers = resolve_num_layers(target_module) name_map = gguf.get_tensor_name_map(arch, num_layers) From 3795dc528d5011ca2272d82b8775719d71eb9065 Mon Sep 17 00:00:00 2001 From: David Chen <530634352@qq.com> Date: Thu, 12 Feb 2026 11:13:08 +0800 Subject: [PATCH 25/62] support gguf fp8 add qwen-image 2 Signed-off-by: David Chen <530634352@qq.com> --- .../diffusion/model_loader/gguf_adapters/qwen_image.py | 9 +++++---- .../diffusion/model_loader/gguf_adapters/z_image.py | 9 +++++---- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/vllm_omni/diffusion/model_loader/gguf_adapters/qwen_image.py b/vllm_omni/diffusion/model_loader/gguf_adapters/qwen_image.py index 472b93cc70..76fe728355 100644 --- a/vllm_omni/diffusion/model_loader/gguf_adapters/qwen_image.py +++ b/vllm_omni/diffusion/model_loader/gguf_adapters/qwen_image.py @@ -51,10 +51,11 @@ def weights_iterator(self) -> Generator[tuple[str, torch.Tensor], None, None]: mapped_name = gguf_name_map.get(tensor.name) if mapped_name is None: mapped_name = self._normalize_name(tensor.name) - if ( - mapped_name not in allowed_names - and self._resolve_linear_qweight(mapped_name, param_names) is None - ): + linear_qweight = self._resolve_linear_qweight(mapped_name, param_names) + if mapped_name not in allowed_names and linear_qweight is None: + continue + if linear_qweight is None and tensor.tensor_type.name not in ("F32", "BF16", "F16"): + # Skip quantized tensors that map to non-quantized parameters. continue mapped.append( _MappedTensor( diff --git a/vllm_omni/diffusion/model_loader/gguf_adapters/z_image.py b/vllm_omni/diffusion/model_loader/gguf_adapters/z_image.py index eedf136120..e9365dddd8 100644 --- a/vllm_omni/diffusion/model_loader/gguf_adapters/z_image.py +++ b/vllm_omni/diffusion/model_loader/gguf_adapters/z_image.py @@ -51,10 +51,11 @@ def weights_iterator(self) -> Generator[tuple[str, torch.Tensor], None, None]: mapped_name = gguf_name_map.get(tensor.name) if mapped_name is None: mapped_name = self._normalize_name(tensor.name) - if ( - mapped_name not in allowed_names - and self._resolve_linear_qweight(mapped_name, param_names) is None - ): + linear_qweight = self._resolve_linear_qweight(mapped_name, param_names) + if mapped_name not in allowed_names and linear_qweight is None: + continue + if linear_qweight is None and tensor.tensor_type.name not in ("F32", "BF16", "F16"): + # Skip quantized tensors that map to non-quantized parameters. continue mapped.append( _MappedTensor( From 2d7a409c566cc5312fa8ebac38eb7165e83b415c Mon Sep 17 00:00:00 2001 From: David Chen <530634352@qq.com> Date: Thu, 12 Feb 2026 11:26:00 +0800 Subject: [PATCH 26/62] support gguf fp8 add note Signed-off-by: David Chen <530634352@qq.com> --- vllm_omni/diffusion/quantization/gguf.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm_omni/diffusion/quantization/gguf.py b/vllm_omni/diffusion/quantization/gguf.py index ddedc3354e..e633c81489 100644 --- a/vllm_omni/diffusion/quantization/gguf.py +++ b/vllm_omni/diffusion/quantization/gguf.py @@ -24,6 +24,9 @@ def apply( x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: + # Dequantize + GEMM path: torch.matmul multiplies over the last + # dimension and broadcasts leading dimensions, so no 2D flattening + # is required here. 
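+        # e.g. torch.matmul over (B, S, K) @ (K, N) yields (B, S, N).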
shard_id = layer.qweight.shard_id if shard_id: From 580a18ef9fd3ecb671eaf4be9586d8cc54ea800d Mon Sep 17 00:00:00 2001 From: David Chen <530634352@qq.com> Date: Thu, 12 Feb 2026 11:37:03 +0800 Subject: [PATCH 27/62] support gguf fp8 fix 1 Signed-off-by: David Chen <530634352@qq.com> --- vllm_omni/diffusion/model_loader/gguf_adapters/z_image.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/vllm_omni/diffusion/model_loader/gguf_adapters/z_image.py b/vllm_omni/diffusion/model_loader/gguf_adapters/z_image.py index e9365dddd8..01459b35bb 100644 --- a/vllm_omni/diffusion/model_loader/gguf_adapters/z_image.py +++ b/vllm_omni/diffusion/model_loader/gguf_adapters/z_image.py @@ -214,7 +214,12 @@ def split_name(name: str) -> tuple[str, str]: return name, "" model_type = resolve_model_type() - arch = resolve_arch(model_type) + try: + arch = resolve_arch(model_type) + except RuntimeError: + # Fallback: some gguf versions may not register z_image arch. + # In that case, rely on direct tensor names from the GGUF file. + return {} target_module = get_target_module(self.model, self.source.prefix) num_layers = resolve_num_layers(target_module) name_map = gguf.get_tensor_name_map(arch, num_layers) From bcc91f6925a9e22274ae8172f930bbc63c812945 Mon Sep 17 00:00:00 2001 From: David Chen <530634352@qq.com> Date: Thu, 12 Feb 2026 11:44:52 +0800 Subject: [PATCH 28/62] support gguf fp8 fix z-image Signed-off-by: David Chen <530634352@qq.com> --- .../model_loader/gguf_adapters/z_image.py | 96 ++++++++++++++++--- 1 file changed, 81 insertions(+), 15 deletions(-) diff --git a/vllm_omni/diffusion/model_loader/gguf_adapters/z_image.py b/vllm_omni/diffusion/model_loader/gguf_adapters/z_image.py index 01459b35bb..930cc27a67 100644 --- a/vllm_omni/diffusion/model_loader/gguf_adapters/z_image.py +++ b/vllm_omni/diffusion/model_loader/gguf_adapters/z_image.py @@ -16,6 +16,7 @@ class _MappedTensor: name: str tensor: Any tensor_type: Any + row_slice: slice | None = None class ZImageGGUFAdapter(GGUFAdapter): @@ -48,22 +49,19 @@ def weights_iterator(self) -> Generator[tuple[str, torch.Tensor], None, None]: mapped: list[_MappedTensor] = [] for tensor in reader.tensors: - mapped_name = gguf_name_map.get(tensor.name) - if mapped_name is None: - mapped_name = self._normalize_name(tensor.name) - linear_qweight = self._resolve_linear_qweight(mapped_name, param_names) - if mapped_name not in allowed_names and linear_qweight is None: - continue - if linear_qweight is None and tensor.tensor_type.name not in ("F32", "BF16", "F16"): - # Skip quantized tensors that map to non-quantized parameters. - continue - mapped.append( - _MappedTensor( - name=mapped_name, - tensor=tensor, - tensor_type=tensor.tensor_type, + for mapped_tensor in self._map_tensor_name(tensor, gguf_name_map): + linear_qweight = self._resolve_linear_qweight( + mapped_tensor.name, param_names ) - ) + if mapped_tensor.name not in allowed_names and linear_qweight is None: + continue + if ( + linear_qweight is None + and tensor.tensor_type.name not in ("F32", "BF16", "F16") + ): + # Skip quantized tensors that map to non-quantized parameters. 
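+                    # (e.g. a Q8_0 norm weight: there is no matching .qweight
+                    # parameter and this path does not dequantize, so drop it)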
+ continue + mapped.append(mapped_tensor) if not mapped: raise RuntimeError( @@ -80,6 +78,8 @@ def weights_iterator(self) -> Generator[tuple[str, torch.Tensor], None, None]: for item in mapped: weight = item.tensor.data + if item.row_slice is not None: + weight = weight[item.row_slice] weight_type = item.tensor_type linear_qweight = self._resolve_linear_qweight(item.name, param_names) if linear_qweight is not None: @@ -106,6 +106,72 @@ def _normalize_name(self, name: str) -> str: name = name.replace(".to_out.0.", ".to_out.") return name + def _get_patch_key(self) -> str: + prefix = getattr(self.source, "prefix", "") + target = self.model.get_submodule(prefix.rstrip(".")) if prefix else self.model + if hasattr(target, "all_x_embedder"): + keys = list(getattr(target, "all_x_embedder").keys()) + if "2-1" in keys: + return "2-1" + if keys: + return sorted(keys)[0] + return "2-1" + + def _apply_zimage_renames(self, name: str) -> str: + if name.startswith("model.diffusion_model."): + name = name.replace("model.diffusion_model.", "", 1) + + patch_key = self._get_patch_key() + if name.startswith("x_embedder.") and not name.startswith("all_x_embedder."): + name = name.replace("x_embedder.", f"all_x_embedder.{patch_key}.", 1) + if name.startswith("final_layer.") and not name.startswith("all_final_layer."): + name = name.replace("final_layer.", f"all_final_layer.{patch_key}.", 1) + + name = name.replace(".attention.out.bias", ".attention.to_out.0.bias") + name = name.replace(".attention.out.weight", ".attention.to_out.0.weight") + name = name.replace(".attention.k_norm.weight", ".attention.norm_k.weight") + name = name.replace(".attention.q_norm.weight", ".attention.norm_q.weight") + return name + + def _map_tensor_name(self, tensor, gguf_name_map: dict[str, str]) -> list[_MappedTensor]: + name = gguf_name_map.get(tensor.name) + if name is None: + name = self._normalize_name(tensor.name) + name = self._apply_zimage_renames(name) + + if ".attention.qkv.weight" in name: + weight = tensor.data + dim0 = weight.shape[0] + split = dim0 // 3 + return [ + _MappedTensor( + name=name.replace(".attention.qkv.weight", ".attention.to_q.weight"), + tensor=tensor, + tensor_type=tensor.tensor_type, + row_slice=slice(0, split), + ), + _MappedTensor( + name=name.replace(".attention.qkv.weight", ".attention.to_k.weight"), + tensor=tensor, + tensor_type=tensor.tensor_type, + row_slice=slice(split, 2 * split), + ), + _MappedTensor( + name=name.replace(".attention.qkv.weight", ".attention.to_v.weight"), + tensor=tensor, + tensor_type=tensor.tensor_type, + row_slice=slice(2 * split, 3 * split), + ), + ] + + return [ + _MappedTensor( + name=name, + tensor=tensor, + tensor_type=tensor.tensor_type, + ) + ] + def _build_allowed_names(self) -> set[str]: prefix = getattr(self.source, "prefix", "") target = self.model.get_submodule(prefix.rstrip(".")) if prefix else self.model From 5dcf32d3bf4474e91b00f54958f14ebad3261eb5 Mon Sep 17 00:00:00 2001 From: David Chen <530634352@qq.com> Date: Thu, 12 Feb 2026 14:09:12 +0800 Subject: [PATCH 29/62] support gguf fp8 fix z-image 2 Signed-off-by: David Chen <530634352@qq.com> --- .../diffusion/model_loader/gguf_adapters/z_image.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/vllm_omni/diffusion/model_loader/gguf_adapters/z_image.py b/vllm_omni/diffusion/model_loader/gguf_adapters/z_image.py index 930cc27a67..197a6e14f4 100644 --- a/vllm_omni/diffusion/model_loader/gguf_adapters/z_image.py +++ b/vllm_omni/diffusion/model_loader/gguf_adapters/z_image.py @@ 
-205,8 +205,14 @@ def _build_param_names(self) -> set[str]: def _resolve_linear_qweight(self, name: str, param_names: set[str]) -> str | None: if not name.endswith(".weight"): return None + candidate = name.replace(".weight", ".qweight") + if candidate in param_names: + return candidate if ".to_out.0." in name: - name = name.replace(".to_out.0.", ".to_out.") + alt_name = name.replace(".to_out.0.", ".to_out.") + candidate = alt_name.replace(".weight", ".qweight") + if candidate in param_names: + return candidate for shard_token in ( ".to_q.", ".to_k.", @@ -216,9 +222,6 @@ def _resolve_linear_qweight(self, name: str, param_names: set[str]) -> str | Non ): if shard_token in name: return name.replace(".weight", ".qweight") - candidate = name.replace(".weight", ".qweight") - if candidate in param_names: - return candidate return None def _build_gguf_name_map(self, reader) -> dict[str, str]: From 172dcf24a68d443006aee5738ad6636041ec8c0d Mon Sep 17 00:00:00 2001 From: David Chen <530634352@qq.com> Date: Thu, 12 Feb 2026 14:19:48 +0800 Subject: [PATCH 30/62] support gguf fp8 fix z-image 3 Signed-off-by: David Chen <530634352@qq.com> --- vllm_omni/diffusion/quantization/gguf.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/vllm_omni/diffusion/quantization/gguf.py b/vllm_omni/diffusion/quantization/gguf.py index e633c81489..5472c3bb48 100644 --- a/vllm_omni/diffusion/quantization/gguf.py +++ b/vllm_omni/diffusion/quantization/gguf.py @@ -4,13 +4,23 @@ import torch import gguf -from vllm.model_executor.layers.quantization.gguf import GGUFConfig, GGUFLinearMethod, is_layer_skipped_gguf, LinearBase, QuantizeMethodBase, UnquantizedLinearMethod +from vllm.model_executor.layers.quantization.gguf import ( + GGUFConfig, + GGUFLinearMethod, + UNQUANTIZED_TYPES, + is_layer_skipped_gguf, + LinearBase, + QuantizeMethodBase, + UnquantizedLinearMethod, +) from vllm import _custom_ops as ops from .base import DiffusionQuantizationConfig def dequant_gemm_gguf(x: torch.Tensor, qweight: torch.Tensor, qweight_type: int) -> torch.Tensor: + if qweight_type in UNQUANTIZED_TYPES: + return x @ qweight.T block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type] shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size) weight = ops.ggml_dequantize(qweight, qweight_type, *shape, x.dtype) From ae0824975ddc6ad228607b95b48da27030bbb678 Mon Sep 17 00:00:00 2001 From: David Chen <530634352@qq.com> Date: Thu, 12 Feb 2026 16:52:22 +0800 Subject: [PATCH 31/62] fix pre-commit Signed-off-by: David Chen <530634352@qq.com> --- docs/user_guide/diffusion/quantization/fp8.md | 2 +- .../text_to_image/text_to_image.py | 5 +---- vllm_omni/diffusion/data.py | 1 - .../model_loader/diffusers_loader.py | 6 ++--- .../model_loader/gguf_adapters/base.py | 8 ++----- .../model_loader/gguf_adapters/flux2_klein.py | 10 +++------ .../model_loader/gguf_adapters/qwen_image.py | 8 ++----- .../model_loader/gguf_adapters/z_image.py | 20 +++++------------ .../flux2_klein/pipeline_flux2_klein.py | 2 +- vllm_omni/diffusion/quantization/gguf.py | 22 ++++++------------- vllm_omni/entrypoints/cli/serve.py | 2 +- 11 files changed, 25 insertions(+), 61 deletions(-) diff --git a/docs/user_guide/diffusion/quantization/fp8.md b/docs/user_guide/diffusion/quantization/fp8.md index 3a3132677e..260543adc0 100644 --- a/docs/user_guide/diffusion/quantization/fp8.md +++ b/docs/user_guide/diffusion/quantization/fp8.md @@ -74,4 +74,4 @@ omni = Omni( cache_backend="tea_cache", cache_config={"rel_l1_thresh": 0.2}, ) -``` \ No newline 
at end of file +``` diff --git a/examples/offline_inference/text_to_image/text_to_image.py b/examples/offline_inference/text_to_image/text_to_image.py index 6a1dea646d..36f9517f5d 100644 --- a/examples/offline_inference/text_to_image/text_to_image.py +++ b/examples/offline_inference/text_to_image/text_to_image.py @@ -124,10 +124,7 @@ def parse_args() -> argparse.Namespace: type=str, default=None, choices=["fp8", "gguf"], - help="Quantization method for the transformer. " - "Options: 'fp8' (FP8 W8A8 on Ada/Hopper, weight-only on older GPUs), " - "'gguf' (load transformer weights from a GGUF file). " - "Default: None (no quantization, uses BF16).", + help=("GGUF file path or HF reference for transformer weights. Required when --quantization gguf is set."), ) parser.add_argument( "--gguf-model", diff --git a/vllm_omni/diffusion/data.py b/vllm_omni/diffusion/data.py index ac92248d4a..9fa349c96e 100644 --- a/vllm_omni/diffusion/data.py +++ b/vllm_omni/diffusion/data.py @@ -474,7 +474,6 @@ def __post_init__(self): if self.quantization is not None or self.quantization_config is not None: from vllm_omni.diffusion.quantization import ( DiffusionQuantizationConfig, - get_diffusion_quant_config, ) # Handle dict or DictConfig (from OmegaConf) - use Mapping for broader compatibility diff --git a/vllm_omni/diffusion/model_loader/diffusers_loader.py b/vllm_omni/diffusion/model_loader/diffusers_loader.py index 4928993e62..5eebc04c1d 100644 --- a/vllm_omni/diffusion/model_loader/diffusers_loader.py +++ b/vllm_omni/diffusion/model_loader/diffusers_loader.py @@ -9,8 +9,8 @@ from typing import cast import torch -from torch import nn from huggingface_hub import hf_hub_download +from torch import nn from vllm.config import ModelConfig from vllm.config.load import LoadConfig from vllm.logger import init_logger @@ -377,9 +377,7 @@ def _load_weights_with_gguf(self, model: nn.Module, od_config: OmniDiffusionConf loadable_names = loadable_names or self._get_model_loadable_names(model) hf_iter = self._get_weights_iterator(source) hf_iter = ( - (name, tensor) - for (name, tensor) in hf_iter - if name in loadable_names and name not in loaded + (name, tensor) for (name, tensor) in hf_iter if name in loadable_names and name not in loaded ) loaded |= model.load_weights(hf_iter) else: diff --git a/vllm_omni/diffusion/model_loader/gguf_adapters/base.py b/vllm_omni/diffusion/model_loader/gguf_adapters/base.py index 2eae17a1f4..690cac44db 100644 --- a/vllm_omni/diffusion/model_loader/gguf_adapters/base.py +++ b/vllm_omni/diffusion/model_loader/gguf_adapters/base.py @@ -30,9 +30,7 @@ def _build_gguf_name_map(self) -> dict[str, str]: try: import gguf # type: ignore except Exception as exc: # pragma: no cover - dependency error - raise RuntimeError( - "GGUF support requires the 'gguf' package to be installed." - ) from exc + raise RuntimeError("GGUF support requires the 'gguf' package to be installed.") from exc def resolve_model_type() -> str: cfg = self.od_config.tf_model_config @@ -131,7 +129,5 @@ def split_name(name: str) -> tuple[str, str]: gguf_to_model_map[gguf_name] = name if not gguf_to_model_map: - raise RuntimeError( - f"No GGUF tensors were mapped for model_class_name={self.od_config.model_class_name!r}." 
- ) + raise RuntimeError(f"No GGUF tensors were mapped for model_class_name={self.od_config.model_class_name!r}.") return gguf_to_model_map diff --git a/vllm_omni/diffusion/model_loader/gguf_adapters/flux2_klein.py b/vllm_omni/diffusion/model_loader/gguf_adapters/flux2_klein.py index 0bf0ff2214..1ed545deaa 100644 --- a/vllm_omni/diffusion/model_loader/gguf_adapters/flux2_klein.py +++ b/vllm_omni/diffusion/model_loader/gguf_adapters/flux2_klein.py @@ -42,9 +42,7 @@ def weights_iterator(self) -> Generator[tuple[str, torch.Tensor], None, None]: try: import gguf # type: ignore except Exception as exc: # pragma: no cover - dependency error - raise RuntimeError( - "GGUF support requires the 'gguf' package to be installed." - ) from exc + raise RuntimeError("GGUF support requires the 'gguf' package to be installed.") from exc reader = gguf.GGUFReader(self.gguf_file) allowed_names = self._build_allowed_names() @@ -55,16 +53,14 @@ def weights_iterator(self) -> Generator[tuple[str, torch.Tensor], None, None]: for mapped_tensor in self._map_tensor_name(tensor): if ( mapped_tensor.name not in allowed_names - and self._resolve_linear_qweight(mapped_tensor.name, param_names) - is None + and self._resolve_linear_qweight(mapped_tensor.name, param_names) is None ): continue mapped.append(mapped_tensor) if not mapped: raise RuntimeError( - "No GGUF tensors were mapped for Flux2 GGUF loader. " - "Please verify the GGUF file and model structure." + "No GGUF tensors were mapped for Flux2 GGUF loader. Please verify the GGUF file and model structure." ) for item in mapped: diff --git a/vllm_omni/diffusion/model_loader/gguf_adapters/qwen_image.py b/vllm_omni/diffusion/model_loader/gguf_adapters/qwen_image.py index 76fe728355..3904976dc2 100644 --- a/vllm_omni/diffusion/model_loader/gguf_adapters/qwen_image.py +++ b/vllm_omni/diffusion/model_loader/gguf_adapters/qwen_image.py @@ -37,9 +37,7 @@ def weights_iterator(self) -> Generator[tuple[str, torch.Tensor], None, None]: try: import gguf # type: ignore except Exception as exc: # pragma: no cover - dependency error - raise RuntimeError( - "GGUF support requires the 'gguf' package to be installed." - ) from exc + raise RuntimeError("GGUF support requires the 'gguf' package to be installed.") from exc reader = gguf.GGUFReader(self.gguf_file) gguf_name_map = self._build_gguf_name_map(reader) @@ -161,9 +159,7 @@ def _build_gguf_name_map(self, reader) -> dict[str, str]: try: import gguf # type: ignore except Exception as exc: # pragma: no cover - dependency error - raise RuntimeError( - "GGUF support requires the 'gguf' package to be installed." - ) from exc + raise RuntimeError("GGUF support requires the 'gguf' package to be installed.") from exc gguf_tensor_names = {tensor.name for tensor in reader.tensors} diff --git a/vllm_omni/diffusion/model_loader/gguf_adapters/z_image.py b/vllm_omni/diffusion/model_loader/gguf_adapters/z_image.py index 197a6e14f4..4571393613 100644 --- a/vllm_omni/diffusion/model_loader/gguf_adapters/z_image.py +++ b/vllm_omni/diffusion/model_loader/gguf_adapters/z_image.py @@ -38,9 +38,7 @@ def weights_iterator(self) -> Generator[tuple[str, torch.Tensor], None, None]: try: import gguf # type: ignore except Exception as exc: # pragma: no cover - dependency error - raise RuntimeError( - "GGUF support requires the 'gguf' package to be installed." 
- ) from exc + raise RuntimeError("GGUF support requires the 'gguf' package to be installed.") from exc reader = gguf.GGUFReader(self.gguf_file) gguf_name_map = self._build_gguf_name_map(reader) @@ -50,23 +48,17 @@ def weights_iterator(self) -> Generator[tuple[str, torch.Tensor], None, None]: for tensor in reader.tensors: for mapped_tensor in self._map_tensor_name(tensor, gguf_name_map): - linear_qweight = self._resolve_linear_qweight( - mapped_tensor.name, param_names - ) + linear_qweight = self._resolve_linear_qweight(mapped_tensor.name, param_names) if mapped_tensor.name not in allowed_names and linear_qweight is None: continue - if ( - linear_qweight is None - and tensor.tensor_type.name not in ("F32", "BF16", "F16") - ): + if linear_qweight is None and tensor.tensor_type.name not in ("F32", "BF16", "F16"): # Skip quantized tensors that map to non-quantized parameters. continue mapped.append(mapped_tensor) if not mapped: raise RuntimeError( - "No GGUF tensors were mapped for Z-Image GGUF loader. " - "Please verify the GGUF file and model structure." + "No GGUF tensors were mapped for Z-Image GGUF loader. Please verify the GGUF file and model structure." ) for item in mapped: @@ -228,9 +220,7 @@ def _build_gguf_name_map(self, reader) -> dict[str, str]: try: import gguf # type: ignore except Exception as exc: # pragma: no cover - dependency error - raise RuntimeError( - "GGUF support requires the 'gguf' package to be installed." - ) from exc + raise RuntimeError("GGUF support requires the 'gguf' package to be installed.") from exc gguf_tensor_names = {tensor.name for tensor in reader.tensors} diff --git a/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py b/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py index 09df716ada..6fd2de3c94 100644 --- a/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py +++ b/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py @@ -42,8 +42,8 @@ from vllm_omni.diffusion.models.flux2_klein.flux2_klein_transformer import ( Flux2Transformer2DModel, ) -from vllm_omni.diffusion.quantization import get_vllm_quant_config_for_layers from vllm_omni.diffusion.models.interface import SupportImageInput +from vllm_omni.diffusion.quantization import get_vllm_quant_config_for_layers from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs from vllm_omni.model_executor.model_loader.weight_utils import download_weights_from_hf_specific diff --git a/vllm_omni/diffusion/quantization/gguf.py b/vllm_omni/diffusion/quantization/gguf.py index 5472c3bb48..ee053b8c59 100644 --- a/vllm_omni/diffusion/quantization/gguf.py +++ b/vllm_omni/diffusion/quantization/gguf.py @@ -2,18 +2,18 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """GGUF quantization config for diffusion transformers.""" -import torch import gguf +import torch +from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.gguf import ( + UNQUANTIZED_TYPES, GGUFConfig, GGUFLinearMethod, - UNQUANTIZED_TYPES, - is_layer_skipped_gguf, LinearBase, QuantizeMethodBase, UnquantizedLinearMethod, + is_layer_skipped_gguf, ) -from vllm import _custom_ops as ops from .base import DiffusionQuantizationConfig @@ -47,11 +47,7 @@ def apply( for idx in shard_id: start, end, offset = layer.qweight.shard_offset_map[idx] qweight_type = layer.qweight_type.shard_weight_type[idx] - result.append( - dequant_gemm_gguf( - x, qweight[start:end, :offset].contiguous(), qweight_type 
- ) - ) + result.append(dequant_gemm_gguf(x, qweight[start:end, :offset].contiguous(), qweight_type)) out = torch.cat(result, axis=-1) else: qweight = layer.qweight @@ -63,13 +59,9 @@ def apply( class _GGUFConfig(GGUFConfig): - def get_quant_method( - self, layer: torch.nn.Module, prefix: str - ) -> "QuantizeMethodBase": + def get_quant_method(self, layer: torch.nn.Module, prefix: str) -> "QuantizeMethodBase": if isinstance(layer, LinearBase): - if is_layer_skipped_gguf( - prefix, self.unquantized_modules, self.packed_modules_mapping - ): + if is_layer_skipped_gguf(prefix, self.unquantized_modules, self.packed_modules_mapping): return UnquantizedLinearMethod() return DiffusionGGUFLinearMethod(self) return None diff --git a/vllm_omni/entrypoints/cli/serve.py b/vllm_omni/entrypoints/cli/serve.py index 4904781945..b1fb048fbb 100644 --- a/vllm_omni/entrypoints/cli/serve.py +++ b/vllm_omni/entrypoints/cli/serve.py @@ -168,7 +168,7 @@ def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgu default=None, help=( "JSON string for diffusion quantization_config. " - "Example: '{\"method\":\"gguf\",\"gguf_model\":\"/path/to/model.gguf\"}'." + 'Example: \'{"method":"gguf","gguf_model":"/path/to/model.gguf"}\'.' ), ) From 03217e0e145decdb5c59c046bcc3d43d83fc2549 Mon Sep 17 00:00:00 2001 From: David Chen <530634352@qq.com> Date: Thu, 12 Feb 2026 16:56:34 +0800 Subject: [PATCH 32/62] simple doc Signed-off-by: David Chen <530634352@qq.com> --- .../user_guide/diffusion/quantization/gguf.md | 51 ------------------- 1 file changed, 51 deletions(-) diff --git a/docs/user_guide/diffusion/quantization/gguf.md b/docs/user_guide/diffusion/quantization/gguf.md index 3b18a31be0..17d4301b96 100644 --- a/docs/user_guide/diffusion/quantization/gguf.md +++ b/docs/user_guide/diffusion/quantization/gguf.md @@ -1,7 +1,5 @@ # Diffusion Quantization Design: Native GGUF -Date: 2026-02-12 - ## Goals 1. Reuse vLLM quantization configs and weight loaders as much as possible. 2. Add native GGUF support to diffusion transformers without changing model definitions. @@ -134,50 +132,6 @@ Adapter paths: - Z-Image: `vllm_omni/diffusion/model_loader/gguf_adapters/z_image.py` - Flux2-Klein: `vllm_omni/diffusion/model_loader/gguf_adapters/flux2_klein.py` -## Flux2-Klein GGUF Mapping (Key Rules) -1. **Core rename (diffusers-compatible)**: - - `img_in` -> `x_embedder` - - `txt_in` -> `context_embedder` - - `time_in.*` -> `time_guidance_embed.timestep_embedder.*` - - `guidance_in.*` -> `time_guidance_embed.guidance_embedder.*` - - `double_stream_modulation_*` -> `double_stream_modulation_*.linear` - - `single_stream_modulation.lin` -> `single_stream_modulation.linear` - - `final_layer.linear` -> `proj_out` -2. **Double blocks (img/txt)**: - - `double_blocks.{i}.img_attn.qkv.weight` - -> `transformer_blocks.{i}.attn.to_q/to_k/to_v` - - `double_blocks.{i}.txt_attn.qkv.weight` - -> `transformer_blocks.{i}.attn.add_q_proj/add_k_proj/add_v_proj` - - Other mappings: - - `img_attn.norm.query_norm` -> `attn.norm_q` - - `img_attn.norm.key_norm` -> `attn.norm_k` - - `img_attn.proj` -> `attn.to_out.0` - - `img_mlp.0` -> `ff.linear_in` - - `img_mlp.2` -> `ff.linear_out` - - `txt_attn.norm.query_norm` -> `attn.norm_added_q` - - `txt_attn.norm.key_norm` -> `attn.norm_added_k` - - `txt_attn.proj` -> `attn.to_add_out` - - `txt_mlp.0` -> `ff_context.linear_in` - - `txt_mlp.2` -> `ff_context.linear_out` -3. 
**Single blocks**: - - `single_blocks.{i}.linear1` -> `single_transformer_blocks.{i}.attn.to_qkv_mlp_proj` - - `single_blocks.{i}.linear2` -> `single_transformer_blocks.{i}.attn.to_out` - - `single_blocks.{i}.norm.query_norm` -> `single_transformer_blocks.{i}.attn.norm_q` - - `single_blocks.{i}.norm.key_norm` -> `single_transformer_blocks.{i}.attn.norm_k` -4. **AdaLN swap**: - - `final_layer.adaLN_modulation.1.weight` -> `norm_out.linear.weight` with (shift, scale) swapped. - -## Flux2-Klein GGUF Loader Logic -1. **Iterator flow**: - - Read GGUF tensors via `gguf.GGUFReader`. - - Apply Flux2-Klein mapping rules to produce diffusers-style names. - - For QKV tensors, split along dim0 into Q/K/V shards. -2. **Linear weights go to `qweight`** (both quantized and BF16/F16): - - Always emit `qweight_type` for linear weights. - - Use shard names (`to_q.qweight`, `to_k.qweight`, `to_v.qweight`) so vLLM can reassemble into `to_qkv.qweight`. -3. **Non-linear weights** (norm/bias/scale) keep `.weight`/`.bias` names. -4. **Remaining HF weights** are loaded after GGUF to fill gaps. - ## User Usage (Offline) ### Baseline BF16 @@ -235,8 +189,3 @@ curl -X POST http://localhost:8000/v1/images/generations \ "num_inference_steps": 4 }' ``` - -## Validation Checklist -1. Fix the date in logs and docs for comparisons. -2. Use the same prompt, size, steps, and seed for BF16 vs GGUF comparisons. -3. Expect accuracy differences for Q8_0 GGUF; verify mapping with F16/BF16 GGUF if needed. From d3ab484f16aa9f6f4103d9ef39f01243ef1c121e Mon Sep 17 00:00:00 2001 From: David Chen <530634352@qq.com> Date: Thu, 12 Feb 2026 17:01:33 +0800 Subject: [PATCH 33/62] fix pre-commit Signed-off-by: David Chen <530634352@qq.com> --- docs/user_guide/diffusion/quantization/gguf.md | 6 +----- examples/offline_inference/text_to_image/text_to_image.py | 5 +---- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/docs/user_guide/diffusion/quantization/gguf.md b/docs/user_guide/diffusion/quantization/gguf.md index 17d4301b96..bef992f5ec 100644 --- a/docs/user_guide/diffusion/quantization/gguf.md +++ b/docs/user_guide/diffusion/quantization/gguf.md @@ -1,4 +1,4 @@ -# Diffusion Quantization Design: Native GGUF +# GGUF Quantization ## Goals 1. Reuse vLLM quantization configs and weight loaders as much as possible. @@ -108,10 +108,6 @@ ops.ggml_dequantize x @ weight.T ``` -Notes: -1. GGUF linear inputs are flattened to 2D inside `GGUFLinearMethod.apply` and reshaped back. -2. As of 2026-02-10 in this branch, `_fused_mul_mat_gguf` is forced to the dequantize path. - ## GGUF Weight Loading Path (Transformer-Only) 1. `DiffusersPipelineLoader.load_model` detects `quantization_config.method == "gguf"`. 2. `gguf_model` is resolved as one of: local file, URL, `repo/file.gguf`, or `repo:quant_type`. diff --git a/examples/offline_inference/text_to_image/text_to_image.py b/examples/offline_inference/text_to_image/text_to_image.py index 36f9517f5d..ab0354ede2 100644 --- a/examples/offline_inference/text_to_image/text_to_image.py +++ b/examples/offline_inference/text_to_image/text_to_image.py @@ -130,10 +130,7 @@ def parse_args() -> argparse.Namespace: "--gguf-model", type=str, default=None, - help=( - "GGUF file path or HF reference for transformer weights. " - "Required when --quantization gguf is set." - ), + help=("GGUF file path or HF reference for transformer weights. 
Required when --quantization gguf is set."), ) parser.add_argument( "--ignored-layers", From 43dc33a0269559d08987ac6912c558fd6d644310 Mon Sep 17 00:00:00 2001 From: David Chen <530634352@qq.com> Date: Thu, 12 Feb 2026 17:03:04 +0800 Subject: [PATCH 34/62] fix pre-commit Signed-off-by: David Chen <530634352@qq.com> --- vllm_omni/diffusion/model_loader/gguf_adapters/base.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm_omni/diffusion/model_loader/gguf_adapters/base.py b/vllm_omni/diffusion/model_loader/gguf_adapters/base.py index 690cac44db..3928ccabf2 100644 --- a/vllm_omni/diffusion/model_loader/gguf_adapters/base.py +++ b/vllm_omni/diffusion/model_loader/gguf_adapters/base.py @@ -4,7 +4,6 @@ from collections.abc import Generator import torch - from vllm.model_executor.model_loader.weight_utils import gguf_quant_weights_iterator From fb43b4fb318afc6196762df4926c554192a0e07c Mon Sep 17 00:00:00 2001 From: David Chen <530634352@qq.com> Date: Fri, 13 Feb 2026 09:26:57 +0800 Subject: [PATCH 35/62] fix comment 1 Signed-off-by: David Chen <530634352@qq.com> --- docs/user_guide/diffusion/quantization/fp8.md | 2 +- .../text_to_image/text_to_image.py | 6 ++- .../model_loader/diffusers_loader.py | 45 ++++++++++++++++++- .../model_loader/gguf_adapters/__init__.py | 18 +++++++- .../model_loader/gguf_adapters/z_image.py | 2 + vllm_omni/diffusion/quantization/gguf.py | 2 +- vllm_omni/engine/arg_utils.py | 2 + 7 files changed, 71 insertions(+), 6 deletions(-) diff --git a/docs/user_guide/diffusion/quantization/fp8.md b/docs/user_guide/diffusion/quantization/fp8.md index 260543adc0..2c065546bf 100644 --- a/docs/user_guide/diffusion/quantization/fp8.md +++ b/docs/user_guide/diffusion/quantization/fp8.md @@ -1,4 +1,4 @@ -# FP8 Quantization +# FP8 Quantization ## Overview diff --git a/examples/offline_inference/text_to_image/text_to_image.py b/examples/offline_inference/text_to_image/text_to_image.py index ab0354ede2..2051fdd412 100644 --- a/examples/offline_inference/text_to_image/text_to_image.py +++ b/examples/offline_inference/text_to_image/text_to_image.py @@ -124,7 +124,11 @@ def parse_args() -> argparse.Namespace: type=str, default=None, choices=["fp8", "gguf"], - help=("GGUF file path or HF reference for transformer weights. Required when --quantization gguf is set."), + help=( + "Quantization method for the transformer. " + "Options: 'fp8' (FP8 W8A8), 'gguf' (GGUF quantized weights). " + "Default: None (no quantization, uses BF16)." 
+ ), ) parser.add_argument( "--gguf-model", diff --git a/vllm_omni/diffusion/model_loader/diffusers_loader.py b/vllm_omni/diffusion/model_loader/diffusers_loader.py index 5eebc04c1d..b0d54a0601 100644 --- a/vllm_omni/diffusion/model_loader/diffusers_loader.py +++ b/vllm_omni/diffusion/model_loader/diffusers_loader.py @@ -3,7 +3,11 @@ import dataclasses import glob import os +import shutil +import tempfile import time +from urllib.parse import urlparse +from urllib.request import urlopen from collections.abc import Generator, Iterable from pathlib import Path from typing import cast @@ -292,7 +296,7 @@ def _is_gguf_quantization(self, od_config: OmniDiffusionConfig) -> bool: # Normal path: DiffusionQuantizationConfig try: is_gguf = quant_config.get_name() == "gguf" - except Exception: + except AttributeError: # Fallback: if it carries gguf_model, treat as GGUF gguf_model = getattr(quant_config, "gguf_model", None) return bool(gguf_model) @@ -320,7 +324,7 @@ def _resolve_gguf_model_path(self, gguf_model: str, revision: str | None) -> str return gguf_model # raw HTTPS link if gguf_model.startswith(("http://", "https://")) and gguf_model.endswith(".gguf"): - return hf_hub_download(url=gguf_model) + return self._download_raw_gguf_url(gguf_model) # repo_id/filename.gguf if "/" in gguf_model and gguf_model.endswith(".gguf"): repo_id, filename = gguf_model.rsplit("/", 1) @@ -345,6 +349,43 @@ def _resolve_gguf_model_path(self, gguf_model: str, revision: str | None) -> str "raw URL, /.gguf, or :)" ) + def _download_raw_gguf_url(self, url: str) -> str: + parsed = urlparse(url) + filename = os.path.basename(parsed.path) + if not filename: + raise ValueError(f"Cannot infer GGUF filename from URL: {url!r}") + + cache_dir = self.load_config.download_dir + if cache_dir is None: + cache_dir = os.path.join( + os.path.expanduser("~"), + ".cache", + "vllm-omni", + "gguf", + ) + os.makedirs(cache_dir, exist_ok=True) + target_path = os.path.join(cache_dir, filename) + if os.path.exists(target_path): + return target_path + + tmp_fd, tmp_path = tempfile.mkstemp( + suffix=".gguf", + prefix="gguf-", + dir=cache_dir, + ) + os.close(tmp_fd) + try: + with urlopen(url) as response, open(tmp_path, "wb") as out_file: + shutil.copyfileobj(response, out_file) + os.replace(tmp_path, target_path) + except Exception: + try: + os.remove(tmp_path) + except OSError: + pass + raise + return target_path + def _get_gguf_weights_iterator( self, source: "ComponentSource", diff --git a/vllm_omni/diffusion/model_loader/gguf_adapters/__init__.py b/vllm_omni/diffusion/model_loader/gguf_adapters/__init__.py index 7abceda31b..aa2cea482a 100644 --- a/vllm_omni/diffusion/model_loader/gguf_adapters/__init__.py +++ b/vllm_omni/diffusion/model_loader/gguf_adapters/__init__.py @@ -1,13 +1,29 @@ # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations +from typing import TYPE_CHECKING + +import torch + from .base import GGUFAdapter from .flux2_klein import Flux2KleinGGUFAdapter from .qwen_image import QwenImageGGUFAdapter from .z_image import ZImageGGUFAdapter -def get_gguf_adapter(gguf_file: str, model, source, od_config) -> GGUFAdapter: +if TYPE_CHECKING: + from vllm_omni.diffusion.data import OmniDiffusionConfig + from vllm_omni.diffusion.model_loader.diffusers_loader import ( + DiffusersPipelineLoader, + ) + + +def get_gguf_adapter( + gguf_file: str, + model: torch.nn.Module, + source: "DiffusersPipelineLoader.ComponentSource", + od_config: "OmniDiffusionConfig", +) -> GGUFAdapter: for adapter_cls in (QwenImageGGUFAdapter, 
ZImageGGUFAdapter, Flux2KleinGGUFAdapter): if adapter_cls.is_compatible(od_config, model, source): return adapter_cls(gguf_file, model, source, od_config) diff --git a/vllm_omni/diffusion/model_loader/gguf_adapters/z_image.py b/vllm_omni/diffusion/model_loader/gguf_adapters/z_image.py index 4571393613..9850ca9507 100644 --- a/vllm_omni/diffusion/model_loader/gguf_adapters/z_image.py +++ b/vllm_omni/diffusion/model_loader/gguf_adapters/z_image.py @@ -104,6 +104,8 @@ def _get_patch_key(self) -> str: if hasattr(target, "all_x_embedder"): keys = list(getattr(target, "all_x_embedder").keys()) if "2-1" in keys: + # Default to the standard Z-Image Turbo patch/frequency config + # (patch_size=2, f_patch_size=1) when available. return "2-1" if keys: return sorted(keys)[0] diff --git a/vllm_omni/diffusion/quantization/gguf.py b/vllm_omni/diffusion/quantization/gguf.py index ee053b8c59..85ec2ee33a 100644 --- a/vllm_omni/diffusion/quantization/gguf.py +++ b/vllm_omni/diffusion/quantization/gguf.py @@ -37,7 +37,7 @@ def apply( # Dequantize + GEMM path: torch.matmul multiplies over the last # dimension and broadcasts leading dimensions, so no 2D flattening # is required here. - shard_id = layer.qweight.shard_id + shard_id = getattr(layer.qweight, "shard_id", None) if shard_id: # dequantize shard weights respectively diff --git a/vllm_omni/engine/arg_utils.py b/vllm_omni/engine/arg_utils.py index d550cf5554..c28dfaa5c6 100644 --- a/vllm_omni/engine/arg_utils.py +++ b/vllm_omni/engine/arg_utils.py @@ -74,6 +74,7 @@ class OmniEngineArgs(EngineArgs): stage_connector_spec: dict[str, Any] = field(default_factory=dict) async_chunk: bool = False omni_kv_config: dict | None = None + quantization_config: Any | None = None def draw_hf_text_config(self, config_dict: dict) -> Qwen3OmniMoeTextConfig: # transformers' get_text_config method is used to get the text config from thinker_config. @@ -176,6 +177,7 @@ class AsyncOmniEngineArgs(AsyncEngineArgs): stage_connector_spec: dict[str, Any] = field(default_factory=dict) async_chunk: bool = False omni_kv_config: dict | None = None + quantization_config: Any | None = None def draw_hf_text_config(self, config_dict: dict) -> Qwen3OmniMoeTextConfig: # transformers' get_text_config method is used to get the text config from thinker_config. 
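A minimal usage sketch for the GGUF path wired up above, assuming the `text_to_image.py` flags and the `--quantization-config` JSON format added in this series (`<model-id>` and the GGUF reference are placeholders; per `_resolve_gguf_model_path`, `--gguf-model` accepts a local file, a raw URL, `repo_id/filename.gguf`, or `repo_id:quant_type`):

```bash
# Offline: load transformer weights from a GGUF file
python text_to_image.py \
    --model <model-id> \
    --quantization gguf \
    --gguf-model /path/to/model.gguf

# Serving: the same settings passed as a JSON quantization config
vllm serve <model-id> \
    --quantization-config '{"method":"gguf","gguf_model":"/path/to/model.gguf"}'
```

Remaining HF safetensors are loaded after the GGUF weights to fill any gaps, so the GGUF file only needs to cover the transformer.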
From 45fac53d548ee5b70dc8ceca3250fb0db10cb942 Mon Sep 17 00:00:00 2001 From: David Chen <530634352@qq.com> Date: Fri, 13 Feb 2026 09:32:43 +0800 Subject: [PATCH 36/62] fix pre-commit Signed-off-by: David Chen <530634352@qq.com> --- vllm_omni/diffusion/model_loader/diffusers_loader.py | 4 ++-- vllm_omni/diffusion/model_loader/gguf_adapters/__init__.py | 5 ++--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/vllm_omni/diffusion/model_loader/diffusers_loader.py b/vllm_omni/diffusion/model_loader/diffusers_loader.py index b0d54a0601..428f0df41c 100644 --- a/vllm_omni/diffusion/model_loader/diffusers_loader.py +++ b/vllm_omni/diffusion/model_loader/diffusers_loader.py @@ -6,11 +6,11 @@ import shutil import tempfile import time -from urllib.parse import urlparse -from urllib.request import urlopen from collections.abc import Generator, Iterable from pathlib import Path from typing import cast +from urllib.parse import urlparse +from urllib.request import urlopen import torch from huggingface_hub import hf_hub_download diff --git a/vllm_omni/diffusion/model_loader/gguf_adapters/__init__.py b/vllm_omni/diffusion/model_loader/gguf_adapters/__init__.py index aa2cea482a..770158de78 100644 --- a/vllm_omni/diffusion/model_loader/gguf_adapters/__init__.py +++ b/vllm_omni/diffusion/model_loader/gguf_adapters/__init__.py @@ -10,7 +10,6 @@ from .qwen_image import QwenImageGGUFAdapter from .z_image import ZImageGGUFAdapter - if TYPE_CHECKING: from vllm_omni.diffusion.data import OmniDiffusionConfig from vllm_omni.diffusion.model_loader.diffusers_loader import ( @@ -21,8 +20,8 @@ def get_gguf_adapter( gguf_file: str, model: torch.nn.Module, - source: "DiffusersPipelineLoader.ComponentSource", - od_config: "OmniDiffusionConfig", + source: DiffusersPipelineLoader.ComponentSource, + od_config: OmniDiffusionConfig, ) -> GGUFAdapter: for adapter_cls in (QwenImageGGUFAdapter, ZImageGGUFAdapter, Flux2KleinGGUFAdapter): if adapter_cls.is_compatible(od_config, model, source): From e5a70d4e554f6b3214bcd77b681a57e059ee71ec Mon Sep 17 00:00:00 2001 From: David Chen <530634352@qq.com> Date: Fri, 13 Feb 2026 09:48:37 +0800 Subject: [PATCH 37/62] fix comment2 Signed-off-by: David Chen <530634352@qq.com> --- .../model_loader/gguf_adapters/base.py | 75 +++++++++++++++ .../model_loader/gguf_adapters/flux2_klein.py | 95 +++++-------------- .../model_loader/gguf_adapters/qwen_image.py | 78 +++------------ .../model_loader/gguf_adapters/z_image.py | 87 ++++------------- 4 files changed, 132 insertions(+), 203 deletions(-) diff --git a/vllm_omni/diffusion/model_loader/gguf_adapters/base.py b/vllm_omni/diffusion/model_loader/gguf_adapters/base.py index 3928ccabf2..4523921b69 100644 --- a/vllm_omni/diffusion/model_loader/gguf_adapters/base.py +++ b/vllm_omni/diffusion/model_loader/gguf_adapters/base.py @@ -2,14 +2,32 @@ from __future__ import annotations from collections.abc import Generator +from dataclasses import dataclass +from typing import Any import torch from vllm.model_executor.model_loader.weight_utils import gguf_quant_weights_iterator +@dataclass +class MappedTensor: + name: str + tensor: Any + tensor_type: Any + row_slice: slice | None = None + swap_scale_shift: bool = False + + class GGUFAdapter: """Default GGUF adapter using gguf-py's tensor name mapping.""" + _include_qkv_virtuals: bool = False + _include_add_kv_proj_virtuals: bool = False + _include_to_out_virtuals: bool = False + _include_w13_virtuals: bool = False + _shard_tokens: tuple[str, ...] 
= () + _prefer_exact_qweight: bool = True + def __init__(self, gguf_file: str, model: torch.nn.Module, source, od_config) -> None: self.gguf_file = gguf_file self.model = model @@ -25,6 +43,63 @@ def weights_iterator(self) -> Generator[tuple[str, torch.Tensor], None, None]: name_map = self._build_gguf_name_map() return gguf_quant_weights_iterator(self.gguf_file, name_map) + def _get_target_module(self) -> torch.nn.Module: + prefix = getattr(self.source, "prefix", "") + return self.model.get_submodule(prefix.rstrip(".")) if prefix else self.model + + def _build_allowed_names(self) -> set[str]: + target = self._get_target_module() + allowed = {name for name, _ in target.named_parameters()} + allowed.update(name for name, _ in target.named_buffers()) + for name in list(allowed): + if name.endswith(".qweight"): + allowed.add(name.replace(".qweight", ".weight")) + elif name.endswith(".qweight_type"): + allowed.add(name.replace(".qweight_type", ".weight")) + + virtual_names = set() + for name in allowed: + if self._include_qkv_virtuals and ".to_qkv." in name: + virtual_names.add(name.replace(".to_qkv.", ".to_q.")) + virtual_names.add(name.replace(".to_qkv.", ".to_k.")) + virtual_names.add(name.replace(".to_qkv.", ".to_v.")) + if self._include_add_kv_proj_virtuals and ".add_kv_proj." in name: + virtual_names.add(name.replace(".add_kv_proj.", ".add_q_proj.")) + virtual_names.add(name.replace(".add_kv_proj.", ".add_k_proj.")) + virtual_names.add(name.replace(".add_kv_proj.", ".add_v_proj.")) + if self._include_w13_virtuals and ".w13." in name: + virtual_names.add(name.replace(".w13.", ".w1.")) + virtual_names.add(name.replace(".w13.", ".w3.")) + if self._include_to_out_virtuals and ".to_out." in name: + virtual_names.add(name.replace(".to_out.", ".to_out.0.")) + allowed.update(virtual_names) + return allowed + + def _build_param_names(self) -> set[str]: + target = self._get_target_module() + return {name for name, _ in target.named_parameters()} + + def _resolve_linear_qweight(self, name: str, param_names: set[str]) -> str | None: + if not name.endswith(".weight"): + return None + if self._prefer_exact_qweight: + candidate = name.replace(".weight", ".qweight") + if candidate in param_names: + return candidate + if ".to_out.0." 
in name: + alt_name = name.replace(".to_out.0.", ".to_out.") + candidate = alt_name.replace(".weight", ".qweight") + if candidate in param_names: + return candidate + name = alt_name + for shard_token in self._shard_tokens: + if shard_token in name: + return name.replace(".weight", ".qweight") + candidate = name.replace(".weight", ".qweight") + if candidate in param_names: + return candidate + return None + def _build_gguf_name_map(self) -> dict[str, str]: try: import gguf # type: ignore diff --git a/vllm_omni/diffusion/model_loader/gguf_adapters/flux2_klein.py b/vllm_omni/diffusion/model_loader/gguf_adapters/flux2_klein.py index 1ed545deaa..001766ac23 100644 --- a/vllm_omni/diffusion/model_loader/gguf_adapters/flux2_klein.py +++ b/vllm_omni/diffusion/model_loader/gguf_adapters/flux2_klein.py @@ -2,27 +2,28 @@ from __future__ import annotations from collections.abc import Generator -from dataclasses import dataclass -from typing import Any import numpy as np import torch -from .base import GGUFAdapter - - -@dataclass -class _MappedTensor: - name: str - tensor: Any - tensor_type: Any - row_slice: slice | None = None - swap_scale_shift: bool = False +from .base import GGUFAdapter, MappedTensor class Flux2KleinGGUFAdapter(GGUFAdapter): """GGUF adapter for Flux2-Klein models with qkv splitting and adaLN swap.""" + _include_qkv_virtuals = True + _include_add_kv_proj_virtuals = True + _include_to_out_virtuals = True + _shard_tokens = ( + ".to_q.", + ".to_k.", + ".to_v.", + ".add_q_proj.", + ".add_k_proj.", + ".add_v_proj.", + ) + @staticmethod def is_compatible(od_config, model: torch.nn.Module, source) -> bool: model_class = od_config.model_class_name or "" @@ -47,7 +48,7 @@ def weights_iterator(self) -> Generator[tuple[str, torch.Tensor], None, None]: reader = gguf.GGUFReader(self.gguf_file) allowed_names = self._build_allowed_names() param_names = self._build_param_names() - mapped: list[_MappedTensor] = [] + mapped: list[MappedTensor] = [] for tensor in reader.tensors: for mapped_tensor in self._map_tensor_name(tensor): @@ -97,55 +98,7 @@ def weights_iterator(self) -> Generator[tuple[str, torch.Tensor], None, None]: yield name, param - def _build_allowed_names(self) -> set[str]: - prefix = getattr(self.source, "prefix", "") - target = self.model.get_submodule(prefix.rstrip(".")) if prefix else self.model - allowed = {name for name, _ in target.named_parameters()} - allowed.update(name for name, _ in target.named_buffers()) - for name in list(allowed): - if name.endswith(".qweight"): - allowed.add(name.replace(".qweight", ".weight")) - elif name.endswith(".qweight_type"): - allowed.add(name.replace(".qweight_type", ".weight")) - - virtual_names = set() - for name in allowed: - if ".to_qkv." in name: - virtual_names.add(name.replace(".to_qkv.", ".to_q.")) - virtual_names.add(name.replace(".to_qkv.", ".to_k.")) - virtual_names.add(name.replace(".to_qkv.", ".to_v.")) - if ".add_kv_proj." 
in name: - virtual_names.add(name.replace(".add_kv_proj.", ".add_q_proj.")) - virtual_names.add(name.replace(".add_kv_proj.", ".add_k_proj.")) - virtual_names.add(name.replace(".add_kv_proj.", ".add_v_proj.")) - allowed.update(virtual_names) - return allowed - - def _build_param_names(self) -> set[str]: - prefix = getattr(self.source, "prefix", "") - target = self.model.get_submodule(prefix.rstrip(".")) if prefix else self.model - return {name for name, _ in target.named_parameters()} - - def _resolve_linear_qweight(self, name: str, param_names: set[str]) -> str | None: - if not name.endswith(".weight"): - return None - # Keep QKV shard names so load_weights can attach shard_id correctly. - for shard_token in ( - ".to_q.", - ".to_k.", - ".to_v.", - ".add_q_proj.", - ".add_k_proj.", - ".add_v_proj.", - ): - if shard_token in name: - return name.replace(".weight", ".qweight") - candidate = name.replace(".weight", ".qweight") - if candidate in param_names: - return candidate - return None - - def _map_tensor_name(self, tensor) -> list[_MappedTensor]: + def _map_tensor_name(self, tensor) -> list[MappedTensor]: name = tensor.name if name.startswith("double_blocks."): @@ -154,7 +107,7 @@ def _map_tensor_name(self, tensor) -> list[_MappedTensor]: return self._map_single_blocks(tensor) if name.startswith("final_layer.adaLN_modulation.1") and name.endswith(".weight"): return [ - _MappedTensor( + MappedTensor( name="norm_out.linear.weight", tensor=tensor, tensor_type=tensor.tensor_type, @@ -166,14 +119,14 @@ def _map_tensor_name(self, tensor) -> list[_MappedTensor]: name = name.replace(src, dst) return [ - _MappedTensor( + MappedTensor( name=name, tensor=tensor, tensor_type=tensor.tensor_type, ) ] - def _map_double_blocks(self, tensor) -> list[_MappedTensor]: + def _map_double_blocks(self, tensor) -> list[MappedTensor]: name = tensor.name parts = name.split(".") block_idx = parts[1] @@ -198,18 +151,18 @@ def _map_double_blocks(self, tensor) -> list[_MappedTensor]: dim0 = weight.shape[0] split = dim0 // 3 return [ - _MappedTensor(q_name, tensor, tensor.tensor_type, slice(0, split)), - _MappedTensor(k_name, tensor, tensor.tensor_type, slice(split, 2 * split)), - _MappedTensor(v_name, tensor, tensor.tensor_type, slice(2 * split, 3 * split)), + MappedTensor(q_name, tensor, tensor.tensor_type, slice(0, split)), + MappedTensor(k_name, tensor, tensor.tensor_type, slice(split, 2 * split)), + MappedTensor(v_name, tensor, tensor.tensor_type, slice(2 * split, 3 * split)), ] mapped_name = _FLUX2_TRANSFORMER_DOUBLE_BLOCK_KEY_MAP.get(within_block_name) if mapped_name is None: return [] target = f"transformer_blocks.{block_idx}.{mapped_name}.{param_type}" - return [_MappedTensor(target, tensor, tensor.tensor_type)] + return [MappedTensor(target, tensor, tensor.tensor_type)] - def _map_single_blocks(self, tensor) -> list[_MappedTensor]: + def _map_single_blocks(self, tensor) -> list[MappedTensor]: name = tensor.name parts = name.split(".") block_idx = parts[1] @@ -222,7 +175,7 @@ def _map_single_blocks(self, tensor) -> list[_MappedTensor]: if mapped_name is None: return [] target = f"single_transformer_blocks.{block_idx}.{mapped_name}.{param_type}" - return [_MappedTensor(target, tensor, tensor.tensor_type)] + return [MappedTensor(target, tensor, tensor.tensor_type)] _FLUX2_TRANSFORMER_KEYS_RENAME_DICT = { diff --git a/vllm_omni/diffusion/model_loader/gguf_adapters/qwen_image.py b/vllm_omni/diffusion/model_loader/gguf_adapters/qwen_image.py index 3904976dc2..6b5e2cbabc 100644 --- 
a/vllm_omni/diffusion/model_loader/gguf_adapters/qwen_image.py +++ b/vllm_omni/diffusion/model_loader/gguf_adapters/qwen_image.py @@ -2,25 +2,28 @@ from __future__ import annotations from collections.abc import Generator -from dataclasses import dataclass -from typing import Any import numpy as np import torch -from .base import GGUFAdapter - - -@dataclass -class _MappedTensor: - name: str - tensor: Any - tensor_type: Any +from .base import GGUFAdapter, MappedTensor class QwenImageGGUFAdapter(GGUFAdapter): """GGUF adapter for Qwen-Image models with QKV shard support.""" + _include_qkv_virtuals = True + _include_add_kv_proj_virtuals = True + _include_to_out_virtuals = True + _shard_tokens = ( + ".to_q.", + ".to_k.", + ".to_v.", + ".add_q_proj.", + ".add_k_proj.", + ".add_v_proj.", + ) + @staticmethod def is_compatible(od_config, model: torch.nn.Module, source) -> bool: model_class = od_config.model_class_name or "" @@ -43,7 +46,7 @@ def weights_iterator(self) -> Generator[tuple[str, torch.Tensor], None, None]: gguf_name_map = self._build_gguf_name_map(reader) allowed_names = self._build_allowed_names() param_names = self._build_param_names() - mapped: list[_MappedTensor] = [] + mapped: list[MappedTensor] = [] for tensor in reader.tensors: mapped_name = gguf_name_map.get(tensor.name) @@ -56,7 +59,7 @@ def weights_iterator(self) -> Generator[tuple[str, torch.Tensor], None, None]: # Skip quantized tensors that map to non-quantized parameters. continue mapped.append( - _MappedTensor( + MappedTensor( name=mapped_name, tensor=tensor, tensor_type=tensor.tensor_type, @@ -104,57 +107,6 @@ def _normalize_name(self, name: str) -> str: name = name.replace(".to_out.0.", ".to_out.") return name - def _build_allowed_names(self) -> set[str]: - prefix = getattr(self.source, "prefix", "") - target = self.model.get_submodule(prefix.rstrip(".")) if prefix else self.model - allowed = {name for name, _ in target.named_parameters()} - allowed.update(name for name, _ in target.named_buffers()) - for name in list(allowed): - if name.endswith(".qweight"): - allowed.add(name.replace(".qweight", ".weight")) - elif name.endswith(".qweight_type"): - allowed.add(name.replace(".qweight_type", ".weight")) - - virtual_names = set() - for name in allowed: - if ".to_qkv." in name: - virtual_names.add(name.replace(".to_qkv.", ".to_q.")) - virtual_names.add(name.replace(".to_qkv.", ".to_k.")) - virtual_names.add(name.replace(".to_qkv.", ".to_v.")) - if ".add_kv_proj." in name: - virtual_names.add(name.replace(".add_kv_proj.", ".add_q_proj.")) - virtual_names.add(name.replace(".add_kv_proj.", ".add_k_proj.")) - virtual_names.add(name.replace(".add_kv_proj.", ".add_v_proj.")) - if ".to_out." in name: - virtual_names.add(name.replace(".to_out.", ".to_out.0.")) - allowed.update(virtual_names) - return allowed - - def _build_param_names(self) -> set[str]: - prefix = getattr(self.source, "prefix", "") - target = self.model.get_submodule(prefix.rstrip(".")) if prefix else self.model - return {name for name, _ in target.named_parameters()} - - def _resolve_linear_qweight(self, name: str, param_names: set[str]) -> str | None: - if not name.endswith(".weight"): - return None - if ".to_out.0." 
in name: - name = name.replace(".to_out.0.", ".to_out.") - for shard_token in ( - ".to_q.", - ".to_k.", - ".to_v.", - ".add_q_proj.", - ".add_k_proj.", - ".add_v_proj.", - ): - if shard_token in name: - return name.replace(".weight", ".qweight") - candidate = name.replace(".weight", ".qweight") - if candidate in param_names: - return candidate - return None - def _build_gguf_name_map(self, reader) -> dict[str, str]: try: import gguf # type: ignore diff --git a/vllm_omni/diffusion/model_loader/gguf_adapters/z_image.py b/vllm_omni/diffusion/model_loader/gguf_adapters/z_image.py index 9850ca9507..412d78ce59 100644 --- a/vllm_omni/diffusion/model_loader/gguf_adapters/z_image.py +++ b/vllm_omni/diffusion/model_loader/gguf_adapters/z_image.py @@ -2,26 +2,27 @@ from __future__ import annotations from collections.abc import Generator -from dataclasses import dataclass -from typing import Any import numpy as np import torch -from .base import GGUFAdapter - - -@dataclass -class _MappedTensor: - name: str - tensor: Any - tensor_type: Any - row_slice: slice | None = None +from .base import GGUFAdapter, MappedTensor class ZImageGGUFAdapter(GGUFAdapter): """GGUF adapter for Z-Image models with QKV/FFN shard support.""" + _include_qkv_virtuals = True + _include_to_out_virtuals = True + _include_w13_virtuals = True + _shard_tokens = ( + ".to_q.", + ".to_k.", + ".to_v.", + ".w1.", + ".w3.", + ) + @staticmethod def is_compatible(od_config, model: torch.nn.Module, source) -> bool: model_class = od_config.model_class_name or "" @@ -44,7 +45,7 @@ def weights_iterator(self) -> Generator[tuple[str, torch.Tensor], None, None]: gguf_name_map = self._build_gguf_name_map(reader) allowed_names = self._build_allowed_names() param_names = self._build_param_names() - mapped: list[_MappedTensor] = [] + mapped: list[MappedTensor] = [] for tensor in reader.tensors: for mapped_tensor in self._map_tensor_name(tensor, gguf_name_map): @@ -127,7 +128,7 @@ def _apply_zimage_renames(self, name: str) -> str: name = name.replace(".attention.q_norm.weight", ".attention.norm_q.weight") return name - def _map_tensor_name(self, tensor, gguf_name_map: dict[str, str]) -> list[_MappedTensor]: + def _map_tensor_name(self, tensor, gguf_name_map: dict[str, str]) -> list[MappedTensor]: name = gguf_name_map.get(tensor.name) if name is None: name = self._normalize_name(tensor.name) @@ -138,19 +139,19 @@ def _map_tensor_name(self, tensor, gguf_name_map: dict[str, str]) -> list[_Mappe dim0 = weight.shape[0] split = dim0 // 3 return [ - _MappedTensor( + MappedTensor( name=name.replace(".attention.qkv.weight", ".attention.to_q.weight"), tensor=tensor, tensor_type=tensor.tensor_type, row_slice=slice(0, split), ), - _MappedTensor( + MappedTensor( name=name.replace(".attention.qkv.weight", ".attention.to_k.weight"), tensor=tensor, tensor_type=tensor.tensor_type, row_slice=slice(split, 2 * split), ), - _MappedTensor( + MappedTensor( name=name.replace(".attention.qkv.weight", ".attention.to_v.weight"), tensor=tensor, tensor_type=tensor.tensor_type, @@ -159,65 +160,13 @@ def _map_tensor_name(self, tensor, gguf_name_map: dict[str, str]) -> list[_Mappe ] return [ - _MappedTensor( + MappedTensor( name=name, tensor=tensor, tensor_type=tensor.tensor_type, ) ] - def _build_allowed_names(self) -> set[str]: - prefix = getattr(self.source, "prefix", "") - target = self.model.get_submodule(prefix.rstrip(".")) if prefix else self.model - allowed = {name for name, _ in target.named_parameters()} - allowed.update(name for name, _ in target.named_buffers()) - for name 
in list(allowed): - if name.endswith(".qweight"): - allowed.add(name.replace(".qweight", ".weight")) - elif name.endswith(".qweight_type"): - allowed.add(name.replace(".qweight_type", ".weight")) - - virtual_names = set() - for name in allowed: - if ".to_qkv." in name: - virtual_names.add(name.replace(".to_qkv.", ".to_q.")) - virtual_names.add(name.replace(".to_qkv.", ".to_k.")) - virtual_names.add(name.replace(".to_qkv.", ".to_v.")) - if ".w13." in name: - virtual_names.add(name.replace(".w13.", ".w1.")) - virtual_names.add(name.replace(".w13.", ".w3.")) - if ".to_out." in name: - virtual_names.add(name.replace(".to_out.", ".to_out.0.")) - allowed.update(virtual_names) - return allowed - - def _build_param_names(self) -> set[str]: - prefix = getattr(self.source, "prefix", "") - target = self.model.get_submodule(prefix.rstrip(".")) if prefix else self.model - return {name for name, _ in target.named_parameters()} - - def _resolve_linear_qweight(self, name: str, param_names: set[str]) -> str | None: - if not name.endswith(".weight"): - return None - candidate = name.replace(".weight", ".qweight") - if candidate in param_names: - return candidate - if ".to_out.0." in name: - alt_name = name.replace(".to_out.0.", ".to_out.") - candidate = alt_name.replace(".weight", ".qweight") - if candidate in param_names: - return candidate - for shard_token in ( - ".to_q.", - ".to_k.", - ".to_v.", - ".w1.", - ".w3.", - ): - if shard_token in name: - return name.replace(".weight", ".qweight") - return None - def _build_gguf_name_map(self, reader) -> dict[str, str]: try: import gguf # type: ignore From 5906200d700f25d7535cb92bfc2d40fceacc2113 Mon Sep 17 00:00:00 2001 From: David Chen <530634352@qq.com> Date: Fri, 13 Feb 2026 09:53:40 +0800 Subject: [PATCH 38/62] fix bug Signed-off-by: David Chen <530634352@qq.com> --- vllm_omni/diffusion/quantization/base.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/vllm_omni/diffusion/quantization/base.py b/vllm_omni/diffusion/quantization/base.py index 0cd9e4147e..d18d970743 100644 --- a/vllm_omni/diffusion/quantization/base.py +++ b/vllm_omni/diffusion/quantization/base.py @@ -31,15 +31,16 @@ class DiffusionQuantizationConfig(ABC): # The underlying vLLM config instance _vllm_config: "QuantizationConfig | None" = None - @classmethod - def get_name(cls) -> str: + def get_name(self) -> str: """Return the quantization method name (e.g., 'fp8', 'int8'). - By default, delegates to the underlying vLLM config class. + By default, delegates to the underlying vLLM config instance. """ - if cls.quant_config_cls is not None: - return cls.quant_config_cls.get_name() - raise NotImplementedError("Subclass must set quant_config_cls or override get_name()") + if self._vllm_config is not None: + return self._vllm_config.get_name() + raise NotImplementedError( + "Subclass must initialize _vllm_config or override get_name()." 
+ ) def get_vllm_quant_config(self) -> "QuantizationConfig | None": """Return the underlying vLLM QuantizationConfig for linear layers.""" From 20307f6de21eca6cd3d001ca662147015dc11edc Mon Sep 17 00:00:00 2001 From: David Chen <530634352@qq.com> Date: Fri, 13 Feb 2026 10:19:19 +0800 Subject: [PATCH 39/62] fix comment 2 Signed-off-by: David Chen <530634352@qq.com> --- vllm_omni/diffusion/model_loader/gguf_adapters/z_image.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm_omni/diffusion/model_loader/gguf_adapters/z_image.py b/vllm_omni/diffusion/model_loader/gguf_adapters/z_image.py index 412d78ce59..6380cef611 100644 --- a/vllm_omni/diffusion/model_loader/gguf_adapters/z_image.py +++ b/vllm_omni/diffusion/model_loader/gguf_adapters/z_image.py @@ -224,6 +224,10 @@ def split_name(name: str) -> tuple[str, str]: return name, "" model_type = resolve_model_type() + if model_type in {"z_image", "zimage", "z-image"}: + # gguf-py does not register a Z-Image architecture, so we rely on + # direct tensor names from the GGUF file. + return {} try: arch = resolve_arch(model_type) except RuntimeError: From b35870e5365416af9327b6e5a65b8988e4c633be Mon Sep 17 00:00:00 2001 From: David Chen <530634352@qq.com> Date: Fri, 13 Feb 2026 10:23:07 +0800 Subject: [PATCH 40/62] fix pre-commit Signed-off-by: David Chen <530634352@qq.com> --- vllm_omni/diffusion/quantization/base.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm_omni/diffusion/quantization/base.py b/vllm_omni/diffusion/quantization/base.py index d18d970743..17e6d32ead 100644 --- a/vllm_omni/diffusion/quantization/base.py +++ b/vllm_omni/diffusion/quantization/base.py @@ -38,9 +38,7 @@ def get_name(self) -> str: """ if self._vllm_config is not None: return self._vllm_config.get_name() - raise NotImplementedError( - "Subclass must initialize _vllm_config or override get_name()." 
- ) + raise NotImplementedError("Subclass must initialize _vllm_config or override get_name().") def get_vllm_quant_config(self) -> "QuantizationConfig | None": """Return the underlying vLLM QuantizationConfig for linear layers.""" From c9fceb9d5e4f3f1113b60e7fb5bc19d7ed1ef5a3 Mon Sep 17 00:00:00 2001 From: David Chen <530634352@qq.com> Date: Fri, 13 Feb 2026 10:39:41 +0800 Subject: [PATCH 41/62] add doc Signed-off-by: David Chen <530634352@qq.com> --- docs/.nav.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/.nav.yml b/docs/.nav.yml index 07db1b4651..f4144b38e5 100644 --- a/docs/.nav.yml +++ b/docs/.nav.yml @@ -46,6 +46,7 @@ nav: - Quantization: - Overview: user_guide/diffusion/quantization/overview.md - FP8: user_guide/diffusion/quantization/fp8.md + - GGUF: user_guide/diffusion/quantization/gguf.md - Parallelism Acceleration: user_guide/diffusion/parallelism_acceleration.md - CPU Offloading: user_guide/diffusion/cpu_offload_diffusion.md - ComfyUI: features/comfyui.md From d1550cf065ef1045b287644b14238dc3d6f27bdb Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Wed, 18 Feb 2026 22:30:48 +0800 Subject: [PATCH 42/62] draft Signed-off-by: Isotr0py --- .../model_loader/gguf_adapters/flux2_klein.py | 78 +++++++++++++------ 1 file changed, 53 insertions(+), 25 deletions(-) diff --git a/vllm_omni/diffusion/model_loader/gguf_adapters/flux2_klein.py b/vllm_omni/diffusion/model_loader/gguf_adapters/flux2_klein.py index 001766ac23..17cd167d0f 100644 --- a/vllm_omni/diffusion/model_loader/gguf_adapters/flux2_klein.py +++ b/vllm_omni/diffusion/model_loader/gguf_adapters/flux2_klein.py @@ -6,39 +6,67 @@ import numpy as np import torch +from vllm.model_executor.models.utils import WeightsMapper + from .base import GGUFAdapter, MappedTensor +FLUX2_TRANSFORMER_KEYS_RENAME_DICT = { + # Image and text input projections + "img_in": "x_embedder", + "txt_in": "context_embedder", + # Timestep and guidance embeddings + "time_in.in_layer": "time_guidance_embed.timestep_embedder.linear_1", + "time_in.out_layer": "time_guidance_embed.timestep_embedder.linear_2", + "guidance_in.in_layer": "time_guidance_embed.guidance_embedder.linear_1", + "guidance_in.out_layer": "time_guidance_embed.guidance_embedder.linear_2", + # Modulation parameters + "double_stream_modulation_img.lin": "double_stream_modulation_img.linear", + "double_stream_modulation_txt.lin": "double_stream_modulation_txt.linear", + "single_stream_modulation.lin": "single_stream_modulation.linear", + # Final output layer + # "final_layer.adaLN_modulation.1": "norm_out.linear", # Handle separately since we need to swap mod params + "final_layer.linear": "proj_out", +} + +FLUX2_TRANSFORMER_ADA_LAYER_NORM_KEY_MAP = { + "final_layer.adaLN_modulation.1": "norm_out.linear", +} + +FLUX2_TRANSFORMER_DOUBLE_BLOCK_KEY_MAP = { + # Handle fused QKV projections separately as we need to break into Q, K, V projections + "img_attn.norm.query_norm": "attn.norm_q", + "img_attn.norm.key_norm": "attn.norm_k", + "img_attn.proj": "attn.to_out.0", + "img_mlp.0": "ff.linear_in", + "img_mlp.2": "ff.linear_out", + "txt_attn.norm.query_norm": "attn.norm_added_q", + "txt_attn.norm.key_norm": "attn.norm_added_k", + "txt_attn.proj": "attn.to_add_out", + "txt_mlp.0": "ff_context.linear_in", + "txt_mlp.2": "ff_context.linear_out", + # Additional for fuse qkv + "img_attn.qkv": "attn.to_qkv_mlp_proj", + "txt_attn.qkv": "attn.add_kv_proj", +} + +FLUX2_TRANSFORMER_SINGLE_BLOCK_KEY_MAP = { + "linear1": "attn.to_qkv_mlp_proj", + "linear2": "attn.to_out", + "norm.query_norm": 
"attn.norm_q", + "norm.key_norm": "attn.norm_k", +} + + class Flux2KleinGGUFAdapter(GGUFAdapter): """GGUF adapter for Flux2-Klein models with qkv splitting and adaLN swap.""" - _include_qkv_virtuals = True - _include_add_kv_proj_virtuals = True - _include_to_out_virtuals = True - _shard_tokens = ( - ".to_q.", - ".to_k.", - ".to_v.", - ".add_q_proj.", - ".add_k_proj.", - ".add_v_proj.", + gguf_to_hf_mapper = WeightsMapper( + # double_stream_modulation + orig_to_new_prefix = FLUX2_TRANSFORMER_KEYS_RENAME_DICT | FLUX2_TRANSFORMER_ADA_LAYER_NORM_KEY_MAP, + orig_to_new_substr = FLUX2_TRANSFORMER_DOUBLE_BLOCK_KEY_MAP | FLUX2_TRANSFORMER_SINGLE_BLOCK_KEY_MAP, ) - @staticmethod - def is_compatible(od_config, model: torch.nn.Module, source) -> bool: - model_class = od_config.model_class_name or "" - if model_class.startswith("Flux2Klein"): - return True - cfg = od_config.tf_model_config - if cfg is not None: - model_type = str(cfg.get("model_type", "")).lower() - if model_type.startswith("flux2"): - return True - # Fallback: Flux2 transformer has single_transformer_blocks - prefix = getattr(source, "prefix", "") - target = model.get_submodule(prefix.rstrip(".")) if prefix else model - return hasattr(target, "single_transformer_blocks") - def weights_iterator(self) -> Generator[tuple[str, torch.Tensor], None, None]: try: import gguf # type: ignore From 02f9c1b3b1a9d61680057057b718c646bf73e895 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Fri, 20 Feb 2026 16:11:13 +0800 Subject: [PATCH 43/62] update Signed-off-by: Isotr0py --- .../model_loader/gguf_adapters/flux2_klein.py | 117 +++++++++--------- 1 file changed, 60 insertions(+), 57 deletions(-) diff --git a/vllm_omni/diffusion/model_loader/gguf_adapters/flux2_klein.py b/vllm_omni/diffusion/model_loader/gguf_adapters/flux2_klein.py index 17cd167d0f..4b6101e8c7 100644 --- a/vllm_omni/diffusion/model_loader/gguf_adapters/flux2_klein.py +++ b/vllm_omni/diffusion/model_loader/gguf_adapters/flux2_klein.py @@ -7,6 +7,7 @@ import torch from vllm.model_executor.models.utils import WeightsMapper +from vllm.model_executor.model_loader.weight_utils import gguf_quant_weights_iterator from .base import GGUFAdapter, MappedTensor @@ -68,63 +69,65 @@ class Flux2KleinGGUFAdapter(GGUFAdapter): ) def weights_iterator(self) -> Generator[tuple[str, torch.Tensor], None, None]: - try: - import gguf # type: ignore - except Exception as exc: # pragma: no cover - dependency error - raise RuntimeError("GGUF support requires the 'gguf' package to be installed.") from exc - - reader = gguf.GGUFReader(self.gguf_file) - allowed_names = self._build_allowed_names() - param_names = self._build_param_names() - mapped: list[MappedTensor] = [] - - for tensor in reader.tensors: - for mapped_tensor in self._map_tensor_name(tensor): - if ( - mapped_tensor.name not in allowed_names - and self._resolve_linear_qweight(mapped_tensor.name, param_names) is None - ): - continue - mapped.append(mapped_tensor) - - if not mapped: - raise RuntimeError( - "No GGUF tensors were mapped for Flux2 GGUF loader. Please verify the GGUF file and model structure." 
- ) - - for item in mapped: - linear_qweight = self._resolve_linear_qweight(item.name, param_names) - is_linear_weight = linear_qweight is not None - if not is_linear_weight: - continue - weight_type_name = linear_qweight.replace("qweight", "qweight_type") - yield weight_type_name, torch.tensor(item.tensor_type) - - for item in mapped: - weight = item.tensor.data - if item.row_slice is not None: - weight = weight[item.row_slice] - weight_type = item.tensor_type - linear_qweight = self._resolve_linear_qweight(item.name, param_names) - is_linear_weight = linear_qweight is not None - if is_linear_weight: - name = linear_qweight - else: - name = item.name - - if weight_type.name == "BF16" and weight.dtype == np.uint8: - weight = weight.view(np.uint16) - if reader.byte_order == "S": - weight = weight.byteswap() - param = torch.tensor(weight).view(torch.bfloat16) - else: - param = torch.tensor(weight) - - if item.swap_scale_shift: - shift, scale = param.chunk(2, dim=0) - param = torch.cat([scale, shift], dim=0) - - yield name, param + weights = gguf_quant_weights_iterator(self.gguf_file, {}) + yield from self.gguf_to_hf_mapper.apply(weights) + # try: + # import gguf # type: ignore + # except Exception as exc: # pragma: no cover - dependency error + # raise RuntimeError("GGUF support requires the 'gguf' package to be installed.") from exc + + # reader = gguf.GGUFReader(self.gguf_file) + # allowed_names = self._build_allowed_names() + # param_names = self._build_param_names() + # mapped: list[MappedTensor] = [] + + # for tensor in reader.tensors: + # for mapped_tensor in self._map_tensor_name(tensor): + # if ( + # mapped_tensor.name not in allowed_names + # and self._resolve_linear_qweight(mapped_tensor.name, param_names) is None + # ): + # continue + # mapped.append(mapped_tensor) + + # if not mapped: + # raise RuntimeError( + # "No GGUF tensors were mapped for Flux2 GGUF loader. Please verify the GGUF file and model structure." 
+ # ) + + # for item in mapped: + # linear_qweight = self._resolve_linear_qweight(item.name, param_names) + # is_linear_weight = linear_qweight is not None + # if not is_linear_weight: + # continue + # weight_type_name = linear_qweight.replace("qweight", "qweight_type") + # yield weight_type_name, torch.tensor(item.tensor_type) + + # for item in mapped: + # weight = item.tensor.data + # if item.row_slice is not None: + # weight = weight[item.row_slice] + # weight_type = item.tensor_type + # linear_qweight = self._resolve_linear_qweight(item.name, param_names) + # is_linear_weight = linear_qweight is not None + # if is_linear_weight: + # name = linear_qweight + # else: + # name = item.name + + # if weight_type.name == "BF16" and weight.dtype == np.uint8: + # weight = weight.view(np.uint16) + # if reader.byte_order == "S": + # weight = weight.byteswap() + # param = torch.tensor(weight).view(torch.bfloat16) + # else: + # param = torch.tensor(weight) + + # if item.swap_scale_shift: + # shift, scale = param.chunk(2, dim=0) + # param = torch.cat([scale, shift], dim=0) + + # yield name, param def _map_tensor_name(self, tensor) -> list[MappedTensor]: name = tensor.name From b1091fb0708e263924b8a46397afd67e9d048ea5 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Fri, 20 Feb 2026 23:43:07 +0800 Subject: [PATCH 44/62] update Signed-off-by: Isotr0py --- .../model_loader/diffusers_loader.py | 21 ++++-- .../model_loader/gguf_adapters/flux2_klein.py | 66 ++++++++++++++++++- .../flux2_klein/flux2_klein_transformer.py | 30 +++++---- .../flux2_klein/pipeline_flux2_klein.py | 6 +- 4 files changed, 98 insertions(+), 25 deletions(-) diff --git a/vllm_omni/diffusion/model_loader/diffusers_loader.py b/vllm_omni/diffusion/model_loader/diffusers_loader.py index d83780be55..36d9f3ca07 100644 --- a/vllm_omni/diffusion/model_loader/diffusers_loader.py +++ b/vllm_omni/diffusion/model_loader/diffusers_loader.py @@ -282,12 +282,12 @@ def load_weights(self, model: nn.Module) -> None: # We only enable strict check for non-quantized models # that have loaded weights tracking currently. 
        if loaded_weights is not None:
-            _ = weights_to_load - loaded_weights
-            # if weights_not_loaded:
-            #     raise ValueError(
-            #         "Following weights were not initialized from "
-            #         f"checkpoint: {weights_not_loaded}"
-            #     )
+            weights_not_loaded = weights_to_load - loaded_weights
+            if weights_not_loaded:
+                raise ValueError(
+                    "Following weights were not initialized from "
+                    f"checkpoint: {weights_not_loaded}"
+                )
 
     def _is_gguf_quantization(self, od_config: OmniDiffusionConfig) -> bool:
         quant_config = od_config.quantization_config
@@ -433,4 +433,11 @@ def _load_weights_with_gguf(self, model: nn.Module, od_config: OmniDiffusionConf
             loaded |= model.load_weights(hf_iter)
         else:
             loaded |= model.load_weights(self._get_weights_iterator(source))
-        return loaded
+
+        weights_to_load = {name for name, _ in model.named_parameters()}
+        weights_not_loaded = weights_to_load - loaded
+        # if weights_not_loaded:
+        #     raise ValueError(
+        #         "Following weights were not initialized from "
+        #         f"checkpoint: {weights_not_loaded}"
+        #     )
diff --git a/vllm_omni/diffusion/model_loader/gguf_adapters/flux2_klein.py b/vllm_omni/diffusion/model_loader/gguf_adapters/flux2_klein.py
index 4b6101e8c7..a67729e6e4 100644
--- a/vllm_omni/diffusion/model_loader/gguf_adapters/flux2_klein.py
+++ b/vllm_omni/diffusion/model_loader/gguf_adapters/flux2_klein.py
@@ -3,16 +3,17 @@
 
 from collections.abc import Generator
 
+import gguf
 import numpy as np
 import torch
 
 from vllm.model_executor.models.utils import WeightsMapper
-from vllm.model_executor.model_loader.weight_utils import gguf_quant_weights_iterator
 
 from .base import GGUFAdapter, MappedTensor
 
 
 FLUX2_TRANSFORMER_KEYS_RENAME_DICT = {
+    "single_blocks.": "single_transformer_blocks.",
     # Image and text input projections
     "img_in": "x_embedder",
     "txt_in": "context_embedder",
@@ -35,6 +36,7 @@
 }
 
 FLUX2_TRANSFORMER_DOUBLE_BLOCK_KEY_MAP = {
+    "double_blocks.": "transformer_blocks.",
     # Handle fused QKV projections separately as we need to break into Q, K, V projections
     "img_attn.norm.query_norm": "attn.norm_q",
     "img_attn.norm.key_norm": "attn.norm_k",
@@ -47,7 +49,7 @@
     "txt_mlp.0": "ff_context.linear_in",
     "txt_mlp.2": "ff_context.linear_out",
     # Additional for fuse qkv
-    "img_attn.qkv": "attn.to_qkv_mlp_proj",
+    "img_attn.qkv": "attn.to_qkv",
     "txt_attn.qkv": "attn.add_kv_proj",
 }
 
@@ -59,6 +61,51 @@
 }
 
 
+def gguf_quant_weights_iterator(
+    gguf_file: str, gguf_to_hf_name_map: dict[str, str]
+) -> Generator[tuple[str, torch.Tensor], None, None]:
+    """
+    Iterate over the quant weights in the model gguf files and convert
+    them to torch tensors.
+    Be careful of the order of yielding weight types and weights data,
+    we have to yield all weight types first before yielding any weights.
+    Otherwise it would cause issues when loading weights for packed
+    layers with different quant types.
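+    For example, if the Q, K and V shards feeding one packed "to_qkv"
+    parameter were stored with different quant types, the weight loader
+    must already know all three "...qweight_type" entries by the time the
+    first "...qweight" payload arrives.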
+ """ + + reader = gguf.GGUFReader(gguf_file) + + for tensor in reader.tensors: + weight_type = tensor.tensor_type + name = tensor.name + + if weight_type.name not in ("F32", "BF16", "F16"): + weight_type_name = name.replace("weight", "qweight_type") + weight_type = torch.tensor(weight_type) + yield weight_type_name, weight_type + + for tensor in reader.tensors: + weight = tensor.data + weight_type = tensor.tensor_type + name = tensor.name + if weight_type.name not in ("F32", "BF16", "F16"): + name = name.replace("weight", "qweight") + elif name.endswith(".scale"): + name = name.replace(".scale", ".weight") + if weight_type.name == "BF16" and tensor.data.dtype == np.uint8: + # BF16 is currently the only "quantization" type that isn't + # actually quantized but is read as a raw byte tensor. + # Reinterpret as `torch.bfloat16` tensor. + weight = weight.view(np.uint16) + if reader.byte_order == "S": + # GGUF endianness != system endianness + weight = weight.byteswap() + param = torch.tensor(weight).view(torch.bfloat16) + else: + param = torch.tensor(weight) + yield name, param + + class Flux2KleinGGUFAdapter(GGUFAdapter): """GGUF adapter for Flux2-Klein models with qkv splitting and adaLN swap.""" @@ -69,8 +116,18 @@ class Flux2KleinGGUFAdapter(GGUFAdapter): ) def weights_iterator(self) -> Generator[tuple[str, torch.Tensor], None, None]: + def custom_weights_adapter(weights): + for name, weight in weights: + # Handle the special case for adaLN modulation parameters that require swapping shift and scale + if name == "norm_out.linear.weight": + shift, scale = weight.chunk(2, dim=0) + weight = torch.cat([scale, shift], dim=0) + yield name, weight + else: + yield name, weight weights = gguf_quant_weights_iterator(self.gguf_file, {}) - yield from self.gguf_to_hf_mapper.apply(weights) + weights = self.gguf_to_hf_mapper.apply(weights) + yield from custom_weights_adapter(weights) # try: # import gguf # type: ignore # except Exception as exc: # pragma: no cover - dependency error @@ -220,6 +277,9 @@ def _map_single_blocks(self, tensor) -> list[MappedTensor]: "double_stream_modulation_txt.lin": "double_stream_modulation_txt.linear", "single_stream_modulation.lin": "single_stream_modulation.linear", "final_layer.linear": "proj_out", + # prefix + "double_blocks.": "transformer_blocks.", + "single_blocks.": "single_transformer_blocks.", } _FLUX2_TRANSFORMER_DOUBLE_BLOCK_KEY_MAP = { diff --git a/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py b/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py index 5ac0b4e1a9..d73446c23c 100644 --- a/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py +++ b/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py @@ -755,9 +755,9 @@ def forward( def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ - (".to_qkv", ".to_q", "q"), - (".to_qkv", ".to_k", "k"), - (".to_qkv", ".to_v", "v"), + (".to_qkv.", ".to_q.", "q"), + (".to_qkv.", ".to_k.", "k"), + (".to_qkv.", ".to_v.", "v"), (".add_kv_proj", ".add_q_proj", "q"), (".add_kv_proj", ".add_k_proj", "k"), (".add_kv_proj", ".add_v_proj", "v"), @@ -771,28 +771,30 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loaded_params: set[str] = set() for name, loaded_weight in weights: - if "to_qkvkv_mlp_proj" in name: - name = name.replace("to_qkvkv_mlp_proj", "to_qkv_mlp_proj") - if "to_qkv_mlp_proj" in name: - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", 
default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - continue + # if "to_qkvkv_mlp_proj" in name: + # name = name.replace("to_qkvkv_mlp_proj", "to_qkv_mlp_proj") + # if "to_qkv_mlp_proj" in name: + # param = params_dict[name] + # weight_loader = getattr(param, "weight_loader", default_weight_loader) + # weight_loader(param, loaded_weight) + # loaded_params.add(name) + # continue # GGUF fused QKV weights already target .to_qkv/.add_kv_proj. # Avoid substring replacement that would duplicate "qkv". - is_fused_qkv = ".to_qkv." in name or ".add_kv_proj." in name + # is_fused_qkv = ".to_qkv." in name or ".add_kv_proj." in name + print(name, loaded_weight.shape) for param_name, weight_name, shard_id in stacked_params_mapping: - if is_fused_qkv or weight_name not in name: + if weight_name not in name: continue name = name.replace(weight_name, param_name) param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) + loaded_params.add(name) break else: param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) - loaded_params.add(name) + loaded_params.add(name) return loaded_params diff --git a/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py b/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py index 6fd2de3c94..51bafb28ee 100644 --- a/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py +++ b/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py @@ -995,4 +995,8 @@ def forward( def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) - return loader.load_weights(weights) + loaded_weights = loader.load_weights(weights) + # Record other components not tracked by AutoWeightsLoader + loaded_weights |= {f"vae.{name}" for name, _ in self.vae.named_parameters()} + loaded_weights |= {f"text_encoder.{name}" for name, _ in self.text_encoder.named_parameters()} + return loaded_weights From a4e0ffef1c045aea40c795cc37dd8cddbc907266 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Sat, 21 Feb 2026 00:00:35 +0800 Subject: [PATCH 45/62] clean Signed-off-by: Isotr0py --- .../model_loader/diffusers_loader.py | 12 +- .../model_loader/gguf_adapters/base.py | 48 +++- .../model_loader/gguf_adapters/flux2_klein.py | 241 +----------------- .../flux2_klein/flux2_klein_transformer.py | 12 - 4 files changed, 60 insertions(+), 253 deletions(-) diff --git a/vllm_omni/diffusion/model_loader/diffusers_loader.py b/vllm_omni/diffusion/model_loader/diffusers_loader.py index 36d9f3ca07..d08fef5f78 100644 --- a/vllm_omni/diffusion/model_loader/diffusers_loader.py +++ b/vllm_omni/diffusion/model_loader/diffusers_loader.py @@ -284,10 +284,7 @@ def load_weights(self, model: nn.Module) -> None: if loaded_weights is not None: weights_not_loaded = weights_to_load - loaded_weights if weights_not_loaded: - raise ValueError( - "Following weights were not initialized from " - f"checkpoint: {weights_not_loaded}" - ) + raise ValueError(f"Following weights were not initialized from checkpoint: {weights_not_loaded}") def _is_gguf_quantization(self, od_config: OmniDiffusionConfig) -> bool: quant_config = od_config.quantization_config @@ -436,8 +433,5 @@ def _load_weights_with_gguf(self, model: nn.Module, od_config: OmniDiffusionConf weights_to_load = {name for name, _ in model.named_parameters()} weights_not_loaded = weights_to_load - loaded - # if weights_not_loaded: - # raise ValueError( - # 
"Following weights were not initialized from " - # f"checkpoint: {weights_not_loaded}" - # ) + if weights_not_loaded: + raise ValueError(f"Following weights were not initialized from checkpoint: {weights_not_loaded}") diff --git a/vllm_omni/diffusion/model_loader/gguf_adapters/base.py b/vllm_omni/diffusion/model_loader/gguf_adapters/base.py index 4523921b69..013a249a55 100644 --- a/vllm_omni/diffusion/model_loader/gguf_adapters/base.py +++ b/vllm_omni/diffusion/model_loader/gguf_adapters/base.py @@ -5,8 +5,9 @@ from dataclasses import dataclass from typing import Any +import gguf +import numpy as np import torch -from vllm.model_executor.model_loader.weight_utils import gguf_quant_weights_iterator @dataclass @@ -40,8 +41,7 @@ def is_compatible(od_config, model: torch.nn.Module, source) -> bool: return True def weights_iterator(self) -> Generator[tuple[str, torch.Tensor], None, None]: - name_map = self._build_gguf_name_map() - return gguf_quant_weights_iterator(self.gguf_file, name_map) + return gguf_quant_weights_iterator(self.gguf_file) def _get_target_module(self) -> torch.nn.Module: prefix = getattr(self.source, "prefix", "") @@ -205,3 +205,45 @@ def split_name(name: str) -> tuple[str, str]: if not gguf_to_model_map: raise RuntimeError(f"No GGUF tensors were mapped for model_class_name={self.od_config.model_class_name!r}.") return gguf_to_model_map + + +# FIXME(Isotr0py): Sync implemnentation with upstream vLLM? +def gguf_quant_weights_iterator(gguf_file: str) -> Generator[tuple[str, torch.Tensor]]: + """ + Iterate over the quant weights in the model gguf files and convert + them to torch tensors. + Be careful of the order of yielding weight types and weights data, + we have to yield all weight types first before yielding any weights. + Otherwise it would cause issue when loading weights with for packed + layer with different quant types. + """ + + reader = gguf.GGUFReader(gguf_file) + + for tensor in reader.tensors: + weight_type = tensor.tensor_type + name = tensor.name + + if weight_type.name not in ("F32", "BF16", "F16"): + weight_type_name = name.replace("weight", "qweight_type") + weight_type = torch.tensor(weight_type) + yield weight_type_name, weight_type + + for tensor in reader.tensors: + weight = tensor.data + weight_type = tensor.tensor_type + name = tensor.name + if weight_type.name not in ("F32", "BF16", "F16"): + name = name.replace("weight", "qweight") + if weight_type.name == "BF16" and tensor.data.dtype == np.uint8: + # BF16 is currently the only "quantization" type that isn't + # actually quantized but is read as a raw byte tensor. + # Reinterpret as `torch.bfloat16` tensor. 
+            weight = weight.view(np.uint16)
+            if reader.byte_order == "S":
+                # GGUF endianness != system endianness
+                weight = weight.byteswap()
+            param = torch.tensor(weight).view(torch.bfloat16)
+        else:
+            param = torch.tensor(weight)
+        yield name, param
diff --git a/vllm_omni/diffusion/model_loader/gguf_adapters/flux2_klein.py b/vllm_omni/diffusion/model_loader/gguf_adapters/flux2_klein.py
index a67729e6e4..f16a2af1df 100644
--- a/vllm_omni/diffusion/model_loader/gguf_adapters/flux2_klein.py
+++ b/vllm_omni/diffusion/model_loader/gguf_adapters/flux2_klein.py
@@ -1,16 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
 from __future__ import annotations
 
-from collections.abc import Generator
+from collections.abc import Iterable
 
-import gguf
-import numpy as np
 import torch
-
 from vllm.model_executor.models.utils import WeightsMapper
 
-from .base import GGUFAdapter, MappedTensor
+from .base import GGUFAdapter, gguf_quant_weights_iterator
 
 FLUX2_TRANSFORMER_KEYS_RENAME_DICT = {
     "single_blocks.": "single_transformer_blocks.",
@@ -61,243 +57,30 @@
 }
 
 
-def gguf_quant_weights_iterator(
-    gguf_file: str, gguf_to_hf_name_map: dict[str, str]
-) -> Generator[tuple[str, torch.Tensor], None, None]:
-    """
-    Iterate over the quant weights in the model gguf files and convert
-    them to torch tensors.
-    Be careful of the order of yielding weight types and weights data,
-    we have to yield all weight types first before yielding any weights.
-    Otherwise it would cause issues when loading weights for packed
-    layers with different quant types.
-    For example, if the Q, K and V shards feeding one packed "to_qkv"
-    parameter were stored with different quant types, the weight loader
-    must already know all three "...qweight_type" entries by the time the
-    first "...qweight" payload arrives.
-    """
-
-    reader = gguf.GGUFReader(gguf_file)
-
-    for tensor in reader.tensors:
-        weight_type = tensor.tensor_type
-        name = tensor.name
-
-        if weight_type.name not in ("F32", "BF16", "F16"):
-            weight_type_name = name.replace("weight", "qweight_type")
-            weight_type = torch.tensor(weight_type)
-            yield weight_type_name, weight_type
-
-    for tensor in reader.tensors:
-        weight = tensor.data
-        weight_type = tensor.tensor_type
-        name = tensor.name
-        if weight_type.name not in ("F32", "BF16", "F16"):
-            name = name.replace("weight", "qweight")
-        elif name.endswith(".scale"):
-            name = name.replace(".scale", ".weight")
-        if weight_type.name == "BF16" and tensor.data.dtype == np.uint8:
-            # BF16 is currently the only "quantization" type that isn't
-            # actually quantized but is read as a raw byte tensor.
-            # Reinterpret as `torch.bfloat16` tensor.
- weight = weight.view(np.uint16) - if reader.byte_order == "S": - # GGUF endianness != system endianness - weight = weight.byteswap() - param = torch.tensor(weight).view(torch.bfloat16) - else: - param = torch.tensor(weight) - yield name, param - - class Flux2KleinGGUFAdapter(GGUFAdapter): """GGUF adapter for Flux2-Klein models with qkv splitting and adaLN swap.""" gguf_to_hf_mapper = WeightsMapper( # double_stream_modulation - orig_to_new_prefix = FLUX2_TRANSFORMER_KEYS_RENAME_DICT | FLUX2_TRANSFORMER_ADA_LAYER_NORM_KEY_MAP, - orig_to_new_substr = FLUX2_TRANSFORMER_DOUBLE_BLOCK_KEY_MAP | FLUX2_TRANSFORMER_SINGLE_BLOCK_KEY_MAP, + orig_to_new_prefix=FLUX2_TRANSFORMER_KEYS_RENAME_DICT | FLUX2_TRANSFORMER_ADA_LAYER_NORM_KEY_MAP, + orig_to_new_substr=FLUX2_TRANSFORMER_DOUBLE_BLOCK_KEY_MAP | FLUX2_TRANSFORMER_SINGLE_BLOCK_KEY_MAP, ) - def weights_iterator(self) -> Generator[tuple[str, torch.Tensor], None, None]: - def custom_weights_adapter(weights): + def weights_iterator(self) -> Iterable[tuple[str, torch.Tensor]]: + def custom_weights_adapter( + weights: Iterable[tuple[str, torch.Tensor]], + ) -> Iterable[tuple[str, torch.Tensor]]: for name, weight in weights: # Handle the special case for adaLN modulation parameters that require swapping shift and scale + if name.endswith(".scale"): + name = name.replace(".scale", ".weight") if name == "norm_out.linear.weight": shift, scale = weight.chunk(2, dim=0) weight = torch.cat([scale, shift], dim=0) yield name, weight else: yield name, weight - weights = gguf_quant_weights_iterator(self.gguf_file, {}) + + weights = gguf_quant_weights_iterator(self.gguf_file) weights = self.gguf_to_hf_mapper.apply(weights) yield from custom_weights_adapter(weights) - # try: - # import gguf # type: ignore - # except Exception as exc: # pragma: no cover - dependency error - # raise RuntimeError("GGUF support requires the 'gguf' package to be installed.") from exc - - # reader = gguf.GGUFReader(self.gguf_file) - # allowed_names = self._build_allowed_names() - # param_names = self._build_param_names() - # mapped: list[MappedTensor] = [] - - # for tensor in reader.tensors: - # for mapped_tensor in self._map_tensor_name(tensor): - # if ( - # mapped_tensor.name not in allowed_names - # and self._resolve_linear_qweight(mapped_tensor.name, param_names) is None - # ): - # continue - # mapped.append(mapped_tensor) - - # if not mapped: - # raise RuntimeError( - # "No GGUF tensors were mapped for Flux2 GGUF loader. Please verify the GGUF file and model structure." 
- # ) - - # for item in mapped: - # linear_qweight = self._resolve_linear_qweight(item.name, param_names) - # is_linear_weight = linear_qweight is not None - # if not is_linear_weight: - # continue - # weight_type_name = linear_qweight.replace("qweight", "qweight_type") - # yield weight_type_name, torch.tensor(item.tensor_type) - - # for item in mapped: - # weight = item.tensor.data - # if item.row_slice is not None: - # weight = weight[item.row_slice] - # weight_type = item.tensor_type - # linear_qweight = self._resolve_linear_qweight(item.name, param_names) - # is_linear_weight = linear_qweight is not None - # if is_linear_weight: - # name = linear_qweight - # else: - # name = item.name - - # if weight_type.name == "BF16" and weight.dtype == np.uint8: - # weight = weight.view(np.uint16) - # if reader.byte_order == "S": - # weight = weight.byteswap() - # param = torch.tensor(weight).view(torch.bfloat16) - # else: - # param = torch.tensor(weight) - - # if item.swap_scale_shift: - # shift, scale = param.chunk(2, dim=0) - # param = torch.cat([scale, shift], dim=0) - - # yield name, param - - def _map_tensor_name(self, tensor) -> list[MappedTensor]: - name = tensor.name - - if name.startswith("double_blocks."): - return self._map_double_blocks(tensor) - if name.startswith("single_blocks."): - return self._map_single_blocks(tensor) - if name.startswith("final_layer.adaLN_modulation.1") and name.endswith(".weight"): - return [ - MappedTensor( - name="norm_out.linear.weight", - tensor=tensor, - tensor_type=tensor.tensor_type, - swap_scale_shift=True, - ) - ] - - for src, dst in _FLUX2_TRANSFORMER_KEYS_RENAME_DICT.items(): - name = name.replace(src, dst) - - return [ - MappedTensor( - name=name, - tensor=tensor, - tensor_type=tensor.tensor_type, - ) - ] - - def _map_double_blocks(self, tensor) -> list[MappedTensor]: - name = tensor.name - parts = name.split(".") - block_idx = parts[1] - within_block_name = ".".join(parts[2:-1]) - param_type = parts[-1] - if param_type == "scale": - param_type = "weight" - - if "qkv" in within_block_name: - if "img_attn" in within_block_name: - q_name = f"transformer_blocks.{block_idx}.attn.to_q.{param_type}" - k_name = f"transformer_blocks.{block_idx}.attn.to_k.{param_type}" - v_name = f"transformer_blocks.{block_idx}.attn.to_v.{param_type}" - elif "txt_attn" in within_block_name: - q_name = f"transformer_blocks.{block_idx}.attn.add_q_proj.{param_type}" - k_name = f"transformer_blocks.{block_idx}.attn.add_k_proj.{param_type}" - v_name = f"transformer_blocks.{block_idx}.attn.add_v_proj.{param_type}" - else: - return [] - - weight = tensor.data - dim0 = weight.shape[0] - split = dim0 // 3 - return [ - MappedTensor(q_name, tensor, tensor.tensor_type, slice(0, split)), - MappedTensor(k_name, tensor, tensor.tensor_type, slice(split, 2 * split)), - MappedTensor(v_name, tensor, tensor.tensor_type, slice(2 * split, 3 * split)), - ] - - mapped_name = _FLUX2_TRANSFORMER_DOUBLE_BLOCK_KEY_MAP.get(within_block_name) - if mapped_name is None: - return [] - target = f"transformer_blocks.{block_idx}.{mapped_name}.{param_type}" - return [MappedTensor(target, tensor, tensor.tensor_type)] - - def _map_single_blocks(self, tensor) -> list[MappedTensor]: - name = tensor.name - parts = name.split(".") - block_idx = parts[1] - within_block_name = ".".join(parts[2:-1]) - param_type = parts[-1] - if param_type == "scale": - param_type = "weight" - - mapped_name = _FLUX2_TRANSFORMER_SINGLE_BLOCK_KEY_MAP.get(within_block_name) - if mapped_name is None: - return [] - target = 
f"single_transformer_blocks.{block_idx}.{mapped_name}.{param_type}" - return [MappedTensor(target, tensor, tensor.tensor_type)] - - -_FLUX2_TRANSFORMER_KEYS_RENAME_DICT = { - "img_in": "x_embedder", - "txt_in": "context_embedder", - "time_in.in_layer": "time_guidance_embed.timestep_embedder.linear_1", - "time_in.out_layer": "time_guidance_embed.timestep_embedder.linear_2", - "guidance_in.in_layer": "time_guidance_embed.guidance_embedder.linear_1", - "guidance_in.out_layer": "time_guidance_embed.guidance_embedder.linear_2", - "double_stream_modulation_img.lin": "double_stream_modulation_img.linear", - "double_stream_modulation_txt.lin": "double_stream_modulation_txt.linear", - "single_stream_modulation.lin": "single_stream_modulation.linear", - "final_layer.linear": "proj_out", - # prefix - "double_blocks.": "transformer_blocks.", - "single_blocks.": "single_transformer_blocks.", -} - -_FLUX2_TRANSFORMER_DOUBLE_BLOCK_KEY_MAP = { - "img_attn.norm.query_norm": "attn.norm_q", - "img_attn.norm.key_norm": "attn.norm_k", - "img_attn.proj": "attn.to_out.0", - "img_mlp.0": "ff.linear_in", - "img_mlp.2": "ff.linear_out", - "txt_attn.norm.query_norm": "attn.norm_added_q", - "txt_attn.norm.key_norm": "attn.norm_added_k", - "txt_attn.proj": "attn.to_add_out", - "txt_mlp.0": "ff_context.linear_in", - "txt_mlp.2": "ff_context.linear_out", -} - -_FLUX2_TRANSFORMER_SINGLE_BLOCK_KEY_MAP = { - "linear1": "attn.to_qkv_mlp_proj", - "linear2": "attn.to_out", - "norm.query_norm": "attn.norm_q", - "norm.key_norm": "attn.norm_k", -} diff --git a/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py b/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py index d73446c23c..ff9c79e167 100644 --- a/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py +++ b/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py @@ -771,18 +771,6 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loaded_params: set[str] = set() for name, loaded_weight in weights: - # if "to_qkvkv_mlp_proj" in name: - # name = name.replace("to_qkvkv_mlp_proj", "to_qkv_mlp_proj") - # if "to_qkv_mlp_proj" in name: - # param = params_dict[name] - # weight_loader = getattr(param, "weight_loader", default_weight_loader) - # weight_loader(param, loaded_weight) - # loaded_params.add(name) - # continue - # GGUF fused QKV weights already target .to_qkv/.add_kv_proj. - # Avoid substring replacement that would duplicate "qkv". - # is_fused_qkv = ".to_qkv." in name or ".add_kv_proj." 
in name - print(name, loaded_weight.shape) for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue From bd2cce6eff8919af5d4922750c7f5c3bce1ff7ec Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Sat, 21 Feb 2026 00:43:04 +0800 Subject: [PATCH 46/62] draft Signed-off-by: Isotr0py --- .../model_loader/gguf_adapters/z_image.py | 278 ++---------------- .../models/z_image/pipeline_z_image.py | 6 +- 2 files changed, 24 insertions(+), 260 deletions(-) diff --git a/vllm_omni/diffusion/model_loader/gguf_adapters/z_image.py b/vllm_omni/diffusion/model_loader/gguf_adapters/z_image.py index 6380cef611..daaba5f102 100644 --- a/vllm_omni/diffusion/model_loader/gguf_adapters/z_image.py +++ b/vllm_omni/diffusion/model_loader/gguf_adapters/z_image.py @@ -1,272 +1,32 @@ # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations -from collections.abc import Generator +from collections.abc import Iterable -import numpy as np import torch +from vllm.model_executor.models.utils import WeightsMapper -from .base import GGUFAdapter, MappedTensor +from .base import GGUFAdapter, gguf_quant_weights_iterator + + +Z_IMAGE_KEYS_RENAME_DICT = { + "final_layer.": "all_final_layer.2-1.", + "x_embedder.": "all_x_embedder.2-1.", + ".attention.out.bias": ".attention.to_out.0.bias", + ".attention.k_norm": ".attention.norm_k.weight", + ".attention.q_norm": ".attention.norm_q.weight", + ".attention.out.weight": ".attention.to_out.0.weight", + "model.diffusion_model.": "", +} class ZImageGGUFAdapter(GGUFAdapter): """GGUF adapter for Z-Image models with QKV/FFN shard support.""" - _include_qkv_virtuals = True - _include_to_out_virtuals = True - _include_w13_virtuals = True - _shard_tokens = ( - ".to_q.", - ".to_k.", - ".to_v.", - ".w1.", - ".w3.", + gguf_to_hf_mapper = WeightsMapper( + orig_to_new_substr=Z_IMAGE_KEYS_RENAME_DICT, ) - @staticmethod - def is_compatible(od_config, model: torch.nn.Module, source) -> bool: - model_class = od_config.model_class_name or "" - if model_class.startswith("ZImage"): - return True - cfg = od_config.tf_model_config - if cfg is not None: - model_type = str(cfg.get("model_type", "")).lower() - if model_type in {"z_image", "zimage", "z-image"}: - return True - return False - - def weights_iterator(self) -> Generator[tuple[str, torch.Tensor], None, None]: - try: - import gguf # type: ignore - except Exception as exc: # pragma: no cover - dependency error - raise RuntimeError("GGUF support requires the 'gguf' package to be installed.") from exc - - reader = gguf.GGUFReader(self.gguf_file) - gguf_name_map = self._build_gguf_name_map(reader) - allowed_names = self._build_allowed_names() - param_names = self._build_param_names() - mapped: list[MappedTensor] = [] - - for tensor in reader.tensors: - for mapped_tensor in self._map_tensor_name(tensor, gguf_name_map): - linear_qweight = self._resolve_linear_qweight(mapped_tensor.name, param_names) - if mapped_tensor.name not in allowed_names and linear_qweight is None: - continue - if linear_qweight is None and tensor.tensor_type.name not in ("F32", "BF16", "F16"): - # Skip quantized tensors that map to non-quantized parameters. - continue - mapped.append(mapped_tensor) - - if not mapped: - raise RuntimeError( - "No GGUF tensors were mapped for Z-Image GGUF loader. Please verify the GGUF file and model structure." 
- ) - - for item in mapped: - linear_qweight = self._resolve_linear_qweight(item.name, param_names) - if linear_qweight is None: - continue - weight_type_name = linear_qweight.replace("qweight", "qweight_type") - yield weight_type_name, torch.tensor(item.tensor_type) - - for item in mapped: - weight = item.tensor.data - if item.row_slice is not None: - weight = weight[item.row_slice] - weight_type = item.tensor_type - linear_qweight = self._resolve_linear_qweight(item.name, param_names) - if linear_qweight is not None: - name = linear_qweight - else: - name = item.name - - if weight_type.name == "BF16" and weight.dtype == np.uint8: - weight = weight.view(np.uint16) - if reader.byte_order == "S": - weight = weight.byteswap() - param = torch.tensor(weight).view(torch.bfloat16) - else: - param = torch.tensor(weight) - - yield name, param - - def _normalize_name(self, name: str) -> str: - if name.endswith(".scale"): - name = name[:-6] + ".weight" - if name.endswith("_weight"): - name = name[:-7] + ".weight" - if ".to_out.0." in name: - name = name.replace(".to_out.0.", ".to_out.") - return name - - def _get_patch_key(self) -> str: - prefix = getattr(self.source, "prefix", "") - target = self.model.get_submodule(prefix.rstrip(".")) if prefix else self.model - if hasattr(target, "all_x_embedder"): - keys = list(getattr(target, "all_x_embedder").keys()) - if "2-1" in keys: - # Default to the standard Z-Image Turbo patch/frequency config - # (patch_size=2, f_patch_size=1) when available. - return "2-1" - if keys: - return sorted(keys)[0] - return "2-1" - - def _apply_zimage_renames(self, name: str) -> str: - if name.startswith("model.diffusion_model."): - name = name.replace("model.diffusion_model.", "", 1) - - patch_key = self._get_patch_key() - if name.startswith("x_embedder.") and not name.startswith("all_x_embedder."): - name = name.replace("x_embedder.", f"all_x_embedder.{patch_key}.", 1) - if name.startswith("final_layer.") and not name.startswith("all_final_layer."): - name = name.replace("final_layer.", f"all_final_layer.{patch_key}.", 1) - - name = name.replace(".attention.out.bias", ".attention.to_out.0.bias") - name = name.replace(".attention.out.weight", ".attention.to_out.0.weight") - name = name.replace(".attention.k_norm.weight", ".attention.norm_k.weight") - name = name.replace(".attention.q_norm.weight", ".attention.norm_q.weight") - return name - - def _map_tensor_name(self, tensor, gguf_name_map: dict[str, str]) -> list[MappedTensor]: - name = gguf_name_map.get(tensor.name) - if name is None: - name = self._normalize_name(tensor.name) - name = self._apply_zimage_renames(name) - - if ".attention.qkv.weight" in name: - weight = tensor.data - dim0 = weight.shape[0] - split = dim0 // 3 - return [ - MappedTensor( - name=name.replace(".attention.qkv.weight", ".attention.to_q.weight"), - tensor=tensor, - tensor_type=tensor.tensor_type, - row_slice=slice(0, split), - ), - MappedTensor( - name=name.replace(".attention.qkv.weight", ".attention.to_k.weight"), - tensor=tensor, - tensor_type=tensor.tensor_type, - row_slice=slice(split, 2 * split), - ), - MappedTensor( - name=name.replace(".attention.qkv.weight", ".attention.to_v.weight"), - tensor=tensor, - tensor_type=tensor.tensor_type, - row_slice=slice(2 * split, 3 * split), - ), - ] - - return [ - MappedTensor( - name=name, - tensor=tensor, - tensor_type=tensor.tensor_type, - ) - ] - - def _build_gguf_name_map(self, reader) -> dict[str, str]: - try: - import gguf # type: ignore - except Exception as exc: # pragma: no cover - dependency 
error - raise RuntimeError("GGUF support requires the 'gguf' package to be installed.") from exc - - gguf_tensor_names = {tensor.name for tensor in reader.tensors} - - def resolve_model_type() -> str: - cfg = self.od_config.tf_model_config - model_type = None - if cfg is not None: - model_type = cfg.get("model_type") - if model_type: - return model_type - model_class = self.od_config.model_class_name or "" - if model_class.startswith("ZImage"): - return "z_image" - raise ValueError("Cannot infer gguf model_type for Z-Image.") - - def resolve_arch(model_type: str): - for key, value in gguf.MODEL_ARCH_NAMES.items(): - if value == model_type: - return key - raise RuntimeError(f"Unknown gguf model_type: {model_type}") - - def resolve_num_layers(target_module: torch.nn.Module) -> int: - if hasattr(target_module, "layers"): - return len(getattr(target_module, "layers")) - cfg = self.od_config.tf_model_config - if cfg is not None: - for key in ("num_hidden_layers", "num_layers", "n_layers"): - value = cfg.get(key) - if isinstance(value, int) and value > 0: - return value - raise ValueError("Cannot infer gguf num_layers for Z-Image.") - - def get_target_module(root: torch.nn.Module, prefix: str) -> torch.nn.Module: - if not prefix: - return root - prefix = prefix.rstrip(".") - if hasattr(root, "get_submodule"): - return root.get_submodule(prefix) - current = root - for part in prefix.split("."): - current = getattr(current, part) - return current - - def split_name(name: str) -> tuple[str, str]: - if name.endswith("_weight"): - return name[:-7], "weight" - if "." in name: - base, suffix = name.rsplit(".", 1) - return base, suffix - return name, "" - - model_type = resolve_model_type() - if model_type in {"z_image", "zimage", "z-image"}: - # gguf-py does not register a Z-Image architecture, so we rely on - # direct tensor names from the GGUF file. - return {} - try: - arch = resolve_arch(model_type) - except RuntimeError: - # Fallback: some gguf versions may not register z_image arch. - # In that case, rely on direct tensor names from the GGUF file. - return {} - target_module = get_target_module(self.model, self.source.prefix) - num_layers = resolve_num_layers(target_module) - name_map = gguf.get_tensor_name_map(arch, num_layers) - - candidate_names = {name for name, _ in target_module.named_parameters()} - candidate_names.update(name for name, _ in target_module.named_buffers()) - for name in list(candidate_names): - if ".to_qkv." in name: - candidate_names.add(name.replace(".to_qkv.", ".to_q.")) - candidate_names.add(name.replace(".to_qkv.", ".to_k.")) - candidate_names.add(name.replace(".to_qkv.", ".to_v.")) - if ".w13." in name: - candidate_names.add(name.replace(".w13.", ".w1.")) - candidate_names.add(name.replace(".w13.", ".w3.")) - if ".to_out." 
in name: - candidate_names.add(name.replace(".to_out.", ".to_out.0.")) - - gguf_to_model_map: dict[str, str] = {} - for name in candidate_names: - base_name, suffix = split_name(name) - gguf_base = name_map.get_name(base_name) - if gguf_base is None: - continue - candidates = [] - if suffix: - candidates.append(f"{gguf_base}.{suffix}") - if suffix == "weight": - candidates.append(f"{gguf_base}.scale") - else: - candidates.append(gguf_base) - gguf_name = next((c for c in candidates if c in gguf_tensor_names), None) - if gguf_name is None: - continue - gguf_to_model_map[gguf_name] = name - - return gguf_to_model_map + def weights_iterator(self) -> Iterable[tuple[str, torch.Tensor]]: + weights = gguf_quant_weights_iterator(self.gguf_file) + yield from self.gguf_to_hf_mapper.apply(weights) diff --git a/vllm_omni/diffusion/models/z_image/pipeline_z_image.py b/vllm_omni/diffusion/models/z_image/pipeline_z_image.py index e025440a79..25d006a349 100644 --- a/vllm_omni/diffusion/models/z_image/pipeline_z_image.py +++ b/vllm_omni/diffusion/models/z_image/pipeline_z_image.py @@ -644,4 +644,8 @@ def forward( def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) - return loader.load_weights(weights) + loaded_weights = loader.load_weights(weights) + # Record other components not tracked by AutoWeightsLoader + loaded_weights |= {f"vae.{name}" for name, _ in self.vae.named_parameters()} + loaded_weights |= {f"text_encoder.{name}" for name, _ in self.text_encoder.named_parameters()} + return loaded_weights From 8f7edd67d509731f98c2a9c4ba65ac2e55000f99 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Sat, 21 Feb 2026 23:50:40 +0800 Subject: [PATCH 47/62] fix Signed-off-by: Isotr0py --- .../model_loader/gguf_adapters/base.py | 4 +-- .../model_loader/gguf_adapters/z_image.py | 8 ++--- .../models/z_image/z_image_transformer.py | 35 +++++++++++-------- 3 files changed, 26 insertions(+), 21 deletions(-) diff --git a/vllm_omni/diffusion/model_loader/gguf_adapters/base.py b/vllm_omni/diffusion/model_loader/gguf_adapters/base.py index 013a249a55..6d82cdcf2c 100644 --- a/vllm_omni/diffusion/model_loader/gguf_adapters/base.py +++ b/vllm_omni/diffusion/model_loader/gguf_adapters/base.py @@ -224,7 +224,7 @@ def gguf_quant_weights_iterator(gguf_file: str) -> Generator[tuple[str, torch.Te weight_type = tensor.tensor_type name = tensor.name - if weight_type.name not in ("F32", "BF16", "F16"): + if weight_type.name not in ("F32", "F16"): weight_type_name = name.replace("weight", "qweight_type") weight_type = torch.tensor(weight_type) yield weight_type_name, weight_type @@ -233,7 +233,7 @@ def gguf_quant_weights_iterator(gguf_file: str) -> Generator[tuple[str, torch.Te weight = tensor.data weight_type = tensor.tensor_type name = tensor.name - if weight_type.name not in ("F32", "BF16", "F16"): + if weight_type.name not in ("F32", "F16"): name = name.replace("weight", "qweight") if weight_type.name == "BF16" and tensor.data.dtype == np.uint8: # BF16 is currently the only "quantization" type that isn't diff --git a/vllm_omni/diffusion/model_loader/gguf_adapters/z_image.py b/vllm_omni/diffusion/model_loader/gguf_adapters/z_image.py index daaba5f102..677d79781a 100644 --- a/vllm_omni/diffusion/model_loader/gguf_adapters/z_image.py +++ b/vllm_omni/diffusion/model_loader/gguf_adapters/z_image.py @@ -12,10 +12,10 @@ Z_IMAGE_KEYS_RENAME_DICT = { "final_layer.": "all_final_layer.2-1.", "x_embedder.": "all_x_embedder.2-1.", - ".attention.out.bias": ".attention.to_out.0.bias", - 
".attention.k_norm": ".attention.norm_k.weight", - ".attention.q_norm": ".attention.norm_q.weight", - ".attention.out.weight": ".attention.to_out.0.weight", + ".attention.qkv": ".attention.to_qkv", + ".attention.k_norm": ".attention.norm_k", + ".attention.q_norm": ".attention.norm_q", + ".attention.out": ".attention.to_out.0", "model.diffusion_model.": "", } diff --git a/vllm_omni/diffusion/models/z_image/z_image_transformer.py b/vllm_omni/diffusion/models/z_image/z_image_transformer.py index a4f073faa5..1d31aae683 100644 --- a/vllm_omni/diffusion/models/z_image/z_image_transformer.py +++ b/vllm_omni/diffusion/models/z_image/z_image_transformer.py @@ -30,6 +30,7 @@ MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear, + ReplicatedLinear, ) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -207,21 +208,25 @@ def validate_zimage_tp_constraints( class TimestepEmbedder(nn.Module): - def __init__(self, out_size, mid_size=None, frequency_embedding_size=256): + def __init__(self, out_size, mid_size=None, frequency_embedding_size=256, quant_config: "QuantizationConfig | None" = None): super().__init__() if mid_size is None: mid_size = out_size self.mlp = nn.Sequential( - nn.Linear( + ReplicatedLinear( frequency_embedding_size, mid_size, bias=True, + quant_config=quant_config, + return_bias=False, ), nn.SiLU(), - nn.Linear( + ReplicatedLinear( mid_size, out_size, bias=True, + quant_config=quant_config, + return_bias=False, ), ) @@ -241,7 +246,7 @@ def timestep_embedding(t, dim, max_period=10000): def forward(self, t): t_freq = self.timestep_embedding(t, self.frequency_embedding_size) - weight_dtype = self.mlp[0].weight.dtype + weight_dtype = self.mlp[0].bias.dtype if weight_dtype.is_floating_point: t_freq = t_freq.to(weight_dtype) t_emb = self.mlp(t_freq) @@ -420,7 +425,7 @@ def __init__( self.modulation = modulation if modulation: self.adaLN_modulation = nn.Sequential( - nn.Linear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True), + ReplicatedLinear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True, return_bias=False, quant_config=quant_config), ) def forward( @@ -473,14 +478,14 @@ def forward( class FinalLayer(nn.Module): - def __init__(self, hidden_size, out_channels): + def __init__(self, hidden_size, out_channels, quant_config: "QuantizationConfig | None" = None): super().__init__() self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - self.linear = nn.Linear(hidden_size, out_channels, bias=True) + self.linear = ReplicatedLinear(hidden_size, out_channels, bias=True, quant_config=quant_config, return_bias=False) self.adaLN_modulation = nn.Sequential( nn.SiLU(), - nn.Linear(min(hidden_size, ADALN_EMBED_DIM), hidden_size, bias=True), + ReplicatedLinear(min(hidden_size, ADALN_EMBED_DIM), hidden_size, bias=True, quant_config=quant_config, return_bias=False), ) def forward(self, x, c): @@ -652,10 +657,10 @@ def __init__( all_x_embedder = {} all_final_layer = {} for patch_idx, (patch_size, f_patch_size) in enumerate(zip(all_patch_size, all_f_patch_size)): - x_embedder = nn.Linear(f_patch_size * patch_size * patch_size * in_channels, dim, bias=True) + x_embedder = ReplicatedLinear(f_patch_size * patch_size * patch_size * in_channels, dim, bias=True, quant_config=quant_config, return_bias=False) all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder - final_layer = FinalLayer(dim, patch_size * patch_size * f_patch_size * self.out_channels) + final_layer = FinalLayer(dim, patch_size * patch_size * f_patch_size * self.out_channels, 
quant_config=quant_config) all_final_layer[f"{patch_size}-{f_patch_size}"] = final_layer self.all_x_embedder = nn.ModuleDict(all_x_embedder) @@ -690,10 +695,10 @@ def __init__( for layer_id in range(n_refiner_layers) ] ) - self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024) + self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024, quant_config=quant_config) self.cap_embedder = nn.Sequential( RMSNorm(cap_feat_dim, eps=norm_eps), - nn.Linear(cap_feat_dim, dim, bias=True), + ReplicatedLinear(cap_feat_dim, dim, bias=True, return_bias=False, quant_config=quant_config), ) self.x_pad_token = nn.Parameter(torch.empty((1, dim))) @@ -957,9 +962,9 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) # self-attn - (".to_qkv", ".to_q", "q"), - (".to_qkv", ".to_k", "k"), - (".to_qkv", ".to_v", "v"), + (".to_qkv.", ".to_q.", "q"), + (".to_qkv.", ".to_k.", "k"), + (".to_qkv.", ".to_v.", "v"), # ffn (".w13", ".w1", 0), (".w13", ".w3", 1), From 16b2dd8e8d7e0c8af9f68bc8f913fc6beec0d846 Mon Sep 17 00:00:00 2001 From: David Chen <530634352@qq.com> Date: Wed, 25 Feb 2026 15:14:32 +0800 Subject: [PATCH 48/62] fix pre-commit Signed-off-by: David Chen <530634352@qq.com> --- .../model_loader/gguf_adapters/base.py | 5 ---- .../models/z_image/z_image_transformer.py | 28 +++++++++++++++---- 2 files changed, 22 insertions(+), 11 deletions(-) diff --git a/vllm_omni/diffusion/model_loader/gguf_adapters/base.py b/vllm_omni/diffusion/model_loader/gguf_adapters/base.py index 6d82cdcf2c..0eb55c43ca 100644 --- a/vllm_omni/diffusion/model_loader/gguf_adapters/base.py +++ b/vllm_omni/diffusion/model_loader/gguf_adapters/base.py @@ -101,11 +101,6 @@ def _resolve_linear_qweight(self, name: str, param_names: set[str]) -> str | Non return None def _build_gguf_name_map(self) -> dict[str, str]: - try: - import gguf # type: ignore - except Exception as exc: # pragma: no cover - dependency error - raise RuntimeError("GGUF support requires the 'gguf' package to be installed.") from exc - def resolve_model_type() -> str: cfg = self.od_config.tf_model_config model_type = None diff --git a/vllm_omni/diffusion/models/z_image/z_image_transformer.py b/vllm_omni/diffusion/models/z_image/z_image_transformer.py index 1d31aae683..59fb351569 100644 --- a/vllm_omni/diffusion/models/z_image/z_image_transformer.py +++ b/vllm_omni/diffusion/models/z_image/z_image_transformer.py @@ -208,7 +208,9 @@ def validate_zimage_tp_constraints( class TimestepEmbedder(nn.Module): - def __init__(self, out_size, mid_size=None, frequency_embedding_size=256, quant_config: "QuantizationConfig | None" = None): + def __init__( + self, out_size, mid_size=None, frequency_embedding_size=256, quant_config: "QuantizationConfig | None" = None + ): super().__init__() if mid_size is None: mid_size = out_size @@ -425,7 +427,9 @@ def __init__( self.modulation = modulation if modulation: self.adaLN_modulation = nn.Sequential( - ReplicatedLinear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True, return_bias=False, quant_config=quant_config), + ReplicatedLinear( + min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True, return_bias=False, quant_config=quant_config + ), ) def forward( @@ -481,11 +485,15 @@ class FinalLayer(nn.Module): def __init__(self, hidden_size, out_channels, quant_config: "QuantizationConfig | None" = None): super().__init__() self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - self.linear = 
ReplicatedLinear(hidden_size, out_channels, bias=True, quant_config=quant_config, return_bias=False) + self.linear = ReplicatedLinear( + hidden_size, out_channels, bias=True, quant_config=quant_config, return_bias=False + ) self.adaLN_modulation = nn.Sequential( nn.SiLU(), - ReplicatedLinear(min(hidden_size, ADALN_EMBED_DIM), hidden_size, bias=True, quant_config=quant_config, return_bias=False), + ReplicatedLinear( + min(hidden_size, ADALN_EMBED_DIM), hidden_size, bias=True, quant_config=quant_config, return_bias=False + ), ) def forward(self, x, c): @@ -657,10 +665,18 @@ def __init__( all_x_embedder = {} all_final_layer = {} for patch_idx, (patch_size, f_patch_size) in enumerate(zip(all_patch_size, all_f_patch_size)): - x_embedder = ReplicatedLinear(f_patch_size * patch_size * patch_size * in_channels, dim, bias=True, quant_config=quant_config, return_bias=False) + x_embedder = ReplicatedLinear( + f_patch_size * patch_size * patch_size * in_channels, + dim, + bias=True, + quant_config=quant_config, + return_bias=False, + ) all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder - final_layer = FinalLayer(dim, patch_size * patch_size * f_patch_size * self.out_channels, quant_config=quant_config) + final_layer = FinalLayer( + dim, patch_size * patch_size * f_patch_size * self.out_channels, quant_config=quant_config + ) all_final_layer[f"{patch_size}-{f_patch_size}"] = final_layer self.all_x_embedder = nn.ModuleDict(all_x_embedder) From b0808f87c3a100aa518c62a680b2a3a8e9bf9344 Mon Sep 17 00:00:00 2001 From: David Chen <530634352@qq.com> Date: Wed, 25 Feb 2026 15:17:53 +0800 Subject: [PATCH 49/62] fix pre-commit Signed-off-by: David Chen <530634352@qq.com> --- vllm_omni/diffusion/models/z_image/z_image_transformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm_omni/diffusion/models/z_image/z_image_transformer.py b/vllm_omni/diffusion/models/z_image/z_image_transformer.py index 59fb351569..83d0ecb9b9 100644 --- a/vllm_omni/diffusion/models/z_image/z_image_transformer.py +++ b/vllm_omni/diffusion/models/z_image/z_image_transformer.py @@ -29,8 +29,8 @@ from vllm.model_executor.layers.linear import ( MergedColumnParallelLinear, QKVParallelLinear, - RowParallelLinear, ReplicatedLinear, + RowParallelLinear, ) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -210,7 +210,7 @@ def validate_zimage_tp_constraints( class TimestepEmbedder(nn.Module): def __init__( self, out_size, mid_size=None, frequency_embedding_size=256, quant_config: "QuantizationConfig | None" = None - ): + ): super().__init__() if mid_size is None: mid_size = out_size From f684cd62288c9975f356e1e27ed97e759f1900fa Mon Sep 17 00:00:00 2001 From: David Chen <530634352@qq.com> Date: Wed, 25 Feb 2026 15:20:19 +0800 Subject: [PATCH 50/62] fix pre-commit Signed-off-by: David Chen <530634352@qq.com> --- vllm_omni/diffusion/model_loader/gguf_adapters/z_image.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm_omni/diffusion/model_loader/gguf_adapters/z_image.py b/vllm_omni/diffusion/model_loader/gguf_adapters/z_image.py index 677d79781a..ead0831152 100644 --- a/vllm_omni/diffusion/model_loader/gguf_adapters/z_image.py +++ b/vllm_omni/diffusion/model_loader/gguf_adapters/z_image.py @@ -8,7 +8,6 @@ from .base import GGUFAdapter, gguf_quant_weights_iterator - Z_IMAGE_KEYS_RENAME_DICT = { "final_layer.": "all_final_layer.2-1.", "x_embedder.": "all_x_embedder.2-1.", From 42a29058b6cdc4a29d2762dcb33c97828907d1e8 Mon Sep 17 00:00:00 2001 From: David Chen <530634352@qq.com> 
Date: Wed, 25 Feb 2026 15:53:45 +0800 Subject: [PATCH 51/62] fix comment 1 Signed-off-by: David Chen <530634352@qq.com> --- vllm_omni/diffusion/model_loader/diffusers_loader.py | 7 +++---- .../model_loader/gguf_adapters/flux2_klein.py | 12 ++++++++++++ .../diffusion/model_loader/gguf_adapters/z_image.py | 12 ++++++++++++ .../models/flux2_klein/pipeline_flux2_klein.py | 2 +- .../diffusion/models/z_image/pipeline_z_image.py | 2 +- 5 files changed, 29 insertions(+), 6 deletions(-) diff --git a/vllm_omni/diffusion/model_loader/diffusers_loader.py b/vllm_omni/diffusion/model_loader/diffusers_loader.py index d08fef5f78..e62993ffe7 100644 --- a/vllm_omni/diffusion/model_loader/diffusers_loader.py +++ b/vllm_omni/diffusion/model_loader/diffusers_loader.py @@ -301,12 +301,11 @@ def _is_gguf_quantization(self, od_config: OmniDiffusionConfig) -> bool: return True # Normal path: DiffusionQuantizationConfig - try: - is_gguf = quant_config.get_name() == "gguf" - except AttributeError: + if not hasattr(quant_config, "get_name"): # Fallback: if it carries gguf_model, treat as GGUF gguf_model = getattr(quant_config, "gguf_model", None) return bool(gguf_model) + is_gguf = quant_config.get_name() == "gguf" if not is_gguf: return False gguf_model = getattr(quant_config, "gguf_model", None) @@ -382,7 +381,7 @@ def _download_raw_gguf_url(self, url: str) -> str: ) os.close(tmp_fd) try: - with urlopen(url) as response, open(tmp_path, "wb") as out_file: + with urlopen(url, timeout=300) as response, open(tmp_path, "wb") as out_file: shutil.copyfileobj(response, out_file) os.replace(tmp_path, target_path) except Exception: diff --git a/vllm_omni/diffusion/model_loader/gguf_adapters/flux2_klein.py b/vllm_omni/diffusion/model_loader/gguf_adapters/flux2_klein.py index f16a2af1df..4ee6439b28 100644 --- a/vllm_omni/diffusion/model_loader/gguf_adapters/flux2_klein.py +++ b/vllm_omni/diffusion/model_loader/gguf_adapters/flux2_klein.py @@ -60,6 +60,18 @@ class Flux2KleinGGUFAdapter(GGUFAdapter): """GGUF adapter for Flux2-Klein models with qkv splitting and adaLN swap.""" + @staticmethod + def is_compatible(od_config, model: torch.nn.Module, source) -> bool: + model_class = od_config.model_class_name or "" + if model_class.startswith("Flux2"): + return True + cfg = od_config.tf_model_config + if cfg is not None: + model_type = str(cfg.get("model_type", "")).lower() + if model_type.startswith("flux"): + return True + return False + gguf_to_hf_mapper = WeightsMapper( # double_stream_modulation orig_to_new_prefix=FLUX2_TRANSFORMER_KEYS_RENAME_DICT | FLUX2_TRANSFORMER_ADA_LAYER_NORM_KEY_MAP, diff --git a/vllm_omni/diffusion/model_loader/gguf_adapters/z_image.py b/vllm_omni/diffusion/model_loader/gguf_adapters/z_image.py index ead0831152..7d89633559 100644 --- a/vllm_omni/diffusion/model_loader/gguf_adapters/z_image.py +++ b/vllm_omni/diffusion/model_loader/gguf_adapters/z_image.py @@ -22,6 +22,18 @@ class ZImageGGUFAdapter(GGUFAdapter): """GGUF adapter for Z-Image models with QKV/FFN shard support.""" + @staticmethod + def is_compatible(od_config, model: torch.nn.Module, source) -> bool: + model_class = od_config.model_class_name or "" + if model_class.startswith("ZImage"): + return True + cfg = od_config.tf_model_config + if cfg is not None: + model_type = str(cfg.get("model_type", "")).lower() + if model_type in {"z_image", "zimage", "z-image"}: + return True + return False + gguf_to_hf_mapper = WeightsMapper( orig_to_new_substr=Z_IMAGE_KEYS_RENAME_DICT, ) diff --git 
a/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py b/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py index 51bafb28ee..d43748380b 100644 --- a/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py +++ b/vllm_omni/diffusion/models/flux2_klein/pipeline_flux2_klein.py @@ -996,7 +996,7 @@ def forward( def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) loaded_weights = loader.load_weights(weights) - # Record other components not tracked by AutoWeightsLoader + # Record components loaded by diffusers submodules to satisfy strict checks. loaded_weights |= {f"vae.{name}" for name, _ in self.vae.named_parameters()} loaded_weights |= {f"text_encoder.{name}" for name, _ in self.text_encoder.named_parameters()} return loaded_weights diff --git a/vllm_omni/diffusion/models/z_image/pipeline_z_image.py b/vllm_omni/diffusion/models/z_image/pipeline_z_image.py index 25d006a349..2c6ef4c86f 100644 --- a/vllm_omni/diffusion/models/z_image/pipeline_z_image.py +++ b/vllm_omni/diffusion/models/z_image/pipeline_z_image.py @@ -645,7 +645,7 @@ def forward( def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) loaded_weights = loader.load_weights(weights) - # Record other components not tracked by AutoWeightsLoader + # Record components loaded by diffusers submodules to satisfy strict checks. loaded_weights |= {f"vae.{name}" for name, _ in self.vae.named_parameters()} loaded_weights |= {f"text_encoder.{name}" for name, _ in self.text_encoder.named_parameters()} return loaded_weights From b4c7f6de112271cc8975a5fcd3f3c2c34a6236ba Mon Sep 17 00:00:00 2001 From: David Chen <530634352@qq.com> Date: Wed, 25 Feb 2026 16:43:47 +0800 Subject: [PATCH 52/62] remove qwen-image Signed-off-by: David Chen <530634352@qq.com> --- .../model_loader/gguf_adapters/qwen_image.py | 209 ------------------ 1 file changed, 209 deletions(-) delete mode 100644 vllm_omni/diffusion/model_loader/gguf_adapters/qwen_image.py diff --git a/vllm_omni/diffusion/model_loader/gguf_adapters/qwen_image.py b/vllm_omni/diffusion/model_loader/gguf_adapters/qwen_image.py deleted file mode 100644 index 6b5e2cbabc..0000000000 --- a/vllm_omni/diffusion/model_loader/gguf_adapters/qwen_image.py +++ /dev/null @@ -1,209 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -from __future__ import annotations - -from collections.abc import Generator - -import numpy as np -import torch - -from .base import GGUFAdapter, MappedTensor - - -class QwenImageGGUFAdapter(GGUFAdapter): - """GGUF adapter for Qwen-Image models with QKV shard support.""" - - _include_qkv_virtuals = True - _include_add_kv_proj_virtuals = True - _include_to_out_virtuals = True - _shard_tokens = ( - ".to_q.", - ".to_k.", - ".to_v.", - ".add_q_proj.", - ".add_k_proj.", - ".add_v_proj.", - ) - - @staticmethod - def is_compatible(od_config, model: torch.nn.Module, source) -> bool: - model_class = od_config.model_class_name or "" - if model_class.startswith("QwenImage"): - return True - cfg = od_config.tf_model_config - if cfg is not None: - model_type = str(cfg.get("model_type", "")).lower() - if model_type.startswith("qwen_image"): - return True - return False - - def weights_iterator(self) -> Generator[tuple[str, torch.Tensor], None, None]: - try: - import gguf # type: ignore - except Exception as exc: # pragma: no cover - dependency error - raise RuntimeError("GGUF support requires the 'gguf' package to be installed.") from exc - - reader = 
gguf.GGUFReader(self.gguf_file) - gguf_name_map = self._build_gguf_name_map(reader) - allowed_names = self._build_allowed_names() - param_names = self._build_param_names() - mapped: list[MappedTensor] = [] - - for tensor in reader.tensors: - mapped_name = gguf_name_map.get(tensor.name) - if mapped_name is None: - mapped_name = self._normalize_name(tensor.name) - linear_qweight = self._resolve_linear_qweight(mapped_name, param_names) - if mapped_name not in allowed_names and linear_qweight is None: - continue - if linear_qweight is None and tensor.tensor_type.name not in ("F32", "BF16", "F16"): - # Skip quantized tensors that map to non-quantized parameters. - continue - mapped.append( - MappedTensor( - name=mapped_name, - tensor=tensor, - tensor_type=tensor.tensor_type, - ) - ) - - if not mapped: - raise RuntimeError( - "No GGUF tensors were mapped for Qwen-Image GGUF loader. " - "Please verify the GGUF file and model structure." - ) - - for item in mapped: - linear_qweight = self._resolve_linear_qweight(item.name, param_names) - if linear_qweight is None: - continue - weight_type_name = linear_qweight.replace("qweight", "qweight_type") - yield weight_type_name, torch.tensor(item.tensor_type) - - for item in mapped: - weight = item.tensor.data - weight_type = item.tensor_type - linear_qweight = self._resolve_linear_qweight(item.name, param_names) - if linear_qweight is not None: - name = linear_qweight - else: - name = item.name - - if weight_type.name == "BF16" and weight.dtype == np.uint8: - weight = weight.view(np.uint16) - if reader.byte_order == "S": - weight = weight.byteswap() - param = torch.tensor(weight).view(torch.bfloat16) - else: - param = torch.tensor(weight) - - yield name, param - - def _normalize_name(self, name: str) -> str: - if name.endswith(".scale"): - name = name[:-6] + ".weight" - if name.endswith("_weight"): - name = name[:-7] + ".weight" - if ".to_out.0." 
in name: - name = name.replace(".to_out.0.", ".to_out.") - return name - - def _build_gguf_name_map(self, reader) -> dict[str, str]: - try: - import gguf # type: ignore - except Exception as exc: # pragma: no cover - dependency error - raise RuntimeError("GGUF support requires the 'gguf' package to be installed.") from exc - - gguf_tensor_names = {tensor.name for tensor in reader.tensors} - - def resolve_model_type() -> str: - cfg = self.od_config.tf_model_config - model_type = None - if cfg is not None: - model_type = cfg.get("model_type") - if model_type: - return model_type - model_class = self.od_config.model_class_name or "" - if model_class.startswith("QwenImage"): - return "qwen_image" - raise ValueError("Cannot infer gguf model_type for Qwen-Image.") - - def resolve_arch(model_type: str): - for key, value in gguf.MODEL_ARCH_NAMES.items(): - if value == model_type: - return key - raise RuntimeError(f"Unknown gguf model_type: {model_type}") - - def resolve_num_layers(target_module: torch.nn.Module) -> int: - if hasattr(target_module, "transformer_blocks"): - return len(getattr(target_module, "transformer_blocks")) - cfg = self.od_config.tf_model_config - if cfg is not None: - for key in ("num_hidden_layers", "num_layers", "n_layers"): - value = cfg.get(key) - if isinstance(value, int) and value > 0: - return value - raise ValueError("Cannot infer gguf num_layers for Qwen-Image.") - - def get_target_module(root: torch.nn.Module, prefix: str) -> torch.nn.Module: - if not prefix: - return root - prefix = prefix.rstrip(".") - if hasattr(root, "get_submodule"): - return root.get_submodule(prefix) - current = root - for part in prefix.split("."): - current = getattr(current, part) - return current - - def split_name(name: str) -> tuple[str, str]: - if name.endswith("_weight"): - return name[:-7], "weight" - if "." in name: - base, suffix = name.rsplit(".", 1) - return base, suffix - return name, "" - - model_type = resolve_model_type() - try: - arch = resolve_arch(model_type) - except RuntimeError: - # Fallback: some gguf versions may not register qwen_image arch. - # In that case, rely on direct tensor names from the GGUF file. - return {} - target_module = get_target_module(self.model, self.source.prefix) - num_layers = resolve_num_layers(target_module) - name_map = gguf.get_tensor_name_map(arch, num_layers) - - candidate_names = {name for name, _ in target_module.named_parameters()} - candidate_names.update(name for name, _ in target_module.named_buffers()) - for name in list(candidate_names): - if ".to_qkv." in name: - candidate_names.add(name.replace(".to_qkv.", ".to_q.")) - candidate_names.add(name.replace(".to_qkv.", ".to_k.")) - candidate_names.add(name.replace(".to_qkv.", ".to_v.")) - if ".add_kv_proj." in name: - candidate_names.add(name.replace(".add_kv_proj.", ".add_q_proj.")) - candidate_names.add(name.replace(".add_kv_proj.", ".add_k_proj.")) - candidate_names.add(name.replace(".add_kv_proj.", ".add_v_proj.")) - if ".to_out." 
in name:
-                candidate_names.add(name.replace(".to_out.", ".to_out.0."))
-
-        gguf_to_model_map: dict[str, str] = {}
-        for name in candidate_names:
-            base_name, suffix = split_name(name)
-            gguf_base = name_map.get_name(base_name)
-            if gguf_base is None:
-                continue
-            candidates = []
-            if suffix:
-                candidates.append(f"{gguf_base}.{suffix}")
-                if suffix == "weight":
-                    candidates.append(f"{gguf_base}.scale")
-            else:
-                candidates.append(gguf_base)
-            gguf_name = next((c for c in candidates if c in gguf_tensor_names), None)
-            if gguf_name is None:
-                continue
-            gguf_to_model_map[gguf_name] = name
-
-        return gguf_to_model_map

From 9c0a8f093dfe2a4df81cbaaed8d115356b282b9c Mon Sep 17 00:00:00 2001
From: David Chen <530634352@qq.com>
Date: Wed, 25 Feb 2026 16:48:20 +0800
Subject: [PATCH 53/62] remove qwen-image

Signed-off-by: David Chen <530634352@qq.com>
---
 docs/user_guide/diffusion/quantization/gguf.md | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/docs/user_guide/diffusion/quantization/gguf.md b/docs/user_guide/diffusion/quantization/gguf.md
index bef992f5ec..61d1d520fc 100644
--- a/docs/user_guide/diffusion/quantization/gguf.md
+++ b/docs/user_guide/diffusion/quantization/gguf.md
@@ -6,7 +6,7 @@
 3. Keep user-facing knobs minimal and consistent across offline and online flows.
 
 ## Scope
-1. Models: Qwen-Image, Z-Image, and Flux2-klein.
+1. Models: Z-Image and Flux2-Klein.
 2. Components: diffusion transformer weights, loader paths, and quantization configs.
 3. Modes: native GGUF (transformer-only weights).
 
@@ -112,19 +112,17 @@ x @ weight.T
 
 ## GGUF Weight Loading Path
 1. `DiffusersPipelineLoader.load_model` detects `quantization_config.method == "gguf"`.
 2. `gguf_model` is resolved as one of: local file, URL, `repo/file.gguf`, or `repo:quant_type`.
 3. GGUF weights are routed through adapters in `vllm_omni/diffusion/model_loader/gguf_adapters/`.
-4. Name mapping is applied per-architecture (Qwen-Image, Z-Image, Flux2Klein).
+4. Name mapping is applied per-architecture (Z-Image, Flux2Klein).
 5. GGUF weights are loaded into transformer modules, remaining non-transformer weights come from the HF checkpoint.
 
 ## GGUF Adapter Design
 1. `GGUFAdapter` (base) implements default gguf-py tensor name mapping.
 2. `Flux2KleinGGUFAdapter` implements Flux2-Klein remapping + qkv split + adaLN swap.
-3. `QwenImageGGUFAdapter` implements Qwen-Image qkv shard handling and linear qweight routing.
-4. `ZImageGGUFAdapter` implements Z-Image qkv + ffn shard handling and linear qweight routing.
-5. `get_gguf_adapter(...)` selects the adapter by model class/config and returns an iterator of `(name, tensor)`.
+3. `ZImageGGUFAdapter` implements Z-Image qkv + ffn shard handling and linear qweight routing.
+4. `get_gguf_adapter(...)` selects the adapter by model class/config and returns an iterator of `(name, tensor)`.
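
For orientation, here is a condensed sketch of the dispatch contract the list above describes. It is an illustration rather than the verbatim module: only `ZImageGGUFAdapter` is shown, constructor bookkeeping is trimmed, and the base-adapter fallback reflects the state at this patch (PATCH 58 later replaces it with a strict error).

```python
# Sketch of the dispatch in vllm_omni/diffusion/model_loader/gguf_adapters/__init__.py.
class GGUFAdapter:
    """Base adapter; at this point in the series it matches any model."""

    def __init__(self, gguf_file, model, source, od_config):
        self.gguf_file = gguf_file
        self.model = model
        self.source = source
        self.od_config = od_config

    @staticmethod
    def is_compatible(od_config, model, source) -> bool:
        return True


class ZImageGGUFAdapter(GGUFAdapter):
    @staticmethod
    def is_compatible(od_config, model, source) -> bool:
        # Match on the pipeline class name first, then on the HF model_type.
        if (od_config.model_class_name or "").startswith("ZImage"):
            return True
        cfg = od_config.tf_model_config
        return cfg is not None and str(cfg.get("model_type", "")).lower() in {"z_image", "zimage", "z-image"}


def get_gguf_adapter(gguf_file, model, source, od_config):
    # First compatible adapter wins, so specific adapters are tried before the base.
    for adapter_cls in (ZImageGGUFAdapter,):
        if adapter_cls.is_compatible(od_config, model, source):
            return adapter_cls(gguf_file, model, source, od_config)
    return GGUFAdapter(gguf_file, model, source, od_config)
```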
Adapter paths: - Base: `vllm_omni/diffusion/model_loader/gguf_adapters/base.py` -- Qwen-Image: `vllm_omni/diffusion/model_loader/gguf_adapters/qwen_image.py` - Z-Image: `vllm_omni/diffusion/model_loader/gguf_adapters/z_image.py` - Flux2-Klein: `vllm_omni/diffusion/model_loader/gguf_adapters/flux2_klein.py` From 3f08d842855cbfd7e7f91c88d588d7f99922379b Mon Sep 17 00:00:00 2001 From: David Chen <530634352@qq.com> Date: Wed, 25 Feb 2026 16:49:51 +0800 Subject: [PATCH 54/62] remove qwen-image Signed-off-by: David Chen <530634352@qq.com> --- vllm_omni/diffusion/model_loader/gguf_adapters/__init__.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm_omni/diffusion/model_loader/gguf_adapters/__init__.py b/vllm_omni/diffusion/model_loader/gguf_adapters/__init__.py index 770158de78..03b2412f6e 100644 --- a/vllm_omni/diffusion/model_loader/gguf_adapters/__init__.py +++ b/vllm_omni/diffusion/model_loader/gguf_adapters/__init__.py @@ -7,7 +7,6 @@ from .base import GGUFAdapter from .flux2_klein import Flux2KleinGGUFAdapter -from .qwen_image import QwenImageGGUFAdapter from .z_image import ZImageGGUFAdapter if TYPE_CHECKING: @@ -23,7 +22,7 @@ def get_gguf_adapter( source: DiffusersPipelineLoader.ComponentSource, od_config: OmniDiffusionConfig, ) -> GGUFAdapter: - for adapter_cls in (QwenImageGGUFAdapter, ZImageGGUFAdapter, Flux2KleinGGUFAdapter): + for adapter_cls in (ZImageGGUFAdapter, Flux2KleinGGUFAdapter): if adapter_cls.is_compatible(od_config, model, source): return adapter_cls(gguf_file, model, source, od_config) return GGUFAdapter(gguf_file, model, source, od_config) @@ -32,7 +31,6 @@ def get_gguf_adapter( __all__ = [ "GGUFAdapter", "Flux2KleinGGUFAdapter", - "QwenImageGGUFAdapter", "ZImageGGUFAdapter", "get_gguf_adapter", ] From 5205599df4f5223376fc43acb64f2ea7b903ac3f Mon Sep 17 00:00:00 2001 From: David Chen <530634352@qq.com> Date: Thu, 26 Feb 2026 11:32:30 +0800 Subject: [PATCH 55/62] fix flux2 Signed-off-by: David Chen <530634352@qq.com> --- .../flux2_klein/flux2_klein_transformer.py | 30 ++++++++++++++----- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py b/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py index ff9c79e167..5ee5ee440a 100644 --- a/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py +++ b/vllm_omni/diffusion/models/flux2_klein/flux2_klein_transformer.py @@ -771,18 +771,32 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loaded_params: set[str] = set() for name, loaded_weight in weights: + original_name = name + mapped = False for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: + if weight_name not in original_name: continue - name = name.replace(weight_name, param_name) - param = params_dict[name] + name = original_name.replace(weight_name, param_name) + param = params_dict.get(name) + if param is None: + break weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) loaded_params.add(name) + mapped = True break - else: - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) + if mapped: + continue + + name = original_name + if name not in params_dict and ".to_out.0." 
in name: + name = name.replace(".to_out.0.", ".to_out.") + # Some GGUF checkpoints include quantized tensors for modules that + # are intentionally left unquantized in this model. + param = params_dict.get(name) + if param is None: + continue + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) return loaded_params From 740aaab8e5d37825bd904e03441411767f200dcc Mon Sep 17 00:00:00 2001 From: David Chen <530634352@qq.com> Date: Thu, 26 Feb 2026 14:29:41 +0800 Subject: [PATCH 56/62] fix ci Signed-off-by: David Chen <530634352@qq.com> --- docs/mkdocs/hooks/generate_argparse.py | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/mkdocs/hooks/generate_argparse.py b/docs/mkdocs/hooks/generate_argparse.py index b4eb441292..6cef7cfbd2 100644 --- a/docs/mkdocs/hooks/generate_argparse.py +++ b/docs/mkdocs/hooks/generate_argparse.py @@ -124,6 +124,7 @@ def add_parser(self, name, **kwargs): "logger": logger, "DummySubparsers": DummySubparsers, "argparse": __import__("argparse"), + "json": __import__("json"), "DESCRIPTION": DESCRIPTION, } exec(code, exec_globals, local_vars) From 11de53eeb1e06fbb620f00e52d1225916c5d7494 Mon Sep 17 00:00:00 2001 From: David Chen <530634352@qq.com> Date: Thu, 26 Feb 2026 15:44:56 +0800 Subject: [PATCH 57/62] fix comment 1 Signed-off-by: David Chen <530634352@qq.com> --- .../user_guide/diffusion/quantization/gguf.md | 4 +- .../model_loader/diffusers_loader.py | 56 +++---------------- 2 files changed, 11 insertions(+), 49 deletions(-) diff --git a/docs/user_guide/diffusion/quantization/gguf.md b/docs/user_guide/diffusion/quantization/gguf.md index 61d1d520fc..e76f717152 100644 --- a/docs/user_guide/diffusion/quantization/gguf.md +++ b/docs/user_guide/diffusion/quantization/gguf.md @@ -110,7 +110,7 @@ x @ weight.T ## GGUF Weight Loading Path (Transformer-Only) 1. `DiffusersPipelineLoader.load_model` detects `quantization_config.method == "gguf"`. -2. `gguf_model` is resolved as one of: local file, URL, `repo/file.gguf`, or `repo:quant_type`. +2. `gguf_model` is resolved as one of: local file, `repo/file.gguf`, or `repo:quant_type`. 3. GGUF weights are routed through adapters in `vllm_omni/diffusion/model_loader/gguf_adapters/`. 4. Name mapping is applied per-architecture (Z-Image, Flux2Klein). 5. GGUF weights are loaded into transformer modules, remaining non-transformer weights come from the HF checkpoint. @@ -160,7 +160,7 @@ python examples/offline_inference/text_to_image/text_to_image.py \ Notes for GGUF: 1. Many GGUF repos do not ship `model_index.json` and configs. Use the base repo for `--model` and only pass the GGUF file via `--gguf-model`. -2. `gguf_model` supports local path, URL, `repo/file.gguf`, or `repo:quant_type`. +2. `gguf_model` supports local path, `repo/file.gguf`, or `repo:quant_type`. 
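
For reference, the three accepted `gguf_model` forms map onto `_resolve_gguf_model_path` roughly as sketched below. The first two branches mirror the diff in this patch; the `repo:quant_type` branch falls between the hunks shown, so its body here is only a placeholder.

```python
import os

from huggingface_hub import hf_hub_download


def resolve_gguf_model_path(gguf_model: str, revision: str | None = None) -> str:
    # 1. Local file path: used as-is.
    if os.path.isfile(gguf_model):
        return gguf_model
    # 2. <repo_id>/<filename>.gguf: download that file from the Hub.
    if "/" in gguf_model and gguf_model.endswith(".gguf"):
        repo_id, filename = gguf_model.rsplit("/", 1)
        return hf_hub_download(repo_id=repo_id, filename=filename, revision=revision)
    # 3. <repo_id>:<quant_type>: resolve the repo's matching .gguf file
    #    (the real lookup lives in DiffusersPipelineLoader; not reproduced here).
    if ":" in gguf_model:
        raise NotImplementedError("quant_type lookup elided from this sketch")
    raise ValueError(f"Unrecognized GGUF reference: {gguf_model!r}")
```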
 ## User Usage (Online)

diff --git a/vllm_omni/diffusion/model_loader/diffusers_loader.py b/vllm_omni/diffusion/model_loader/diffusers_loader.py
index e62993ffe7..276bdf4458 100644
--- a/vllm_omni/diffusion/model_loader/diffusers_loader.py
+++ b/vllm_omni/diffusion/model_loader/diffusers_loader.py
@@ -3,14 +3,10 @@
 import dataclasses
 import glob
 import os
-import shutil
-import tempfile
 import time
 from collections.abc import Generator, Iterable
 from pathlib import Path
 from typing import cast
-from urllib.parse import urlparse
-from urllib.request import urlopen
 
 import torch
 from huggingface_hub import hf_hub_download
@@ -328,9 +324,6 @@ def _get_model_loadable_names(self, model: nn.Module) -> set[str]:
     def _resolve_gguf_model_path(self, gguf_model: str, revision: str | None) -> str:
         if os.path.isfile(gguf_model):
             return gguf_model
-        # raw HTTPS link
-        if gguf_model.startswith(("http://", "https://")) and gguf_model.endswith(".gguf"):
-            return self._download_raw_gguf_url(gguf_model)
         # repo_id/filename.gguf
         if "/" in gguf_model and gguf_model.endswith(".gguf"):
             repo_id, filename = gguf_model.rsplit("/", 1)
@@ -352,46 +345,9 @@ def _resolve_gguf_model_path(self, gguf_model: str, revision: str | None) -> str
         )
         raise ValueError(
             f"Unrecognized GGUF reference: {gguf_model!r} (expected local file, "
-            "raw URL, <repo_id>/<filename>.gguf, or <repo_id>:<quant_type>)"
+            "<repo_id>/<filename>.gguf, or <repo_id>:<quant_type>)"
         )
 
-    def _download_raw_gguf_url(self, url: str) -> str:
-        parsed = urlparse(url)
-        filename = os.path.basename(parsed.path)
-        if not filename:
-            raise ValueError(f"Cannot infer GGUF filename from URL: {url!r}")
-
-        cache_dir = self.load_config.download_dir
-        if cache_dir is None:
-            cache_dir = os.path.join(
-                os.path.expanduser("~"),
-                ".cache",
-                "vllm-omni",
-                "gguf",
-            )
-        os.makedirs(cache_dir, exist_ok=True)
-        target_path = os.path.join(cache_dir, filename)
-        if os.path.exists(target_path):
-            return target_path
-
-        tmp_fd, tmp_path = tempfile.mkstemp(
-            suffix=".gguf",
-            prefix="gguf-",
-            dir=cache_dir,
-        )
-        os.close(tmp_fd)
-        try:
-            with urlopen(url, timeout=300) as response, open(tmp_path, "wb") as out_file:
-                shutil.copyfileobj(response, out_file)
-            os.replace(tmp_path, target_path)
-        except Exception:
-            try:
-                os.remove(tmp_path)
-            except OSError:
-                pass
-            raise
-        return target_path
-
     def _get_gguf_weights_iterator(
         self,
         source: "ComponentSource",
@@ -419,9 +375,15 @@ def _load_weights_with_gguf(self, model: nn.Module, od_config: OmniDiffusionConf
         if self._is_transformer_source(source):
            loaded |= model.load_weights(self._get_gguf_weights_iterator(source, model, od_config))
 
-        # Load any remaining float weights (e.g., non-quantized layers)
-        # from the base HF checkpoint while skipping already-loaded names.
+        # GGUF checkpoints can be transformer-only or partially quantized.
+        # Only fall back to HF if this source still has missing loadable weights.
loadable_names = loadable_names or self._get_model_loadable_names(model) + has_missing_for_source = any( + name.startswith(source.prefix) and name not in loaded for name in loadable_names + ) + if not has_missing_for_source: + continue + hf_iter = self._get_weights_iterator(source) hf_iter = ( (name, tensor) for (name, tensor) in hf_iter if name in loadable_names and name not in loaded From 3b281e8c99dd035dfbc38683b40fe9544016bf04 Mon Sep 17 00:00:00 2001 From: David Chen <530634352@qq.com> Date: Thu, 26 Feb 2026 16:32:52 +0800 Subject: [PATCH 58/62] fix comment 2 Signed-off-by: David Chen <530634352@qq.com> --- docs/user_guide/diffusion/quantization/gguf.md | 4 ++-- .../model_loader/gguf_adapters/__init__.py | 13 +++++++++++-- .../diffusion/model_loader/gguf_adapters/base.py | 11 ++++++----- 3 files changed, 19 insertions(+), 9 deletions(-) diff --git a/docs/user_guide/diffusion/quantization/gguf.md b/docs/user_guide/diffusion/quantization/gguf.md index e76f717152..025180e185 100644 --- a/docs/user_guide/diffusion/quantization/gguf.md +++ b/docs/user_guide/diffusion/quantization/gguf.md @@ -116,10 +116,10 @@ x @ weight.T 5. GGUF weights are loaded into transformer modules, remaining non-transformer weights come from the HF checkpoint. ## GGUF Adapter Design -1. `GGUFAdapter` (base) implements default gguf-py tensor name mapping. +1. `GGUFAdapter` is an abstract base class for model-specific adapters. 2. `Flux2KleinGGUFAdapter` implements Flux2-Klein remapping + qkv split + adaLN swap. 3. `ZImageGGUFAdapter` implements Z-Image qkv + ffn shard handling and linear qweight routing. -4. `get_gguf_adapter(...)` selects the adapter by model class/config and returns an iterator of `(name, tensor)`. +4. `get_gguf_adapter(...)` strictly selects by model class/config; unsupported models raise an error (no fallback adapter). Adapter paths: - Base: `vllm_omni/diffusion/model_loader/gguf_adapters/base.py` diff --git a/vllm_omni/diffusion/model_loader/gguf_adapters/__init__.py b/vllm_omni/diffusion/model_loader/gguf_adapters/__init__.py index 03b2412f6e..416ebc7a84 100644 --- a/vllm_omni/diffusion/model_loader/gguf_adapters/__init__.py +++ b/vllm_omni/diffusion/model_loader/gguf_adapters/__init__.py @@ -22,10 +22,19 @@ def get_gguf_adapter( source: DiffusersPipelineLoader.ComponentSource, od_config: OmniDiffusionConfig, ) -> GGUFAdapter: - for adapter_cls in (ZImageGGUFAdapter, Flux2KleinGGUFAdapter): + adapter_classes = (ZImageGGUFAdapter, Flux2KleinGGUFAdapter) + for adapter_cls in adapter_classes: if adapter_cls.is_compatible(od_config, model, source): return adapter_cls(gguf_file, model, source, od_config) - return GGUFAdapter(gguf_file, model, source, od_config) + model_type = None + if od_config.tf_model_config is not None: + model_type = od_config.tf_model_config.get("model_type") + supported = ", ".join(cls.__name__ for cls in adapter_classes) + raise ValueError( + "No GGUF adapter matched diffusion model " + f"(model_class_name={od_config.model_class_name!r}, model_type={model_type!r}). " + f"Supported adapters: {supported}." 
+ ) __all__ = [ diff --git a/vllm_omni/diffusion/model_loader/gguf_adapters/base.py b/vllm_omni/diffusion/model_loader/gguf_adapters/base.py index 0eb55c43ca..8794ecff73 100644 --- a/vllm_omni/diffusion/model_loader/gguf_adapters/base.py +++ b/vllm_omni/diffusion/model_loader/gguf_adapters/base.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations +from abc import ABC, abstractmethod from collections.abc import Generator from dataclasses import dataclass from typing import Any @@ -19,8 +20,8 @@ class MappedTensor: swap_scale_shift: bool = False -class GGUFAdapter: - """Default GGUF adapter using gguf-py's tensor name mapping.""" +class GGUFAdapter(ABC): + """Base class for model-specific GGUF adapters.""" _include_qkv_virtuals: bool = False _include_add_kv_proj_virtuals: bool = False @@ -37,11 +38,11 @@ def __init__(self, gguf_file: str, model: torch.nn.Module, source, od_config) -> @staticmethod def is_compatible(od_config, model: torch.nn.Module, source) -> bool: - # Default adapter matches any model. - return True + return False + @abstractmethod def weights_iterator(self) -> Generator[tuple[str, torch.Tensor], None, None]: - return gguf_quant_weights_iterator(self.gguf_file) + raise NotImplementedError def _get_target_module(self) -> torch.nn.Module: prefix = getattr(self.source, "prefix", "") From ea6917628a079e5066afd6cf8c68ac7cab30e738 Mon Sep 17 00:00:00 2001 From: David Chen <530634352@qq.com> Date: Thu, 26 Feb 2026 19:37:03 +0800 Subject: [PATCH 59/62] fix ci Signed-off-by: David Chen <530634352@qq.com> --- tests/diffusion/test_diffusers_loader.py | 73 +++++++++++++++++++ .../model_loader/diffusers_loader.py | 41 ++++++++--- 2 files changed, 104 insertions(+), 10 deletions(-) create mode 100644 tests/diffusion/test_diffusers_loader.py diff --git a/tests/diffusion/test_diffusers_loader.py b/tests/diffusion/test_diffusers_loader.py new file mode 100644 index 0000000000..3f63960274 --- /dev/null +++ b/tests/diffusion/test_diffusers_loader.py @@ -0,0 +1,73 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch +import torch.nn as nn + +from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader + +pytestmark = [pytest.mark.core_model, pytest.mark.diffusion, pytest.mark.cpu] + + +class _DummyPipelineModel(nn.Module): + def __init__(self, *, source_prefix: str): + super().__init__() + self.transformer = nn.Linear(2, 2, bias=False) + self.vae = nn.Linear(2, 2, bias=False) + self.weights_sources = [ + DiffusersPipelineLoader.ComponentSource( + model_or_path="dummy", + subfolder="transformer", + revision=None, + prefix=source_prefix, + fall_back_to_pt=True, + ) + ] + + def load_weights(self, weights): + params = dict(self.named_parameters()) + loaded: set[str] = set() + for name, tensor in weights: + if name not in params: + continue + params[name].data.copy_(tensor.to(dtype=params[name].dtype)) + loaded.add(name) + return loaded + + +def _make_loader_with_weights(weight_names: list[str]) -> DiffusersPipelineLoader: + loader = object.__new__(DiffusersPipelineLoader) + loader.counter_before_loading_weights = 0.0 + loader.counter_after_loading_weights = 0.0 + + def _iter_weights(_model): + for name in weight_names: + yield name, torch.zeros((2, 2)) + + loader.get_all_weights = _iter_weights # type: ignore[assignment] + return loader + + +def test_strict_check_only_validates_source_prefix_parameters(): + model = 
_DummyPipelineModel(source_prefix="transformer.") + loader = _make_loader_with_weights(["transformer.weight"]) + + # Should not require VAE parameters because they are outside weights_sources. + loader.load_weights(model) + + +def test_strict_check_raises_when_source_parameters_are_missing(): + model = _DummyPipelineModel(source_prefix="transformer.") + loader = _make_loader_with_weights([]) + + with pytest.raises(ValueError, match="transformer.weight"): + loader.load_weights(model) + + +def test_empty_source_prefix_keeps_full_model_strict_check(): + model = _DummyPipelineModel(source_prefix="") + loader = _make_loader_with_weights(["transformer.weight"]) + + with pytest.raises(ValueError, match="vae.weight"): + loader.load_weights(model) diff --git a/vllm_omni/diffusion/model_loader/diffusers_loader.py b/vllm_omni/diffusion/model_loader/diffusers_loader.py index 276bdf4458..72a73ccf35 100644 --- a/vllm_omni/diffusion/model_loader/diffusers_loader.py +++ b/vllm_omni/diffusion/model_loader/diffusers_loader.py @@ -196,13 +196,36 @@ def get_all_weights( self, model: nn.Module, ) -> Generator[tuple[str, torch.Tensor], None, None]: - sources = cast( - Iterable[DiffusersPipelineLoader.ComponentSource], - getattr(model, "weights_sources", ()), - ) + sources = self._get_weight_sources(model) for source in sources: yield from self._get_weights_iterator(source) + def _get_weight_sources(self, model: nn.Module) -> tuple["ComponentSource", ...]: + return tuple( + cast( + Iterable[DiffusersPipelineLoader.ComponentSource], + getattr(model, "weights_sources", ()), + ) + ) + + def _get_expected_parameter_names(self, model: nn.Module) -> set[str]: + """Return parameter names that should be covered by strict load checks.""" + all_parameter_names = {name for name, _ in model.named_parameters()} + sources = self._get_weight_sources(model) + + # Keep strict behavior if no source metadata exists. + if not sources: + return all_parameter_names + + # Empty prefix means "root" source, i.e. entire model should be covered. + if any(source.prefix == "" for source in sources): + return all_parameter_names + + source_prefixes = tuple(source.prefix for source in sources if source.prefix) + if not source_prefixes: + return all_parameter_names + return {name for name in all_parameter_names if name.startswith(source_prefixes)} + def download_model(self, model_config: ModelConfig) -> None: self._prepare_weights( model_name_or_path=model_config.model, @@ -265,7 +288,7 @@ def _process_weights_after_loading(self, model: nn.Module, target_device: torch. 
module.to(module_device) def load_weights(self, model: nn.Module) -> None: - weights_to_load = {name for name, _ in model.named_parameters()} + weights_to_load = self._get_expected_parameter_names(model) loaded_weights = model.load_weights(self.get_all_weights(model)) self.counter_after_loading_weights = time.perf_counter() @@ -364,10 +387,7 @@ def _get_gguf_weights_iterator( return ((source.prefix + name, tensor) for (name, tensor) in weights_iter) def _load_weights_with_gguf(self, model: nn.Module, od_config: OmniDiffusionConfig) -> set[str]: - sources = cast( - Iterable[DiffusersPipelineLoader.ComponentSource], - getattr(model, "weights_sources", ()), - ) + sources = self._get_weight_sources(model) loaded: set[str] = set() loadable_names: set[str] | None = None @@ -392,7 +412,8 @@ def _load_weights_with_gguf(self, model: nn.Module, od_config: OmniDiffusionConf else: loaded |= model.load_weights(self._get_weights_iterator(source)) - weights_to_load = {name for name, _ in model.named_parameters()} + weights_to_load = self._get_expected_parameter_names(model) weights_not_loaded = weights_to_load - loaded if weights_not_loaded: raise ValueError(f"Following weights were not initialized from checkpoint: {weights_not_loaded}") + return loaded From e996f821bb773a5f60b317ed52fb4a90517bfade Mon Sep 17 00:00:00 2001 From: David Chen <530634352@qq.com> Date: Thu, 26 Feb 2026 20:17:52 +0800 Subject: [PATCH 60/62] fix ci 2 Signed-off-by: David Chen <530634352@qq.com> --- .../diffusion/models/stable_audio/pipeline_stable_audio.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm_omni/diffusion/models/stable_audio/pipeline_stable_audio.py b/vllm_omni/diffusion/models/stable_audio/pipeline_stable_audio.py index c48d68efd6..22dfc06c5a 100644 --- a/vllm_omni/diffusion/models/stable_audio/pipeline_stable_audio.py +++ b/vllm_omni/diffusion/models/stable_audio/pipeline_stable_audio.py @@ -30,6 +30,7 @@ from vllm_omni.diffusion.models.interface import SupportAudioOutput from vllm_omni.diffusion.models.stable_audio.stable_audio_transformer import StableAudioDiTModel from vllm_omni.diffusion.request import OmniDiffusionRequest +from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs logger = init_logger(__name__) @@ -127,8 +128,9 @@ def __init__( local_files_only=local_files_only, ).to(self.device) - # Initialize our custom transformer (weights loaded via load_weights) - self.transformer = StableAudioDiTModel(od_config=od_config) + # Initialize transformer from HF config to keep architecture aligned with checkpoint. 
+ transformer_kwargs = get_transformer_config_kwargs(od_config.tf_model_config, StableAudioDiTModel) + self.transformer = StableAudioDiTModel(od_config=od_config, **transformer_kwargs) # Load scheduler self.scheduler = CosineDPMSolverMultistepScheduler.from_pretrained( From fe9457c45ff35ce3d7dca75f12afc7d4d7134352 Mon Sep 17 00:00:00 2001 From: David Chen <530634352@qq.com> Date: Fri, 27 Feb 2026 09:07:22 +0800 Subject: [PATCH 61/62] fix pre-commit Signed-off-by: David Chen <530634352@qq.com> --- vllm_omni/diffusion/model_loader/diffusers_loader.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm_omni/diffusion/model_loader/diffusers_loader.py b/vllm_omni/diffusion/model_loader/diffusers_loader.py index e28e426779..adb68c6df1 100644 --- a/vllm_omni/diffusion/model_loader/diffusers_loader.py +++ b/vllm_omni/diffusion/model_loader/diffusers_loader.py @@ -28,8 +28,8 @@ from vllm.utils.torch_utils import set_default_torch_dtype from vllm_omni.diffusion.data import OmniDiffusionConfig -from vllm_omni.diffusion.model_loader.gguf_adapters import get_gguf_adapter from vllm_omni.diffusion.distributed.hsdp import HSDPInferenceConfig +from vllm_omni.diffusion.model_loader.gguf_adapters import get_gguf_adapter from vllm_omni.diffusion.registry import initialize_model logger = init_logger(__name__) @@ -420,7 +420,7 @@ def _load_weights_with_gguf(self, model: nn.Module, od_config: OmniDiffusionConf if weights_not_loaded: raise ValueError(f"Following weights were not initialized from checkpoint: {weights_not_loaded}") return loaded - + def _load_model_with_hsdp( self, od_config: OmniDiffusionConfig, From ad543fe6dac33c4304cdc38afd12a8fd8220f7b6 Mon Sep 17 00:00:00 2001 From: David Chen <530634352@qq.com> Date: Fri, 27 Feb 2026 09:56:49 +0800 Subject: [PATCH 62/62] fix wrong merge Signed-off-by: David Chen <530634352@qq.com> --- vllm_omni/diffusion/model_loader/diffusers_loader.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/vllm_omni/diffusion/model_loader/diffusers_loader.py b/vllm_omni/diffusion/model_loader/diffusers_loader.py index adb68c6df1..c4271c3383 100644 --- a/vllm_omni/diffusion/model_loader/diffusers_loader.py +++ b/vllm_omni/diffusion/model_loader/diffusers_loader.py @@ -258,8 +258,11 @@ def load_model( model_cls = resolve_obj_by_qualname(custom_pipeline_name) model = model_cls(od_config=od_config) logger.debug("Loading weights on %s ...", load_device) - # Quantization does not happen in `load_weights` but after it - self.load_weights(model) + if self._is_gguf_quantization(od_config): + self._load_weights_with_gguf(model, od_config) + else: + # Quantization does not happen in `load_weights` but after it + self.load_weights(model) # Process weights after loading for quantization (e.g., FP8 online quantization) # This is needed for vLLM's quantization methods that need to transform weights
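
With the series applied, the loader's weight path reduces to the sketch below (illustrative only; model construction, device placement, and the post-load quantization step are omitted).

```python
# End-state view of the branch added in PATCH 62 inside
# DiffusersPipelineLoader.load_model.
def load_model_weights(loader, model, od_config) -> None:
    if loader._is_gguf_quantization(od_config):
        # GGUF path: transformer weights come from the GGUF file via a
        # model-specific adapter; names still missing for a source fall
        # back to the base HF checkpoint.
        loader._load_weights_with_gguf(model, od_config)
    else:
        # Default path; online quantization (e.g. FP8) happens after
        # load_weights, during process-weights-after-loading.
        loader.load_weights(model)
```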