Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 41 additions & 5 deletions docs/user_guide/diffusion/quantization/fp8.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,40 @@

FP8 quantization converts BF16/FP16 weights to FP8 at model load time. No calibration or pre-quantized checkpoint needed.

vLLM-Omni supports FP8 quantization for three types of diffusion model components:

| Component | Layer Types | Mechanism | Memory Savings |
|-----------|------------|-----------|---------------|
| **DiT (transformer)** | `nn.Linear` | vLLM W8A8 quantized linear layers | ~50% weights + compute speedup |
| **Text encoder** | `nn.Linear` | FP8 weight storage with hooks | ~50% weights |
| **VAE** | `nn.Conv2d`, `nn.Conv3d` | FP8 weight storage with hooks | ~50% weights |

### DiT Quantization

For DiT linear layers, vLLM-Omni uses vLLM's native FP8 W8A8 quantization infrastructure. On Ada/Hopper GPUs (SM 89+), this provides both memory savings and inference speedup through hardware-accelerated FP8 compute.

Depending on the model, either all layers can be quantized, or some sensitive layers should stay in BF16. See the [per-model table](#supported-models) for which case applies.

Common sensitive layers in DiT-based diffusion models include **image-stream MLPs** (`img_mlp`). These are particularly vulnerable to FP8 precision loss because they process denoising latents whose dynamic range shifts significantly across timesteps, and unlike attention projections (which benefit from QK-Norm stabilization), MLPs have no built-in normalization to absorb quantization error. In deep architectures (e.g., 60+ residual blocks), small per-layer errors compound and degrade output quality. Other layers such as **attention projections** (`to_qkv`, `to_out`) and **text-stream MLPs** (`txt_mlp`) are generally more robust due to normalization or more stable input statistics.

### Text Encoder and VAE Quantization

For text encoders and VAEs loaded via `from_pretrained()`, vLLM-Omni uses **FP8 weight-only storage**. Weights are stored in `float8_e4m3fn` and dequantized to BF16 before each forward pass. This saves ~50% memory with no accuracy loss since computation still happens in BF16.

This approach is necessary because:

- **Text encoders** use standard `nn.Linear` layers but are loaded outside vLLM's weight pipeline
- **VAEs** use `nn.Conv2d`/`nn.Conv3d` layers, for which PyTorch has no FP8 compute kernels

The hook mechanism ensures only one layer's BF16 weight exists in memory at a time:

```
At rest: All weights stored in FP8 (half memory)
Pre-hook: Dequantize current layer's weight to BF16
Forward: Normal computation in BF16
Post-hook: Re-quantize weight back to FP8 (free BF16)
```

## Configuration

1. **Python API**: set `quantization="fp8"`. To skip sensitive layers, use `quantization_config` with `ignored_layers`.
Expand Down Expand Up @@ -56,13 +86,19 @@ vllm serve <your-model> --omni --quantization fp8

The available `ignored_layers` names depend on the model architecture (e.g., `to_qkv`, `to_out`, `img_mlp`, `txt_mlp`). Consult the transformer source for your target model.

!!! note
The `ignored_layers` parameter only applies to DiT linear layers. Text encoder and VAE FP8 weight storage is applied to all layers when quantization is enabled.

## Supported Models

| Model | HF Models | Recommendation | `ignored_layers` |
|-------|-----------|---------------|------------------|
| Z-Image | `Tongyi-MAI/Z-Image-Turbo` | All layers | None |
| Qwen-Image | `Qwen/Qwen-Image`, `Qwen/Qwen-Image-2512` | Skip sensitive layers | `img_mlp` |
| Flux | `black-forest-labs/FLUX.1-dev` | All layers | None |
| Model | HF Models | DiT FP8 | Text Encoder FP8 | VAE FP8 | `ignored_layers` |
|-------|-----------|:-------:|:-----------------:|:-------:|------------------|
| Z-Image | `Tongyi-MAI/Z-Image-Turbo` | ✅ | ✅ | — | None |
| Qwen-Image | `Qwen/Qwen-Image`, `Qwen/Qwen-Image-2512` | ✅ | ✅ | ✅ | `img_mlp` |
| Qwen-Image-Edit | `Qwen/Qwen-Image-Edit` | ✅ | ✅ | ✅ | — |
| Qwen-Image-Edit-Plus | `Qwen/Qwen-Image-Layered` | ✅ | ✅ | ✅ | — |
| Flux | `black-forest-labs/FLUX.1-dev` | ✅ | — | — | None |
| Wan 2.2 | `Wan-AI/Wan2.2-T2V-A14B-Diffusers` | ✅ | — | — | — |

## Combining with Other Features

Expand Down
34 changes: 34 additions & 0 deletions examples/offline_inference/image_to_video/image_to_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import os
import time
from pathlib import Path
from typing import Any

import numpy as np
import PIL.Image
Expand Down Expand Up @@ -181,6 +182,23 @@ def parse_args() -> argparse.Namespace:
action="store_true",
help="Enable diffusion pipeline profiler to display stage durations.",
)
parser.add_argument(
"--quantization",
type=str,
default=None,
choices=["fp8"],
help="Quantization method for the transformer. "
"Options: 'fp8' (FP8 W8A8 on Ada/Hopper, weight-only on older GPUs). "
"Default: None (no quantization, uses BF16).",
)
parser.add_argument(
"--ignored-layers",
type=str,
default=None,
help="Comma-separated list of layer name patterns to skip quantization. "
"Only used when --quantization is set. "
"Example: --ignored-layers 'to_qkv,to_out'",
)
return parser.parse_args()

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Issue: Same inconsistent API usage

