187 changes: 187 additions & 0 deletions docs/user_guide/quantization/mxfp8.md
@@ -0,0 +1,187 @@
# W8A8 MXFP8 Quantization

## Overview

W8A8 MXFP8 (Microscaling FP8) quantizes both weights and activations to FP8
using the OCP MX format: each group of 32 elements along the K dimension shares
a single `float8_e8m0fnu` power-of-two scale. This finer-grained scaling
typically gives better accuracy than channel-wise FP8, at the cost of one extra
scale byte per 32 elements on top of the 8-bit weights.
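
To make the block-scaling idea concrete, here is a minimal pure-Python sketch of how a shared power-of-two scale could be derived for each group of 32 values. It is not the kernel vLLM-Omni uses. The helper name `mx_block_scales` is hypothetical, the shared exponent follows the usual OCP MX recipe (`floor(log2(amax))` minus the E4M3 maximum exponent of 8), and the scaled values are only clipped to the E4M3 range (±448), not rounded to FP8.

```python
import math

GROUP_SIZE = 32      # elements per MX block, as described above
E4M3_EMAX = 8        # largest E4M3 exponent: 448 = 1.75 * 2**8
E4M3_MAX = 448.0     # largest finite E4M3 magnitude
E8M0_BIAS = 127      # E8M0 stores the exponent with a bias of 127


def mx_block_scales(values, group_size=GROUP_SIZE):
    """Split `values` into blocks and return (scaled_blocks, e8m0_bytes).

    Hypothetical helper for illustration only.
    """
    if len(values) % group_size != 0:
        raise ValueError("length must be a multiple of the group size")
    scaled_blocks, e8m0_bytes = [], []
    for start in range(0, len(values), group_size):
        block = values[start:start + group_size]
        amax = max(abs(v) for v in block)
        if amax > 0:
            # frexp gives amax = m * 2**e with 0.5 <= m < 1,
            # so floor(log2(amax)) == e - 1.
            exponent = (math.frexp(amax)[1] - 1) - E4M3_EMAX
        else:
            exponent = -E8M0_BIAS  # all-zero block: use the smallest scale
        exponent = max(-E8M0_BIAS, min(E8M0_BIAS, exponent))
        scale = math.ldexp(1.0, exponent)  # 2 ** exponent
        # Clip to the E4M3 range; a real kernel would also round to FP8 here.
        scaled_blocks.append(
            [max(-E4M3_MAX, min(E4M3_MAX, v / scale)) for v in block]
        )
        e8m0_bytes.append(exponent + E8M0_BIAS)
    return scaled_blocks, e8m0_bytes


blocks, scale_bytes = mx_block_scales([1.0] * 32 + [3.0] * 32)
print(scale_bytes)
```

Because each block gets its own scale, an outlier only affects the 31 values that share its block, which is the intuition behind the accuracy gain over one scale per channel.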

This method supports two modes:

| Mode | Description |
|------|-------------|
| **Online** | BF16 weights are quantized to MXFP8 at load time — no pre-processing needed |
| **Offline** | msModelSlim-exported MXFP8 weights converted to diffusers format via `merge_mxfp8_checkpoint.py` — weights and scales are loaded directly from the preprocessed checkpoint |

## Hardware Support

| Device | Support |
|--------|---------|
| NVIDIA Blackwell GPU (SM 100+) | ⭕ |
| NVIDIA Ada/Hopper GPU (SM 89+) | ⭕ |
| NVIDIA Ampere GPU (SM 80+) | ⭕ |
| AMD ROCm | ⭕ |
| Intel XPU | ⭕ |
| Ascend NPU (Atlas 950 A5) | ✅ |

Legend: `✅` supported, `❌` unsupported, `⭕` not verified in this
guide.

## Model Type Support

### Diffusion Model (Wan2.2)

| Model | Mode | Notes |
|-------|------|-------|
| Wan2.2-T2V-A14B | Online + Offline | MoE cascade; quantizes two transformers (`transformer` + `transformer_2`) |
| Wan2.2-I2V-A14B | Online + Offline | MoE cascade; quantizes two transformers (`transformer` + `transformer_2`) |
| Wan2.2-TI2V-5B | Online + Offline | Single transformer |

### Multi-Stage Omni/TTS Model (Qwen3-Omni, Qwen3-TTS)

| Model | Status | Notes |
|-------|--------|-------|
| Qwen3-Omni | Not validated | — |
| Qwen3-TTS | Not validated | — |

### Multi-Stage Diffusion Model (BAGEL, GLM-Image)

| Model | Status | Notes |
|-------|--------|-------|
| BAGEL | Not validated | — |
| GLM-Image | Not validated | — |

## Configuration

### Online Mode

Online mode requires no pre-processing. vLLM-Omni quantizes BF16 weights to
MXFP8 at load time.

Python API:

```python
from vllm_omni import Omni
from vllm_omni.inputs.data import OmniDiffusionSamplingParams

omni = Omni(model="<your-model>", quantization="mxfp8")

outputs = omni.generate(
    "A cat sitting on a windowsill",
    OmniDiffusionSamplingParams(num_inference_steps=50),
)
```

CLI:

```bash
python text_to_video.py --model <your-model> --quantization mxfp8

# Online serving
vllm serve <your-model> --omni --quantization mxfp8
```
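
The example scripts touched by this PR forward `--quantization` to `Omni` only when the flag is set, so the default behaviour is unchanged. A self-contained sketch of that pattern follows. The function names `parse_args` and `build_omni_kwargs` are illustrative, the real scripts accept many more arguments, and the returned dictionary stands in for the keyword arguments passed to `Omni`.

```python
import argparse


def parse_args(argv=None):
    """Parse a reduced set of the example-script arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True)
    parser.add_argument(
        "--quantization",
        type=str,
        default=None,
        choices=["fp8", "mxfp8", "int8", "gguf"],
        help="Quantization method for the transformer.",
    )
    return parser.parse_args(argv)


def build_omni_kwargs(args):
    """Build keyword arguments, adding `quantization` only when requested."""
    kwargs = {"model": args.model}
    if args.quantization is not None:
        kwargs["quantization"] = args.quantization
    return kwargs


args = parse_args(["--model", "<your-model>", "--quantization", "mxfp8"])
print(build_omni_kwargs(args))
```

Omitting the key when the flag is absent leaves room for the offline path, where the method is read from the checkpoint's `config.json` instead.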

### Offline Mode

Offline mode loads a pre-quantized checkpoint from msModelSlim. A preprocessing
step converts the raw quantized output to the diffusers format expected by
vLLM-Omni and injects the quantization config into `transformer/config.json` so
that vLLM-Omni auto-detects the offline path without a `--quantization` flag.

#### Step 1 — Quantize with msModelSlim

```bash
msmodelslim quant \
    --model_path /path/to/Wan2.2-TI2V-5B-Diffusers \
    --save_path /path/to/wan2_2_ti2v_quantized_raw \
    --device npu \
    --model_type Wan2_2 \
    --config_path /path/to/wan2_2_w8a8f8_mxfp.yaml \
    --trust_remote_code True
```

After this step, `--save_path` contains the raw quantized safetensors files and
a metadata JSON (`quant_model_description*.json`).

For cascade MoE models (T2V-A14B, I2V-A14B), msModelSlim outputs two
subdirectories: `high_noise_model/` and `low_noise_model/`.

#### Step 2 — Preprocess with merge_mxfp8_checkpoint.py

The script (`vllm_omni/quantization/tools/merge_mxfp8_checkpoint.py`):

1. Copies the original diffusers model to `--output-path` (VAE, text encoder,
scheduler, etc. are preserved).
2. Remaps tensor names from msModelSlim convention to diffusers convention.
3. Saves the converted weights as `diffusion_pytorch_model.safetensors`.
4. Copies the original `transformer/config.json` and injects
`quantization_config` so that vLLM-Omni auto-detects offline MXFP8.

For cascade MoE models, steps 2–4 run separately for `high_noise_model/` →
`transformer/` and `low_noise_model/` → `transformer_2/`.

```bash
python vllm_omni/quantization/tools/merge_mxfp8_checkpoint.py \
    --model-type Wan2.2-TI2V-5B \
    --original-model /path/to/Wan2.2-TI2V-5B-Diffusers \
    --quant-path /path/to/wan2_2_ti2v_quantized_raw \
    --output-path /path/to/Wan2.2-TI2V-5B-MXFP8
```

| Argument | Description |
|----------|-------------|
| `--model-type` | Model variant: `Wan2.2-T2V-A14B`, `Wan2.2-I2V-A14B`, or `Wan2.2-TI2V-5B` |
| `--original-model` | Root directory of the original BF16 diffusers model |
| `--quant-path` | Root directory of the msModelSlim quantized output |
| `--output-path` | Output directory for the merged model (created by the script) |
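
When the preprocessing step is driven from Python, for example in a batch job covering several variants, the command line can be assembled from the four documented arguments. The sketch below is a hypothetical wrapper, not part of vLLM-Omni: `build_merge_command` is an assumed name, and the script is assumed to be run from the repository root.

```python
import sys

# Values documented for --model-type in the table above.
SUPPORTED_MODEL_TYPES = ("Wan2.2-T2V-A14B", "Wan2.2-I2V-A14B", "Wan2.2-TI2V-5B")
SCRIPT = "vllm_omni/quantization/tools/merge_mxfp8_checkpoint.py"


def build_merge_command(model_type, original_model, quant_path, output_path):
    """Return the argv list for merge_mxfp8_checkpoint.py (hypothetical helper)."""
    if model_type not in SUPPORTED_MODEL_TYPES:
        raise ValueError(f"unsupported model type: {model_type!r}")
    return [
        sys.executable, SCRIPT,
        "--model-type", model_type,
        "--original-model", str(original_model),
        "--quant-path", str(quant_path),
        "--output-path", str(output_path),
    ]


cmd = build_merge_command(
    "Wan2.2-TI2V-5B",
    "/path/to/Wan2.2-TI2V-5B-Diffusers",
    "/path/to/wan2_2_ti2v_quantized_raw",
    "/path/to/Wan2.2-TI2V-5B-MXFP8",
)
print(" ".join(cmd))
```

The list can be passed to `subprocess.run(cmd, check=True)`, which raises if the script exits with a non-zero status.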

The script outputs a complete diffusers model directory at `--output-path`,
with each transformer subfolder containing:

- `diffusion_pytorch_model.safetensors` — converted FP8 weights
- `config.json` — original transformer config with `quantization_config` injected
- `quant_model_description.json` — renamed quantization metadata (reference only)
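
Before serving, it can be useful to confirm that each transformer subfolder has the files listed above and that `quantization_config` was injected. The following sketch uses only the standard library. The function name `check_merged_checkpoint` is hypothetical, and the check assumes that the injected config is a JSON object with a `quant_method` key, which is what the auto-detection code in this PR looks for.

```python
import json
from pathlib import Path


def check_merged_checkpoint(model_dir, subfolders=("transformer",)):
    """Return a list of problems found in a merged checkpoint directory.

    Hypothetical helper; pass ("transformer", "transformer_2") for the
    cascade MoE models.
    """
    problems = []
    for sub in subfolders:
        folder = Path(model_dir) / sub
        if not (folder / "diffusion_pytorch_model.safetensors").is_file():
            problems.append(f"{sub}: missing diffusion_pytorch_model.safetensors")
        config_path = folder / "config.json"
        if not config_path.is_file():
            problems.append(f"{sub}: missing config.json")
            continue
        config = json.loads(config_path.read_text(encoding="utf-8"))
        quant = config.get("quantization_config")
        if not isinstance(quant, dict) or "quant_method" not in quant:
            problems.append(f"{sub}: config.json has no usable quantization_config")
    return problems
```

An empty list means the layout matches what this guide describes. It does not validate the tensor contents.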

#### Step 3 — Serve

```bash
python text_to_video.py --model /path/to/Wan2.2-TI2V-5B-MXFP8

# Online serving
vllm serve /path/to/Wan2.2-TI2V-5B-MXFP8 --omni
```

Python API:

```python
omni = Omni(model="/path/to/Wan2.2-TI2V-5B-MXFP8")
```

!!! note
    No `--quantization` flag is needed for offline mode. The preprocessing
    script injects `quantization_config` into each `transformer/config.json`,
    which vLLM-Omni reads automatically to activate the offline MXFP8 method.
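
The interaction between an explicit `--quantization` flag and the injected `quantization_config` can be summarised as a small decision function. The sketch below mirrors the precedence rules added to `create_transformer_from_config` in this PR, but it is a simplification: it works on plain values, not `QuantizationConfig` objects, the name `resolve_quantization` is hypothetical, and a string-valued config is assumed to mean online mode.

```python
def resolve_quantization(cli_method, cli_is_serialized, disk_qc):
    """Return (method, is_serialized) after reconciling CLI and config.json.

    cli_method: method from --quantization, or None when the flag is absent.
    cli_is_serialized: whether the CLI-built config already targets offline mode.
    disk_qc: the `quantization_config` value from config.json, or None.
    """
    if isinstance(disk_qc, str):
        # Plain string on disk: used only when no CLI flag was given.
        if cli_method is None:
            return disk_qc, False
        return cli_method, cli_is_serialized
    if not isinstance(disk_qc, dict) or "quant_method" not in disk_qc:
        return cli_method, cli_is_serialized

    disk_method = disk_qc["quant_method"]
    disk_serialized = bool(disk_qc.get("is_checkpoint_mxfp8_serialized", False))
    if cli_method is None:
        # No flag: full auto-detection from the checkpoint.
        return disk_method, disk_serialized
    if cli_method != disk_method:
        # Mismatch: refuse early rather than load tensors with the wrong method.
        raise ValueError(
            f"checkpoint declares {disk_method!r} but {cli_method!r} was requested"
        )
    if disk_serialized and not cli_is_serialized:
        # Same method, but the checkpoint is pre-quantized: switch to offline.
        return disk_method, True
    return cli_method, cli_is_serialized
```

In practice this means `--quantization mxfp8` is harmless on an offline checkpoint, while a flag naming a different method is rejected with an error.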

## Parameters

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `method` | str | — | Must be `"mxfp8"` |
| `is_checkpoint_mxfp8_serialized` | bool | `False` | `True` for offline pre-quantized checkpoints; auto-set from `config.json` when using the preprocessing script |
| `ignored_layers` | list[str] | `[]` | Layer name substrings to keep in BF16 (e.g. `"to_out"` matches `blocks.0.attn1.to_out.0`) |
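
Since `ignored_layers` entries are matched as substrings, a short pattern can exclude more layers than intended. The sketch below illustrates the matching rule described in the table. `is_ignored` is a hypothetical helper and the layer names are examples, not a listing of the real model.

```python
def is_ignored(layer_name, ignored_layers):
    """Return True if any ignored pattern occurs as a substring of the name."""
    return any(pattern in layer_name for pattern in ignored_layers)


layer_names = [
    "blocks.0.attn1.to_out.0",
    "blocks.0.attn2.to_out.0",
    "blocks.0.attn1.to_q",
]
kept_in_bf16 = [name for name in layer_names if is_ignored(name, ["to_out"])]
print(kept_in_bf16)
```

Here `"to_out"` matches the output projections of both attention modules. A longer substring such as `"attn1.to_out"` narrows the match to one of them.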

## Validation and Notes

1. Online mode quantizes BF16 weights at load time using
`npu_dynamic_mx_quant`. This adds overhead once per model load, not per
inference step, and requires no checkpoint preparation.
2. Offline mode loads FP8 weights directly from the checkpoint. Scales are
stored as `uint8` bytes in safetensors (same bit layout as
`float8_e8m0fnu`) and are reinterpreted at load time without a dtype
conversion.
3. If the offline checkpoint was produced with the old `merge_mxfp8_checkpoint.py`
interface (arguments `--quant-dir`, `--orig-dir`, `--meta-json`,
`--output-dir`), regenerate it with the current script. The old script
wrote a separate `quantization_config.json` that is not read by vLLM-Omni;
the current script injects the config directly into `transformer/config.json`.
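
As background for note 2, an E8M0 scale is a biased exponent with no sign or mantissa bits, which is why the stored `uint8` bytes can be reinterpreted without conversion. The sketch below decodes such bytes in pure Python following the OCP MX definition (value `2**(byte - 127)`, with `0xFF` reserved for NaN). The helper names are hypothetical, and the runtime itself reinterprets the tensor in place and does not decode element by element.

```python
import math


def e8m0_to_float(byte):
    """Decode one E8M0 scale byte into its floating-point value."""
    if not 0 <= byte <= 255:
        raise ValueError("E8M0 scales are single bytes")
    if byte == 255:
        return math.nan  # 0xFF is the NaN encoding
    return math.ldexp(1.0, byte - 127)  # 2 ** (byte - 127)


def decode_scales(raw):
    """Decode a bytes object of E8M0 scales into a list of floats."""
    return [e8m0_to_float(b) for b in raw]


print(decode_scales(bytes([127, 119, 128])))
```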
23 changes: 13 additions & 10 deletions docs/user_guide/quantization/overview.md
@@ -9,20 +9,20 @@ type has a different quantization scope.

| Mode | Guide | Description | Methods |
|------|-------|-------------|---------|
-| Online quantization | [Online Quantization](online.md) | vLLM-Omni computes quantized weights and scales while loading the model. | FP8 W8A8, Int8 W8A8 |
+| Online quantization | [Online Quantization](online.md) | vLLM-Omni computes quantized weights and scales while loading the model. | FP8 W8A8, Int8 W8A8, MXFP8 W8A8 |
 | Runtime attention quantization | [Quantized KV Cache](quantized_kvcache.md) | vLLM-Omni dynamically quantizes eligible diffusion Flash Attention tensors during inference. | FP8 FA |
-| Pre-quantized checkpoints | Method-specific guides | The checkpoint or an offline quantizer provides quantized weights and scales before serving. | ModelOpt, GGUF, AutoRound, msModelSlim, serialized Int8 |
+| Pre-quantized checkpoints | Method-specific guides | The checkpoint or an offline quantizer provides quantized weights and scales before serving. | ModelOpt, GGUF, AutoRound, msModelSlim, serialized Int8, offline MXFP8 |

## Hardware Support

-| Device | FP8 W8A8 | Int8 W8A8 | ModelOpt | GGUF | AutoRound | msModelSlim |
-|--------|----------|-----------|----------|------|-----------|-------------|
-| NVIDIA Blackwell GPU (SM 100+) | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
-| NVIDIA Ada/Hopper GPU (SM 89+) | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
-| NVIDIA Ampere GPU (SM 80+) | ✅ | ✅ | ⭕ | ✅ | ✅ | ❌ |
-| AMD ROCm | ⭕ | ⭕ | ⭕ | ⭕ | ⭕ | ❌ |
-| Intel XPU | ⭕ | ⭕ | ⭕ | ⭕ | ✅ | ❌ |
-| Ascend NPU | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ |
+| Device | FP8 W8A8 | Int8 W8A8 | ModelOpt | MXFP8 W8A8 | GGUF | AutoRound | msModelSlim |
+|--------|----------|-----------|----------|------------|------|-----------|-------------|
+| NVIDIA Blackwell GPU (SM 100+) | ✅ | ✅ | ✅ | ⭕ | ✅ | ✅ | ❌ |
+| NVIDIA Ada/Hopper GPU (SM 89+) | ✅ | ✅ | ✅ | ⭕ | ✅ | ✅ | ❌ |
+| NVIDIA Ampere GPU (SM 80+) | ✅ | ✅ | ⭕ | ⭕ | ✅ | ✅ | ❌ |
+| AMD ROCm | ⭕ | ⭕ | ⭕ | ⭕ | ⭕ | ⭕ | ❌ |
+| Intel XPU | ⭕ | ⭕ | ⭕ | ⭕ | ⭕ | ✅ | ❌ |
+| Ascend NPU | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ |

Legend: `✅` supported, `❌` unsupported, `⭕` not verified in this
guide. FP8 on Ampere may use a weight-only path where available.
@@ -41,6 +41,7 @@ otherwise.
| FP8 W8A8 | [FP8](fp8.md) | Online W8A8 or checkpoint FP8 | Qwen-Image; Wan2.2 is not validated | Validated for Qwen-Image family and other DiT models |
| Int8 W8A8 | [Int8](int8.md) | Online or serialized W8A8 | Qwen-Image; Wan2.2 is not validated | Validated for Qwen-Image and Z-Image |
| ModelOpt | [ModelOpt](modelopt.md) | Pre-quantized FP8 checkpoints | Qwen-Image, Z-Image, FLUX.2, HunyuanImage-3.0 | Validated for ModelOpt FP8 diffusion checkpoints |
+| MXFP8 W8A8 | [MXFP8](mxfp8.md) | Online W8A8 or offline pre-quantized | Wan2.2-T2V-A14B, I2V-A14B, TI2V-5B | Ascend NPU only; validated for Wan2.2 |
| GGUF | [GGUF](gguf.md) | Pre-quantized transformer weights | Qwen-Image | Validated where a model-specific GGUF adapter exists |
| AutoRound | [AutoRound](autoround.md) | Pre-quantized W4A16 checkpoints | FLUX.1-dev; Qwen-Image/Wan2.2 not validated | Checkpoint-driven |
| msModelSlim | [msModelSlim](msmodelslim.md) | Pre-quantized Ascend checkpoints | Wan2.2 recipe; HunyuanImage-3.0 inference target | Ascend/NPU path |
@@ -56,6 +57,7 @@ in BF16 unless the model guide explicitly adds support.
|--------|-------|-------|----------------|--------|
| ModelOpt | [ModelOpt](modelopt.md) | Thinker or language-model checkpoint config | Qwen3-Omni thinker | ModelOpt checkpoint path |
| Int8 | [Int8](int8.md) | Not currently validated for omni/TTS stages | Qwen3-Omni, Qwen3-TTS | Not validated |
+| MXFP8 | [MXFP8](mxfp8.md) | Not currently validated for omni/TTS stages | Qwen3-Omni, Qwen3-TTS | Not validated |
| GGUF | [GGUF](gguf.md) | Not currently validated for omni/TTS stages | Qwen3-Omni, Qwen3-TTS | Not validated |
| AutoRound | [AutoRound](autoround.md) | Thinker or language-model checkpoint config | Qwen2.5-Omni, Qwen3-Omni | Supported through AutoRound checkpoints |
| msModelSlim | [msModelSlim](msmodelslim.md) | Not currently validated for omni/TTS stages | Qwen3-Omni, Qwen3-TTS | Not validated |
@@ -70,6 +72,7 @@ attached to the intended stage rather than applied globally.
| FP8 | [FP8](fp8.md) | Stage-specific DiT or transformer module | BAGEL, GLM-Image | Requires model-specific validation |
| Int8 | [Int8](int8.md) | Stage-specific DiT or transformer module | BAGEL, GLM-Image | Requires model-specific validation |
| ModelOpt | [ModelOpt](modelopt.md) | Checkpoint-defined diffusion stage | BAGEL, GLM-Image | Requires model-specific validation |
+| MXFP8 | [MXFP8](mxfp8.md) | Stage-specific DiT or transformer module | BAGEL, GLM-Image | Not validated |
| GGUF | [GGUF](gguf.md) | Stage-specific transformer weights | BAGEL, GLM-Image | No validated adapter listed |
| AutoRound | [AutoRound](autoround.md) | Checkpoint-defined stage | BAGEL, GLM-Image | No validated checkpoint listed |
| msModelSlim | [msModelSlim](msmodelslim.md) | Ascend-generated stage weights | GLM-Image | Requires model-specific adaptation |
12 changes: 11 additions & 1 deletion examples/offline_inference/image_to_video/image_to_video.py
@@ -218,6 +218,13 @@ def parse_args() -> argparse.Namespace:
             "Default 1 means pure sharding (no replication). "
         ),
     )
+    parser.add_argument(
+        "--quantization",
+        type=str,
+        default=None,
+        choices=["fp8", "mxfp8", "int8", "gguf"],
+        help="Quantization method for the transformer. mxfp8: W8A8 MXFP8 online quant (NPU). fp8: online FP8 (GPU).",
+    )
     parser.add_argument(
         "--enable-diffusion-pipeline-profiler",
         action="store_true",
@@ -320,7 +327,7 @@ def main():
         hsdp_shard_size=args.hsdp_shard_size,
         hsdp_replicate_size=args.hsdp_replicate_size,
     )
-    omni = Omni(
+    omni_kwargs = dict(
         model=args.model,
         enable_layerwise_offload=args.enable_layerwise_offload,
         vae_use_slicing=args.vae_use_slicing,
@@ -339,6 +346,9 @@ def main():
         enable_diffusion_pipeline_profiler=args.enable_diffusion_pipeline_profiler,
         profiler_config=args.profiler_config,
     )
+    if args.quantization is not None:
+        omni_kwargs["quantization"] = args.quantization
+    omni = Omni(**omni_kwargs)

     if profiler_enabled:
         print("[Profiler] Starting profiling...")
4 changes: 2 additions & 2 deletions examples/offline_inference/text_to_video/text_to_video.py
@@ -206,8 +206,8 @@ def parse_args() -> argparse.Namespace:
         "--quantization",
         type=str,
         default=None,
-        choices=["fp8", "gguf"],
-        help="Quantization method for the transformer (fp8 for online FP8 quantization).",
+        choices=["fp8", "mxfp8", "int8", "gguf"],
+        help="Quantization method for the transformer. mxfp8: W8A8 MXFP8 online quant (NPU). fp8: online FP8 (GPU).",
     )
     parser.add_argument(
         "--use-hsdp",
65 changes: 61 additions & 4 deletions vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py
@@ -15,6 +15,7 @@
from diffusers.utils.torch_utils import randn_tensor
from torch import nn
from transformers import AutoTokenizer, UMT5EncoderModel
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.models.utils import AutoWeightsLoader

from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig
@@ -118,9 +119,12 @@ def load_transformer_config(model_path: str, subfolder: str = "transformer", loc
return {}


-def create_transformer_from_config(config: dict) -> WanTransformer3DModel:
+def create_transformer_from_config(
+    config: dict,
+    quant_config: QuantizationConfig | None = None,
+) -> WanTransformer3DModel:
     """Create WanTransformer3DModel from config dict."""
-    kwargs = {}
+    kwargs: dict = {}

     if "patch_size" in config:
         kwargs["patch_size"] = tuple(config["patch_size"])
@@ -153,6 +157,58 @@ def create_transformer_from_config(config: dict) -> WanTransformer3DModel:
     if "pos_embed_seq_len" in config:
         kwargs["pos_embed_seq_len"] = config["pos_embed_seq_len"]

+    # Auto-detect quantization from transformer's config.json when not explicitly provided.
+    # merge_mxfp8_checkpoint.py injects quantization_config into config.json so that
+    # offline quantized checkpoints are recognized here without a CLI flag.
+    if "quantization_config" in config:
+        from vllm_omni.quantization.factory import build_quant_config
+
+        disk_qc = config["quantization_config"]
+        if isinstance(disk_qc, dict) and "quant_method" in disk_qc:
+            qc_method = disk_qc["quant_method"]
+            qc_kwargs = {k: v for k, v in disk_qc.items() if k != "quant_method"}
+            if quant_config is None:
+                # No CLI flag: full auto-detection.
+                quant_config = build_quant_config(qc_method, **qc_kwargs)
+                logger.info(
+                    "Auto-detected quantization from transformer config.json: method=%s kwargs=%s",
+                    qc_method,
+                    qc_kwargs,
+                )
+            elif quant_config.get_name() != qc_method:
+                # The caller supplied a quant_config of a different method than
+                # what the checkpoint was built with. Loading serialized tensors
+                # (e.g. MXFP8 weight scales) with the wrong linear method would
+                # produce corrupt output or a shape mismatch crash. Reject early
+                # so the user gets a clear message instead of a silent failure.
+                raise ValueError(
+                    f"Checkpoint config.json declares quant_method={qc_method!r} but the "
+                    f"active quantization config is {quant_config.get_name()!r}. "
+                    "Pass a matching --quantization flag or omit it for auto-detection."
+                )
+            elif (
+                qc_kwargs.get("is_checkpoint_mxfp8_serialized", False)
+                and hasattr(quant_config, "is_checkpoint_mxfp8_serialized")
+                and not quant_config.is_checkpoint_mxfp8_serialized
+            ):
+                # Same method: CLI provided online mode but config.json marks this
+                # as a pre-quantized offline checkpoint. Switch to offline mode so
+                # users can pass --quantization mxfp8 without knowing the
+                # online/offline distinction.
+                quant_config = build_quant_config(qc_method, **qc_kwargs)
+                logger.info(
+                    "config.json marks checkpoint as serialized; switching from online to offline MXFP8 mode.",
+                )
+        elif isinstance(disk_qc, str) and quant_config is None:
+            quant_config = build_quant_config(disk_qc)
+            logger.info(
+                "Auto-detected quantization from transformer config.json: method=%s",
+                disk_qc,
+            )
+
+    if quant_config is not None:
+        kwargs["quant_config"] = quant_config
+
     return WanTransformer3DModel(**kwargs)


@@ -371,8 +427,9 @@ def __init__(
)

     def _create_transformer(self, config: dict) -> WanTransformer3DModel:
-        """Create a transformer from a config dict. Subclasses may override."""
-        return create_transformer_from_config(config)
+        """Create a transformer from a config dict. Respects od_config.quantization_config."""
+        quant_config = getattr(self.od_config, "quantization_config", None)
+        return create_transformer_from_config(config, quant_config=quant_config)

@property
def guidance_scale(self):