diff --git a/docs/.nav.yml b/docs/.nav.yml index ae4a58f5c7..07db1b4651 100644 --- a/docs/.nav.yml +++ b/docs/.nav.yml @@ -43,6 +43,9 @@ nav: - Overview: user_guide/diffusion_acceleration.md - TeaCache: user_guide/diffusion/teacache.md - Cache-DiT: user_guide/diffusion/cache_dit_acceleration.md + - Quantization: + - Overview: user_guide/diffusion/quantization/overview.md + - FP8: user_guide/diffusion/quantization/fp8.md - Parallelism Acceleration: user_guide/diffusion/parallelism_acceleration.md - CPU Offloading: user_guide/diffusion/cpu_offload_diffusion.md - ComfyUI: features/comfyui.md diff --git a/docs/user_guide/diffusion/quantization/fp8.md b/docs/user_guide/diffusion/quantization/fp8.md new file mode 100644 index 0000000000..2c065546bf --- /dev/null +++ b/docs/user_guide/diffusion/quantization/fp8.md @@ -0,0 +1,77 @@ +# FP8 Quantization + +## Overview + +FP8 quantization converts BF16/FP16 weights to FP8 at model load time. No calibration or pre-quantized checkpoint needed. + +Depending on the model, either all layers can be quantized, or some sensitive layers should stay in BF16. See the [per-model table](#supported-models) for which case applies. + +Common sensitive layers in DiT-based diffusion models include **image-stream MLPs** (`img_mlp`). These are particularly vulnerable to FP8 precision loss because they process denoising latents whose dynamic range shifts significantly across timesteps, and unlike attention projections (which benefit from QK-Norm stabilization), MLPs have no built-in normalization to absorb quantization error. In deep architectures (e.g., 60+ residual blocks), small per-layer errors compound and degrade output quality. Other layers such as **attention projections** (`to_qkv`, `to_out`) and **text-stream MLPs** (`txt_mlp`) are generally more robust due to normalization or more stable input statistics. + +## Configuration + +1. **Python API**: set `quantization="fp8"`. To skip sensitive layers, use `quantization_config` with `ignored_layers`. + +```python +from vllm_omni import Omni +from vllm_omni.inputs.data import OmniDiffusionSamplingParams + +# All layers quantized +omni = Omni(model="", quantization="fp8") + +# Skip sensitive layers +omni = Omni( + model="", + quantization_config={ + "method": "fp8", + "ignored_layers": [""], + }, +) + +outputs = omni.generate( + "A cat sitting on a windowsill", + OmniDiffusionSamplingParams(num_inference_steps=50), +) +``` + +2. **CLI**: pass `--quantization fp8` and optionally `--ignored-layers`. + +```bash +# All layers +python text_to_image.py --model --quantization fp8 + +# Skip sensitive layers +python text_to_image.py --model --quantization fp8 --ignored-layers "img_mlp" + +# Online serving +vllm serve --omni --quantization fp8 +``` + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `method` | str | — | Quantization method (`"fp8"`) | +| `ignored_layers` | list[str] | `[]` | Layer name patterns to keep in BF16 | +| `activation_scheme` | str | `"dynamic"` | `"dynamic"` (no calibration) or `"static"` | +| `weight_block_size` | list[int] \| None | `None` | Block size for block-wise weight quantization | + +The available `ignored_layers` names depend on the model architecture (e.g., `to_qkv`, `to_out`, `img_mlp`, `txt_mlp`). Consult the transformer source for your target model. 
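+
+To double-check which layers were actually quantized (and which stayed in BF16 via `ignored_layers`), you can inspect weight dtypes on the loaded transformer module. The snippet below is only a sketch: it assumes you have a handle to the transformer `nn.Module` (how you obtain it depends on your integration and is not part of the documented API) and that FP8 weights use the `torch.float8_e4m3fn` dtype.
+
+```python
+import torch
+import torch.nn as nn
+
+
+def summarize_weight_dtypes(transformer: nn.Module) -> dict[str, int]:
+    """Count 2D+ weight tensors by dtype, e.g. {'torch.float8_e4m3fn': 120, 'torch.bfloat16': 60}."""
+    counts: dict[str, int] = {}
+    for _, module in transformer.named_modules():
+        weight = getattr(module, "weight", None)
+        if isinstance(weight, torch.Tensor) and weight.dim() >= 2:
+            key = str(weight.dtype)
+            counts[key] = counts.get(key, 0) + 1
+    return counts
+```
+
+With `ignored_layers=["img_mlp"]`, for example, you would expect the `img_mlp` projections to land in the BF16 bucket while the remaining linear layers report an FP8 dtype.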
+ +## Supported Models + +| Model | HF Models | Recommendation | `ignored_layers` | +|-------|-----------|---------------|------------------| +| Z-Image | `Tongyi-MAI/Z-Image-Turbo` | All layers | None | +| Qwen-Image | `Qwen/Qwen-Image`, `Qwen/Qwen-Image-2512` | Skip sensitive layers | `img_mlp` | + +## Combining with Other Features + +FP8 quantization can be combined with cache acceleration: + +```python +omni = Omni( + model="", + quantization="fp8", + cache_backend="tea_cache", + cache_config={"rel_l1_thresh": 0.2}, +) +``` diff --git a/docs/user_guide/diffusion/quantization/overview.md b/docs/user_guide/diffusion/quantization/overview.md new file mode 100644 index 0000000000..7dede292fc --- /dev/null +++ b/docs/user_guide/diffusion/quantization/overview.md @@ -0,0 +1,17 @@ +# Quantization for Diffusion Transformers + +vLLM-Omni supports quantization of DiT linear layers to reduce memory usage and accelerate inference. + +## Supported Methods + +| Method | Guide | +|--------|-------| +| FP8 | [FP8](fp8.md) | + +## Device Compatibility + +| GPU Generation | Example GPUs | FP8 Mode | +|---------------|-------------------|----------| +| Ada/Hopper (SM 89+) | RTX 4090, H100, H200 | Full W8A8 with native hardware | + +Kernel selection is automatic. diff --git a/docs/user_guide/diffusion_acceleration.md b/docs/user_guide/diffusion_acceleration.md index 7856ea9606..1a2e0a7d23 100644 --- a/docs/user_guide/diffusion_acceleration.md +++ b/docs/user_guide/diffusion_acceleration.md @@ -1,6 +1,6 @@ # Diffusion Acceleration Overview -vLLM-Omni supports various cache acceleration methods to speed up diffusion model inference with minimal quality degradation. These methods include **cache methods** that intelligently cache intermediate computations to avoid redundant work across diffusion timesteps, and **parallelism methods** that distribute the computation across multiple devices. +vLLM-Omni supports various acceleration methods to speed up diffusion model inference with minimal quality degradation. These include **cache methods** that intelligently cache intermediate computations to avoid redundant work across diffusion timesteps, **parallelism methods** that distribute the computation across multiple devices, and **quantization methods** that reduce memory footprint while preserving accuracy. ## Supported Acceleration Methods @@ -14,6 +14,10 @@ vLLM-Omni currently supports two main cache acceleration backends: Both methods can provide significant speedups (typically **1.5x-2.0x**) while maintaining high output quality. +vLLM-Omni also supports quantization methods: + +3. **[FP8 Quantization](diffusion/quantization/overview.md)** - Reduces DiT linear layers from BF16 to FP8, providing ~1.28x speedup with minimal quality loss. Supports per-layer skip for sensitive layers. + vLLM-Omni also supports parallelism methods for diffusion models, including: 1. [Ulysses-SP](diffusion/parallelism_acceleration.md#ulysses-sp) - splits the input along the sequence dimension and uses all-to-all communication to allow each device to compute only a subset of attention heads. 
@@ -35,6 +39,12 @@ vLLM-Omni also supports parallelism methods for diffusion models, including: | **TeaCache** | `cache_backend="tea_cache"` | Simple, adaptive caching with minimal configuration | Quick setup, balanced speed/quality | | **Cache-DiT** | `cache_backend="cache_dit"` | Advanced caching with multiple techniques (DBCache, TaylorSeer, SCM) | Maximum acceleration, fine-grained control | +### Quantization Methods + +| Method | Configuration | Description | Best For | +|--------|--------------|-------------|----------| +| **FP8** | `quantization="fp8"` | FP8 W8A8 on Ada/Hopper, weight-only on older GPUs | Memory reduction, inference speedup | + ## Supported Models The following table shows which models are currently supported by each acceleration method: @@ -58,10 +68,18 @@ The following table shows which models are currently supported by each accelerat ### VideoGen -| Model | Model Identifier | TeaCache | Cache-DiT | Ulysses-SP | Ring-Attention |CFG-Parallel | +| Model | Model Identifier | TeaCache | Cache-DiT | Ulysses-SP | Ring-Attention | CFG-Parallel | |-------|------------------|:--------:|:---------:|:----------:|:--------------:|:----------------:| | **Wan2.2** | `Wan-AI/Wan2.2-T2V-A14B-Diffusers` | ❌ | ✅ | ✅ | ✅ | ✅ | +### Quantization + +| Model | Model Identifier | FP8 | +|-------|------------------|:---:| +| **Qwen-Image** | `Qwen/Qwen-Image` | ✅ | +| **Qwen-Image-2512** | `Qwen/Qwen-Image-2512` | ✅ | +| **Z-Image** | `Tongyi-MAI/Z-Image-Turbo` | ✅ | + ## Performance Benchmarks @@ -272,12 +290,30 @@ outputs = omni.generate( ) ``` +### Using FP8 Quantization + +```python +from vllm_omni import Omni +from vllm_omni.inputs.data import OmniDiffusionSamplingParams + +omni = Omni( + model="", + quantization="fp8", +) + +outputs = omni.generate( + "A cat sitting on a windowsill", + OmniDiffusionSamplingParams(num_inference_steps=50), +) +``` + ## Documentation For detailed information on each acceleration method: - **[TeaCache Guide](diffusion/teacache.md)** - Complete TeaCache documentation, configuration options, and best practices - **[Cache-DiT Acceleration Guide](diffusion/cache_dit_acceleration.md)** - Comprehensive Cache-DiT guide covering DBCache, TaylorSeer, SCM, and configuration parameters +- **[FP8 Quantization Guide](diffusion/quantization/overview.md)** - FP8 quantization for DiT models with per-layer control - **[Tensor Parallelism](diffusion/parallelism_acceleration.md#tensor-parallelism)** - Guidance on how to enable TP for diffusion models. - **[Sequence Parallelism](diffusion/parallelism_acceleration.md#sequence-parallelism)** - Guidance on how to set sequence parallelism with configuration. - **[CFG-Parallel](diffusion/parallelism_acceleration.md#cfg-parallel)** - Guidance on how to set CFG-Parallel to run positive/negative branches across ranks. diff --git a/examples/offline_inference/text_to_image/text_to_image.py b/examples/offline_inference/text_to_image/text_to_image.py index 66ae9c102d..5da0c18799 100644 --- a/examples/offline_inference/text_to_image/text_to_image.py +++ b/examples/offline_inference/text_to_image/text_to_image.py @@ -5,6 +5,7 @@ import os import time from pathlib import Path +from typing import Any import torch @@ -118,6 +119,24 @@ def parse_args() -> argparse.Namespace: default=1, help="Number of ready layers (blocks) to keep on GPU during generation.", ) + parser.add_argument( + "--quantization", + type=str, + default=None, + choices=["fp8"], + help="Quantization method for the transformer. 
" + "Options: 'fp8' (FP8 W8A8 on Ada/Hopper, weight-only on older GPUs). " + "Default: None (no quantization, uses BF16).", + ) + parser.add_argument( + "--ignored-layers", + type=str, + default=None, + help="Comma-separated list of layer name patterns to skip quantization. " + "Only used when --quantization is set. " + "Available layers: to_qkv, to_out, add_kv_proj, to_add_out, img_mlp, txt_mlp, proj_out. " + "Example: --ignored-layers 'add_kv_proj,to_add_out'", + ) parser.add_argument( "--vae-use-slicing", action="store_true", @@ -188,6 +207,18 @@ def main(): # Check if profiling is requested via environment variable profiler_enabled = bool(os.getenv("VLLM_TORCH_PROFILER_DIR")) + # Build quantization kwargs: use quantization_config dict when + # ignored_layers is specified so the list flows through OmniDiffusionConfig + quant_kwargs: dict[str, Any] = {} + ignored_layers = [s.strip() for s in args.ignored_layers.split(",") if s.strip()] if args.ignored_layers else None + if args.quantization and ignored_layers: + quant_kwargs["quantization_config"] = { + "method": args.quantization, + "ignored_layers": ignored_layers, + } + elif args.quantization: + quant_kwargs["quantization"] = args.quantization + omni = Omni( model=args.model, enable_layerwise_offload=args.enable_layerwise_offload, @@ -200,6 +231,7 @@ def main(): parallel_config=parallel_config, enforce_eager=args.enforce_eager, enable_cpu_offload=args.enable_cpu_offload, + **quant_kwargs, ) if profiler_enabled: @@ -212,6 +244,9 @@ def main(): print(f" Model: {args.model}") print(f" Inference steps: {args.num_inference_steps}") print(f" Cache backend: {args.cache_backend if args.cache_backend else 'None (no acceleration)'}") + print(f" Quantization: {args.quantization if args.quantization else 'None (BF16)'}") + if ignored_layers: + print(f" Ignored layers: {ignored_layers}") print( f" Parallel configuration: tensor_parallel_size={args.tensor_parallel_size}, " f"ulysses_degree={args.ulysses_degree}, ring_degree={args.ring_degree}, cfg_parallel_size={args.cfg_parallel_size}, " diff --git a/mkdocs.yml b/mkdocs.yml index d3fc520737..4f554003a9 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -94,6 +94,7 @@ plugins: exclude: - "re:vllm_omni\\._.*" # Internal modules - "vllm_omni.diffusion.models.qwen_image" # avoid importing vllm in mkdocs building + - "vllm_omni.diffusion.quantization" # avoid importing vllm in mkdocs building - "vllm_omni.entrypoints.async_diffusion" # avoid importing vllm in mkdocs building - "vllm_omni.entrypoints.openai" # avoid importing vllm in mkdocs building - "vllm_omni.entrypoints.openai.protocol" # avoid importing vllm in mkdocs building diff --git a/tests/diffusion/quantization/__init__.py b/tests/diffusion/quantization/__init__.py new file mode 100644 index 0000000000..208f01a7cb --- /dev/null +++ b/tests/diffusion/quantization/__init__.py @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project diff --git a/tests/diffusion/quantization/test_fp8_config.py b/tests/diffusion/quantization/test_fp8_config.py new file mode 100644 index 0000000000..57661d2990 --- /dev/null +++ b/tests/diffusion/quantization/test_fp8_config.py @@ -0,0 +1,137 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unit tests for FP8 quantization config.""" + +import pytest + + +def test_fp8_config_creation(): + """Test that FP8 config can be created.""" + from vllm_omni.diffusion.quantization import 
get_diffusion_quant_config + + config = get_diffusion_quant_config("fp8") + assert config is not None + assert config.get_name() == "fp8" + + +def test_vllm_config_extraction(): + """Test that vLLM config can be extracted from diffusion config.""" + from vllm_omni.diffusion.quantization import ( + get_diffusion_quant_config, + get_vllm_quant_config_for_layers, + ) + + diff_config = get_diffusion_quant_config("fp8") + vllm_config = get_vllm_quant_config_for_layers(diff_config) + assert vllm_config is not None + assert vllm_config.activation_scheme == "dynamic" + + +def test_none_quantization(): + """Test that None quantization returns None config.""" + from vllm_omni.diffusion.quantization import ( + get_diffusion_quant_config, + get_vllm_quant_config_for_layers, + ) + + config = get_diffusion_quant_config(None) + assert config is None + vllm_config = get_vllm_quant_config_for_layers(config) + assert vllm_config is None + + +def test_invalid_quantization(): + """Test that invalid quantization method raises error.""" + from vllm_omni.diffusion.quantization import get_diffusion_quant_config + + with pytest.raises(ValueError, match="Unknown quantization method"): + get_diffusion_quant_config("invalid_method") + + +def test_fp8_config_with_custom_params(): + """Test FP8 config with custom parameters.""" + from vllm_omni.diffusion.quantization import get_diffusion_quant_config + + config = get_diffusion_quant_config( + "fp8", + activation_scheme="static", + ignored_layers=["proj_out"], + ) + assert config is not None + assert config.activation_scheme == "static" + assert "proj_out" in config.ignored_layers + + +def test_supported_methods(): + """Test that supported methods list is correct.""" + from vllm_omni.diffusion.quantization import SUPPORTED_QUANTIZATION_METHODS + + assert "fp8" in SUPPORTED_QUANTIZATION_METHODS + + +def test_quantization_integration(): + """Test end-to-end quantization flow through OmniDiffusionConfig.""" + from vllm_omni.diffusion.data import OmniDiffusionConfig + + # Test with quantization string only + config = OmniDiffusionConfig(model="test", quantization="fp8") + assert config.quantization_config is not None + assert config.quantization_config.get_name() == "fp8" + + # Test with quantization_config dict + config2 = OmniDiffusionConfig( + model="test", + quantization_config={"method": "fp8", "activation_scheme": "static"}, + ) + assert config2.quantization_config is not None + assert config2.quantization_config.get_name() == "fp8" + assert config2.quantization_config.activation_scheme == "static" + + # Test that vLLM config can be extracted + vllm_config = config.quantization_config.get_vllm_quant_config() + assert vllm_config is not None + + +def test_quantization_dict_not_mutated(): + """Test that passing a dict to quantization_config doesn't mutate it.""" + from vllm_omni.diffusion.data import OmniDiffusionConfig + + original_dict = {"method": "fp8", "activation_scheme": "static"} + dict_copy = original_dict.copy() + + OmniDiffusionConfig(model="test", quantization_config=original_dict) + + # Original dict should be unchanged + assert original_dict == dict_copy + + +def test_quantization_conflicting_methods_warning(caplog): + """Test warning when quantization and quantization_config['method'] conflict.""" + import logging + + from vllm_omni.diffusion.data import OmniDiffusionConfig + + with caplog.at_level(logging.WARNING): + config = OmniDiffusionConfig( + model="test", + quantization="fp8", # This should be overridden + quantization_config={"method": "fp8", 
"activation_scheme": "static"}, + ) + # No warning when methods match + assert config.quantization_config is not None + + +def test_fp8_delegates_to_vllm_config(): + """Test that DiffusionFp8Config delegates to vLLM's Fp8Config.""" + from vllm.model_executor.layers.quantization.fp8 import Fp8Config + + from vllm_omni.diffusion.quantization import DiffusionFp8Config + + # Test that quant_config_cls is set correctly + assert DiffusionFp8Config.quant_config_cls is Fp8Config + + # Test that get_name() delegates to vLLM + assert DiffusionFp8Config.get_name() == Fp8Config.get_name() + + # Test that get_min_capability() delegates to vLLM + assert DiffusionFp8Config.get_min_capability() == Fp8Config.get_min_capability() diff --git a/vllm_omni/diffusion/data.py b/vllm_omni/diffusion/data.py index a4f0ba6fa5..e55bbc6f05 100644 --- a/vllm_omni/diffusion/data.py +++ b/vllm_omni/diffusion/data.py @@ -4,7 +4,7 @@ import enum import os import random -from collections.abc import Callable +from collections.abc import Callable, Mapping from dataclasses import dataclass, field, fields from typing import Any @@ -14,6 +14,10 @@ from vllm.config.utils import config from vllm.logger import init_logger +from vllm_omni.diffusion.quantization import ( + DiffusionQuantizationConfig, + get_diffusion_quant_config, +) from vllm_omni.diffusion.utils.network_utils import is_port_available logger = init_logger(__name__) @@ -369,6 +373,11 @@ class OmniDiffusionConfig: # Omni configuration (injected from stage config) omni_kv_config: dict[str, Any] = field(default_factory=dict) + # Quantization settings + # Supported methods: "fp8" (FP8 W8A8 on Ada/Hopper, weight-only on older GPUs) + quantization: str | None = None + quantization_config: "DiffusionQuantizationConfig | dict[str, Any] | None" = None + def settle_port(self, port: int, port_inc: int = 42, max_attempts: int = 100) -> int: """ Find an available port with retry logic. @@ -455,6 +464,33 @@ def __post_init__(self): # If it's neither dict nor DiffusionCacheConfig, convert to empty config self.cache_config = DiffusionCacheConfig() + # Convert quantization config + if self.quantization is not None or self.quantization_config is not None: + # Handle dict or DictConfig (from OmegaConf) - use Mapping for broader compatibility + if isinstance(self.quantization_config, Mapping): + # Convert DictConfig to dict if needed (OmegaConf compatibility) + config_dict = dict(self.quantization_config) + # Use get() instead of pop() to avoid mutating original dict + quant_method = config_dict.get("method", self.quantization) + # Filter out "method" key for kwargs + quant_kwargs = {k: v for k, v in config_dict.items() if k != "method"} + + # Validate conflicting methods + if self.quantization is not None and quant_method is not None and quant_method != self.quantization: + logger.warning( + f"Conflicting quantization methods: quantization={self.quantization!r}, " + f"quantization_config['method']={quant_method!r}. Using quantization_config['method']." 
+ ) + + self.quantization_config = get_diffusion_quant_config(quant_method, **quant_kwargs) + elif self.quantization_config is None and self.quantization is not None: + self.quantization_config = get_diffusion_quant_config(self.quantization) + elif not isinstance(self.quantization_config, DiffusionQuantizationConfig): + raise TypeError( + f"quantization_config must be a DiffusionQuantizationConfig, dict, or None, " + f"got {type(self.quantization_config)!r}" + ) + if self.max_cpu_loras is None: self.max_cpu_loras = 1 elif self.max_cpu_loras < 1: diff --git a/vllm_omni/diffusion/model_loader/diffusers_loader.py b/vllm_omni/diffusion/model_loader/diffusers_loader.py index 892954ce9e..b61f70b697 100644 --- a/vllm_omni/diffusion/model_loader/diffusers_loader.py +++ b/vllm_omni/diffusion/model_loader/diffusers_loader.py @@ -13,6 +13,7 @@ from vllm.config import ModelConfig from vllm.config.load import LoadConfig from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase from vllm.model_executor.model_loader.weight_utils import ( download_safetensors_index_file_from_hf, download_weights_from_hf, @@ -217,8 +218,36 @@ def load_model(self, od_config: OmniDiffusionConfig, load_device: str) -> nn.Mod logger.debug("Loading weights on %s ...", load_device) # Quantization does not happen in `load_weights` but after it self.load_weights(model) + + # Process weights after loading for quantization (e.g., FP8 online quantization) + # This is needed for vLLM's quantization methods that need to transform weights + self._process_weights_after_loading(model, target_device) + return model.eval() + def _process_weights_after_loading(self, model: nn.Module, target_device: torch.device) -> None: + """Process weights after loading for quantization methods. + + This handles vLLM's quantization methods that need to process weights + after loading (e.g., FP8 online quantization from BF16/FP16 weights). 
+ """ + for _, module in model.named_modules(): + quant_method = getattr(module, "quant_method", None) + if isinstance(quant_method, QuantizeMethodBase): + # Move module to target device for processing if needed + module_device = next(module.parameters(), None) + if module_device is not None: + module_device = module_device.device + needs_device_move = module_device != target_device + + if needs_device_move: + module.to(target_device) + + quant_method.process_weights_after_loading(module) + + if needs_device_move: + module.to(module_device) + def load_weights(self, model: nn.Module) -> None: weights_to_load = {name for name, _ in model.named_parameters()} loaded_weights = model.load_weights(self.get_all_weights(model)) diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py index d85d98b5bf..a610443041 100644 --- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py +++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py @@ -33,6 +33,7 @@ from vllm_omni.diffusion.models.qwen_image.qwen_image_transformer import ( QwenImageTransformer2DModel, ) +from vllm_omni.diffusion.quantization import get_vllm_quant_config_for_layers from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.diffusion.utils.tf_utils import get_transformer_config_kwargs from vllm_omni.model_executor.model_loader.weight_utils import ( @@ -272,7 +273,10 @@ def __init__( self.device ) transformer_kwargs = get_transformer_config_kwargs(od_config.tf_model_config, QwenImageTransformer2DModel) - self.transformer = QwenImageTransformer2DModel(od_config=od_config, **transformer_kwargs) + quant_config = get_vllm_quant_config_for_layers(od_config.quantization_config) + self.transformer = QwenImageTransformer2DModel( + od_config=od_config, quant_config=quant_config, **transformer_kwargs + ) self.tokenizer = Qwen2Tokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only) diff --git a/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py b/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py index 5f2b8c68a2..2d8d49eee9 100644 --- a/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py +++ b/vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py @@ -1,10 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from __future__ import annotations + from collections.abc import Iterable from functools import lru_cache from math import prod -from typing import Any +from typing import TYPE_CHECKING, Any import torch import torch.nn as nn @@ -23,6 +25,11 @@ ) from vllm.model_executor.model_loader.weight_utils import default_weight_loader +if TYPE_CHECKING: + from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, + ) + from vllm_omni.diffusion.attention.backends.abstract import ( AttentionMetadata, ) @@ -390,7 +397,16 @@ def _compute_video_freqs(self, frame, height, width, idx=0): class ColumnParallelApproxGELU(nn.Module): - def __init__(self, dim_in: int, dim_out: int, *, approximate: str, bias: bool = True): + def __init__( + self, + dim_in: int, + dim_out: int, + *, + approximate: str, + bias: bool = True, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ): super().__init__() self.proj = ColumnParallelLinear( dim_in, @@ -398,6 +414,8 @@ def __init__(self, dim_in: int, dim_out: int, *, approximate: str, bias: bool = bias=bias, 
gather_output=False, return_bias=False, + quant_config=quant_config, + prefix=prefix, ) self.approximate = approximate @@ -415,6 +433,8 @@ def __init__( activation_fn: str = "gelu-approximate", inner_dim: int | None = None, bias: bool = True, + quant_config: QuantizationConfig | None = None, + prefix: str = "", ) -> None: super().__init__() @@ -424,13 +444,17 @@ def __init__( dim_out = dim_out or dim layers: list[nn.Module] = [ - ColumnParallelApproxGELU(dim, inner_dim, approximate="tanh", bias=bias), + ColumnParallelApproxGELU( + dim, inner_dim, approximate="tanh", bias=bias, quant_config=quant_config, prefix=prefix + ), nn.Identity(), # placeholder for weight loading RowParallelLinear( inner_dim, dim_out, input_is_parallel=True, return_bias=False, + quant_config=quant_config, + prefix=prefix, ), ] @@ -456,6 +480,7 @@ def __init__( pre_only: bool = False, context_pre_only: bool = False, out_dim: int | None = None, + quant_config: QuantizationConfig | None = None, ) -> None: super().__init__() assert dim % num_heads == 0 @@ -471,6 +496,8 @@ def __init__( hidden_size=dim, head_size=self.head_dim, total_num_heads=num_heads, + quant_config=quant_config, + prefix="to_qkv", ) self.query_num_heads = self.to_qkv.num_heads self.kv_num_heads = self.to_qkv.num_kv_heads @@ -485,6 +512,8 @@ def __init__( hidden_size=added_kv_proj_dim, head_size=head_dim, total_num_heads=num_heads, + quant_config=quant_config, + prefix="add_kv_proj", ) self.add_query_num_heads = self.add_kv_proj.num_heads self.add_kv_num_heads = self.add_kv_proj.num_kv_heads @@ -496,6 +525,8 @@ def __init__( bias=out_bias, input_is_parallel=True, return_bias=False, + quant_config=quant_config, + prefix="to_add_out", ) assert not pre_only @@ -505,6 +536,8 @@ def __init__( bias=out_bias, input_is_parallel=True, return_bias=False, + quant_config=quant_config, + prefix="to_out", ) self.norm_added_q = RMSNorm(head_dim, eps=eps) @@ -637,6 +670,7 @@ def __init__( qk_norm: str = "rms_norm", eps: float = 1e-6, zero_cond_t: bool = False, + quant_config: QuantizationConfig | None = None, ): super().__init__() @@ -656,9 +690,10 @@ def __init__( added_kv_proj_dim=dim, context_pre_only=False, head_dim=attention_head_dim, + quant_config=quant_config, ) self.img_norm2 = AdaLayerNorm(dim, elementwise_affine=False, eps=eps) - self.img_mlp = FeedForward(dim=dim, dim_out=dim) + self.img_mlp = FeedForward(dim=dim, dim_out=dim, quant_config=quant_config, prefix="img_mlp") # Text processing modules self.txt_mod = nn.Sequential( @@ -668,7 +703,7 @@ def __init__( self.txt_norm1 = AdaLayerNorm(dim, elementwise_affine=False, eps=eps) # Text doesn't need separate attention - it's handled by img_attn joint computation self.txt_norm2 = AdaLayerNorm(dim, elementwise_affine=False, eps=eps) - self.txt_mlp = FeedForward(dim=dim, dim_out=dim) + self.txt_mlp = FeedForward(dim=dim, dim_out=dim, quant_config=quant_config, prefix="txt_mlp") self.zero_cond_t = zero_cond_t @@ -862,6 +897,7 @@ def __init__( zero_cond_t: bool = False, use_additional_t_cond: bool = False, use_layer3d_rope: bool = False, + quant_config: QuantizationConfig | None = None, ): super().__init__() self.parallel_config = od_config.parallel_config @@ -891,6 +927,7 @@ def __init__( num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim, zero_cond_t=zero_cond_t, + quant_config=quant_config, ) for _ in range(num_layers) ] diff --git a/vllm_omni/diffusion/models/z_image/pipeline_z_image.py b/vllm_omni/diffusion/models/z_image/pipeline_z_image.py index 2c938f39bf..8a39131953 100644 --- 
a/vllm_omni/diffusion/models/z_image/pipeline_z_image.py +++ b/vllm_omni/diffusion/models/z_image/pipeline_z_image.py @@ -37,6 +37,7 @@ from vllm_omni.diffusion.models.z_image.z_image_transformer import ( ZImageTransformer2DModel, ) +from vllm_omni.diffusion.quantization import get_vllm_quant_config_for_layers from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.model_executor.model_loader.weight_utils import ( download_weights_from_hf_specific, @@ -173,7 +174,9 @@ def __init__( self.vae = AutoencoderKL.from_pretrained(model, subfolder="vae", local_files_only=local_files_only).to( self._execution_device ) - self.transformer = ZImageTransformer2DModel() + # Get vLLM quantization config for linear layers + quant_config = get_vllm_quant_config_for_layers(od_config.quantization_config) + self.transformer = ZImageTransformer2DModel(quant_config=quant_config) self.tokenizer = AutoTokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only) # Note: Context parallelism is applied centrally in registry.initialize_model() diff --git a/vllm_omni/diffusion/models/z_image/z_image_transformer.py b/vllm_omni/diffusion/models/z_image/z_image_transformer.py index 8efb899312..67e6723438 100644 --- a/vllm_omni/diffusion/models/z_image/z_image_transformer.py +++ b/vllm_omni/diffusion/models/z_image/z_image_transformer.py @@ -18,6 +18,7 @@ import math from collections.abc import Iterable +from typing import TYPE_CHECKING import torch import torch.nn as nn @@ -32,6 +33,11 @@ ) from vllm.model_executor.model_loader.weight_utils import default_weight_loader +if TYPE_CHECKING: + from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, + ) + from vllm_omni.diffusion.attention.layer import Attention from vllm_omni.diffusion.cache.base import CachedTransformer from vllm_omni.diffusion.distributed.sp_plan import ( @@ -250,6 +256,7 @@ def __init__( num_kv_heads: int, qk_norm: bool = True, eps: float = 1e-6, + quant_config: "QuantizationConfig | None" = None, ) -> None: super().__init__() self.dim = dim @@ -264,6 +271,7 @@ def __init__( total_num_heads=num_heads, total_num_kv_heads=num_kv_heads, bias=False, + quant_config=quant_config, ) assert qk_norm is True @@ -281,6 +289,7 @@ def __init__( bias=False, input_is_parallel=True, return_bias=False, + quant_config=quant_config, ) ] ) @@ -343,13 +352,19 @@ def forward( class FeedForward(nn.Module): - def __init__(self, dim: int, hidden_dim: int): + def __init__( + self, + dim: int, + hidden_dim: int, + quant_config: "QuantizationConfig | None" = None, + ): super().__init__() self.w13 = MergedColumnParallelLinear( dim, [hidden_dim] * 2, bias=False, return_bias=False, + quant_config=quant_config, ) self.act = SiluAndMul() self.w2 = RowParallelLinear( @@ -358,6 +373,7 @@ def __init__(self, dim: int, hidden_dim: int): bias=False, input_is_parallel=True, return_bias=False, + quant_config=quant_config, ) def forward(self, x): @@ -374,6 +390,7 @@ def __init__( norm_eps: float, qk_norm: bool, modulation=True, + quant_config: "QuantizationConfig | None" = None, ): super().__init__() self.dim = dim @@ -384,9 +401,14 @@ def __init__( num_kv_heads=n_kv_heads, qk_norm=qk_norm, eps=1e-5, + quant_config=quant_config, ) - self.feed_forward = FeedForward(dim=dim, hidden_dim=int(dim / 3 * 8)) + self.feed_forward = FeedForward( + dim=dim, + hidden_dim=int(dim / 3 * 8), + quant_config=quant_config, + ) self.layer_id = layer_id self.attention_norm1 = RMSNorm(dim, eps=norm_eps) @@ -589,6 +611,7 @@ def __init__( 
t_scale=1000.0, axes_dims=[32, 48, 48], axes_lens=[1024, 512, 512], + quant_config: "QuantizationConfig | None" = None, ) -> None: super().__init__() self.dtype = torch.bfloat16 @@ -648,6 +671,7 @@ def __init__( norm_eps, qk_norm, modulation=True, + quant_config=quant_config, ) for layer_id in range(n_refiner_layers) ] @@ -662,6 +686,7 @@ def __init__( norm_eps, qk_norm, modulation=False, + quant_config=quant_config, ) for layer_id in range(n_refiner_layers) ] @@ -677,7 +702,15 @@ def __init__( self.layers = nn.ModuleList( [ - ZImageTransformerBlock(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm) + ZImageTransformerBlock( + layer_id, + dim, + n_heads, + n_kv_heads, + norm_eps, + qk_norm, + quant_config=quant_config, + ) for layer_id in range(n_layers) ] ) diff --git a/vllm_omni/diffusion/quantization/__init__.py b/vllm_omni/diffusion/quantization/__init__.py new file mode 100644 index 0000000000..cc1bb547f7 --- /dev/null +++ b/vllm_omni/diffusion/quantization/__init__.py @@ -0,0 +1,114 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Quantization support for diffusion models. + +This module provides a unified interface for quantizing diffusion transformers +using various methods (FP8, etc.). It wraps vLLM's quantization infrastructure +while allowing diffusion-model-specific defaults and optimizations. + +Example usage: + from vllm_omni.diffusion.quantization import ( + get_diffusion_quant_config, + get_vllm_quant_config_for_layers, + ) + + # Create FP8 config for diffusion model + diff_config = get_diffusion_quant_config("fp8") + + # Get vLLM config to pass to linear layers + vllm_config = get_vllm_quant_config_for_layers(diff_config) + + # Use in model initialization + linear_layer = QKVParallelLinear(..., quant_config=vllm_config) +""" + +from typing import TYPE_CHECKING + +from vllm.logger import init_logger + +from .base import DiffusionQuantizationConfig +from .fp8 import DiffusionFp8Config + +if TYPE_CHECKING: + from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, + ) + +logger = init_logger(__name__) + +# Registry of supported quantization methods +# To add a new method, create a new config class and register it here +_QUANT_CONFIG_REGISTRY: dict[str, type[DiffusionQuantizationConfig]] = { + "fp8": DiffusionFp8Config, +} + +SUPPORTED_QUANTIZATION_METHODS = list(_QUANT_CONFIG_REGISTRY.keys()) + + +def get_diffusion_quant_config( + quantization: str | None, + **kwargs, +) -> DiffusionQuantizationConfig | None: + """Factory function to create quantization config for diffusion models. + + Args: + quantization: Quantization method name ("fp8", etc.) or None to disable + **kwargs: Method-specific parameters passed to the config constructor + + Returns: + DiffusionQuantizationConfig instance or None if quantization is disabled + + Raises: + ValueError: If the quantization method is not supported + + Example: + # Default FP8 with dynamic activation scaling + config = get_diffusion_quant_config("fp8") + + # FP8 with custom parameters + config = get_diffusion_quant_config( + "fp8", + activation_scheme="static", + ignored_layers=["proj_out"], + ) + """ + if quantization is None or quantization.lower() == "none": + return None + + quantization = quantization.lower() + if quantization not in _QUANT_CONFIG_REGISTRY: + raise ValueError( + f"Unknown quantization method: {quantization!r}. 
Supported methods: {SUPPORTED_QUANTIZATION_METHODS}" + ) + + config_cls = _QUANT_CONFIG_REGISTRY[quantization] + logger.info("Creating diffusion quantization config: %s", quantization) + return config_cls(**kwargs) + + +def get_vllm_quant_config_for_layers( + diffusion_quant_config: DiffusionQuantizationConfig | None, +) -> "QuantizationConfig | None": + """Get the vLLM QuantizationConfig to pass to linear layers. + + This extracts the underlying vLLM config from a DiffusionQuantizationConfig, + which can then be passed to vLLM linear layers (QKVParallelLinear, etc.). + + Args: + diffusion_quant_config: The diffusion quantization config, or None + + Returns: + vLLM QuantizationConfig instance, or None if input is None + """ + if diffusion_quant_config is None: + return None + return diffusion_quant_config.get_vllm_quant_config() + + +__all__ = [ + "DiffusionQuantizationConfig", + "DiffusionFp8Config", + "get_diffusion_quant_config", + "get_vllm_quant_config_for_layers", + "SUPPORTED_QUANTIZATION_METHODS", +] diff --git a/vllm_omni/diffusion/quantization/base.py b/vllm_omni/diffusion/quantization/base.py new file mode 100644 index 0000000000..0cd9e4147e --- /dev/null +++ b/vllm_omni/diffusion/quantization/base.py @@ -0,0 +1,61 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Base class for diffusion model quantization configurations.""" + +from abc import ABC +from typing import TYPE_CHECKING, ClassVar + +import torch + +if TYPE_CHECKING: + from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, + ) + + +class DiffusionQuantizationConfig(ABC): + """Base class for diffusion model quantization configurations. + + This provides a thin wrapper over vLLM's quantization configs, + allowing diffusion-model-specific defaults and future extensibility. + + Subclasses should: + - Set quant_config_cls to the vLLM QuantizationConfig class + - Call super().__init__() after creating self._vllm_config + - Optionally override get_name() and get_min_capability() if needed + """ + + # Subclasses should set this to the vLLM QuantizationConfig class + quant_config_cls: ClassVar[type["QuantizationConfig"] | None] = None + + # The underlying vLLM config instance + _vllm_config: "QuantizationConfig | None" = None + + @classmethod + def get_name(cls) -> str: + """Return the quantization method name (e.g., 'fp8', 'int8'). + + By default, delegates to the underlying vLLM config class. + """ + if cls.quant_config_cls is not None: + return cls.quant_config_cls.get_name() + raise NotImplementedError("Subclass must set quant_config_cls or override get_name()") + + def get_vllm_quant_config(self) -> "QuantizationConfig | None": + """Return the underlying vLLM QuantizationConfig for linear layers.""" + return self._vllm_config + + @classmethod + def get_supported_act_dtypes(cls) -> list[torch.dtype]: + """Return supported activation dtypes.""" + return [torch.bfloat16, torch.float16] + + @classmethod + def get_min_capability(cls) -> int: + """Minimum GPU compute capability required. + + By default, delegates to the underlying vLLM config class. 
+ """ + if cls.quant_config_cls is not None: + return cls.quant_config_cls.get_min_capability() + return 80 # Ampere default diff --git a/vllm_omni/diffusion/quantization/fp8.py b/vllm_omni/diffusion/quantization/fp8.py new file mode 100644 index 0000000000..68abf9c229 --- /dev/null +++ b/vllm_omni/diffusion/quantization/fp8.py @@ -0,0 +1,50 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""FP8 quantization config for diffusion transformers.""" + +from vllm.model_executor.layers.quantization.fp8 import Fp8Config + +from .base import DiffusionQuantizationConfig + + +class DiffusionFp8Config(DiffusionQuantizationConfig): + """FP8 quantization config optimized for diffusion transformers. + + Uses dynamic activation scaling (no calibration dataset needed) and + online weight quantization from BF16/FP16 checkpoints. + + Device Compatibility: + - Turing (SM 75+): Weight-only FP8 via Marlin kernel + - Ada/Hopper (SM 89+): Full W8A8 FP8 with native hardware support + + The kernel selection is automatic based on GPU capability. + + Args: + activation_scheme: Activation quantization scheme. + - "dynamic": Per-token dynamic scaling (default, no calibration) + - "static": Single per-tensor scale (requires calibration) + weight_block_size: Block size for block-wise weight quantization. + Format: [block_n, block_k]. If None, uses per-tensor scaling. + ignored_layers: List of layer name patterns to skip quantization. + """ + + # Tight coupling with vLLM's Fp8Config - delegates get_name() and get_min_capability() + quant_config_cls = Fp8Config + + def __init__( + self, + activation_scheme: str = "dynamic", + weight_block_size: list[int] | None = None, + ignored_layers: list[str] | None = None, + ): + self.activation_scheme = activation_scheme + self.weight_block_size = weight_block_size + self.ignored_layers = ignored_layers or [] + + # Create underlying vLLM FP8 config + self._vllm_config = Fp8Config( + is_checkpoint_fp8_serialized=False, # Online quantization from BF16 + activation_scheme=activation_scheme, + weight_block_size=weight_block_size, + ignored_layers=ignored_layers, + )
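
The class above delegates `get_min_capability()` to vLLM's `Fp8Config`, and kernel selection is automatic at load time. For a quick offline check of which FP8 path a given GPU falls into (full W8A8 on SM 89+, weight-only on older supported parts, per the docstring above), a rough sketch follows; the helper is illustrative and not part of the vLLM-Omni API.

```python
import torch


def fp8_mode_for_current_gpu() -> str:
    """Roughly classify FP8 support following the DiffusionFp8Config docstring."""
    if not torch.cuda.is_available():
        return "no CUDA device"
    major, minor = torch.cuda.get_device_capability()
    sm = 10 * major + minor
    if sm >= 89:  # Ada / Hopper and newer
        return "full W8A8"
    if sm >= 75:  # Turing through Ampere
        return "weight-only FP8"
    return "below minimum capability"


if __name__ == "__main__":
    print(fp8_mode_for_current_gpu())
```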