Same concern as in text_to_video.py - consider unifying the quantization config format.


Expand Down Expand Up @@ -273,6 +291,18 @@ def main():
hsdp_shard_size=args.hsdp_shard_size,
hsdp_replicate_size=args.hsdp_replicate_size,
)

# Build quantization kwargs
quant_kwargs: dict[str, Any] = {}
ignored_layers = [s.strip() for s in args.ignored_layers.split(",") if s.strip()] if args.ignored_layers else None
if args.quantization and ignored_layers:
quant_kwargs["quantization_config"] = {
"method": args.quantization,
"ignored_layers": ignored_layers,
}
elif args.quantization:
quant_kwargs["quantization"] = args.quantization

omni = Omni(
model=args.model,
enable_layerwise_offload=args.enable_layerwise_offload,
Expand All @@ -287,6 +317,7 @@ def main():
cache_backend=args.cache_backend,
cache_config=cache_config,
enable_diffusion_pipeline_profiler=args.enable_diffusion_pipeline_profiler,
**quant_kwargs,
)

if profiler_enabled:
Expand All @@ -303,6 +334,9 @@ def main():
f" Parallel configuration: cfg_parallel_size={args.cfg_parallel_size},"
f" tensor_parallel_size={args.tensor_parallel_size}, vae_patch_parallel_size={args.vae_patch_parallel_size}"
)
print(f" Quantization: {args.quantization if args.quantization else 'None (BF16)'}")
if ignored_layers:
print(f" Ignored layers: {ignored_layers}")
print(f" Video size: {args.width}x{args.height}")
print(f"{'=' * 60}\n")

Expand Down
33 changes: 33 additions & 0 deletions examples/offline_inference/text_to_video/text_to_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os
import time
from pathlib import Path
from typing import Any

import numpy as np
import torch
Expand Down Expand Up @@ -137,6 +138,23 @@ def parse_args() -> argparse.Namespace:
action="store_true",
help="Enable diffusion pipeline profiler to display stage durations.",
)
parser.add_argument(
"--quantization",

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The CLI interface is well-designed with clear help text. The --quantization and --ignored-layers args provide flexibility for users to experiment with different quantization strategies.

type=str,
default=None,
choices=["fp8"],
help="Quantization method for the transformer. "

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how about text encoder?

@lishunyang12 lishunyang12 Feb 24, 2026

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hii, thanks for review.
The text encoder (UMT5) is not quantized here — same as
what Z-Image does. Only the diffusion transformer layers
get FP8. The text encoder is relatively small compared to
the transformer, so quantizing it has less impact on
memory while potentially hurting prompt embedding quality.
We could add text encoder quantization as a follow-up.

"Options: 'fp8' (FP8 W8A8 on Ada/Hopper, weight-only on older GPUs). "
"Default: None (no quantization, uses BF16).",
)
parser.add_argument(
"--ignored-layers",
type=str,
default=None,
help="Comma-separated list of layer name patterns to skip quantization. "
"Only used when --quantization is set. "
"Example: --ignored-layers 'to_qkv,to_out'",
)
return parser.parse_args()

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Issue: Inconsistent API usage

The code uses two different approaches depending on whether ignored_layers is provided:

  • With ignored_layers: quantization_config dict with method and ignored_layers
  • Without: Simple quantization string

This could be confusing. Consider unifying to always use the same format:

if args.quantization:
    quant_kwargs["quantization_config"] = {
        "method": args.quantization,
        **(({"ignored_layers": ignored_layers} if ignored_layers else {}))
    }

Or verify that Omni handles both formats identically.


Expand Down Expand Up @@ -176,6 +194,17 @@ def main():
# Check if profiling is requested via environment variable
profiler_enabled = bool(os.getenv("VLLM_TORCH_PROFILER_DIR"))

# Build quantization kwargs

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The quantization_config dict construction properly handles both the quantization method and ignored_layers, matching the OmniDiffusionConfig expectations.

quant_kwargs: dict[str, Any] = {}
ignored_layers = [s.strip() for s in args.ignored_layers.split(",") if s.strip()] if args.ignored_layers else None
if args.quantization and ignored_layers:
quant_kwargs["quantization_config"] = {
"method": args.quantization,
"ignored_layers": ignored_layers,
}
elif args.quantization:
quant_kwargs["quantization"] = args.quantization

omni = Omni(
model=args.model,
enable_layerwise_offload=args.enable_layerwise_offload,
Expand All @@ -190,6 +219,7 @@ def main():
cache_backend=args.cache_backend,
cache_config=cache_config,
enable_diffusion_pipeline_profiler=args.enable_diffusion_pipeline_profiler,
**quant_kwargs,
)

if profiler_enabled:
Expand All @@ -207,6 +237,9 @@ def main():
f" cfg_parallel_size={args.cfg_parallel_size}, tensor_parallel_size={args.tensor_parallel_size},"
f" vae_patch_parallel_size={args.vae_patch_parallel_size}, enable_expert_parallel={args.enable_expert_parallel}"
)
print(f" Quantization: {args.quantization if args.quantization else 'None (BF16)'}")
if ignored_layers:
print(f" Ignored layers: {ignored_layers}")
print(f" Video size: {args.width}x{args.height}")
print(f"{'=' * 60}\n")

Expand Down
17 changes: 16 additions & 1 deletion vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,16 @@ def __init__(
self.vae = DistributedAutoencoderKLQwenImage.from_pretrained(
model, subfolder="vae", local_files_only=local_files_only
).to(self.device)

# Apply FP8 weight quantization to VAE and text encoder
if (
od_config.quantization_config is not None
and getattr(od_config.quantization_config, "quant_method", None) == "fp8"
):
from vllm_omni.diffusion.models.utils import apply_fp8_weight_storage

apply_fp8_weight_storage(self.vae)
apply_fp8_weight_storage(self.text_encoder)
transformer_kwargs = get_transformer_config_kwargs(od_config.tf_model_config, QwenImageTransformer2DModel)
quant_config = get_vllm_quant_config_for_layers(od_config.quantization_config)
self.transformer = QwenImageTransformer2DModel(
Expand Down Expand Up @@ -724,4 +734,9 @@ def forward(

def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
loaded_weights = loader.load_weights(weights)
# VAE and text_encoder are loaded via from_pretrained(), not through
# the weight pipeline, so mark their weights as loaded.
loaded_weights |= {f"vae.{name}" for name, _ in self.vae.named_parameters()}
loaded_weights |= {f"text_encoder.{name}" for name, _ in self.text_encoder.named_parameters()}
return loaded_weights
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,17 @@ def __init__(
self.vae = AutoencoderKLQwenImage.from_pretrained(model, subfolder="vae", local_files_only=local_files_only).to(
self.device
)

# Apply FP8 weight quantization to VAE and text encoder
if (
od_config.quantization_config is not None
and getattr(od_config.quantization_config, "quant_method", None) == "fp8"
):
from vllm_omni.diffusion.models.utils import apply_fp8_weight_storage

apply_fp8_weight_storage(self.vae)
apply_fp8_weight_storage(self.text_encoder)

transformer_kwargs = get_transformer_config_kwargs(od_config.tf_model_config, QwenImageTransformer2DModel)
self.transformer = QwenImageTransformer2DModel(od_config=od_config, **transformer_kwargs)
self.tokenizer = Qwen2Tokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only)
Expand Down Expand Up @@ -824,4 +835,9 @@ def forward(

def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
loaded_weights = loader.load_weights(weights)
# VAE and text_encoder are loaded via from_pretrained(), not through
# the weight pipeline, so mark their weights as loaded.
loaded_weights |= {f"vae.{name}" for name, _ in self.vae.named_parameters()}
loaded_weights |= {f"text_encoder.{name}" for name, _ in self.text_encoder.named_parameters()}
return loaded_weights
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,16 @@ def __init__(
self.device
)

# Apply FP8 weight quantization to VAE and text encoder
if (
od_config.quantization_config is not None
and getattr(od_config.quantization_config, "quant_method", None) == "fp8"
):
from vllm_omni.diffusion.models.utils import apply_fp8_weight_storage

apply_fp8_weight_storage(self.vae)
apply_fp8_weight_storage(self.text_encoder)

transformer_kwargs = get_transformer_config_kwargs(od_config.tf_model_config, QwenImageTransformer2DModel)
self.transformer = QwenImageTransformer2DModel(od_config=od_config, **transformer_kwargs)
self.tokenizer = Qwen2Tokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only)
Expand Down Expand Up @@ -781,4 +791,9 @@ def forward(

def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
loaded_weights = loader.load_weights(weights)
# VAE and text_encoder are loaded via from_pretrained(), not through
# the weight pipeline, so mark their weights as loaded.
loaded_weights |= {f"vae.{name}" for name, _ in self.vae.named_parameters()}
loaded_weights |= {f"text_encoder.{name}" for name, _ in self.text_encoder.named_parameters()}
return loaded_weights
83 changes: 83 additions & 0 deletions vllm_omni/diffusion/models/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Utility functions for diffusion models."""

import logging

import torch
from torch import nn

logger = logging.getLogger(__name__)

# Maximum value for float8_e4m3fn
_FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max


_FP8_TARGET_LAYERS = (nn.Linear, nn.Conv2d, nn.Conv3d)


def apply_fp8_weight_storage(model: nn.Module) -> None:
"""Apply FP8 weight-only storage to Linear/Conv2d/Conv3d layers.

Stores weights in float8_e4m3fn with per-tensor scales.
Dequantizes to the original compute dtype before each forward pass,
then re-quantizes afterward to free BF16 memory.

This saves ~50% of memory with no accuracy loss since computation
still happens in the original dtype.

Args:
model: The model whose layers will be quantized.
"""
count = 0
for name, module in model.named_modules():
if not isinstance(module, _FP8_TARGET_LAYERS):
continue

# P4: Idempotency guard -- skip if already quantized
if hasattr(module, "_fp8_weight"):
continue

weight = module.weight.data
compute_dtype = weight.dtype

# Compute per-tensor scale
amax = weight.abs().amax().clamp(min=1e-12)
scale = amax / _FP8_E4M3_MAX

# Quantize weight to FP8
fp8_weight = (weight / scale).clamp(min=-_FP8_E4M3_MAX, max=_FP8_E4M3_MAX).to(torch.float8_e4m3fn)

# Store FP8 weight and metadata as buffers (not parameters)
module.register_buffer("_fp8_weight", fp8_weight)
module.register_buffer("_fp8_scale", scale.to(torch.float32))
module._fp8_compute_dtype = compute_dtype

# P1: Keep the parameter at the original compute dtype so that
# model.dtype (derived from parameters) stays correct. We store
# the FP8 representation in the _fp8_weight buffer and only
# dequantize into module.weight.data inside the pre-hook.
# After forward, the post-hook swaps back to FP8 storage to
# free the BF16/FP16 memory.
module.weight.data = fp8_weight.to(compute_dtype)

def _pre_hook(mod, args):
# Dequantize: restore BF16/FP16 weight for computation
# P2: Cast back to compute dtype to avoid float32 promotion
# from the float32 scale tensor.
mod.weight.data = (mod._fp8_weight.to(mod._fp8_compute_dtype) * mod._fp8_scale).to(mod._fp8_compute_dtype)

def _post_hook(mod, args, output):
# Re-quantize: swap back to FP8-dequantized placeholder to
# free full-precision memory while keeping dtype correct.
mod.weight.data = mod._fp8_weight.to(mod._fp8_compute_dtype)

module.register_forward_pre_hook(_pre_hook)
module.register_forward_hook(_post_hook)
count += 1

logger.info(
"Applied FP8 weight storage to %d layers in %s",
count,
model.__class__.__name__,
)
Loading
Loading