From 6c716351cec050d50e63f8b1581dd195137aaba2 Mon Sep 17 00:00:00 2001 From: hyh_hh Date: Wed, 13 May 2026 17:04:45 +0800 Subject: [PATCH 1/8] w4a4 online & offline quant Signed-off-by: hyh_hh --- docs/user_guide/quantization/mxfp4.md | 247 ++++++++ docs/user_guide/quantization/overview.md | 25 +- .../image_to_video/image_to_video.py | 4 +- .../text_to_video/text_to_video.py | 4 +- .../models/wan2_2/pipeline_wan2_2.py | 71 ++- vllm_omni/quantization/factory.py | 16 + vllm_omni/quantization/mixed_mxfp_config.py | 149 +++++ vllm_omni/quantization/mxfp4_config.py | 558 ++++++++++++++++++ .../tools/merge_mxfp4_dualscale_checkpoint.py | 471 +++++++++++++++ 9 files changed, 1524 insertions(+), 21 deletions(-) create mode 100644 docs/user_guide/quantization/mxfp4.md create mode 100644 vllm_omni/quantization/mixed_mxfp_config.py create mode 100644 vllm_omni/quantization/mxfp4_config.py create mode 100644 vllm_omni/quantization/tools/merge_mxfp4_dualscale_checkpoint.py diff --git a/docs/user_guide/quantization/mxfp4.md b/docs/user_guide/quantization/mxfp4.md new file mode 100644 index 00000000000..2b3f20bfd1a --- /dev/null +++ b/docs/user_guide/quantization/mxfp4.md @@ -0,0 +1,247 @@ +# W4A4 MXFP4 Quantization + +## Overview + +W4A4 MXFP4 (Microscaling FP4) quantizes both weights and activations to FP4 +(`float4_e2m1fn_x2`, packed 2 values per byte) using the OCP MX format: groups +of 32 K-dimension elements share a single `float8_e8m0fnu` exponent scale. + +This method supports two modes that differ significantly in scale structure and +checkpoint format: + +| Mode | Scale structure | Description | +|------|----------------|-------------| +| **Online** | Single-scale (per-32 fine only) | BF16 weights are quantized to MXFP4 at load time — no pre-processing needed | +| **Offline** | Dual-scale (fine per-32 + coarse per-512 + per-channel smooth pre-scale) | msModelSlim-exported MXFP4 DualScale weights converted to diffusers format via preprocessing script — all scale tensors are loaded directly from the checkpoint | + +!!! warning "Online ≠ Offline" + Online mode uses a **single-scale** (`NPUMxfp4OnlineLinearMethod`): one + `float8_e8m0fnu` exponent per 32 K elements, computed on the fly from the + BF16 weight. Offline mode uses a **dual-scale** (`NPUMxfp4DualScaleLinearMethod`): a + fine scale (per-32 K), a coarse scale (per-512 K), and a per-input-channel + smooth pre-scale (`mul_scale`) produced by calibration. The two levels and + the smooth pre-scale are all stored in the checkpoint; loading an offline + checkpoint with the online method (or vice versa) will produce incorrect + results. + +## Hardware Support + +| Device | Support | +|--------|---------| +| NVIDIA Blackwell GPU (SM 100+) | ⭕ | +| NVIDIA Ada/Hopper GPU (SM 89+) | ⭕ | +| NVIDIA Ampere GPU (SM 80+) | ⭕ | +| AMD ROCm | ⭕ | +| Intel XPU | ⭕ | +| Ascend NPU (Atlas 950 A5) | ✅ | + +Legend: `✅` supported, `❌` unsupported, `⭕` not verified in this guide. + +## Model Type Support + +### Diffusion Model (Wan2.2) + +| Model | Mode | Notes | +|-------|------|-------| +| Wan2.2-T2V-A14B | Online + Offline | MoE cascade; quantizes two transformers (`transformer` + `transformer_2`); offline uses mixed MXFP8 (early blocks) + MXFP4 DualScale (remaining blocks) | +| Wan2.2-I2V-A14B | Online + Offline | MoE cascade; same mixed-precision scheme as T2V-A14B | +| Wan2.2-TI2V-5B | ❌ Not supported | Parameter count too small; W4A4 quantization causes unacceptable accuracy loss | + +!!! note "Mixed MXFP8 + MXFP4 for cascade models" + For the A14B cascade models, the offline checkpoint uses + `quant_method: mxfp8_mxfp4_dualscale`: the first `num_mxfp8_blocks` + transformer blocks are stored as MXFP8 (W8A8), and the remaining blocks as + MXFP4 DualScale (W4A4). The split is recorded in the injected + `quantization_config` and is transparent to the serving command. + +!!! warning "TI2V-5B not supported" + Wan2.2-TI2V-5B is excluded from W4A4 quantization. Its smaller parameter + count makes it significantly more sensitive to 4-bit quantization noise, + resulting in unacceptable accuracy loss. Use [MXFP8](mxfp8.md) for TI2V-5B. + +## Configuration + +### Online Mode + +Online mode requires no pre-processing. vLLM-Omni quantizes BF16 weights to +MXFP4 at load time using `npu_dynamic_mx_quant`. A single block scale +(`float8_e8m0fnu`, one per 32 K elements) is computed on the fly; no +calibration `mul_scale` is available. + +Python API: + +```python +from vllm_omni import Omni +from vllm_omni.inputs.data import OmniDiffusionSamplingParams + +omni = Omni(model="", quantization="mxfp4") + +outputs = omni.generate( + "A cat sitting on a windowsill", + OmniDiffusionSamplingParams(num_inference_steps=50), +) +``` + +CLI: + +```bash +python text_to_video.py --model --quantization mxfp4 + +# Online serving +vllm serve --omni --quantization mxfp4 +``` + +### Offline Mode (DualScale) + +Offline mode loads a pre-quantized DualScale checkpoint from msModelSlim. A +preprocessing step converts the raw quantized output to the diffusers format +expected by vLLM-Omni and injects the quantization config into +`transformer/config.json` so that vLLM-Omni auto-detects the offline path +without a `--quantization` flag. + +#### Checkpoint tensor layout + +Each quantized linear layer stores four tensors: + +| Tensor | Shape | dtype | Description | +|--------|-------|-------|-------------| +| `weight` | `(N, K)` | float8_e4m3fn | FP4 packed (2 values per byte) | +| `weight_scale` | `(N, K//32)` | uint8 | Fine block scale (`float8_e8m0fnu` bit pattern) | +| `weight_dual_scale` | `(N, K//512, 1)` | float32 | Coarse block scale | +| `mul_scale` | `(K,)` | float32 | Per-input-channel smooth pre-scale (from calibration) | + +#### Step 1 — Quantize with msModelSlim + +```bash +msmodelslim quant \ + --model_path /path/to/Wan2.2-T2V-A14B-Diffusers \ + --save_path /path/to/wan2_2_t2v_quantized_raw \ + --device npu \ + --model_type Wan2_2 \ + --config_path /path/to/wan2_2_w4a4_mxfp4_dualscale.yaml \ + --trust_remote_code True +``` + +After this step, `--save_path` contains the raw quantized safetensors files, +scale files, and a metadata JSON (`quant_model_description*.json`). + +For cascade MoE models (T2V-A14B, I2V-A14B), msModelSlim outputs two +subdirectories: `high_noise_model/` and `low_noise_model/`. + +#### Step 2 — Preprocess with merge_mxfp4_dualscale_checkpoint.py + +The script (`vllm_omni/quantization/tools/merge_mxfp4_dualscale_checkpoint.py`): + +1. Copies the original diffusers model to `--output-path` (VAE, text encoder, + scheduler, etc. are preserved). +2. Remaps tensor names from msModelSlim convention to diffusers convention. +3. Saves the converted weights, fine/coarse scales, and `mul_scale` as + `diffusion_pytorch_model.safetensors`. +4. Copies the original `transformer/config.json` and injects + `quantization_config` so that vLLM-Omni auto-detects offline MXFP4 + DualScale. + +For cascade MoE models, steps 2–4 run separately for `high_noise_model/` → +`transformer/` and `low_noise_model/` → `transformer_2/`. + +```bash +python vllm_omni/quantization/tools/merge_mxfp4_dualscale_checkpoint.py \ + --model-type Wan2.2-T2V-A14B \ + --original-model /path/to/Wan2.2-T2V-A14B-Diffusers \ + --quant-path /path/to/wan2_2_t2v_quantized_raw \ + --output-path /path/to/Wan2.2-T2V-A14B-MXFP4-DualScale +``` + +| Argument | Description | +|----------|-------------| +| `--model-type` | Model variant: `Wan2.2-T2V-A14B` or `Wan2.2-I2V-A14B` | +| `--original-model` | Root directory of the original BF16 diffusers model | +| `--quant-path` | Root directory of the msModelSlim quantized output | +| `--output-path` | Output directory for the merged model (created by the script) | + +The script outputs a complete diffusers model directory at `--output-path`, +with each transformer subfolder containing: + +- `diffusion_pytorch_model.safetensors` — converted FP4 weights, fine/coarse scales, and `mul_scale` +- `config.json` — original transformer config with `quantization_config` injected +- `quant_model_description.json` — renamed quantization metadata (reference only) + +The `quantization_config` injected into `config.json` for each transformer: + +```json +{ + "quant_method": "mxfp8_mxfp4_dualscale", + "num_mxfp8_blocks": 5, + "is_checkpoint_serialized": true +} +``` + +#### Step 3 — Serve + +```bash +python text_to_video.py --model /path/to/Wan2.2-T2V-A14B-MXFP4-DualScale + +# Online serving +vllm serve /path/to/Wan2.2-T2V-A14B-MXFP4-DualScale --omni +``` + +Python API: + +```python +omni = Omni(model="/path/to/Wan2.2-T2V-A14B-MXFP4-DualScale") +``` + +!!! note + No `--quantization` flag is needed for offline mode. The preprocessing + script injects `quantization_config` into each `transformer/config.json`, + which vLLM-Omni reads automatically to activate the offline MXFP4 + DualScale method. + +## Parameters + +### Online Mode (`mxfp4`) + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `method` | str | — | Must be `"mxfp4"` | +| `is_checkpoint_mxfp4_serialized` | bool | `False` | Set `True` to load a single-scale offline checkpoint; leave `False` (default) for online BF16-to-FP4 quantization | +| `ignored_layers` | list[str] | `[]` | Layer name substrings to keep in BF16 | + +### Offline DualScale Mode (`mxfp8_mxfp4_dualscale`) + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `method` | str | — | Must be `"mxfp8_mxfp4_dualscale"` | +| `num_mxfp8_blocks` | int | `0` | Number of leading transformer blocks kept as MXFP8; remaining blocks use MXFP4 DualScale | +| `is_checkpoint_serialized` | bool | `False` | `True` for offline DualScale checkpoints; auto-set from `config.json` when using the preprocessing script | +| `ignored_layers` | list[str] | `[]` | Layer name substrings to keep in BF16 | + +## Validation and Notes + +1. **Online mode** quantizes BF16 weights at load time using + `npu_dynamic_mx_quant` (single-scale). This adds a one-time overhead on the + first load but requires no checkpoint preparation. No calibration + `mul_scale` is available — all output partitions receive an identity + pre-scale. + +2. **Offline DualScale mode** loads four tensors per quantized layer: the FP4 + packed weight, a fine block scale (`uint8` interpreted as + `float8_e8m0fnu`), a coarse block scale (`float32`), and a per-input-channel + smooth pre-scale (`mul_scale`, `float32`). The `mul_scale` is derived from + calibration and applied to the activation before dual-level quantization + (`npu_dynamic_dual_level_mx_quant`), improving accuracy compared to the + online single-scale path. + +3. **Scale dtype**: fine scales are stored as `uint8` in safetensors (same bit + layout as `float8_e8m0fnu`) and are reinterpreted at load time without a + dtype conversion, avoiding a lossy float32 round-trip. + +4. **Self-attention QKV fusion**: the Q, K, V projection weights are fused into + a single `QKVParallelLinear` layer. Their `mul_scale` tensors are identical + (all three projections share the same input), so the three sequential loads + are idempotent. + +5. W4A4 carries inherently higher quantization noise than W8A8 (16 vs 256 + quantization levels). The DualScale offline method mitigates this with + calibrated `mul_scale` smooth quantization; online single-scale mode trades + accuracy for the convenience of not requiring a pre-processed checkpoint. diff --git a/docs/user_guide/quantization/overview.md b/docs/user_guide/quantization/overview.md index c4605c8c2d4..de6c6388231 100644 --- a/docs/user_guide/quantization/overview.md +++ b/docs/user_guide/quantization/overview.md @@ -9,20 +9,20 @@ type has a different quantization scope. | Mode | Guide | Description | Methods | |------|-------|-------------|---------| -| Online quantization | [Online Quantization](online.md) | vLLM-Omni computes quantized weights and scales while loading the model. | FP8 W8A8, Int8 W8A8, MXFP8 W8A8 | +| Online quantization | [Online Quantization](online.md) | vLLM-Omni computes quantized weights and scales while loading the model. | FP8 W8A8, Int8 W8A8, MXFP8 W8A8, MXFP4 W4A4 | | Runtime attention quantization | [Quantized KV Cache](quantized_kvcache.md) | vLLM-Omni dynamically quantizes eligible diffusion Flash Attention tensors during inference. | FP8 FA | -| Pre-quantized checkpoints | Method-specific guides | The checkpoint or an offline quantizer provides quantized weights and scales before serving. | ModelOpt, GGUF, AutoRound, msModelSlim, serialized Int8, offline MXFP8 | +| Pre-quantized checkpoints | Method-specific guides | The checkpoint or an offline quantizer provides quantized weights and scales before serving. | ModelOpt, GGUF, AutoRound, msModelSlim, serialized Int8, offline MXFP8, offline MXFP4 DualScale | ## Hardware Support -| Device | FP8 W8A8 | Int8 W8A8 | ModelOpt | MXFP8 W8A8 | GGUF | AutoRound | msModelSlim | -|--------|----------|-----------|----------|------------|------|-----------|-------------| -| NVIDIA Blackwell GPU (SM 100+) | ✅ | ✅ | ✅ | ⭕ | ✅ | ✅ | ❌ | -| NVIDIA Ada/Hopper GPU (SM 89+) | ✅ | ✅ | ✅ | ⭕ | ✅ | ✅ | ❌ | -| NVIDIA Ampere GPU (SM 80+) | ✅ | ✅ | ⭕ | ⭕ | ✅ | ✅ | ❌ | -| AMD ROCm | ⭕ | ⭕ | ⭕ | ⭕ | ⭕ | ⭕ | ❌ | -| Intel XPU | ⭕ | ⭕ | ⭕ | ⭕ | ⭕ | ✅ | ❌ | -| Ascend NPU | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | +| Device | FP8 W8A8 | Int8 W8A8 | ModelOpt | MXFP8 W8A8 | MXFP4 W4A4 | GGUF | AutoRound | msModelSlim | +|--------|----------|-----------|----------|------------|------------|------|-----------|-------------| +| NVIDIA Blackwell GPU (SM 100+) | ✅ | ✅ | ✅ | ⭕ | ⭕ | ✅ | ✅ | ❌ | +| NVIDIA Ada/Hopper GPU (SM 89+) | ✅ | ✅ | ✅ | ⭕ | ⭕ | ✅ | ✅ | ❌ | +| NVIDIA Ampere GPU (SM 80+) | ✅ | ✅ | ⭕ | ⭕ | ⭕ | ✅ | ✅ | ❌ | +| AMD ROCm | ⭕ | ⭕ | ⭕ | ⭕ | ⭕ | ⭕ | ⭕ | ❌ | +| Intel XPU | ⭕ | ⭕ | ⭕ | ⭕ | ⭕ | ⭕ | ✅ | ❌ | +| Ascend NPU | ❌ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | Legend: `✅` supported, `❌` unsupported, `⭕` not verified in this guide. FP8 on Ampere may use a weight-only path where available. @@ -42,6 +42,7 @@ otherwise. | Int8 W8A8 | [Int8](int8.md) | Online or serialized W8A8 | Qwen-Image; Wan2.2 is not validated | Validated for Qwen-Image and Z-Image | | ModelOpt | [ModelOpt](modelopt.md) | Pre-quantized FP8 checkpoints | Qwen-Image, Z-Image, FLUX.2, HunyuanImage-3.0 | Validated for ModelOpt FP8 diffusion checkpoints | | MXFP8 W8A8 | [MXFP8](mxfp8.md) | Online W8A8 or offline pre-quantized | Wan2.2-T2V-A14B, I2V-A14B, TI2V-5B | Ascend NPU only; validated for Wan2.2 | +| MXFP4 W4A4 | [MXFP4](mxfp4.md) | Online W4A4 (single-scale) or offline DualScale pre-quantized | Wan2.2-T2V-A14B, I2V-A14B | Ascend NPU only; validated for Wan2.2 A14B cascade models; TI2V-5B not supported (accuracy loss too large); offline uses dual-scale with calibrated `mul_scale` | | GGUF | [GGUF](gguf.md) | Pre-quantized transformer weights | Qwen-Image | Validated where a model-specific GGUF adapter exists | | AutoRound | [AutoRound](autoround.md) | Pre-quantized W4A16 checkpoints | FLUX.1-dev; Qwen-Image/Wan2.2 not validated | Checkpoint-driven | | msModelSlim | [msModelSlim](msmodelslim.md) | Pre-quantized Ascend checkpoints | Wan2.2 recipe; HunyuanImage-3.0 inference target | Ascend/NPU path | @@ -58,6 +59,7 @@ in BF16 unless the model guide explicitly adds support. | ModelOpt | [ModelOpt](modelopt.md) | Thinker or language-model checkpoint config | Qwen3-Omni thinker | ModelOpt checkpoint path | | Int8 | [Int8](int8.md) | Not currently validated for omni/TTS stages | Qwen3-Omni, Qwen3-TTS | Not validated | | MXFP8 | [MXFP8](mxfp8.md) | Not currently validated for omni/TTS stages | Qwen3-Omni, Qwen3-TTS | Not validated | +| MXFP4 | [MXFP4](mxfp4.md) | Not currently validated for omni/TTS stages | Qwen3-Omni, Qwen3-TTS | Not validated | | GGUF | [GGUF](gguf.md) | Not currently validated for omni/TTS stages | Qwen3-Omni, Qwen3-TTS | Not validated | | AutoRound | [AutoRound](autoround.md) | Thinker or language-model checkpoint config | Qwen2.5-Omni, Qwen3-Omni | Supported through AutoRound checkpoints | | msModelSlim | [msModelSlim](msmodelslim.md) | Not currently validated for omni/TTS stages | Qwen3-Omni, Qwen3-TTS | Not validated | @@ -73,6 +75,7 @@ attached to the intended stage rather than applied globally. | Int8 | [Int8](int8.md) | Stage-specific DiT or transformer module | BAGEL, GLM-Image | Requires model-specific validation | | ModelOpt | [ModelOpt](modelopt.md) | Checkpoint-defined diffusion stage | BAGEL, GLM-Image | Requires model-specific validation | | MXFP8 | [MXFP8](mxfp8.md) | Stage-specific DiT or transformer module | BAGEL, GLM-Image | Not validated | +| MXFP4 | [MXFP4](mxfp4.md) | Stage-specific DiT or transformer module | BAGEL, GLM-Image | Not validated | | GGUF | [GGUF](gguf.md) | Stage-specific transformer weights | BAGEL, GLM-Image | No validated adapter listed | | AutoRound | [AutoRound](autoround.md) | Checkpoint-defined stage | BAGEL, GLM-Image | No validated checkpoint listed | | msModelSlim | [msModelSlim](msmodelslim.md) | Ascend-generated stage weights | GLM-Image | Requires model-specific adaptation | @@ -100,7 +103,7 @@ config = build_quant_config({ | Component | Default quantized? | Notes | |-----------|--------------------|-------| -| Diffusion transformer | Yes | Primary target for FP8, Int8, ModelOpt, GGUF, AutoRound, and msModelSlim | +| Diffusion transformer | Yes | Primary target for FP8, Int8, ModelOpt, MXFP8, MXFP4, GGUF, AutoRound, and msModelSlim | | Text encoder | No | Keep BF16 unless a method-specific guide documents support | | VAE | No | Keep BF16; storage-only paths are method-specific | | Scheduler/tokenizer | No | Loaded from the base model repository | diff --git a/examples/offline_inference/image_to_video/image_to_video.py b/examples/offline_inference/image_to_video/image_to_video.py index 84fbf2a94ca..a4bcf8ce71c 100644 --- a/examples/offline_inference/image_to_video/image_to_video.py +++ b/examples/offline_inference/image_to_video/image_to_video.py @@ -222,8 +222,8 @@ def parse_args() -> argparse.Namespace: "--quantization", type=str, default=None, - choices=["fp8", "mxfp8", "int8", "gguf"], - help="Quantization method for the transformer. mxfp8: W8A8 MXFP8 online quant (NPU). fp8: online FP8 (GPU).", + choices=["fp8", "mxfp8", "mxfp4", "mxfp8_mxfp4_dualscale", "int8", "gguf"], + help="Quantization method for the transformer. mxfp8: W8A8 MXFP8 (NPU). mxfp4: W4A4 MXFP4 (NPU). mxfp8_mxfp4_dualscale: mixed MXFP8+MXFP4 dual-scale (NPU). fp8: online FP8 (GPU).", ) parser.add_argument( "--enable-diffusion-pipeline-profiler", diff --git a/examples/offline_inference/text_to_video/text_to_video.py b/examples/offline_inference/text_to_video/text_to_video.py index b19f0095e64..0df24baf152 100644 --- a/examples/offline_inference/text_to_video/text_to_video.py +++ b/examples/offline_inference/text_to_video/text_to_video.py @@ -206,8 +206,8 @@ def parse_args() -> argparse.Namespace: "--quantization", type=str, default=None, - choices=["fp8", "mxfp8", "int8", "gguf"], - help="Quantization method for the transformer. mxfp8: W8A8 MXFP8 online quant (NPU). fp8: online FP8 (GPU).", + choices=["fp8", "mxfp8", "mxfp4", "mxfp8_mxfp4_dualscale", "int8", "gguf"], + help="Quantization method for the transformer. mxfp8: W8A8 MXFP8 (NPU). mxfp4: W4A4 MXFP4 (NPU). mxfp8_mxfp4_dualscale: mixed MXFP8+MXFP4 dual-scale (NPU). fp8: online FP8 (GPU).", ) parser.add_argument( "--use-hsdp", diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py index e3286cf4f50..4a4581787cb 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py @@ -119,6 +119,21 @@ def load_transformer_config(model_path: str, subfolder: str = "transformer", loc return {} +_SERIALIZED_FLAGS = ( + "is_checkpoint_mxfp8_serialized", + "is_checkpoint_mxfp4_serialized", + "is_checkpoint_serialized", +) + + +def _disk_marks_serialized(qc_kwargs: dict, quant_config: object) -> bool: + """Return True when config.json says serialized but the active quant_config does not.""" + for flag in _SERIALIZED_FLAGS: + if qc_kwargs.get(flag, False) and hasattr(quant_config, flag) and not getattr(quant_config, flag): + return True + return False + + def create_transformer_from_config( config: dict, quant_config: QuantizationConfig | None = None, @@ -186,18 +201,30 @@ def create_transformer_from_config( f"active quantization config is {quant_config.get_name()!r}. " "Pass a matching --quantization flag or omit it for auto-detection." ) - elif ( - qc_kwargs.get("is_checkpoint_mxfp8_serialized", False) - and hasattr(quant_config, "is_checkpoint_mxfp8_serialized") - and not quant_config.is_checkpoint_mxfp8_serialized - ): + elif _disk_marks_serialized(qc_kwargs, quant_config): # Same method: CLI provided online mode but config.json marks this # as a pre-quantized offline checkpoint. Switch to offline mode so # users can pass --quantization mxfp8 without knowing the # online/offline distinction. quant_config = build_quant_config(qc_method, **qc_kwargs) logger.info( - "config.json marks checkpoint as serialized; switching from online to offline MXFP8 mode.", + "config.json marks checkpoint as serialized; switching to offline %s mode.", + qc_method, + ) + elif ( + "num_mxfp8_blocks" in qc_kwargs + and hasattr(quant_config, "num_mxfp8_blocks") + and qc_kwargs["num_mxfp8_blocks"] != quant_config.num_mxfp8_blocks + ): + # The transformer's own config.json has a different num_mxfp8_blocks than + # the active quant_config (e.g. built from a stale enriched config or a + # different transformer's config.json in a cascade model). Rebuild from + # disk so the block routing is authoritative for this transformer. + quant_config = build_quant_config(qc_method, **qc_kwargs) + logger.info( + "Disk config.json num_mxfp8_blocks=%d differs from active config; " + "rebuilding quant_config from transformer config.json.", + qc_kwargs["num_mxfp8_blocks"], ) elif isinstance(disk_qc, str) and quant_config is None: quant_config = build_quant_config(disk_qc) @@ -429,6 +456,38 @@ def __init__( def _create_transformer(self, config: dict) -> WanTransformer3DModel: """Create a transformer from a config dict. Respects od_config.quantization_config.""" quant_config = getattr(self.od_config, "quantization_config", None) + + # When od_config.quantization_config is None (no CLI --quantization flag), pre-build + # the quant_config from the transformer's own config.json and propagate it back to + # od_config. This has two effects: + # 1. The first transformer's auto-detected config is reused by the second transformer + # in cascade models (e.g. Wan2.2-T2V-A14B), preventing stale/wrong num_mxfp8_blocks + # from an independent read of transformer_2/config.json. + # 2. od_config.quantization_config becomes non-None so _check_unloaded_weights can + # filter expected quantization suffixes instead of raising on every unloaded param. + if quant_config is None and "quantization_config" in config: + from vllm_omni.quantization.factory import build_quant_config + + disk_qc = config["quantization_config"] + if isinstance(disk_qc, dict) and "quant_method" in disk_qc: + qc_method = disk_qc["quant_method"] + qc_kwargs = {k: v for k, v in disk_qc.items() if k != "quant_method"} + quant_config = build_quant_config(qc_method, **qc_kwargs) + self.od_config.quantization_config = quant_config + logger.info( + "Auto-detected quantization from transformer config.json and propagated to od_config: " + "method=%s kwargs=%s", + qc_method, + qc_kwargs, + ) + elif isinstance(disk_qc, str): + quant_config = build_quant_config(disk_qc) + self.od_config.quantization_config = quant_config + logger.info( + "Auto-detected quantization from transformer config.json and propagated to od_config: method=%s", + disk_qc, + ) + return create_transformer_from_config(config, quant_config=quant_config) @property diff --git a/vllm_omni/quantization/factory.py b/vllm_omni/quantization/factory.py index ea7b8a7fe6e..c00718ac966 100644 --- a/vllm_omni/quantization/factory.py +++ b/vllm_omni/quantization/factory.py @@ -48,6 +48,20 @@ def _build_mxfp8(**kw: Any) -> QuantizationConfig: return DiffusionMXFP8Config(**kw) +def _build_mxfp4(**kw: Any) -> QuantizationConfig: + """Lazy import for W4A4 MXFP4 diffusion config (NPU only).""" + from .mxfp4_config import DiffusionMXFP4Config + + return DiffusionMXFP4Config(**kw) + + +def _build_mxfp8_mxfp4_dualscale(**kw: Any) -> QuantizationConfig: + """Lazy import for MXFP8 (early blocks) + MXFP4 dual-scale (later blocks) config (NPU only).""" + from .mixed_mxfp_config import DiffusionMXFP8MXFP4DualScaleConfig + + return DiffusionMXFP8MXFP4DualScaleConfig(**kw) + + def _build_inc(**kw: Any) -> QuantizationConfig: """Lazy import for INC/AutoRound config with checkpoint kwarg normalization.""" from .inc_config import OmniINCConfig @@ -66,6 +80,8 @@ def _build_inc(**kw: Any) -> QuantizationConfig: "gguf": _build_gguf, "int8": _build_int8, "mxfp8": _build_mxfp8, + "mxfp4": _build_mxfp4, + "mxfp8_mxfp4_dualscale": _build_mxfp8_mxfp4_dualscale, "inc": _build_inc, "auto-round": _build_inc, "auto_round": _build_inc, diff --git a/vllm_omni/quantization/mixed_mxfp_config.py b/vllm_omni/quantization/mixed_mxfp_config.py new file mode 100644 index 00000000000..8dde7dec74d --- /dev/null +++ b/vllm_omni/quantization/mixed_mxfp_config.py @@ -0,0 +1,149 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Mixed-precision quantization configs for diffusion transformers. + +Each class in this file describes one specific combination of quantization methods +applied to different transformer blocks. New combinations should be added here. + +Current configs +--------------- +DiffusionMXFP8MXFP4DualScaleConfig ("mxfp8_mxfp4_dualscale") + Blocks 0..num_mxfp8_blocks-1 → W8A8 MXFP8 + Blocks num_mxfp8_blocks.. → W4A4 MXFP4 dual-scale + + Block-index dispatch requires linear layers to be constructed with a prefix + of the form "blocks.N.*", threaded through WanTransformerBlock in + wan2_2_transformer.py. + + Config injected by merge_mixed_mxfp_checkpoint.py: + { + "quant_method": "mxfp8_mxfp4_dualscale", + "num_mxfp8_blocks": , + "is_checkpoint_serialized": true + } +""" + +from __future__ import annotations + +import re +from typing import TYPE_CHECKING, Any + +import torch +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod +from vllm.model_executor.layers.quantization import QuantizationMethods +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, + QuantizeMethodBase, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped + +from vllm_omni.platforms import current_omni_platform +from vllm_omni.quantization.mxfp4_config import ( + NPUMxfp4DualScaleLinearMethod, + NPUMxfp4DualScaleOnlineLinearMethod, +) +from vllm_omni.quantization.mxfp8_config import NPUMxfp8LinearMethod, NPUMxfp8OnlineLinearMethod + +if TYPE_CHECKING: + from vllm.model_executor.models.utils import WeightsMapper + +logger = init_logger(__name__) + +_BLOCK_IDX_RE = re.compile(r"^blocks\.(\d+)\.") + + +def _parse_block_idx(prefix: str) -> int | None: + """Extract block index from prefix like 'blocks.5.attn1.to_q'.""" + m = _BLOCK_IDX_RE.match(prefix) + return int(m.group(1)) if m else None + + +class DiffusionMXFP8MXFP4DualScaleConfig(QuantizationConfig): + """W8A8 MXFP8 (early blocks) + W4A4 MXFP4 dual-scale (remaining blocks). + + Blocks 0 .. num_mxfp8_blocks-1 are quantized with MXFP8. + Blocks num_mxfp8_blocks .. end are quantized with MXFP4 dual-scale. + + offline mode (is_checkpoint_serialized=True): + MXFP8 blocks → NPUMxfp8LinearMethod + MXFP4 blocks → NPUMxfp4DualScaleLinearMethod + + online mode (is_checkpoint_serialized=False): + MXFP8 blocks → NPUMxfp8OnlineLinearMethod + MXFP4 blocks → NPUMxfp4DualScaleOnlineLinearMethod + + Layers with a prefix not matching "blocks.N.*" (e.g. condition_embedder) are + treated as outside the MXFP8 range and fall through to the MXFP4 dual-scale path. + """ + + def __init__( + self, + num_mxfp8_blocks: int, + is_checkpoint_serialized: bool = False, + ignored_layers: list[str] | None = None, + ) -> None: + super().__init__() + self.num_mxfp8_blocks = num_mxfp8_blocks + self.is_checkpoint_serialized = is_checkpoint_serialized + self.ignored_layers = ignored_layers or [] + + @classmethod + def get_name(cls) -> QuantizationMethods: + return "mxfp8_mxfp4_dualscale" + + @classmethod + def get_supported_act_dtypes(cls) -> list[torch.dtype]: + return [torch.bfloat16, torch.float16] + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_config_filenames(cls) -> list[str]: + return [] + + def apply_vllm_mapper(self, hf_to_vllm_mapper: WeightsMapper) -> None: + if self.ignored_layers: + self.ignored_layers = hf_to_vllm_mapper.apply_list(self.ignored_layers) + + @classmethod + def from_config(cls, config: dict[str, Any]) -> DiffusionMXFP8MXFP4DualScaleConfig: + num_mxfp8_blocks = cls.get_from_keys_or(config, ["num_mxfp8_blocks"], 0) + is_serialized = cls.get_from_keys_or(config, ["is_checkpoint_serialized"], False) + ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None) + if not ignored_layers: + ignored_layers = cls.get_from_keys_or(config, ["modules_to_not_convert"], None) + return cls( + num_mxfp8_blocks=num_mxfp8_blocks, + is_checkpoint_serialized=is_serialized, + ignored_layers=ignored_layers, + ) + + def get_quant_method( + self, + layer: torch.nn.Module, + prefix: str, + ) -> QuantizeMethodBase | None: + if not isinstance(layer, LinearBase): + return None + if is_layer_skipped( + prefix=prefix, + ignored_layers=self.ignored_layers, + fused_mapping=self.packed_modules_mapping, + ): + return UnquantizedLinearMethod() + + if not current_omni_platform.is_npu(): + raise NotImplementedError( + "DiffusionMXFP8MXFP4DualScaleConfig is currently only supported on NPU (Ascend) platforms." + ) + + block_idx = _parse_block_idx(prefix) + in_mxfp8_range = block_idx is not None and block_idx < self.num_mxfp8_blocks + + if self.is_checkpoint_serialized: + return NPUMxfp8LinearMethod(self) if in_mxfp8_range else NPUMxfp4DualScaleLinearMethod(self) + else: + return NPUMxfp8OnlineLinearMethod(self) if in_mxfp8_range else NPUMxfp4DualScaleOnlineLinearMethod(self) diff --git a/vllm_omni/quantization/mxfp4_config.py b/vllm_omni/quantization/mxfp4_config.py new file mode 100644 index 00000000000..60a4b5ae43d --- /dev/null +++ b/vllm_omni/quantization/mxfp4_config.py @@ -0,0 +1,558 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""W4A4 MXFP4 (Microscaling FP4) online/offline quantization for diffusion transformers. + +Architecture mirrors mxfp8_config.py: + + MXFPLinearMethodBase – platform-agnostic skeleton (imported from mxfp8_config) + NPUMxfp4LinearMethod – NPU single-scale offline (W4A4 MXFP4) + NPUMxfp4OnlineLinearMethod – NPU single-scale online (BF16 → FP4) + NPUMxfp4DualScaleLinearMethod – NPU dual-scale offline (W4A4 MXFP4 DualScale) + NPUMxfp4DualScaleOnlineLinearMethod – NPU dual-scale online (BF16 → FP4) + +Key differences from MXFP8: + + 1. Precision: float4_e2m1fn_x2 (FP4 packed, 2 values per element). + npu_dynamic_mx_quant(x) without dst_type defaults to float4_e2m1fn_x2. + + 2. Weight layout: stored as (N, K) — NOT pre-transposed. + FP4 uses a packed format; transposing a packed tensor is not safely contiguous. + Transpose is done inline in _quant_matmul via layer.weight.transpose(0, 1). + + 3. GEMM signature: npu_quant_matmul requires explicit + x1_dtype=float4_e2m1fn_x2 and x2_dtype=float4_e2m1fn_x2. + + Scale layout: (N, S/2, 2) — same reshape as MXFP8, also NOT pre-transposed; + transposed inline in _quant_matmul. + +Dual-scale (W4A4_MXFP4_DUALSCALE): + + Two-level quantization: fine scale (per-32 K) + coarse scale (per-512 K) + per-channel + activation pre-scale (mul_scale from calibration). Uses npu_dynamic_dual_level_mx_quant + and npu_dual_level_quant_matmul. Weight stored in NZ hardware format via npu_format_cast. + + Checkpoint tensor shapes for dual-scale: + weight : (N, K) float8_e4m3fn – FP4 packed + weight_scale : (N, K//32) uint8 – fine scale (float8_e8m0fnu bits) + weight_dual_scale : (N, K//512, 1) float32 – coarse scale (extra dim avoids shape assert) + mul_scale : (K,) float32 – per-input-channel activation pre-scale + +Reference: MindIE-SD W4A4MXFP4QuantLinear / W4A4MXFP4DualQuantLinear (mindiesd/quantization/layer.py). +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import torch +from torch.nn import Module +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import ( + LinearBase, + UnquantizedLinearMethod, +) +from vllm.model_executor.layers.quantization import QuantizationMethods +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, + QuantizeMethodBase, +) +from vllm.model_executor.layers.quantization.fp8 import _copy_missing_attrs +from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped +from vllm.model_executor.model_loader.weight_utils import initialize_single_dummy_weight +from vllm.model_executor.parameter import ModelWeightParameter +from vllm.model_executor.utils import replace_parameter + +from vllm_omni.platforms import current_omni_platform +from vllm_omni.quantization.mxfp8_config import ( + MXFPLinearMethodBase, + _LazyWeightMixin, +) + +if TYPE_CHECKING: + from vllm.model_executor.models.utils import WeightsMapper + +logger = init_logger(__name__) + + +# --------------------------------------------------------------------------- +# Config +# --------------------------------------------------------------------------- + + +class DiffusionMXFP4Config(QuantizationConfig): + """W4A4 MXFP4 quantization config for diffusion transformers. + + Supports both online (BF16 checkpoint → quantize at load time) and offline + (pre-quantized MXFP4 checkpoint) modes, mirroring DiffusionMXFP8Config. + + MX (microscaling) format: groups of 32 K-dimension elements share one + float8_e8m0fnu exponent scale. Weight and activation are float4_e2m1fn_x2. + """ + + def __init__( + self, + is_checkpoint_mxfp4_serialized: bool = False, + ignored_layers: list[str] | None = None, + ) -> None: + super().__init__() + self.is_checkpoint_mxfp4_serialized = is_checkpoint_mxfp4_serialized + self.ignored_layers = ignored_layers or [] + + @classmethod + def get_name(cls) -> QuantizationMethods: + return "mxfp4" + + @classmethod + def get_supported_act_dtypes(cls) -> list[torch.dtype]: + return [torch.bfloat16, torch.float16] + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_config_filenames(cls) -> list[str]: + return [] + + def apply_vllm_mapper(self, hf_to_vllm_mapper: WeightsMapper) -> None: + if self.ignored_layers: + self.ignored_layers = hf_to_vllm_mapper.apply_list(self.ignored_layers) + + @classmethod + def from_config(cls, config: dict[str, Any]) -> DiffusionMXFP4Config: + is_serialized = cls.get_from_keys_or(config, ["is_checkpoint_mxfp4_serialized"], False) + ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None) + if not ignored_layers: + ignored_layers = cls.get_from_keys_or(config, ["modules_to_not_convert"], None) + return cls( + is_checkpoint_mxfp4_serialized=is_serialized, + ignored_layers=ignored_layers, + ) + + def get_quant_method( + self, + layer: torch.nn.Module, + prefix: str, + ) -> QuantizeMethodBase | None: + if isinstance(layer, LinearBase): + if is_layer_skipped( + prefix=prefix, + ignored_layers=self.ignored_layers, + fused_mapping=self.packed_modules_mapping, + ): + return UnquantizedLinearMethod() + if current_omni_platform.is_npu(): + if self.is_checkpoint_mxfp4_serialized: + return NPUMxfp4LinearMethod(self) + return NPUMxfp4OnlineLinearMethod(self) + raise NotImplementedError( + "DiffusionMXFP4Config (W4A4 MXFP4) is currently only supported " + "on NPU (Ascend) platforms. CUDA support is not yet implemented." + ) + return None + + +# --------------------------------------------------------------------------- +# NPU MXFP4 single-scale offline method (pre-quantized checkpoint) +# --------------------------------------------------------------------------- + + +class NPUMxfp4LinearMethod(MXFPLinearMethodBase): + """NPU W4A4 MXFP4 offline linear method for pre-quantized checkpoints. + + Weight canonical layout after process_weights_after_loading: + weight : (N, K) in float4_e2m1fn_x2 — NOT pre-transposed (FP4 packed) + weight_scale: (N, S/2, 2) in float8_e8m0fnu — NOT pre-transposed + + Both are transposed inline in _quant_matmul, unlike MXFP8 which pre-transposes. + NPUMxfp4OnlineLinearMethod normalizes to the same layout so apply() is shared. + """ + + def __init__(self, quant_config: DiffusionMXFP4Config) -> None: + self.quant_config = quant_config + self.out_dtype = torch.get_default_dtype() + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ) -> None: + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + + layer.logical_widths = output_partition_sizes + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + layer.orig_dtype = params_dtype + layer.weight_block_size = None + + # BF16 placeholder; cast to float4_e2m1fn_x2 in process_weights. + layer.register_parameter( + "weight", + ModelWeightParameter( + data=torch.empty(output_size_per_partition, input_size_per_partition, dtype=params_dtype), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ), + ) + + # Scale stored as uint8 in safetensors (float8_e8m0fnu is same bit width). + # Using uint8 avoids a lossy float32 round-trip when loading the checkpoint. + num_groups = (input_size_per_partition + 31) // 32 + layer.register_parameter( + "weight_scale", + ModelWeightParameter( + data=torch.empty(output_size_per_partition, num_groups, dtype=torch.uint8), + input_dim=None, + output_dim=0, + weight_loader=weight_loader, + ), + ) + + def process_weights_after_loading(self, layer: Module) -> None: + if getattr(layer, "_already_called_process_weights_after_loading", False): + return + + import torch_npu + + # NPU: cast to float4_e2m1fn_x2. Weight stays (N, K) — no pre-transpose. + w = layer.weight + if w.dtype != torch_npu.float4_e2m1fn_x2: + w = torch_npu.npu_dtype_cast(w.npu(), torch_npu.float4_e2m1fn_x2) + + # Scale: checkpoint stores uint8 bytes that ARE float8_e8m0fnu bits. + # Only convert if neither uint8 nor the target NPU dtype already. + # (N, K_groups) → (N, K_groups/2, 2). Not pre-transposed; done inline in _quant_matmul. + s = layer.weight_scale.data + if s.dtype not in (torch.uint8, torch_npu.float8_e8m0fnu): + s = s.to(torch_npu.float8_e8m0fnu) + N, K_groups = s.shape + if K_groups % 2 == 1: + s = torch.cat([s, torch.zeros(N, 1, dtype=s.dtype, device=s.device)], dim=1) + s = s.reshape(N, -1, 2).contiguous() + + replace_parameter(layer, "weight", w) + replace_parameter(layer, "weight_scale", s) + layer._already_called_process_weights_after_loading = True + + # --- NPU MXFP4 ops — shared with online path via inheritance --- + + def _quantize_activation(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + import torch_npu + + # No dst_type: npu_dynamic_mx_quant defaults to float4_e2m1fn_x2. + return torch_npu.npu_dynamic_mx_quant(x) + + def _quant_matmul( + self, + x_q: torch.Tensor, + x_scale: torch.Tensor, + layer: torch.nn.Module, + bias: torch.Tensor | None, + ori_dtype: torch.dtype, + ) -> torch.Tensor: + import torch_npu + + if bias is not None and bias.dtype != torch.float32: + bias = bias.to(torch.float32) + # FP4 differences vs FP8: + # weight (N,K) transposed inline → (K,N); scale (N,S/2,2) transposed inline → (S/2,N,2). + # x1_dtype / x2_dtype required — FP4 dtype not inferred from tensor dtype. + return torch_npu.npu_quant_matmul( + x_q, + layer.weight.transpose(0, 1), # (K, N) inline + layer.weight_scale.transpose(0, 1), # (S/2, N, 2) inline + scale_dtype=torch_npu.float8_e8m0fnu, + x1_dtype=torch_npu.float4_e2m1fn_x2, + x2_dtype=torch_npu.float4_e2m1fn_x2, + pertoken_scale=x_scale, + pertoken_scale_dtype=torch_npu.float8_e8m0fnu, + bias=bias, + output_dtype=ori_dtype, + group_sizes=[1, 1, 32], + ) + + +# --------------------------------------------------------------------------- +# NPU MXFP4 single-scale online method (BF16 checkpoint → quantize at load time) +# --------------------------------------------------------------------------- + + +class NPUMxfp4OnlineLinearMethod(_LazyWeightMixin, NPUMxfp4LinearMethod): + """NPU W4A4 MXFP4 online linear method. + + MRO: NPUMxfp4OnlineLinearMethod → _LazyWeightMixin → NPUMxfp4LinearMethod + → MXFPLinearMethodBase → LinearMethodBase + + create_weights : _LazyWeightMixin (meta device + patched loader) + process_weights : NPUMxfp4OnlineLinearMethod (BF16 → FP4 + normalize) + apply / ops : NPUMxfp4LinearMethod / MXFPLinearMethodBase (shared) + """ + + def process_weights_after_loading(self, layer: Module) -> None: + if getattr(layer, "_already_called_process_weights_after_loading", False): + return + + import torch_npu + + if layer.weight.device == torch.device("meta"): + weight = ModelWeightParameter( + data=torch.empty_like(layer.weight, device=layer._load_device), + input_dim=1, + output_dim=0, + weight_loader=layer.weight.weight_loader, + ) + _copy_missing_attrs(layer.weight, weight) + layer.register_parameter("weight", weight) + initialize_single_dummy_weight(layer.weight) + + # NPU: quantize BF16/FP16 (N, K) → FP4. No dst_type → float4_e2m1fn_x2. + weight_fp4, weight_scale_raw = torch_npu.npu_dynamic_mx_quant(layer.weight) + + # Weight stays (N, K) — no pre-transpose for FP4 packed format. + # Scale: (N, S) → (N, S/2, 2). Not pre-transposed; done inline in _quant_matmul. + weight_scale = weight_scale_raw.reshape(weight_scale_raw.shape[0], -1, 2).contiguous() + + replace_parameter(layer, "weight", weight_fp4) + replace_parameter(layer, "weight_scale", weight_scale) + layer._already_called_process_weights_after_loading = True + + +# --------------------------------------------------------------------------- +# NPU MXFP4 dual-scale offline method (W4A4_MXFP4_DUALSCALE checkpoint) +# --------------------------------------------------------------------------- + + +class NPUMxfp4DualScaleLinearMethod(MXFPLinearMethodBase): + """NPU W4A4 MXFP4 dual-scale offline method for pre-quantized checkpoints. + + Checkpoint tensors and their canonical post-load shapes: + weight : (N, K) float8_e4m3fn – FP4 packed (2 values per byte) + weight_scale : (N, K//32) uint8 – fine scale (float8_e8m0fnu bits); reshaped to (N, K//64, 2) + weight_dual_scale : (N, K//512, 1) float32 – coarse scale; transposed to (K//512, N) + mul_scale : (K,) float32 – per-input-channel activation pre-scale (from calibration) + + Forward pass: + x_q, l0, l1 = npu_dynamic_dual_level_mx_quant(x, smooth_scale=mul_scale) + out = npu_dual_level_quant_matmul(x_q, weight, l0, weight_dual_scale, l1, weight_scale) + + Reference: MindIE-SD W4A4MXFP4DualQuantLinear. + """ + + def __init__(self, quant_config: Any) -> None: + self.quant_config = quant_config + self.out_dtype = torch.get_default_dtype() + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ) -> None: + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + + layer.logical_widths = output_partition_sizes + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + layer.orig_dtype = params_dtype + layer.weight_block_size = None + + # FP4 packed: 2 values per float8_e4m3fn byte → checkpoint stores as float8_e4m3fn; register as BF16 + layer.register_parameter( + "weight", + ModelWeightParameter( + data=torch.empty(output_size_per_partition, input_size_per_partition, dtype=params_dtype), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ), + ) + + # Fine scale: one uint8 exponent (float8_e8m0fnu bit pattern) per group of 32 K elements. + num_groups_fine = (input_size_per_partition + 31) // 32 + layer.register_parameter( + "weight_scale", + ModelWeightParameter( + data=torch.empty(output_size_per_partition, num_groups_fine, dtype=torch.uint8), + input_dim=None, + output_dim=0, + weight_loader=weight_loader, + ), + ) + + # Coarse scale: one float32 per group of 512 K elements. + # Shape (N, K_coarse, 1) matches checkpoint layout exactly, avoiding the + # shape-mismatch assert in linear.py:1344. + num_groups_coarse = (input_size_per_partition + 511) // 512 + layer.register_parameter( + "weight_dual_scale", + ModelWeightParameter( + data=torch.empty(output_size_per_partition, num_groups_coarse, 1, dtype=torch.float32), + input_dim=None, + output_dim=0, + weight_loader=weight_loader, + ), + ) + + # mul_scale is a float32 calibration tensor; register as float32 + # to avoid precision loss from an implicit BF16 cast during weight loading. + layer.register_parameter( + "mul_scale", + ModelWeightParameter( + data=torch.empty(input_size_per_partition, dtype=torch.float32), + input_dim=None, + output_dim=None, + weight_loader=weight_loader, + ), + ) + setattr(layer.mul_scale, "ignore_warning", True) + + def process_weights_after_loading(self, layer: Module) -> None: + if getattr(layer, "_already_called_process_weights_after_loading", False): + return + + import torch_npu + + # float8_e4m3fn (FP4 packed) → float4_e2m1fn_x2 → NZ hardware format (format ID 29). + # NZ layout is required by npu_dual_level_quant_matmul for FP4 weight matrices. + w = torch_npu.npu_dtype_cast(layer.weight.data.npu(), torch_npu.float4_e2m1fn_x2) + w = torch_npu.npu_format_cast(w.view(torch.int8), 29, customize_dtype=torch.int8) + + # Fine scale: (N, K//32) uint8 → cast to float8_e8m0fnu → (N, K//64, 2). + # The cast reinterprets the stored uint8 bit-patterns as float8_e8m0fnu exponents, + # which is required for npu_dual_level_quant_matmul to compute correct scale values. + s = layer.weight_scale.data.npu() + s = s.reshape(s.shape[0], -1, 2).contiguous() + + # Coarse scale: (N, K//512, 1) → squeeze → (N, K//512) → transpose → (K//512, N). + ds = layer.weight_dual_scale.data.squeeze(-1).transpose(0, 1).npu().contiguous() + + ms = layer.mul_scale.to(torch.bfloat16).data.view(-1).npu().contiguous() + + replace_parameter(layer, "weight", w) + replace_parameter(layer, "weight_scale", s) + replace_parameter(layer, "weight_dual_scale", ds) + replace_parameter(layer, "mul_scale", ms) + layer._already_called_process_weights_after_loading = True + + def _apply_inner( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None, + ori_dtype: torch.dtype, + ) -> torch.Tensor: + """Dual-scale inner loop: dtype guard → 3-tuple quantize (mul_scale as smooth) → matmul. + + Overrides the default single-scale _apply_inner from MXFPLinearMethodBase. + apply() in the base class handles reshape/unreshape; this method is not + responsible for that. + """ + if ori_dtype not in (torch.bfloat16, torch.float16): + x = x.to(torch.bfloat16) + x_q, l0_scale, l1_scale = self._quantize_activation(x, layer.mul_scale) + return self._quant_matmul(x_q, l0_scale, l1_scale, layer, bias, ori_dtype) + + def _quantize_activation( # type: ignore[override] + self, + x: torch.Tensor, + smooth_scale: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + import torch_npu + + return torch_npu.npu_dynamic_dual_level_mx_quant(x, smooth_scale=smooth_scale) + + def _quant_matmul( # type: ignore[override] + self, + x_q: torch.Tensor, + l0_scale: torch.Tensor, + l1_scale: torch.Tensor, + layer: torch.nn.Module, + bias: torch.Tensor | None, + ori_dtype: torch.dtype, + ) -> torch.Tensor: + import torch_npu + + if bias is not None and bias.dtype != torch.float32: + bias = bias.to(torch.float32) + # weight_scale is (N, K//64, 2) float8_e8m0fnu — operator expects output-major layout. + # weight_dual_scale is (K//512, N) float32 — transposed to K-major in process_weights. + return torch_npu.npu_dual_level_quant_matmul( + x_q, + layer.weight, + l0_scale, + layer.weight_dual_scale, + l1_scale, + layer.weight_scale, + bias=bias, + output_dtype=ori_dtype, + ) + + +# --------------------------------------------------------------------------- +# NPU MXFP4 dual-scale online method (BF16 checkpoint → quantize at load time) +# --------------------------------------------------------------------------- + + +class NPUMxfp4DualScaleOnlineLinearMethod(_LazyWeightMixin, NPUMxfp4DualScaleLinearMethod): + """NPU W4A4 MXFP4 dual-scale online method: quantises BF16 weights at load time. + + MRO: NPUMxfp4DualScaleOnlineLinearMethod → _LazyWeightMixin + → NPUMxfp4DualScaleLinearMethod → MXFPLinearMethodBase + + create_weights : _LazyWeightMixin (meta device + patched loader) + process_weights : NPUMxfp4DualScaleOnlineLinearMethod (BF16 → FP4 + dual scales) + apply / _quant_matmul : NPUMxfp4DualScaleLinearMethod (shared with offline path) + """ + + def process_weights_after_loading(self, layer: Module) -> None: + if getattr(layer, "_already_called_process_weights_after_loading", False): + return + + import torch_npu + + if layer.weight.device == torch.device("meta"): + weight = ModelWeightParameter( + data=torch.empty_like(layer.weight, device=layer._load_device), + input_dim=1, + output_dim=0, + weight_loader=layer.weight.weight_loader, + ) + _copy_missing_attrs(layer.weight, weight) + layer.register_parameter("weight", weight) + initialize_single_dummy_weight(layer.weight) + + # Quantize BF16 weight → FP4 + dual-level scales (no smooth pre-scale for online). + # Returns: (weight_fp4, l0_scale[coarse per-512], l1_scale[fine per-32]) + weight_fp4, w_l0_scale, w_l1_scale = torch_npu.npu_dynamic_dual_level_mx_quant( + layer.weight.data.npu(), smooth_scale=None + ) + + # NZ hardware format for the FP4 weight (same as offline path). + w = torch_npu.npu_format_cast(weight_fp4.view(torch.int8), 29, customize_dtype=torch.int8) + + # Fine scale (l1): (N, K//32) → (N, K//64, 2). Dtype from op output (float8_e8m0fnu). + s = w_l1_scale.reshape(w_l1_scale.shape[0], -1, 2).contiguous() + + # Coarse scale (l0): (N, K_coarse) → (K_coarse, N). Dtype from op output. + ds = w_l0_scale.reshape(w_l0_scale.shape[0], -1).transpose(0, 1).contiguous() + + # No calibration available: identity pre-scale (no smooth quantization effect). + ms = torch.ones(layer.input_size_per_partition, dtype=torch.float32, device="npu") + + replace_parameter(layer, "weight", w) + replace_parameter(layer, "weight_scale", s) + replace_parameter(layer, "weight_dual_scale", ds) + replace_parameter(layer, "mul_scale", ms) + layer._already_called_process_weights_after_loading = True diff --git a/vllm_omni/quantization/tools/merge_mxfp4_dualscale_checkpoint.py b/vllm_omni/quantization/tools/merge_mxfp4_dualscale_checkpoint.py new file mode 100644 index 00000000000..37faf67b0bd --- /dev/null +++ b/vllm_omni/quantization/tools/merge_mxfp4_dualscale_checkpoint.py @@ -0,0 +1,471 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +"""Merge mixed MXFP8 + W4A4_MXFP4_DUALSCALE quantized Wan2.2 weights into HF Diffusers format. + +msModelSlim produces a mixed-precision checkpoint where transformer blocks are split into: + - Early blocks (0..num_mxfp8_blocks-1): W8A8_MXFP8 + - Remaining blocks (num_mxfp8_blocks..): W4A4_MXFP4_DUALSCALE + +MXFP4_DUALSCALE key structure per linear layer +----------------------------------------------- + blocks.N.X.linear.weight W4A4_MXFP4_DUALSCALE – int8 (FP4 packed) + blocks.N.X.linear.weight_scale W4A4_MXFP4_DUALSCALE – uint8 (float8_e8m0fnu fine scale, per-32K) + blocks.N.X.linear.weight_dual_scale W4A4_MXFP4_DUALSCALE – float32 (coarse scale, per-512K) + blocks.N.X.linear.bias FLOAT – bias (if present) + blocks.N.X.div.mul_scale FLOAT – float32 per-input-channel activation pre-scale + +MXFP8 key structure (no wrapper, same as merge_mxfp8_checkpoint.py): + blocks.N.X.weight W8A8_MXFP8 + blocks.N.X.weight_scale W8A8_MXFP8 + +Self-attention QKV notes +------------------------ +Self-attention Q/K/V weights are separate in the checkpoint (self_attn.q/k/v) but fused +into a single to_qkv layer in vllm-omni. The transformer's load_weights() handles this via +stacked_params_mapping. This script keeps Q/K/V keys separate — do NOT pre-fuse them here. + +For mul_scale specifically: even though Q/K/V process the same input (same mul_scale value), +they are kept separate. load_weights() routes all three to the same to_qkv.mul_scale parameter +and each overwrites the previous. Since Q=K=V for mul_scale, the final value is correct. + +NOTE: Pre-fusing as to_qkv.mul_scale would BREAK loading because ".attn1.to_q" is a +substring of ".attn1.to_qkv", causing load_weights() stacked_params_mapping to produce a +garbage key ("to_qkvkv") that is not in params_dict, triggering a break that skips the +direct-load else branch entirely. + +Supported model types: + - Wan2.2-T2V-A14B (MoE cascade: transformer + transformer_2) + - Wan2.2-I2V-A14B (MoE cascade: transformer + transformer_2) + - Wan2.2-TI2V-5B (single transformer) + +Usage: + python merge_mxfp4_dualscale_checkpoint.py \\ + --model-type Wan2.2-T2V-A14B \\ + --original-model /path/to/Wan2.2-T2V-A14B-Diffusers \\ + --quant-path /path/to/msmodelslim-output \\ + --output-path /path/to/merged-output \\ + --num-mxfp8-blocks 5 # auto-detected if omitted +""" + +from __future__ import annotations + +import argparse +import json +import pathlib +import re +import shutil +import warnings +from typing import Any + +import torch +from safetensors.torch import load_file, save_file + +# --------------------------------------------------------------------------- +# Key rename: msModelSlim naming → Diffusers / vllm-omni naming +# Identical to merge_mxfp8_checkpoint.py; applied to all blocks uniformly. +# --------------------------------------------------------------------------- + +TRANSFORMER_KEYS_RENAME_DICT: dict[str, str] = { + "time_embedding.0": "condition_embedder.time_embedder.linear_1", + "time_embedding.2": "condition_embedder.time_embedder.linear_2", + "text_embedding.0": "condition_embedder.text_embedder.linear_1", + "text_embedding.2": "condition_embedder.text_embedder.linear_2", + "time_projection.1": "condition_embedder.time_proj", + "head.modulation": "scale_shift_table", + "head.head": "proj_out", + "modulation": "scale_shift_table", + "ffn.0": "ffn.net.0.proj", + "ffn.2": "ffn.net.2", + # Norm order swap (quant tool: norm1, norm3, norm2 → diffusers: norm1, norm2, norm3) + "norm2": "norm__placeholder", + "norm3": "norm2", + "norm__placeholder": "norm3", + # Self-attention + "self_attn.q": "attn1.to_q", + "self_attn.k": "attn1.to_k", + "self_attn.v": "attn1.to_v", + "self_attn.o": "attn1.to_out.0", + "self_attn.norm_q": "attn1.norm_q", + "self_attn.norm_k": "attn1.norm_k", + # Cross-attention + "cross_attn.q": "attn2.to_q", + "cross_attn.k": "attn2.to_k", + "cross_attn.v": "attn2.to_v", + "cross_attn.o": "attn2.to_out.0", + "cross_attn.norm_q": "attn2.norm_q", + "cross_attn.norm_k": "attn2.norm_k", + "attn2.to_k_img": "attn2.add_k_proj", + "attn2.to_v_img": "attn2.add_v_proj", + "attn2.norm_k_img": "attn2.norm_added_k", + # I2V image embedder + "img_emb.proj.0": "condition_embedder.image_embedder.norm1", + "img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj", + "img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2", + "img_emb.proj.4": "condition_embedder.image_embedder.norm2", + "img_emb.emb_pos": "condition_embedder.image_embedder.pos_embed", +} + +SUPPORTED_MODEL_TYPES = ["Wan2.2-T2V-A14B", "Wan2.2-I2V-A14B", "Wan2.2-TI2V-5B"] +CASCADE_MODEL_TYPES = {"Wan2.2-T2V-A14B", "Wan2.2-I2V-A14B"} + +# Suffixes that appear inside .linear.* wrapper for MXFP4 tensors. +_LINEAR_ATTRS = ("weight_dual_scale", "weight_scale", "weight", "bias") + +_BLOCK_IDX_RE = re.compile(r"^blocks\.(\d+)\.") + + +# --------------------------------------------------------------------------- +# Key transformation helpers +# --------------------------------------------------------------------------- + + +def _parse_block_idx(key: str) -> int | None: + """Extract block index from a key like 'blocks.5.attn1.to_q.weight'.""" + m = _BLOCK_IDX_RE.match(key) + return int(m.group(1)) if m else None + + +def _apply_rename_dict(key: str) -> str: + for src, dst in TRANSFORMER_KEYS_RENAME_DICT.items(): + key = key.replace(src, dst) + return key + + +def _strip_mxfp4_wrapper(key: str) -> str: + """Strip .linear.ATTR or .div.mul_scale wrappers added by msModelSlim. + + MXFP4 tensors are wrapped in sub-modules: + X.linear.weight / weight_scale / weight_dual_scale / bias + X.div.mul_scale + + MXFP8 and FLOAT tensors have no wrappers — this function is a no-op for them. + + Examples after apply_rename_dict: + attn1.to_q.linear.weight → attn1.to_q.weight + attn1.to_q.linear.weight_scale → attn1.to_q.weight_scale + attn1.to_q.div.mul_scale → attn1.to_q.mul_scale + attn1.norm_q.weight → attn1.norm_q.weight (unchanged) + """ + # Check longest attribute names first to avoid partial suffix matches. + for attr in _LINEAR_ATTRS: + suffix = f".linear.{attr}" + if key.endswith(suffix): + return key[: -len(suffix)] + f".{attr}" + if key.endswith(".div.mul_scale"): + return key[: -len(".div.mul_scale")] + ".mul_scale" + return key + + +# --------------------------------------------------------------------------- +# Quantization metadata helpers +# --------------------------------------------------------------------------- + +# Known quantized weight types (FLOAT tensors don't determine block type). +_MXFP8_TYPE = "mxfp8" +_MXFP4_DUALSCALE_TYPE = "mxfp4_dualscale" + + +def _classify_blocks(quant_meta: dict[str, str]) -> dict[int, str]: + """Classify each transformer block by quantization type from quant_meta. + + Returns a dict mapping block_idx → 'mxfp8' | 'mxfp4_dualscale'. + A block's type is determined by the first quantized (non-FLOAT) tensor found for it. + """ + block_types: dict[int, str] = {} + for key, qtype in quant_meta.items(): + idx = _parse_block_idx(key) + if idx is None or idx in block_types: + continue + if qtype.startswith("W8A8_MXFP8"): + block_types[idx] = _MXFP8_TYPE + elif qtype.startswith("W4A4_MXFP4_DUALSCALE"): + block_types[idx] = _MXFP4_DUALSCALE_TYPE + return block_types + + +def _print_block_summary(block_types: dict[int, str]) -> None: + """Print a compact run-length summary of the block layout.""" + if not block_types: + print(" Block layout: (empty)") + return + + sorted_indices = sorted(block_types) + runs: list[tuple[int, int, str]] = [] + run_start = sorted_indices[0] + run_type = block_types[run_start] + for idx in sorted_indices[1:]: + if block_types[idx] != run_type: + runs.append((run_start, idx - 1, run_type)) + run_start = idx + run_type = block_types[idx] + runs.append((run_start, sorted_indices[-1], run_type)) + + print(f" Block layout ({len(sorted_indices)} blocks classified):") + for start, end, btype in runs: + count = end - start + 1 + range_str = f"{start}" if start == end else f"{start}–{end}" + print(f" blocks {range_str:>8}: {btype} ({count} block{'s' if count > 1 else ''})") + + +def _detect_num_mxfp8_blocks(quant_meta: dict[str, str]) -> int: + """Count leading MXFP8 blocks (blocks 0..N-1). + + Two cases handled: + - MXFP8 blocks present in quant_meta (W8A8_MXFP8 markers): + count the consecutive run from block 0. + - MXFP8 blocks absent from quant_meta (msModelSlim may omit them): + the index of the first MXFP4_DUALSCALE block equals num_mxfp8_blocks, + because all blocks before it are implicitly MXFP8. + """ + block_types = _classify_blocks(quant_meta) + if not block_types: + return 0 + + sorted_indices = sorted(block_types) + first_idx = sorted_indices[0] + + if block_types[first_idx] == _MXFP8_TYPE: + # MXFP8 blocks present: count consecutive run starting at block 0. + if first_idx != 0: + warnings.warn( + f"First classified block is {first_idx} (expected 0); " + "cannot determine num_mxfp8_blocks reliably. Returning 0." + ) + return 0 + count = 0 + for idx in sorted_indices: + if block_types[idx] == _MXFP8_TYPE: + count += 1 + else: + break + return count + + # MXFP8 blocks absent from quant_meta: the first MXFP4 block index is the boundary. + return first_idx + + +# --------------------------------------------------------------------------- +# Safetensors I/O +# --------------------------------------------------------------------------- + + +def _load_safetensors_dir(directory: pathlib.Path, glob: str = "*.safetensors") -> dict[str, torch.Tensor]: + candidates = sorted(directory.glob(glob)) + if not candidates: + raise FileNotFoundError(f"No safetensors matching '{glob}' found in {directory}") + state: dict[str, torch.Tensor] = {} + for f in candidates: + state.update(load_file(str(f))) + return state + + +def _load_quant_safetensors(directory: pathlib.Path) -> dict[str, torch.Tensor]: + try: + return _load_safetensors_dir(directory, "quant_model_weight*.safetensors") + except FileNotFoundError: + return _load_safetensors_dir(directory) + + +def _load_quant_meta(directory: pathlib.Path) -> dict[str, str]: + candidates = sorted(directory.glob("quant_model_description*.json")) + if not candidates: + print(f" WARNING: No quant_model_description*.json in {directory}; treating all tensors as FLOAT.") + return {} + with open(candidates[0]) as f: + return json.load(f) + + +# --------------------------------------------------------------------------- +# Per-transformer conversion +# --------------------------------------------------------------------------- + + +def _convert_transformer( + quant_subdir: pathlib.Path, + output_dir: pathlib.Path, + original_transformer_dir: pathlib.Path, + num_mxfp8_blocks: int | None, +) -> int: + """Convert one transformer directory. Returns the resolved num_mxfp8_blocks.""" + output_dir.mkdir(parents=True, exist_ok=True) + + # BF16 base: ensures non-quantized tensors that msModelSlim might omit are present. + print(f" Loading BF16 base from {original_transformer_dir} …") + base_state = _load_safetensors_dir(original_transformer_dir) + print(f" {len(base_state)} BF16 tensors loaded") + + print(f" Loading quantized weights from {quant_subdir} …") + quant_state = _load_quant_safetensors(quant_subdir) + quant_meta = _load_quant_meta(quant_subdir) + print(f" {len(quant_state)} quant tensors, {len(quant_meta)} meta entries") + + # Classify blocks and auto-detect / validate num_mxfp8_blocks. + block_types = _classify_blocks(quant_meta) + detected = _detect_num_mxfp8_blocks(quant_meta) + # Fill in inferred MXFP8 blocks (may be absent from quant_meta). + for i in range(detected): + block_types.setdefault(i, _MXFP8_TYPE) + _print_block_summary(block_types) + if num_mxfp8_blocks is None: + num_mxfp8_blocks = detected + print(f" Auto-detected num_mxfp8_blocks = {num_mxfp8_blocks}") + elif num_mxfp8_blocks != detected: + warnings.warn(f"--num-mxfp8-blocks={num_mxfp8_blocks} but auto-detected {detected}. Using the provided value.") + + # Remap all quantized keys. + # Key transformation is per-block: + # MXFP8 blocks → rename dict only (no .linear./.div. wrappers) + # MXFP4_DUALSCALE blocks → rename dict + strip .linear./.div. wrappers + # Non-block keys → rename dict only + remapped: dict[str, torch.Tensor] = {} + remapped_meta: dict[str, str] = {} + skipped: list[str] = [] + + for key, tensor in quant_state.items(): + renamed = _apply_rename_dict(key) + + block_idx = _parse_block_idx(renamed) + if block_idx is not None and block_types.get(block_idx) == _MXFP4_DUALSCALE_TYPE: + final_key = _strip_mxfp4_wrapper(renamed) + else: + final_key = renamed + + # Skip non-tensor metadata keys that msModelSlim sometimes embeds + # (e.g. quant_type markers stored as scalar tensors). + if final_key.endswith(".quant_type"): + skipped.append(key) + continue + + remapped[final_key] = tensor + if key in quant_meta: + remapped_meta[final_key] = quant_meta[key] + + if skipped: + print(f" Skipped {len(skipped)} metadata keys (quant_type markers): {skipped[:5]}") + + # Overlay: BF16 base provides the scaffold; quant tensors replace their BF16 counterparts + # and add the new scale tensors (weight_scale, weight_dual_scale, mul_scale). + merged = {**base_state, **remapped} + + # Save weights. + out_weights = output_dir / "diffusion_pytorch_model.safetensors" + save_file(merged, str(out_weights)) + print(f" Saved {len(merged)} tensors → {out_weights}") + + # Save remapped quant metadata (for inspection / debugging). + out_meta_path = output_dir / "quant_model_description.json" + with open(out_meta_path, "w") as f: + json.dump(remapped_meta, f, indent=2) + + # Inject quantization_config into config.json. + src_config = original_transformer_dir / "config.json" + if src_config.is_file(): + with open(src_config) as f: + config = json.load(f) + config["quantization_config"] = _build_quant_config(num_mxfp8_blocks) + out_config = output_dir / "config.json" + with open(out_config, "w") as f: + json.dump(config, f, indent=2) + print( + f" Injected quantization_config " + f"(mxfp8_mxfp4_dualscale, num_mxfp8_blocks={num_mxfp8_blocks}) " + f"→ {out_config}" + ) + else: + print(f" WARNING: No config.json at {src_config}; quantization_config not injected.") + + return num_mxfp8_blocks + + +def _build_quant_config(num_mxfp8_blocks: int) -> dict[str, Any]: + return { + "quant_method": "mxfp8_mxfp4_dualscale", + "num_mxfp8_blocks": num_mxfp8_blocks, + "is_checkpoint_serialized": True, + } + + +# --------------------------------------------------------------------------- +# Model-type helpers +# --------------------------------------------------------------------------- + + +def _get_transformer_dirs(model_type: str) -> list[str]: + return ["transformer", "transformer_2"] if model_type in CASCADE_MODEL_TYPES else ["transformer"] + + +def _get_quant_subdir(model_type: str, quant_path: pathlib.Path, transformer_dir: str) -> pathlib.Path: + if model_type in CASCADE_MODEL_TYPES: + sub = "high_noise_model" if transformer_dir == "transformer" else "low_noise_model" + return quant_path / sub + return quant_path + + +# --------------------------------------------------------------------------- +# Main repack +# --------------------------------------------------------------------------- + + +def repack( + model_type: str, + original_model_path: pathlib.Path, + quant_path: pathlib.Path, + output_path: pathlib.Path, + num_mxfp8_blocks: int | None, +) -> None: + transformer_dirs = _get_transformer_dirs(model_type) + + print(f"Copying original model to {output_path} (skipping {transformer_dirs}) …") + shutil.copytree( + str(original_model_path), + str(output_path), + ignore=shutil.ignore_patterns(*transformer_dirs), + ) + + for tdir in transformer_dirs: + q_subdir = _get_quant_subdir(model_type, quant_path, tdir) + out_tdir = output_path / tdir + orig_tdir = original_model_path / tdir + print(f"\nConverting {tdir} (quant source: {q_subdir.name}) …") + resolved_n = _convert_transformer(q_subdir, out_tdir, orig_tdir, num_mxfp8_blocks) + # Use the resolved value for subsequent transformers in the same cascade. + num_mxfp8_blocks = resolved_n + + print(f"\nDone. Merged model → {output_path}") + print("\nRun inference (quantization auto-detected from config.json):") + print(f" python text_to_video.py --model {output_path}") + + +# --------------------------------------------------------------------------- +# CLI +# --------------------------------------------------------------------------- + + +def main() -> None: + parser = argparse.ArgumentParser( + description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument("--model-type", required=True, choices=SUPPORTED_MODEL_TYPES, help="Model variant.") + parser.add_argument("--original-model", required=True, help="Original HF Diffusers model directory (BF16).") + parser.add_argument("--quant-path", required=True, help="msModelSlim quantized weights directory.") + parser.add_argument("--output-path", required=True, help="Output directory for merged model.") + parser.add_argument( + "--num-mxfp8-blocks", + type=int, + default=None, + help=("Number of leading MXFP8 blocks (0..N-1). Auto-detected from quant_model_description.json if omitted."), + ) + args = parser.parse_args() + + repack( + model_type=args.model_type, + original_model_path=pathlib.Path(args.original_model), + quant_path=pathlib.Path(args.quant_path), + output_path=pathlib.Path(args.output_path), + num_mxfp8_blocks=args.num_mxfp8_blocks, + ) + + +if __name__ == "__main__": + main() From a2b1a0289961988a621a916fef8046431d5ee895 Mon Sep 17 00:00:00 2001 From: hyh_hh Date: Thu, 14 May 2026 15:24:09 +0800 Subject: [PATCH 2/8] fix para init Signed-off-by: hyh_hh --- vllm_omni/quantization/mixed_mxfp_config.py | 2 +- vllm_omni/quantization/mxfp4_config.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm_omni/quantization/mixed_mxfp_config.py b/vllm_omni/quantization/mixed_mxfp_config.py index 8dde7dec74d..b4123e52964 100644 --- a/vllm_omni/quantization/mixed_mxfp_config.py +++ b/vllm_omni/quantization/mixed_mxfp_config.py @@ -79,7 +79,7 @@ class DiffusionMXFP8MXFP4DualScaleConfig(QuantizationConfig): def __init__( self, - num_mxfp8_blocks: int, + num_mxfp8_blocks: int = 0, is_checkpoint_serialized: bool = False, ignored_layers: list[str] | None = None, ) -> None: diff --git a/vllm_omni/quantization/mxfp4_config.py b/vllm_omni/quantization/mxfp4_config.py index 60a4b5ae43d..21c48dbdde5 100644 --- a/vllm_omni/quantization/mxfp4_config.py +++ b/vllm_omni/quantization/mxfp4_config.py @@ -549,7 +549,7 @@ def process_weights_after_loading(self, layer: Module) -> None: ds = w_l0_scale.reshape(w_l0_scale.shape[0], -1).transpose(0, 1).contiguous() # No calibration available: identity pre-scale (no smooth quantization effect). - ms = torch.ones(layer.input_size_per_partition, dtype=torch.float32, device="npu") + ms = torch.ones(layer.input_size_per_partition, dtype=torch.bfloat16, device="npu") replace_parameter(layer, "weight", w) replace_parameter(layer, "weight_scale", s) From 3c221e2413a94f39e5c127efa7cc8f7d0046c470 Mon Sep 17 00:00:00 2001 From: hyh_hh Date: Fri, 15 May 2026 15:51:46 +0800 Subject: [PATCH 3/8] add ut Signed-off-by: hyh_hh --- .../wan2_2/test_create_transformer_quant.py | 370 ++++++++++++++++++ .../quantization/test_mxfp4_config.py | 236 +++++++++++ .../quantization/test_mxfp4_key_remap.py | 201 ++++++++++ .../quantization/test_mxfp8_config.py | 308 +++++++++++++++ .../quantization/test_mxfp8_key_remap.py | 214 ++++++++++ vllm_omni/quantization/factory.py | 26 +- vllm_omni/quantization/tools/__init__.py | 2 + .../tools/merge_mxfp4_dualscale_checkpoint.py | 6 +- 8 files changed, 1360 insertions(+), 3 deletions(-) create mode 100644 tests/diffusion/models/wan2_2/test_create_transformer_quant.py create mode 100644 tests/diffusion/quantization/test_mxfp4_config.py create mode 100644 tests/diffusion/quantization/test_mxfp4_key_remap.py create mode 100644 tests/diffusion/quantization/test_mxfp8_config.py create mode 100644 tests/diffusion/quantization/test_mxfp8_key_remap.py create mode 100644 vllm_omni/quantization/tools/__init__.py diff --git a/tests/diffusion/models/wan2_2/test_create_transformer_quant.py b/tests/diffusion/models/wan2_2/test_create_transformer_quant.py new file mode 100644 index 00000000000..4e32ef546cb --- /dev/null +++ b/tests/diffusion/models/wan2_2/test_create_transformer_quant.py @@ -0,0 +1,370 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Regression tests for transformer quant-config auto-detection and cascade propagation. + +The loader path at pipeline_wan2_2.py carries two quantization contracts: + + create_transformer_from_config (~L137) + - Reads quantization_config from config.json (injected by the merge scripts) + - Auto-detects the quant method when no CLI quant_config is provided + - Rejects method mismatches (CLI vs disk) + - Upgrades online → offline when disk marks is_checkpoint_*_serialized=True + - Rebuilds when the active num_mxfp8_blocks differs from the disk value + + Wan22Pipeline._create_transformer (~L456) + - Propagates the auto-detected config to od_config so the second transformer + in a cascade model reuses the same config rather than re-reading independently + - Does NOT overwrite od_config.quantization_config when it is already set + +All tests are pure-CPU and do not load model weights. +""" + +from __future__ import annotations + +from types import SimpleNamespace + +import pytest + +import vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 as wan22_module +from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import create_transformer_from_config + +pytestmark = [pytest.mark.core_model, pytest.mark.diffusion, pytest.mark.cpu] + +# Minimum config that create_transformer_from_config accepts without raising. +_MIN_CFG: dict = {"patch_size": [1, 2, 2]} + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_fake_transformer(): + """Return (FakeTransformer class, captured list). + + Each _FakeTransformer.__init__ call appends its **kwargs to captured, + letting tests inspect what quant_config was passed per transformer. + """ + captured: list[dict] = [] + + class _FakeTransformer: + def __init__(self, **kwargs): + captured.append(kwargs) + + return _FakeTransformer, captured + + +class _FakePipeline: + """Minimal stand-in exposing only what _create_transformer needs from self.""" + + def __init__(self, od_config: SimpleNamespace) -> None: + self.od_config = od_config + + # Bind the real unbound method so the tests exercise production code. + _create_transformer = wan22_module.Wan22Pipeline._create_transformer + + +# --------------------------------------------------------------------------- +# create_transformer_from_config — auto-detection +# --------------------------------------------------------------------------- + + +def test_create_transformer_detects_mxfp8_serialized_from_config_json(monkeypatch): + """When config.json carries MXFP8 quant and no CLI quant_config is provided, + the transformer must receive a DiffusionMXFP8Config with serialized=True.""" + from vllm_omni.quantization.mxfp8_config import DiffusionMXFP8Config + + FakeTransformer, captured = _make_fake_transformer() + monkeypatch.setattr(wan22_module, "WanTransformer3DModel", FakeTransformer) + + config = { + **_MIN_CFG, + "quantization_config": { + "quant_method": "mxfp8", + "is_checkpoint_mxfp8_serialized": True, + }, + } + create_transformer_from_config(config) + + qc = captured[0].get("quant_config") + assert isinstance(qc, DiffusionMXFP8Config) + assert qc.is_checkpoint_mxfp8_serialized is True + + +def test_create_transformer_detects_mxfp4_dualscale_from_config_json(monkeypatch): + """config.json with mxfp8_mxfp4_dualscale + num_mxfp8_blocks must produce + a DiffusionMXFP8MXFP4DualScaleConfig with the correct block count and + serialized flag.""" + from vllm_omni.quantization.mixed_mxfp_config import DiffusionMXFP8MXFP4DualScaleConfig + + FakeTransformer, captured = _make_fake_transformer() + monkeypatch.setattr(wan22_module, "WanTransformer3DModel", FakeTransformer) + + config = { + **_MIN_CFG, + "quantization_config": { + "quant_method": "mxfp8_mxfp4_dualscale", + "num_mxfp8_blocks": 7, + "is_checkpoint_serialized": True, + }, + } + create_transformer_from_config(config) + + qc = captured[0].get("quant_config") + assert isinstance(qc, DiffusionMXFP8MXFP4DualScaleConfig) + assert qc.num_mxfp8_blocks == 7 + assert qc.is_checkpoint_serialized is True + + +def test_create_transformer_without_quantization_config_passes_no_quant(monkeypatch): + """A plain BF16 config.json (no quantization_config key) must result in no + quant_config being passed to WanTransformer3DModel.""" + FakeTransformer, captured = _make_fake_transformer() + monkeypatch.setattr(wan22_module, "WanTransformer3DModel", FakeTransformer) + + create_transformer_from_config(_MIN_CFG) + + assert "quant_config" not in captured[0] + + +# --------------------------------------------------------------------------- +# create_transformer_from_config — method-mismatch guard +# --------------------------------------------------------------------------- + + +def test_create_transformer_rejects_method_mismatch(monkeypatch): + """Passing a CLI quant_config whose get_name() differs from the config.json + quant_method must raise ValueError immediately (prevents silent weight corruption).""" + from vllm.model_executor.layers.quantization.fp8 import Fp8Config + + FakeTransformer, _ = _make_fake_transformer() + monkeypatch.setattr(wan22_module, "WanTransformer3DModel", FakeTransformer) + + fp8_cli = Fp8Config(is_checkpoint_fp8_serialized=True, activation_scheme="dynamic") + config = { + **_MIN_CFG, + "quantization_config": { + "quant_method": "mxfp8", + "is_checkpoint_mxfp8_serialized": True, + }, + } + with pytest.raises(ValueError, match="quant_method"): + create_transformer_from_config(config, quant_config=fp8_cli) + + +# --------------------------------------------------------------------------- +# create_transformer_from_config — online → offline upgrade +# --------------------------------------------------------------------------- + + +def test_create_transformer_upgrades_to_serialized_when_disk_marks_it(monkeypatch): + """CLI passes online (is_checkpoint_mxfp8_serialized=False) but config.json + marks is_checkpoint_mxfp8_serialized=True → must switch to offline (serialized) + so that pre-quantized FP8 tensors are loaded correctly.""" + from vllm_omni.quantization.mxfp8_config import DiffusionMXFP8Config + + FakeTransformer, captured = _make_fake_transformer() + monkeypatch.setattr(wan22_module, "WanTransformer3DModel", FakeTransformer) + + online_cli = DiffusionMXFP8Config(is_checkpoint_mxfp8_serialized=False) + config = { + **_MIN_CFG, + "quantization_config": { + "quant_method": "mxfp8", + "is_checkpoint_mxfp8_serialized": True, + }, + } + create_transformer_from_config(config, quant_config=online_cli) + + qc = captured[0].get("quant_config") + assert isinstance(qc, DiffusionMXFP8Config) + assert qc.is_checkpoint_mxfp8_serialized is True + + +# --------------------------------------------------------------------------- +# create_transformer_from_config — num_mxfp8_blocks rebuild +# --------------------------------------------------------------------------- + + +def test_create_transformer_rebuilds_when_num_mxfp8_blocks_differs(monkeypatch): + """When the active quant_config has num_mxfp8_blocks=5 but config.json says 10, + the config must be rebuilt from disk so block routing is authoritative for + this specific transformer.""" + from vllm_omni.quantization.mixed_mxfp_config import DiffusionMXFP8MXFP4DualScaleConfig + + FakeTransformer, captured = _make_fake_transformer() + monkeypatch.setattr(wan22_module, "WanTransformer3DModel", FakeTransformer) + + stale = DiffusionMXFP8MXFP4DualScaleConfig(num_mxfp8_blocks=5, is_checkpoint_serialized=True) + config = { + **_MIN_CFG, + "quantization_config": { + "quant_method": "mxfp8_mxfp4_dualscale", + "num_mxfp8_blocks": 10, + "is_checkpoint_serialized": True, + }, + } + create_transformer_from_config(config, quant_config=stale) + + qc = captured[0].get("quant_config") + assert isinstance(qc, DiffusionMXFP8MXFP4DualScaleConfig) + assert qc.num_mxfp8_blocks == 10 + + +def test_create_transformer_does_not_rebuild_when_num_mxfp8_blocks_matches(monkeypatch): + """When the active quant_config already has the correct num_mxfp8_blocks, + the same instance must be passed through unchanged (no unnecessary rebuild).""" + from vllm_omni.quantization.mixed_mxfp_config import DiffusionMXFP8MXFP4DualScaleConfig + + FakeTransformer, captured = _make_fake_transformer() + monkeypatch.setattr(wan22_module, "WanTransformer3DModel", FakeTransformer) + + matching = DiffusionMXFP8MXFP4DualScaleConfig(num_mxfp8_blocks=5, is_checkpoint_serialized=True) + config = { + **_MIN_CFG, + "quantization_config": { + "quant_method": "mxfp8_mxfp4_dualscale", + "num_mxfp8_blocks": 5, + "is_checkpoint_serialized": True, + }, + } + create_transformer_from_config(config, quant_config=matching) + + assert captured[0].get("quant_config") is matching + + +# --------------------------------------------------------------------------- +# Wan22Pipeline._create_transformer — od_config propagation +# --------------------------------------------------------------------------- + + +def test_pipeline_create_transformer_propagates_quant_config_to_od_config(monkeypatch): + """When od_config.quantization_config is None, _create_transformer must + auto-detect the quant method from config.json and propagate the built config + back to od_config so the next call can reuse it.""" + from vllm_omni.quantization.mxfp8_config import DiffusionMXFP8Config + + FakeTransformer, _ = _make_fake_transformer() + monkeypatch.setattr(wan22_module, "WanTransformer3DModel", FakeTransformer) + + od_config = SimpleNamespace(quantization_config=None) + pipeline = _FakePipeline(od_config) + + config = { + **_MIN_CFG, + "quantization_config": { + "quant_method": "mxfp8", + "is_checkpoint_mxfp8_serialized": True, + }, + } + pipeline._create_transformer(config) + + assert isinstance(od_config.quantization_config, DiffusionMXFP8Config) + assert od_config.quantization_config.is_checkpoint_mxfp8_serialized is True + + +def test_pipeline_create_transformer_does_not_overwrite_existing_od_config(monkeypatch): + """If od_config.quantization_config is already set (propagated from the first + transformer), _create_transformer must leave it unchanged — the propagated + config is the authority for the cascade.""" + from vllm_omni.quantization.mxfp8_config import DiffusionMXFP8Config + + FakeTransformer, _ = _make_fake_transformer() + monkeypatch.setattr(wan22_module, "WanTransformer3DModel", FakeTransformer) + + existing = DiffusionMXFP8Config(is_checkpoint_mxfp8_serialized=True) + od_config = SimpleNamespace(quantization_config=existing) + pipeline = _FakePipeline(od_config) + + config = { + **_MIN_CFG, + "quantization_config": { + "quant_method": "mxfp8", + "is_checkpoint_mxfp8_serialized": True, + }, + } + pipeline._create_transformer(config) + + assert od_config.quantization_config is existing + + +# --------------------------------------------------------------------------- +# Wan22Pipeline._create_transformer — cascade contracts +# --------------------------------------------------------------------------- + + +def test_pipeline_cascade_both_transformers_get_mxfp8_serialized_config(monkeypatch): + """Cascade model (transformer + transformer_2) with MXFP8 checkpoint: + - First transformer: auto-detects serialized config, propagates to od_config. + - Second transformer: reuses the propagated config (same instance). + Both must receive is_checkpoint_mxfp8_serialized=True.""" + from vllm_omni.quantization.mxfp8_config import DiffusionMXFP8Config + + FakeTransformer, captured = _make_fake_transformer() + monkeypatch.setattr(wan22_module, "WanTransformer3DModel", FakeTransformer) + + od_config = SimpleNamespace(quantization_config=None) + pipeline = _FakePipeline(od_config) + + mxfp8_qc = {"quant_method": "mxfp8", "is_checkpoint_mxfp8_serialized": True} + pipeline._create_transformer({**_MIN_CFG, "quantization_config": mxfp8_qc}) + pipeline._create_transformer({**_MIN_CFG, "quantization_config": mxfp8_qc}) + + assert len(captured) == 2 + for i, kwargs in enumerate(captured): + qc = kwargs.get("quant_config") + assert isinstance(qc, DiffusionMXFP8Config), f"transformer[{i}]: expected DiffusionMXFP8Config, got {type(qc)}" + assert qc.is_checkpoint_mxfp8_serialized is True, f"transformer[{i}]: expected serialized=True" + + # Second transformer must reuse the propagated instance — no unnecessary rebuild. + assert captured[0]["quant_config"] is captured[1]["quant_config"] + + +def test_pipeline_cascade_mxfp4_dualscale_each_transformer_gets_correct_num_blocks(monkeypatch): + """Cascade with mxfp8_mxfp4_dualscale where transformer and transformer_2 have + different num_mxfp8_blocks in their config.json. + + Expected outcome: + transformer → num_mxfp8_blocks=5 (auto-detected, propagated to od_config) + transformer_2 → num_mxfp8_blocks=10 (rebuilt from disk because 10 ≠ 5) + od_config → num_mxfp8_blocks=5 (unchanged; transformer_2's rebuild is local) + """ + from vllm_omni.quantization.mixed_mxfp_config import DiffusionMXFP8MXFP4DualScaleConfig + + FakeTransformer, captured = _make_fake_transformer() + monkeypatch.setattr(wan22_module, "WanTransformer3DModel", FakeTransformer) + + od_config = SimpleNamespace(quantization_config=None) + pipeline = _FakePipeline(od_config) + + cfg1 = { + **_MIN_CFG, + "quantization_config": { + "quant_method": "mxfp8_mxfp4_dualscale", + "num_mxfp8_blocks": 5, + "is_checkpoint_serialized": True, + }, + } + cfg2 = { + **_MIN_CFG, + "quantization_config": { + "quant_method": "mxfp8_mxfp4_dualscale", + "num_mxfp8_blocks": 10, + "is_checkpoint_serialized": True, + }, + } + + pipeline._create_transformer(cfg1) + pipeline._create_transformer(cfg2) + + assert len(captured) == 2 + qc1 = captured[0].get("quant_config") + qc2 = captured[1].get("quant_config") + + assert isinstance(qc1, DiffusionMXFP8MXFP4DualScaleConfig) + assert isinstance(qc2, DiffusionMXFP8MXFP4DualScaleConfig) + assert qc1.num_mxfp8_blocks == 5, f"transformer expected 5 blocks, got {qc1.num_mxfp8_blocks}" + assert qc2.num_mxfp8_blocks == 10, f"transformer_2 expected 10 blocks, got {qc2.num_mxfp8_blocks}" + + # od_config retains the first transformer's config; the rebuild was local. + assert od_config.quantization_config.num_mxfp8_blocks == 5 diff --git a/tests/diffusion/quantization/test_mxfp4_config.py b/tests/diffusion/quantization/test_mxfp4_config.py new file mode 100644 index 00000000000..649298223da --- /dev/null +++ b/tests/diffusion/quantization/test_mxfp4_config.py @@ -0,0 +1,236 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for MXFP4 quantization configs and the mixed MXFP8+MXFP4 config.""" + +import pytest + +pytestmark = [pytest.mark.core_model, pytest.mark.diffusion, pytest.mark.cpu] + + +# --------------------------------------------------------------------------- +# DiffusionMXFP4Config +# --------------------------------------------------------------------------- + + +def test_mxfp4_config_get_name(): + from vllm_omni.quantization.mxfp4_config import DiffusionMXFP4Config + + assert DiffusionMXFP4Config.get_name() == "mxfp4" + + +def test_mxfp4_config_from_config_defaults(): + from vllm_omni.quantization.mxfp4_config import DiffusionMXFP4Config + + cfg = DiffusionMXFP4Config.from_config({}) + assert cfg.is_checkpoint_mxfp4_serialized is False + assert cfg.ignored_layers == [] + + +def test_mxfp4_config_from_config_serialized(): + from vllm_omni.quantization.mxfp4_config import DiffusionMXFP4Config + + cfg = DiffusionMXFP4Config.from_config({"is_checkpoint_mxfp4_serialized": True}) + assert cfg.is_checkpoint_mxfp4_serialized is True + + +def test_mxfp4_config_from_config_ignored_layers(): + from vllm_omni.quantization.mxfp4_config import DiffusionMXFP4Config + + cfg = DiffusionMXFP4Config.from_config({"ignored_layers": ["proj_out"]}) + assert cfg.ignored_layers == ["proj_out"] + + +def test_mxfp4_config_from_config_modules_to_not_convert_fallback(): + """modules_to_not_convert must be accepted as an alias for ignored_layers.""" + from vllm_omni.quantization.mxfp4_config import DiffusionMXFP4Config + + cfg = DiffusionMXFP4Config.from_config({"modules_to_not_convert": ["proj_out"]}) + assert cfg.ignored_layers == ["proj_out"] + + +# --------------------------------------------------------------------------- +# DiffusionMXFP8MXFP4DualScaleConfig +# --------------------------------------------------------------------------- + + +def test_mixed_config_get_name(): + from vllm_omni.quantization.mixed_mxfp_config import DiffusionMXFP8MXFP4DualScaleConfig + + assert DiffusionMXFP8MXFP4DualScaleConfig.get_name() == "mxfp8_mxfp4_dualscale" + + +def test_mixed_config_no_args_does_not_raise(): + """DiffusionMXFP8MXFP4DualScaleConfig() with no args must not raise TypeError.""" + from vllm_omni.quantization.mixed_mxfp_config import DiffusionMXFP8MXFP4DualScaleConfig + + cfg = DiffusionMXFP8MXFP4DualScaleConfig() + assert cfg.num_mxfp8_blocks == 0 + assert cfg.is_checkpoint_serialized is False + assert cfg.ignored_layers == [] + + +def test_mixed_config_from_config_with_num_blocks(): + from vllm_omni.quantization.mixed_mxfp_config import DiffusionMXFP8MXFP4DualScaleConfig + + cfg = DiffusionMXFP8MXFP4DualScaleConfig.from_config( + { + "quant_method": "mxfp8_mxfp4_dualscale", + "num_mxfp8_blocks": 5, + "is_checkpoint_serialized": True, + } + ) + assert cfg.num_mxfp8_blocks == 5 + assert cfg.is_checkpoint_serialized is True + assert cfg.ignored_layers == [] + + +def test_mixed_config_from_config_defaults(): + from vllm_omni.quantization.mixed_mxfp_config import DiffusionMXFP8MXFP4DualScaleConfig + + cfg = DiffusionMXFP8MXFP4DualScaleConfig.from_config({}) + assert cfg.num_mxfp8_blocks == 0 + assert cfg.is_checkpoint_serialized is False + + +def test_mixed_config_from_config_ignored_layers(): + from vllm_omni.quantization.mixed_mxfp_config import DiffusionMXFP8MXFP4DualScaleConfig + + cfg = DiffusionMXFP8MXFP4DualScaleConfig.from_config({"num_mxfp8_blocks": 3, "ignored_layers": ["proj_out"]}) + assert cfg.ignored_layers == ["proj_out"] + + +# --------------------------------------------------------------------------- +# build_quant_config integration +# --------------------------------------------------------------------------- + + +def test_build_quant_config_mxfp4_string(): + from vllm_omni.quantization import build_quant_config + from vllm_omni.quantization.mxfp4_config import DiffusionMXFP4Config + + cfg = build_quant_config("mxfp4") + assert isinstance(cfg, DiffusionMXFP4Config) + assert cfg.get_name() == "mxfp4" + assert cfg.is_checkpoint_mxfp4_serialized is False + + +def test_build_quant_config_mxfp4_dict(): + from vllm_omni.quantization import build_quant_config + from vllm_omni.quantization.mxfp4_config import DiffusionMXFP4Config + + cfg = build_quant_config({"method": "mxfp4", "is_checkpoint_mxfp4_serialized": True}) + assert isinstance(cfg, DiffusionMXFP4Config) + assert cfg.is_checkpoint_mxfp4_serialized is True + + +def test_build_quant_config_mxfp8_mxfp4_dualscale_dict(): + from vllm_omni.quantization import build_quant_config + from vllm_omni.quantization.mixed_mxfp_config import DiffusionMXFP8MXFP4DualScaleConfig + + cfg = build_quant_config( + { + "method": "mxfp8_mxfp4_dualscale", + "num_mxfp8_blocks": 5, + "is_checkpoint_serialized": True, + } + ) + assert isinstance(cfg, DiffusionMXFP8MXFP4DualScaleConfig) + assert cfg.num_mxfp8_blocks == 5 + assert cfg.is_checkpoint_serialized is True + + +def test_build_quant_config_mxfp8_mxfp4_dualscale_warns_without_num_blocks( + monkeypatch: pytest.MonkeyPatch, +): + """build_quant_config('mxfp8_mxfp4_dualscale') must emit WARNING when + num_mxfp8_blocks is absent and default to 0 (all-MXFP4 DualScale mode). + + Uses monkeypatch instead of caplog because vllm's init_logger may configure + propagation in a way that prevents caplog from intercepting the messages. + """ + import vllm_omni.quantization.factory as factory_module + from vllm_omni.quantization import build_quant_config + from vllm_omni.quantization.mixed_mxfp_config import DiffusionMXFP8MXFP4DualScaleConfig + + warning_messages: list[str] = [] + monkeypatch.setattr( + factory_module.logger, + "warning", + lambda msg, *args, **kw: warning_messages.append(msg), + ) + + cfg = build_quant_config("mxfp8_mxfp4_dualscale") + + assert isinstance(cfg, DiffusionMXFP8MXFP4DualScaleConfig) + assert cfg.num_mxfp8_blocks == 0 + assert len(warning_messages) >= 1 + assert any("num_mxfp8_blocks" in msg for msg in warning_messages) + + +def test_build_quant_config_mxfp8_mxfp4_dualscale_info_with_num_blocks( + monkeypatch: pytest.MonkeyPatch, +): + """When num_mxfp8_blocks is provided, INFO is logged and no WARNING is emitted.""" + import vllm_omni.quantization.factory as factory_module + from vllm_omni.quantization import build_quant_config + + warning_messages: list[str] = [] + info_messages: list[str] = [] + monkeypatch.setattr( + factory_module.logger, + "warning", + lambda msg, *args, **kw: warning_messages.append(msg), + ) + monkeypatch.setattr( + factory_module.logger, + "info", + lambda msg, *args, **kw: info_messages.append(msg), + ) + + cfg = build_quant_config( + { + "method": "mxfp8_mxfp4_dualscale", + "num_mxfp8_blocks": 5, + "is_checkpoint_serialized": True, + } + ) + + assert cfg.num_mxfp8_blocks == 5 + assert len(warning_messages) == 0 + assert any("num_mxfp8_blocks" in msg for msg in info_messages) + + +# --------------------------------------------------------------------------- +# Block-index dispatch (_parse_block_idx) +# --------------------------------------------------------------------------- + + +def test_parse_block_idx_valid(): + from vllm_omni.quantization.mixed_mxfp_config import _parse_block_idx + + assert _parse_block_idx("blocks.0.attn1.to_q") == 0 + assert _parse_block_idx("blocks.5.ffn.net.0.proj") == 5 + assert _parse_block_idx("blocks.40.norm1.weight") == 40 + + +def test_parse_block_idx_non_block_prefixes(): + """Prefixes that do not start with 'blocks.N.' must return None.""" + from vllm_omni.quantization.mixed_mxfp_config import _parse_block_idx + + assert _parse_block_idx("condition_embedder.time_embedder.linear_1") is None + assert _parse_block_idx("proj_out.weight") is None + assert _parse_block_idx("model.layers.0.self_attn.q_proj") is None + assert _parse_block_idx("scale_shift_table") is None + + +# --------------------------------------------------------------------------- +# SUPPORTED_QUANTIZATION_METHODS +# --------------------------------------------------------------------------- + + +def test_supported_methods_include_mxfp4_variants(): + from vllm_omni.quantization import SUPPORTED_QUANTIZATION_METHODS + + assert "mxfp4" in SUPPORTED_QUANTIZATION_METHODS + assert "mxfp8" in SUPPORTED_QUANTIZATION_METHODS + assert "mxfp8_mxfp4_dualscale" in SUPPORTED_QUANTIZATION_METHODS diff --git a/tests/diffusion/quantization/test_mxfp4_key_remap.py b/tests/diffusion/quantization/test_mxfp4_key_remap.py new file mode 100644 index 00000000000..4f7edfded5f --- /dev/null +++ b/tests/diffusion/quantization/test_mxfp4_key_remap.py @@ -0,0 +1,201 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for merge_mxfp4_dualscale_checkpoint.py key-remapping helpers. + +These are pure-Python unit tests that exercise the transformation functions +without loading any actual checkpoint files or requiring NPU hardware. +""" + +import pytest + +pytestmark = [pytest.mark.core_model, pytest.mark.diffusion, pytest.mark.cpu] + + +# --------------------------------------------------------------------------- +# SUPPORTED_MODEL_TYPES +# --------------------------------------------------------------------------- + + +def test_supported_model_types_includes_a14b(): + from vllm_omni.quantization.tools.merge_mxfp4_dualscale_checkpoint import SUPPORTED_MODEL_TYPES + + assert "Wan2.2-T2V-A14B" in SUPPORTED_MODEL_TYPES + assert "Wan2.2-I2V-A14B" in SUPPORTED_MODEL_TYPES + + +def test_supported_model_types_excludes_ti2v_5b(): + """TI2V-5B is explicitly NOT supported under W4A4 quantization.""" + from vllm_omni.quantization.tools.merge_mxfp4_dualscale_checkpoint import SUPPORTED_MODEL_TYPES + + assert "Wan2.2-TI2V-5B" not in SUPPORTED_MODEL_TYPES + + +# --------------------------------------------------------------------------- +# _apply_rename_dict +# --------------------------------------------------------------------------- + + +def test_apply_rename_dict_self_attn_qkvo(): + from vllm_omni.quantization.tools.merge_mxfp4_dualscale_checkpoint import _apply_rename_dict + + assert _apply_rename_dict("blocks.0.self_attn.q.weight") == "blocks.0.attn1.to_q.weight" + assert _apply_rename_dict("blocks.0.self_attn.k.weight") == "blocks.0.attn1.to_k.weight" + assert _apply_rename_dict("blocks.0.self_attn.v.weight") == "blocks.0.attn1.to_v.weight" + assert _apply_rename_dict("blocks.0.self_attn.o.weight") == "blocks.0.attn1.to_out.0.weight" + + +def test_apply_rename_dict_ffn(): + from vllm_omni.quantization.tools.merge_mxfp4_dualscale_checkpoint import _apply_rename_dict + + assert _apply_rename_dict("blocks.1.ffn.0.weight") == "blocks.1.ffn.net.0.proj.weight" + assert _apply_rename_dict("blocks.1.ffn.2.weight") == "blocks.1.ffn.net.2.weight" + + +def test_apply_rename_dict_norm_order_swap(): + """norm2↔norm3 swap: quant tool uses norm1/norm3/norm2 order, + Diffusers uses norm1/norm2/norm3.""" + from vllm_omni.quantization.tools.merge_mxfp4_dualscale_checkpoint import _apply_rename_dict + + assert _apply_rename_dict("blocks.0.norm2.weight") == "blocks.0.norm3.weight" + assert _apply_rename_dict("blocks.0.norm3.weight") == "blocks.0.norm2.weight" + + +def test_apply_rename_dict_cross_attn(): + from vllm_omni.quantization.tools.merge_mxfp4_dualscale_checkpoint import _apply_rename_dict + + assert _apply_rename_dict("blocks.0.cross_attn.q.weight") == "blocks.0.attn2.to_q.weight" + assert _apply_rename_dict("blocks.0.cross_attn.k.weight") == "blocks.0.attn2.to_k.weight" + assert _apply_rename_dict("blocks.0.cross_attn.v.weight") == "blocks.0.attn2.to_v.weight" + assert _apply_rename_dict("blocks.0.cross_attn.o.weight") == "blocks.0.attn2.to_out.0.weight" + + +def test_apply_rename_dict_head_and_modulation(): + from vllm_omni.quantization.tools.merge_mxfp4_dualscale_checkpoint import _apply_rename_dict + + assert _apply_rename_dict("head.head.weight") == "proj_out.weight" + + +# --------------------------------------------------------------------------- +# _strip_mxfp4_wrapper +# --------------------------------------------------------------------------- + + +def test_strip_mxfp4_wrapper_linear_weight(): + from vllm_omni.quantization.tools.merge_mxfp4_dualscale_checkpoint import _strip_mxfp4_wrapper + + assert _strip_mxfp4_wrapper("blocks.0.attn1.to_q.linear.weight") == "blocks.0.attn1.to_q.weight" + + +def test_strip_mxfp4_wrapper_linear_weight_scale(): + from vllm_omni.quantization.tools.merge_mxfp4_dualscale_checkpoint import _strip_mxfp4_wrapper + + assert _strip_mxfp4_wrapper("blocks.0.attn1.to_q.linear.weight_scale") == "blocks.0.attn1.to_q.weight_scale" + + +def test_strip_mxfp4_wrapper_linear_weight_dual_scale(): + from vllm_omni.quantization.tools.merge_mxfp4_dualscale_checkpoint import _strip_mxfp4_wrapper + + assert ( + _strip_mxfp4_wrapper("blocks.0.attn1.to_q.linear.weight_dual_scale") == "blocks.0.attn1.to_q.weight_dual_scale" + ) + + +def test_strip_mxfp4_wrapper_linear_bias(): + from vllm_omni.quantization.tools.merge_mxfp4_dualscale_checkpoint import _strip_mxfp4_wrapper + + assert _strip_mxfp4_wrapper("blocks.0.attn1.to_q.linear.bias") == "blocks.0.attn1.to_q.bias" + + +def test_strip_mxfp4_wrapper_div_mul_scale(): + from vllm_omni.quantization.tools.merge_mxfp4_dualscale_checkpoint import _strip_mxfp4_wrapper + + assert _strip_mxfp4_wrapper("blocks.0.attn1.to_q.div.mul_scale") == "blocks.0.attn1.to_q.mul_scale" + + +def test_strip_mxfp4_wrapper_noop_for_plain_weight(): + """MXFP8 / FLOAT tensors have no wrapper — must be returned unchanged.""" + from vllm_omni.quantization.tools.merge_mxfp4_dualscale_checkpoint import _strip_mxfp4_wrapper + + assert _strip_mxfp4_wrapper("blocks.0.attn1.to_q.weight") == "blocks.0.attn1.to_q.weight" + assert _strip_mxfp4_wrapper("blocks.0.norm_q.weight") == "blocks.0.norm_q.weight" + assert _strip_mxfp4_wrapper("condition_embedder.time_embedder.linear_1.weight") == ( + "condition_embedder.time_embedder.linear_1.weight" + ) + + +# --------------------------------------------------------------------------- +# _classify_blocks +# --------------------------------------------------------------------------- + + +def test_classify_blocks_mixed_mxfp8_and_mxfp4(): + from vllm_omni.quantization.tools.merge_mxfp4_dualscale_checkpoint import _classify_blocks + + quant_meta = { + "blocks.0.attn1.to_q.weight": "W8A8_MXFP8", + "blocks.1.attn1.to_q.weight": "W8A8_MXFP8", + "blocks.2.attn1.to_q.weight": "W4A4_MXFP4_DUALSCALE", + "blocks.3.attn1.to_q.weight": "W4A4_MXFP4_DUALSCALE", + "condition_embedder.time_embedder.linear_1.weight": "FLOAT", + } + block_types = _classify_blocks(quant_meta) + assert block_types[0] == "mxfp8" + assert block_types[1] == "mxfp8" + assert block_types[2] == "mxfp4_dualscale" + assert block_types[3] == "mxfp4_dualscale" + # Non-block key must not produce an entry + assert None not in block_types + + +def test_classify_blocks_float_entries_skipped(): + """FLOAT-typed tensors must not contribute to block classification.""" + from vllm_omni.quantization.tools.merge_mxfp4_dualscale_checkpoint import _classify_blocks + + quant_meta = { + "blocks.0.bias": "FLOAT", + "blocks.1.attn1.to_q.weight": "W4A4_MXFP4_DUALSCALE", + } + block_types = _classify_blocks(quant_meta) + assert 0 not in block_types + assert block_types[1] == "mxfp4_dualscale" + + +def test_classify_blocks_empty(): + from vllm_omni.quantization.tools.merge_mxfp4_dualscale_checkpoint import _classify_blocks + + assert _classify_blocks({}) == {} + + +# --------------------------------------------------------------------------- +# _detect_num_mxfp8_blocks +# --------------------------------------------------------------------------- + + +def test_detect_num_mxfp8_blocks_with_mxfp8_present(): + from vllm_omni.quantization.tools.merge_mxfp4_dualscale_checkpoint import _detect_num_mxfp8_blocks + + quant_meta = { + "blocks.0.attn1.to_q.weight": "W8A8_MXFP8", + "blocks.1.attn1.to_q.weight": "W8A8_MXFP8", + "blocks.2.attn1.to_q.weight": "W4A4_MXFP4_DUALSCALE", + "blocks.3.attn1.to_q.weight": "W4A4_MXFP4_DUALSCALE", + } + assert _detect_num_mxfp8_blocks(quant_meta) == 2 + + +def test_detect_num_mxfp8_blocks_without_mxfp8_keys(): + """When MXFP8 blocks are absent from quant_meta, the first MXFP4 block + index equals num_mxfp8_blocks (msModelSlim may omit MXFP8 markers).""" + from vllm_omni.quantization.tools.merge_mxfp4_dualscale_checkpoint import _detect_num_mxfp8_blocks + + quant_meta = { + "blocks.3.attn1.to_q.weight": "W4A4_MXFP4_DUALSCALE", + "blocks.4.attn1.to_q.weight": "W4A4_MXFP4_DUALSCALE", + } + assert _detect_num_mxfp8_blocks(quant_meta) == 3 + + +def test_detect_num_mxfp8_blocks_empty(): + from vllm_omni.quantization.tools.merge_mxfp4_dualscale_checkpoint import _detect_num_mxfp8_blocks + + assert _detect_num_mxfp8_blocks({}) == 0 diff --git a/tests/diffusion/quantization/test_mxfp8_config.py b/tests/diffusion/quantization/test_mxfp8_config.py new file mode 100644 index 00000000000..d831b5d472b --- /dev/null +++ b/tests/diffusion/quantization/test_mxfp8_config.py @@ -0,0 +1,308 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for MXFP8 quantization config and linear method dispatch. + +Coverage: +- DiffusionMXFP8Config.from_config roundtrips (CPU, no NPU required) +- get_quant_method dispatch (mocked platform) +- MXFPLinearMethodBase.apply() reshape skeleton (CPU) +- Weight / scale shape-transform arithmetic from process_weights_after_loading (CPU) +- build_quant_config integration +- MXFP8_QUANT_CONFIG structure as the auto-detection contract +""" + +import pytest +import torch +from pytest_mock import MockerFixture +from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod + +from vllm_omni.platforms import current_omni_platform +from vllm_omni.quantization import build_quant_config + +pytestmark = [pytest.mark.core_model, pytest.mark.diffusion, pytest.mark.cpu] + +npu_available = pytest.mark.skipif(not current_omni_platform.is_npu(), reason="NPU platform not available") + + +# --------------------------------------------------------------------------- +# DiffusionMXFP8Config — from_config roundtrips +# --------------------------------------------------------------------------- + + +def test_mxfp8_config_get_name(): + from vllm_omni.quantization.mxfp8_config import DiffusionMXFP8Config + + assert DiffusionMXFP8Config.get_name() == "mxfp8" + + +def test_mxfp8_config_from_config_defaults(): + from vllm_omni.quantization.mxfp8_config import DiffusionMXFP8Config + + cfg = DiffusionMXFP8Config.from_config({}) + assert cfg.is_checkpoint_mxfp8_serialized is False + assert cfg.ignored_layers == [] + + +def test_mxfp8_config_from_config_serialized(): + from vllm_omni.quantization.mxfp8_config import DiffusionMXFP8Config + + cfg = DiffusionMXFP8Config.from_config({"is_checkpoint_mxfp8_serialized": True}) + assert cfg.is_checkpoint_mxfp8_serialized is True + + +def test_mxfp8_config_from_config_ignored_layers(): + from vllm_omni.quantization.mxfp8_config import DiffusionMXFP8Config + + cfg = DiffusionMXFP8Config.from_config({"ignored_layers": ["proj_out"]}) + assert cfg.ignored_layers == ["proj_out"] + + +def test_mxfp8_config_from_config_modules_to_not_convert_fallback(): + """modules_to_not_convert must be accepted as an alias for ignored_layers.""" + from vllm_omni.quantization.mxfp8_config import DiffusionMXFP8Config + + cfg = DiffusionMXFP8Config.from_config({"modules_to_not_convert": ["proj_out"]}) + assert cfg.ignored_layers == ["proj_out"] + + +# --------------------------------------------------------------------------- +# build_quant_config integration +# --------------------------------------------------------------------------- + + +def test_build_quant_config_mxfp8_string(): + from vllm_omni.quantization.mxfp8_config import DiffusionMXFP8Config + + cfg = build_quant_config("mxfp8") + assert isinstance(cfg, DiffusionMXFP8Config) + assert cfg.get_name() == "mxfp8" + assert cfg.is_checkpoint_mxfp8_serialized is False + + +def test_build_quant_config_mxfp8_dict(): + from vllm_omni.quantization.mxfp8_config import DiffusionMXFP8Config + + cfg = build_quant_config({"method": "mxfp8", "is_checkpoint_mxfp8_serialized": True}) + assert isinstance(cfg, DiffusionMXFP8Config) + assert cfg.is_checkpoint_mxfp8_serialized is True + + +def test_build_quant_config_mxfp8_config_json_format(): + """Verify that the exact quantization_config injected by merge_mxfp8_checkpoint.py + is accepted by build_quant_config and selects the offline (serialized) path. + + This is the critical auto-detection contract: TransformerConfig.from_dict() + reads quant_method + is_checkpoint_mxfp8_serialized to pick NPUMxfp8LinearMethod. + """ + from vllm_omni.quantization.mxfp8_config import DiffusionMXFP8Config + from vllm_omni.quantization.tools.merge_mxfp8_checkpoint import MXFP8_QUANT_CONFIG + + cfg = build_quant_config(MXFP8_QUANT_CONFIG) + assert isinstance(cfg, DiffusionMXFP8Config) + assert cfg.is_checkpoint_mxfp8_serialized is True + + +def test_mxfp8_quant_config_structure(): + """MXFP8_QUANT_CONFIG must contain exactly the keys that auto-detection reads.""" + from vllm_omni.quantization.tools.merge_mxfp8_checkpoint import MXFP8_QUANT_CONFIG + + assert MXFP8_QUANT_CONFIG.get("quant_method") == "mxfp8" + assert MXFP8_QUANT_CONFIG.get("is_checkpoint_mxfp8_serialized") is True + + +# --------------------------------------------------------------------------- +# get_quant_method dispatch +# --------------------------------------------------------------------------- + + +def test_get_quant_method_npu_offline(mocker: MockerFixture, monkeypatch: pytest.MonkeyPatch): + """Offline (serialized) path on NPU must return NPUMxfp8LinearMethod.""" + from vllm_omni.quantization.mxfp8_config import DiffusionMXFP8Config, NPUMxfp8LinearMethod + + config = DiffusionMXFP8Config(is_checkpoint_mxfp8_serialized=True) + layer = mocker.Mock(spec=LinearBase) + monkeypatch.setattr(current_omni_platform, "is_npu", lambda: True) + + method = config.get_quant_method(layer, "blocks.0.attn1.to_q") + assert isinstance(method, NPUMxfp8LinearMethod) + + +def test_get_quant_method_npu_online(mocker: MockerFixture, monkeypatch: pytest.MonkeyPatch): + """Online (BF16 checkpoint) path on NPU must return NPUMxfp8OnlineLinearMethod.""" + from vllm_omni.quantization.mxfp8_config import DiffusionMXFP8Config, NPUMxfp8OnlineLinearMethod + + config = DiffusionMXFP8Config(is_checkpoint_mxfp8_serialized=False) + layer = mocker.Mock(spec=LinearBase) + monkeypatch.setattr(current_omni_platform, "is_npu", lambda: True) + + method = config.get_quant_method(layer, "blocks.0.attn1.to_q") + assert isinstance(method, NPUMxfp8OnlineLinearMethod) + + +def test_get_quant_method_non_npu_raises(mocker: MockerFixture, monkeypatch: pytest.MonkeyPatch): + """Non-NPU platform must raise NotImplementedError (CUDA not yet supported).""" + from vllm_omni.quantization.mxfp8_config import DiffusionMXFP8Config + + config = DiffusionMXFP8Config() + layer = mocker.Mock(spec=LinearBase) + monkeypatch.setattr(current_omni_platform, "is_npu", lambda: False) + monkeypatch.setattr(current_omni_platform, "is_cuda", lambda: False) + + with pytest.raises(NotImplementedError): + config.get_quant_method(layer, "blocks.0.attn1.to_q") + + +def test_get_quant_method_ignored_layer(mocker: MockerFixture, monkeypatch: pytest.MonkeyPatch): + """A prefix in ignored_layers must bypass quantization → UnquantizedLinearMethod.""" + from vllm_omni.quantization.mxfp8_config import DiffusionMXFP8Config + + config = DiffusionMXFP8Config(ignored_layers=["proj_out"]) + layer = mocker.Mock(spec=LinearBase) + monkeypatch.setattr(current_omni_platform, "is_npu", lambda: True) + + method = config.get_quant_method(layer, "proj_out") + assert isinstance(method, UnquantizedLinearMethod) + + +def test_get_quant_method_non_linear_returns_none(monkeypatch: pytest.MonkeyPatch): + """Non-LinearBase layers (norms, embeddings) must get None → no quantization.""" + from vllm_omni.quantization.mxfp8_config import DiffusionMXFP8Config + + config = DiffusionMXFP8Config() + monkeypatch.setattr(current_omni_platform, "is_npu", lambda: True) + + norm_layer = torch.nn.LayerNorm(64) + assert config.get_quant_method(norm_layer, "blocks.0.norm1") is None + + +# --------------------------------------------------------------------------- +# MXFPLinearMethodBase.apply() — reshape skeleton (CPU, no NPU) +# --------------------------------------------------------------------------- + + +def test_apply_reshape_skeleton(): + """apply() must flatten batch dims → _apply_inner → restore original leading dims.""" + from vllm_omni.quantization.mxfp8_config import MXFPLinearMethodBase + + OUT_FEATURES = 4 + + class _StubMethod(MXFPLinearMethodBase): + def create_weights( + self, + layer, + input_size_per_partition, + output_partition_sizes, + input_size, + output_size, + params_dtype, + **extra_weight_attrs, + ): + pass + + def _quantize_activation(self, x): + return x, None + + def _quant_matmul(self, x_q, x_scale, layer, bias, ori_dtype): + return torch.zeros(x_q.shape[0], OUT_FEATURES, dtype=ori_dtype) + + method = _StubMethod() + x = torch.randn(2, 3, 8) # (batch=2, seq=3, K=8) + out = method.apply(None, x) + assert out.shape == (2, 3, OUT_FEATURES) + + +def test_apply_reshape_with_bias(): + """apply() must pass bias through to _apply_inner unchanged.""" + from vllm_omni.quantization.mxfp8_config import MXFPLinearMethodBase + + received_bias = [] + + class _StubMethod(MXFPLinearMethodBase): + def create_weights( + self, + layer, + input_size_per_partition, + output_partition_sizes, + input_size, + output_size, + params_dtype, + **extra_weight_attrs, + ): + pass + + def _quantize_activation(self, x): + return x, None + + def _quant_matmul(self, x_q, x_scale, layer, bias, ori_dtype): + received_bias.append(bias) + return torch.zeros(x_q.shape[0], 4, dtype=ori_dtype) + + method = _StubMethod() + bias = torch.zeros(4) + method.apply(None, torch.randn(2, 8), bias=bias) + assert received_bias[0] is bias + + +# --------------------------------------------------------------------------- +# process_weights_after_loading shape arithmetic (pure torch, no NPU ops) +# +# These tests replicate the CPU-safe portions of process_weights_after_loading +# to guard the key layout contract: (N,K) weight → (K,N) and (N,S) scale → +# (S/2,N,2). They do NOT call NPU ops; they test only the math. +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "n, k", + [(64, 128), (32, 64), (16, 96)], + ids=["64x128", "32x64", "16x96"], +) +def test_weight_transpose_contract(n: int, k: int): + """Weight must be transposed from (N, K) to (K, N) and be contiguous.""" + w = torch.zeros(n, k, dtype=torch.uint8) + w = w.transpose(0, 1).contiguous() + assert w.shape == (k, n) + assert w.is_contiguous() + + +@pytest.mark.parametrize( + "n, k_groups, expected_groups", + [ + (64, 4, 2), # even K_groups — no padding + (64, 3, 2), # odd K_groups — padded to 4 + (32, 1, 1), # odd K_groups — padded to 2 + ], + ids=["even", "odd-to-4", "odd-to-2"], +) +def test_weight_scale_reshape_contract(n: int, k_groups: int, expected_groups: int): + """Scale must be reshaped from (N, K_groups) to (K_groups_even//2, N, 2). + + Odd K_groups must be padded to even before the reshape. + """ + s = torch.zeros(n, k_groups, dtype=torch.uint8) + if k_groups % 2 == 1: + s = torch.cat([s, torch.zeros(n, 1, dtype=s.dtype)], dim=1) + k_groups += 1 + s = s.reshape(n, k_groups // 2, 2).transpose(0, 1).contiguous() + assert s.shape == (expected_groups, n, 2) + assert s.is_contiguous() + + +def test_num_groups_formula(): + """K_groups formula: ceil(K / 32) — spot-check boundary values.""" + assert (31 + 31) // 32 == 1 # K=31 → 1 group + assert (32 + 31) // 32 == 1 # K=32 → 1 group + assert (33 + 31) // 32 == 2 # K=33 → 2 groups + assert (128 + 31) // 32 == 4 # K=128 → 4 groups (even) + assert (96 + 31) // 32 == 3 # K=96 → 3 groups (odd → needs padding) + + +# --------------------------------------------------------------------------- +# SUPPORTED_QUANTIZATION_METHODS +# --------------------------------------------------------------------------- + + +def test_supported_methods_include_mxfp8(): + from vllm_omni.quantization import SUPPORTED_QUANTIZATION_METHODS + + assert "mxfp8" in SUPPORTED_QUANTIZATION_METHODS diff --git a/tests/diffusion/quantization/test_mxfp8_key_remap.py b/tests/diffusion/quantization/test_mxfp8_key_remap.py new file mode 100644 index 00000000000..f6eea04a07f --- /dev/null +++ b/tests/diffusion/quantization/test_mxfp8_key_remap.py @@ -0,0 +1,214 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for merge_mxfp8_checkpoint.py key-remapping helpers and model metadata. + +These are pure-Python unit tests that exercise the transformation functions +without loading any actual checkpoint files or requiring NPU hardware. + +Key contracts verified: +- SUPPORTED_MODEL_TYPES includes Wan2.2-TI2V-5B (MXFP8 supports it; MXFP4 does not) +- MXFP8_QUANT_CONFIG structure matches what TransformerConfig.from_dict() reads +- _remap_keys correctly translates msModelSlim naming → Diffusers naming +- _get_transformer_dirs routes cascade vs single-transformer models +- _get_quant_subdir maps high/low noise subdirs for cascade models +""" + +import pathlib + +import pytest +import torch + +pytestmark = [pytest.mark.core_model, pytest.mark.diffusion, pytest.mark.cpu] + + +# --------------------------------------------------------------------------- +# SUPPORTED_MODEL_TYPES and CASCADE_MODEL_TYPES +# --------------------------------------------------------------------------- + + +def test_supported_model_types_includes_all_wan22(): + from vllm_omni.quantization.tools.merge_mxfp8_checkpoint import SUPPORTED_MODEL_TYPES + + assert "Wan2.2-T2V-A14B" in SUPPORTED_MODEL_TYPES + assert "Wan2.2-I2V-A14B" in SUPPORTED_MODEL_TYPES + + +def test_supported_model_types_includes_ti2v_5b(): + """MXFP8 supports TI2V-5B (contrast: MXFP4 explicitly excludes it).""" + from vllm_omni.quantization.tools.merge_mxfp8_checkpoint import SUPPORTED_MODEL_TYPES + + assert "Wan2.2-TI2V-5B" in SUPPORTED_MODEL_TYPES + + +def test_cascade_model_types_excludes_ti2v(): + """TI2V-5B is a single-transformer model — must not be in CASCADE_MODEL_TYPES.""" + from vllm_omni.quantization.tools.merge_mxfp8_checkpoint import CASCADE_MODEL_TYPES + + assert "Wan2.2-TI2V-5B" not in CASCADE_MODEL_TYPES + assert "Wan2.2-T2V-A14B" in CASCADE_MODEL_TYPES + assert "Wan2.2-I2V-A14B" in CASCADE_MODEL_TYPES + + +# --------------------------------------------------------------------------- +# MXFP8_QUANT_CONFIG — auto-detection contract +# --------------------------------------------------------------------------- + + +def test_mxfp8_quant_config_has_required_keys(): + """MXFP8_QUANT_CONFIG must carry exactly the keys that auto-detection reads: + quant_method (selects DiffusionMXFP8Config) and is_checkpoint_mxfp8_serialized + (selects NPUMxfp8LinearMethod over the online path).""" + from vllm_omni.quantization.tools.merge_mxfp8_checkpoint import MXFP8_QUANT_CONFIG + + assert MXFP8_QUANT_CONFIG["quant_method"] == "mxfp8" + assert MXFP8_QUANT_CONFIG["is_checkpoint_mxfp8_serialized"] is True + + +# --------------------------------------------------------------------------- +# _get_transformer_dirs +# --------------------------------------------------------------------------- + + +def test_get_transformer_dirs_cascade(): + from vllm_omni.quantization.tools.merge_mxfp8_checkpoint import _get_transformer_dirs + + assert _get_transformer_dirs("Wan2.2-T2V-A14B") == ["transformer", "transformer_2"] + assert _get_transformer_dirs("Wan2.2-I2V-A14B") == ["transformer", "transformer_2"] + + +def test_get_transformer_dirs_single(): + from vllm_omni.quantization.tools.merge_mxfp8_checkpoint import _get_transformer_dirs + + assert _get_transformer_dirs("Wan2.2-TI2V-5B") == ["transformer"] + + +# --------------------------------------------------------------------------- +# _get_quant_subdir +# --------------------------------------------------------------------------- + + +def test_get_quant_subdir_cascade_high_noise(): + from vllm_omni.quantization.tools.merge_mxfp8_checkpoint import _get_quant_subdir + + base = pathlib.Path("/quant") + result = _get_quant_subdir("Wan2.2-T2V-A14B", base, "transformer") + assert result == base / "high_noise_model" + + +def test_get_quant_subdir_cascade_low_noise(): + from vllm_omni.quantization.tools.merge_mxfp8_checkpoint import _get_quant_subdir + + base = pathlib.Path("/quant") + result = _get_quant_subdir("Wan2.2-T2V-A14B", base, "transformer_2") + assert result == base / "low_noise_model" + + +def test_get_quant_subdir_non_cascade(): + from vllm_omni.quantization.tools.merge_mxfp8_checkpoint import _get_quant_subdir + + base = pathlib.Path("/quant") + result = _get_quant_subdir("Wan2.2-TI2V-5B", base, "transformer") + assert result == base + + +# --------------------------------------------------------------------------- +# _remap_keys — msModelSlim naming → Diffusers naming +# --------------------------------------------------------------------------- + + +def test_remap_keys_self_attn_q(): + from vllm_omni.quantization.tools.merge_mxfp8_checkpoint import _remap_keys + + state = {"blocks.0.self_attn.q.weight": torch.zeros(1)} + meta = {"blocks.0.self_attn.q.weight": "W8A8_MXFP8"} + new_state, new_meta = _remap_keys(state, meta) + assert "blocks.0.attn1.to_q.weight" in new_state + assert new_meta.get("blocks.0.attn1.to_q.weight") == "W8A8_MXFP8" + + +def test_remap_keys_self_attn_all_heads(): + from vllm_omni.quantization.tools.merge_mxfp8_checkpoint import _remap_keys + + pairs = { + "self_attn.q": "attn1.to_q", + "self_attn.k": "attn1.to_k", + "self_attn.v": "attn1.to_v", + "self_attn.o": "attn1.to_out.0", + } + for src_part, dst_part in pairs.items(): + src_key = f"blocks.0.{src_part}.weight" + state = {src_key: torch.zeros(1)} + new_state, _ = _remap_keys(state, {}) + expected = f"blocks.0.{dst_part}.weight" + assert expected in new_state, f"{src_key} → expected {expected}, got {list(new_state)}" + + +def test_remap_keys_ffn(): + from vllm_omni.quantization.tools.merge_mxfp8_checkpoint import _remap_keys + + state = { + "blocks.1.ffn.0.weight": torch.zeros(1), + "blocks.1.ffn.2.weight": torch.zeros(1), + } + new_state, _ = _remap_keys(state, {}) + assert "blocks.1.ffn.net.0.proj.weight" in new_state + assert "blocks.1.ffn.net.2.weight" in new_state + + +def test_remap_keys_norm_order_swap(): + """norm2↔norm3 swap: msModelSlim uses norm1/norm3/norm2 order, + Diffusers uses norm1/norm2/norm3.""" + from vllm_omni.quantization.tools.merge_mxfp8_checkpoint import _remap_keys + + state = { + "blocks.0.norm2.weight": torch.zeros(1), + "blocks.0.norm3.weight": torch.zeros(1), + } + new_state, _ = _remap_keys(state, {}) + # norm2 → norm3 and norm3 → norm2 + assert "blocks.0.norm3.weight" in new_state + assert "blocks.0.norm2.weight" in new_state + # Both must be present (swap, not collapse) + assert len([k for k in new_state if "norm" in k and "norm_q" not in k and "norm_k" not in k]) == 2 + + +def test_remap_keys_head(): + from vllm_omni.quantization.tools.merge_mxfp8_checkpoint import _remap_keys + + state = {"head.head.weight": torch.zeros(1)} + new_state, _ = _remap_keys(state, {}) + assert "proj_out.weight" in new_state + + +def test_remap_keys_cross_attn(): + from vllm_omni.quantization.tools.merge_mxfp8_checkpoint import _remap_keys + + state = {"blocks.0.cross_attn.q.weight": torch.zeros(1)} + new_state, _ = _remap_keys(state, {}) + assert "blocks.0.attn2.to_q.weight" in new_state + + +def test_remap_keys_meta_only_mapped_for_existing_state_keys(): + """quant_meta entries are only emitted for keys present in state_dict.""" + from vllm_omni.quantization.tools.merge_mxfp8_checkpoint import _remap_keys + + state = {"blocks.0.self_attn.q.weight": torch.zeros(1)} + # meta has an extra key not in state_dict + meta = { + "blocks.0.self_attn.q.weight": "W8A8_MXFP8", + "blocks.0.self_attn.q.weight_scale": "W8A8_MXFP8", + } + _, new_meta = _remap_keys(state, meta) + assert "blocks.0.attn1.to_q.weight" in new_meta + # weight_scale was in meta but not state_dict → must be absent + assert "blocks.0.attn1.to_q.weight_scale" not in new_meta + + +def test_remap_keys_preserves_tensors(): + """Tensor values must survive the key rename unchanged.""" + from vllm_omni.quantization.tools.merge_mxfp8_checkpoint import _remap_keys + + t = torch.randn(4, 8) + state = {"blocks.0.self_attn.q.weight": t} + new_state, _ = _remap_keys(state, {}) + assert torch.equal(new_state["blocks.0.attn1.to_q.weight"], t) diff --git a/vllm_omni/quantization/factory.py b/vllm_omni/quantization/factory.py index c00718ac966..fa51324b61b 100644 --- a/vllm_omni/quantization/factory.py +++ b/vllm_omni/quantization/factory.py @@ -56,7 +56,31 @@ def _build_mxfp4(**kw: Any) -> QuantizationConfig: def _build_mxfp8_mxfp4_dualscale(**kw: Any) -> QuantizationConfig: - """Lazy import for MXFP8 (early blocks) + MXFP4 dual-scale (later blocks) config (NPU only).""" + """Lazy import for MXFP8 (early blocks) + MXFP4 dual-scale (later blocks) config (NPU only). + + This method is checkpoint-topology-dependent: num_mxfp8_blocks is normally + injected into transformer/config.json by merge_mxfp4_dualscale_checkpoint.py + and auto-detected from there during offline checkpoint loading. + If invoked without num_mxfp8_blocks (e.g. via --quantization mxfp8_mxfp4_dualscale + on a BF16 checkpoint), num_mxfp8_blocks defaults to 0 and all blocks fall through + to the MXFP4 DualScale online path — the MXFP8 branch is never selected. + """ + if "num_mxfp8_blocks" not in kw: + logger.warning( + "'mxfp8_mxfp4_dualscale' was requested without num_mxfp8_blocks. " + "Defaulting to num_mxfp8_blocks=0: all transformer blocks will use " + "MXFP4 DualScale online quantization and no MXFP8 blocks will be applied. " + "This mode is not recommended for online (BF16 checkpoint) use. " + "For the intended mixed MXFP8+MXFP4 mode, use a pre-quantized checkpoint " + "produced by merge_mxfp4_dualscale_checkpoint.py and omit --quantization " + "to let vllm-omni auto-detect num_mxfp8_blocks from transformer/config.json." + ) + else: + logger.info( + "Building mxfp8_mxfp4_dualscale config: num_mxfp8_blocks=%d, is_checkpoint_serialized=%s", + kw["num_mxfp8_blocks"], + kw.get("is_checkpoint_serialized", False), + ) from .mixed_mxfp_config import DiffusionMXFP8MXFP4DualScaleConfig return DiffusionMXFP8MXFP4DualScaleConfig(**kw) diff --git a/vllm_omni/quantization/tools/__init__.py b/vllm_omni/quantization/tools/__init__.py new file mode 100644 index 00000000000..208f01a7cb5 --- /dev/null +++ b/vllm_omni/quantization/tools/__init__.py @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project diff --git a/vllm_omni/quantization/tools/merge_mxfp4_dualscale_checkpoint.py b/vllm_omni/quantization/tools/merge_mxfp4_dualscale_checkpoint.py index 37faf67b0bd..efe682a8945 100644 --- a/vllm_omni/quantization/tools/merge_mxfp4_dualscale_checkpoint.py +++ b/vllm_omni/quantization/tools/merge_mxfp4_dualscale_checkpoint.py @@ -36,7 +36,9 @@ Supported model types: - Wan2.2-T2V-A14B (MoE cascade: transformer + transformer_2) - Wan2.2-I2V-A14B (MoE cascade: transformer + transformer_2) - - Wan2.2-TI2V-5B (single transformer) + +Note: Wan2.2-TI2V-5B is NOT supported. Its smaller parameter count causes +unacceptable accuracy loss under W4A4 quantization. Use MXFP8 for TI2V-5B. Usage: python merge_mxfp4_dualscale_checkpoint.py \\ @@ -105,7 +107,7 @@ "img_emb.emb_pos": "condition_embedder.image_embedder.pos_embed", } -SUPPORTED_MODEL_TYPES = ["Wan2.2-T2V-A14B", "Wan2.2-I2V-A14B", "Wan2.2-TI2V-5B"] +SUPPORTED_MODEL_TYPES = ["Wan2.2-T2V-A14B", "Wan2.2-I2V-A14B"] CASCADE_MODEL_TYPES = {"Wan2.2-T2V-A14B", "Wan2.2-I2V-A14B"} # Suffixes that appear inside .linear.* wrapper for MXFP4 tensors. From b58e0a61b206a77ab78f4679806c3548b5f5ed4e Mon Sep 17 00:00:00 2001 From: hyh_hh Date: Sat, 16 May 2026 10:07:18 +0800 Subject: [PATCH 4/8] mxfp8_mxfp4_dualscale -> mxfp4_dualscale Signed-off-by: hyh_hh --- docs/user_guide/quantization/mxfp4.md | 275 +++++++++----- docs/user_guide/quantization/overview.md | 2 +- .../image_to_video/image_to_video.py | 4 +- .../text_to_video/text_to_video.py | 4 +- .../wan2_2/test_create_transformer_quant.py | 98 ++--- .../quantization/test_mxfp4_config.py | 340 ++++++++++++------ .../quantization/test_mxfp4_key_remap.py | 230 ++++++++++-- .../models/wan2_2/pipeline_wan2_2.py | 19 +- vllm_omni/quantization/factory.py | 41 +-- vllm_omni/quantization/mixed_mxfp_config.py | 149 -------- vllm_omni/quantization/mxfp4_config.py | 145 ++++++++ .../tools/merge_mxfp4_dualscale_checkpoint.py | 214 +++++------ 12 files changed, 958 insertions(+), 563 deletions(-) delete mode 100644 vllm_omni/quantization/mixed_mxfp_config.py diff --git a/docs/user_guide/quantization/mxfp4.md b/docs/user_guide/quantization/mxfp4.md index 2b3f20bfd1a..608026b3044 100644 --- a/docs/user_guide/quantization/mxfp4.md +++ b/docs/user_guide/quantization/mxfp4.md @@ -6,23 +6,31 @@ W4A4 MXFP4 (Microscaling FP4) quantizes both weights and activations to FP4 (`float4_e2m1fn_x2`, packed 2 values per byte) using the OCP MX format: groups of 32 K-dimension elements share a single `float8_e8m0fnu` exponent scale. -This method supports two modes that differ significantly in scale structure and -checkpoint format: - -| Mode | Scale structure | Description | -|------|----------------|-------------| -| **Online** | Single-scale (per-32 fine only) | BF16 weights are quantized to MXFP4 at load time — no pre-processing needed | -| **Offline** | Dual-scale (fine per-32 + coarse per-512 + per-channel smooth pre-scale) | msModelSlim-exported MXFP4 DualScale weights converted to diffusers format via preprocessing script — all scale tensors are loaded directly from the checkpoint | - -!!! warning "Online ≠ Offline" - Online mode uses a **single-scale** (`NPUMxfp4OnlineLinearMethod`): one - `float8_e8m0fnu` exponent per 32 K elements, computed on the fly from the - BF16 weight. Offline mode uses a **dual-scale** (`NPUMxfp4DualScaleLinearMethod`): a - fine scale (per-32 K), a coarse scale (per-512 K), and a per-input-channel - smooth pre-scale (`mul_scale`) produced by calibration. The two levels and - the smooth pre-scale are all stored in the checkpoint; loading an offline - checkpoint with the online method (or vice versa) will produce incorrect - results. +vLLM-Omni provides two quantization methods with different scale structures: + +| Method | Scale structure | Mode | Use case | +|--------|----------------|------|----------| +| `mxfp4` | Single-scale (per-32 fine only) | Online only | Quick accuracy baseline; no checkpoint prep needed | +| `mxfp4_dualscale` | Dual-scale (fine per-32 + coarse per-512 + per-channel `mul_scale`) | Online + Offline | Production; better accuracy; offline recommended | + +!!! tip "Recommended: `mxfp4_dualscale` offline" + For production deployments, use the `mxfp4_dualscale` offline mode with a + pre-quantized checkpoint produced by msModelSlim. Offline checkpoints load + calibrated `mul_scale` tensors from disk, providing measurably better accuracy + than any online method. The one-time preprocessing cost amortises across all + subsequent inference runs. + + Use `mxfp4` online only for quick experimentation where preprocessing time + is not acceptable and accuracy loss is tolerable. + +!!! warning "Online single-scale ≠ Offline dual-scale" + `mxfp4_dualscale` offline mode uses `NPUMxfp4DualScaleLinearMethod`: + fine scale (per-32 K), coarse scale (per-512 K), and per-input-channel + `mul_scale` from calibration — all loaded from the checkpoint. + `mxfp4_dualscale` online mode uses `NPUMxfp4DualScaleOnlineLinearMethod`: + dual-level scales computed on the fly from BF16 weights; no calibration + `mul_scale` is available. Loading an offline checkpoint with the online + method (or vice versa) will produce incorrect results or shape errors. ## Hardware Support @@ -41,18 +49,30 @@ Legend: `✅` supported, `❌` unsupported, `⭕` not verified in this guide. ### Diffusion Model (Wan2.2) -| Model | Mode | Notes | -|-------|------|-------| -| Wan2.2-T2V-A14B | Online + Offline | MoE cascade; quantizes two transformers (`transformer` + `transformer_2`); offline uses mixed MXFP8 (early blocks) + MXFP4 DualScale (remaining blocks) | -| Wan2.2-I2V-A14B | Online + Offline | MoE cascade; same mixed-precision scheme as T2V-A14B | -| Wan2.2-TI2V-5B | ❌ Not supported | Parameter count too small; W4A4 quantization causes unacceptable accuracy loss | +| Model | Online | Offline | Notes | +|-------|--------|---------|-------| +| Wan2.2-T2V-A14B | `mxfp4` / `mxfp4_dualscale` | `mxfp4_dualscale` | MoE cascade (`transformer` + `transformer_2`); both transformers quantized with the same config | +| Wan2.2-I2V-A14B | `mxfp4` / `mxfp4_dualscale` | `mxfp4_dualscale` | MoE cascade; same scheme as T2V-A14B | +| Wan2.2-TI2V-5B | ❌ | ❌ | Parameter count too small; W4A4 causes unacceptable accuracy loss | -!!! note "Mixed MXFP8 + MXFP4 for cascade models" - For the A14B cascade models, the offline checkpoint uses - `quant_method: mxfp8_mxfp4_dualscale`: the first `num_mxfp8_blocks` - transformer blocks are stored as MXFP8 (W8A8), and the remaining blocks as - MXFP4 DualScale (W4A4). The split is recorded in the injected - `quantization_config` and is transparent to the serving command. +The choice between `mxfp4` and `mxfp4_dualscale` in **online mode** is about +quantization quality, not model compatibility — both work on cascade (A14B) and +single-transformer models alike, the same as `mxfp8` online: + +- `mxfp4`: single-scale, lower overhead, simpler compute, online only +- `mxfp4_dualscale`: dual-scale + optional BF16 fallback, better accuracy, online **and** offline + +**Offline** checkpoints for A14B are always in `mxfp4_dualscale` format (produced +by the merge script); there is no offline `mxfp4` single-scale format. + +!!! note "Per-layer BF16 fallback in offline cascade models" + The A14B offline checkpoint uses `quant_method: mxfp4_dualscale`. Most + linear layers are stored as W4A4 MXFP4 DualScale; precision-sensitive layers + retain their original BF16 weights and are listed in `ignored_layers` inside + each transformer's `config.json`. The two transformers may have different + `ignored_layers` sets — the pipeline reads each transformer's own `config.json` + and rebuilds the config locally when they differ, so routing is always + per-transformer-accurate. !!! warning "TI2V-5B not supported" Wan2.2-TI2V-5B is excluded from W4A4 quantization. Its smaller parameter @@ -61,44 +81,89 @@ Legend: `✅` supported, `❌` unsupported, `⭕` not verified in this guide. ## Configuration -### Online Mode +### `mxfp4` — Single-Scale Online Mode Online mode requires no pre-processing. vLLM-Omni quantizes BF16 weights to MXFP4 at load time using `npu_dynamic_mx_quant`. A single block scale (`float8_e8m0fnu`, one per 32 K elements) is computed on the fly; no -calibration `mul_scale` is available. - -Python API: +calibration `mul_scale` is available. Applies equally to single-transformer +and cascade (A14B) models — both transformers in a cascade receive the same +quantization config automatically. ```python from vllm_omni import Omni from vllm_omni.inputs.data import OmniDiffusionSamplingParams omni = Omni(model="", quantization="mxfp4") - outputs = omni.generate( "A cat sitting on a windowsill", OmniDiffusionSamplingParams(num_inference_steps=50), ) ``` -CLI: - ```bash +# Single-transformer or cascade model — same command python text_to_video.py --model --quantization mxfp4 # Online serving vllm serve --omni --quantization mxfp4 ``` -### Offline Mode (DualScale) +### `mxfp4_dualscale` — DualScale Online Mode + +Online DualScale mode computes both fine and coarse scales on the fly from BF16 +weights using `npu_dynamic_dual_level_mx_quant`. Applies equally to +single-transformer and cascade (A14B) models. Compared to `mxfp4` online, +DualScale provides better quantization accuracy at higher compute cost. + +The default configuration keeps the leading 5 transformer blocks in BF16 +(`num_bf16_fallback_layers=5`). Accuracy evaluation on Wan2.2-A14B shows this +is sufficient to meet quality requirements and is the recommended setting. + +```python +omni = Omni(model="", quantization="mxfp4_dualscale") +``` + +```bash +python text_to_video.py --model --quantization mxfp4_dualscale +``` + +If accuracy debugging identifies additional precision-sensitive layers, they +can be pinned to BF16 via the Python API: + +```python +omni = Omni( + model="", + quantization={ + "method": "mxfp4_dualscale", + "ignored_layers": ["blocks.10.attn1.to_q"], # explicit per-layer override + }, +) +``` + +BF16 fallback routing in online mode applies two rules in priority order: + +1. **`ignored_layers`** (explicit per-layer override): any layer whose prefix + matches is kept in BF16 regardless of block index. +2. **`num_bf16_fallback_layers`** (coarse leading-block rule): the first N + transformer blocks (`blocks.0` … `blocks.N-1`) fall back to BF16. Defaults + to `5` (recommended). Layers outside `blocks.N.*` + (e.g. `condition_embedder`) are always quantized. + +### `mxfp4_dualscale` — DualScale Offline Mode (Recommended) Offline mode loads a pre-quantized DualScale checkpoint from msModelSlim. A preprocessing step converts the raw quantized output to the diffusers format -expected by vLLM-Omni and injects the quantization config into +expected by vLLM-Omni and injects the quantization config into each `transformer/config.json` so that vLLM-Omni auto-detects the offline path without a `--quantization` flag. +BF16 fallback layers may be interleaved anywhere in the transformer — they are +not restricted to leading blocks. The merge script detects them from +`quant_model_description.json` and writes their prefixes into `ignored_layers` +inside `config.json`. At runtime, each layer's prefix is matched against +`ignored_layers` to decide BF16 vs. MXFP4 DualScale. + #### Checkpoint tensor layout Each quantized linear layer stores four tensors: @@ -110,6 +175,9 @@ Each quantized linear layer stores four tensors: | `weight_dual_scale` | `(N, K//512, 1)` | float32 | Coarse block scale | | `mul_scale` | `(K,)` | float32 | Per-input-channel smooth pre-scale (from calibration) | +BF16 fallback layers have no quantization tensors; only the original `weight` +(and optional `bias`) are present, loaded directly from the base checkpoint. + #### Step 1 — Quantize with msModelSlim ```bash @@ -122,11 +190,12 @@ msmodelslim quant \ --trust_remote_code True ``` -After this step, `--save_path` contains the raw quantized safetensors files, +After this step, `--save_path` contains raw quantized safetensors files, scale files, and a metadata JSON (`quant_model_description*.json`). For cascade MoE models (T2V-A14B, I2V-A14B), msModelSlim outputs two -subdirectories: `high_noise_model/` and `low_noise_model/`. +subdirectories: `high_noise_model/` (transformer) and `low_noise_model/` +(transformer_2). #### Step 2 — Preprocess with merge_mxfp4_dualscale_checkpoint.py @@ -134,15 +203,16 @@ The script (`vllm_omni/quantization/tools/merge_mxfp4_dualscale_checkpoint.py`): 1. Copies the original diffusers model to `--output-path` (VAE, text encoder, scheduler, etc. are preserved). -2. Remaps tensor names from msModelSlim convention to diffusers convention. -3. Saves the converted weights, fine/coarse scales, and `mul_scale` as - `diffusion_pytorch_model.safetensors`. -4. Copies the original `transformer/config.json` and injects - `quantization_config` so that vLLM-Omni auto-detects offline MXFP4 +2. Remaps tensor names from msModelSlim convention to diffusers convention and + strips `.linear.` / `.div.` wrappers added by the quantization tool. +3. Overlays MXFP4 tensors (weight, fine/coarse scales, `mul_scale`) onto the + BF16 base checkpoint. Non-quantized layers keep their original BF16 weights. +4. Detects all linear layers that remain in BF16 and writes their prefixes into + `ignored_layers` in `config.json`. +5. Injects `quantization_config` so vLLM-Omni auto-detects offline MXFP4 DualScale. -For cascade MoE models, steps 2–4 run separately for `high_noise_model/` → -`transformer/` and `low_noise_model/` → `transformer_2/`. +For cascade MoE models, steps 2–5 run separately for each transformer. ```bash python vllm_omni/quantization/tools/merge_mxfp4_dualscale_checkpoint.py \ @@ -162,20 +232,29 @@ python vllm_omni/quantization/tools/merge_mxfp4_dualscale_checkpoint.py \ The script outputs a complete diffusers model directory at `--output-path`, with each transformer subfolder containing: -- `diffusion_pytorch_model.safetensors` — converted FP4 weights, fine/coarse scales, and `mul_scale` +- `diffusion_pytorch_model.safetensors` — MXFP4 weights + scale tensors, with BF16 fallback layers from the base checkpoint - `config.json` — original transformer config with `quantization_config` injected -- `quant_model_description.json` — renamed quantization metadata (reference only) +- `quant_model_description.json` — quantization metadata (reference only) The `quantization_config` injected into `config.json` for each transformer: ```json { - "quant_method": "mxfp8_mxfp4_dualscale", - "num_mxfp8_blocks": 5, - "is_checkpoint_serialized": true + "quant_method": "mxfp4_dualscale", + "is_checkpoint_serialized": true, + "ignored_layers": [ + "blocks.0.attn1.to_qkv", + "blocks.0.attn1.to_out", + "proj_out" + ] } ``` +`ignored_layers` lists every linear layer that retains its original BF16 weight, +using vllm-omni model parameter names (QKV-fused, FFN underscored, `to_out` +unindexed). The exact entries are determined by the quantization tool (msModelSlim) +and may differ between `transformer` and `transformer_2` in a cascade model. + #### Step 3 — Serve ```bash @@ -185,8 +264,6 @@ python text_to_video.py --model /path/to/Wan2.2-T2V-A14B-MXFP4-DualScale vllm serve /path/to/Wan2.2-T2V-A14B-MXFP4-DualScale --omni ``` -Python API: - ```python omni = Omni(model="/path/to/Wan2.2-T2V-A14B-MXFP4-DualScale") ``` @@ -194,54 +271,72 @@ omni = Omni(model="/path/to/Wan2.2-T2V-A14B-MXFP4-DualScale") !!! note No `--quantization` flag is needed for offline mode. The preprocessing script injects `quantization_config` into each `transformer/config.json`, - which vLLM-Omni reads automatically to activate the offline MXFP4 - DualScale method. + which vLLM-Omni reads automatically to activate the correct offline path. ## Parameters -### Online Mode (`mxfp4`) +### `mxfp4` (single-scale, online only) | Parameter | Type | Default | Description | |-----------|------|---------|-------------| -| `method` | str | — | Must be `"mxfp4"` | -| `is_checkpoint_mxfp4_serialized` | bool | `False` | Set `True` to load a single-scale offline checkpoint; leave `False` (default) for online BF16-to-FP4 quantization | -| `ignored_layers` | list[str] | `[]` | Layer name substrings to keep in BF16 | +| `method` | str | — | `"mxfp4"` | +| `ignored_layers` | list[str] | `[]` | Layer prefixes to keep in BF16 | -### Offline DualScale Mode (`mxfp8_mxfp4_dualscale`) +### `mxfp4_dualscale` (dual-scale, online + offline) | Parameter | Type | Default | Description | |-----------|------|---------|-------------| -| `method` | str | — | Must be `"mxfp8_mxfp4_dualscale"` | -| `num_mxfp8_blocks` | int | `0` | Number of leading transformer blocks kept as MXFP8; remaining blocks use MXFP4 DualScale | +| `method` | str | — | `"mxfp4_dualscale"` | | `is_checkpoint_serialized` | bool | `False` | `True` for offline DualScale checkpoints; auto-set from `config.json` when using the preprocessing script | -| `ignored_layers` | list[str] | `[]` | Layer name substrings to keep in BF16 | +| `ignored_layers` | list[str] | `[]` | Layer prefixes to keep in BF16. **Works in both modes**: offline — populated by the merge script for interleaved sensitive layers; online — user-supplied for explicit per-layer precision override | +| `num_bf16_fallback_layers` | int | `5` | **Online mode only**: leading N transformer blocks (`blocks.0` … `blocks.N-1`) kept in BF16. Applied after `ignored_layers`; ignored in offline mode. Default of `5` is the evaluated recommended value for Wan2.2-A14B | + +#### BF16 fallback priority (online mode) + +``` +for each linear layer: + if prefix in ignored_layers → BF16 (explicit override, highest priority) + elif block_idx < num_bf16_fallback_layers → BF16 (coarse leading-block rule) + else → MXFP4 DualScale online +``` + +Layers outside `blocks.N.*` (e.g. `condition_embedder.*`) are always quantized +unless they appear in `ignored_layers`. ## Validation and Notes -1. **Online mode** quantizes BF16 weights at load time using - `npu_dynamic_mx_quant` (single-scale). This adds a one-time overhead on the - first load but requires no checkpoint preparation. No calibration - `mul_scale` is available — all output partitions receive an identity - pre-scale. - -2. **Offline DualScale mode** loads four tensors per quantized layer: the FP4 - packed weight, a fine block scale (`uint8` interpreted as - `float8_e8m0fnu`), a coarse block scale (`float32`), and a per-input-channel - smooth pre-scale (`mul_scale`, `float32`). The `mul_scale` is derived from - calibration and applied to the activation before dual-level quantization - (`npu_dynamic_dual_level_mx_quant`), improving accuracy compared to the - online single-scale path. - -3. **Scale dtype**: fine scales are stored as `uint8` in safetensors (same bit - layout as `float8_e8m0fnu`) and are reinterpreted at load time without a - dtype conversion, avoiding a lossy float32 round-trip. - -4. **Self-attention QKV fusion**: the Q, K, V projection weights are fused into - a single `QKVParallelLinear` layer. Their `mul_scale` tensors are identical - (all three projections share the same input), so the three sequential loads - are idempotent. - -5. W4A4 carries inherently higher quantization noise than W8A8 (16 vs 256 - quantization levels). The DualScale offline method mitigates this with - calibrated `mul_scale` smooth quantization; online single-scale mode trades - accuracy for the convenience of not requiring a pre-processed checkpoint. +1. **Online single-scale (`mxfp4`)** quantizes BF16 weights at load time using + `npu_dynamic_mx_quant` (single-scale). No calibration `mul_scale` is + available — all output partitions receive an identity pre-scale. No offline + checkpoint format exists for this method. + +2. **Online dual-scale (`mxfp4_dualscale`, `is_checkpoint_serialized=False`)** + quantizes BF16 weights using `npu_dynamic_dual_level_mx_quant` (fine + coarse + scales computed on the fly). No calibration `mul_scale`; leading blocks or + explicit `ignored_layers` stay in BF16 for accuracy. + +3. **Offline dual-scale (`mxfp4_dualscale`, `is_checkpoint_serialized=True`)** — + **recommended for production** — loads four tensors per quantized layer: FP4 + weight, fine scale (`uint8` reinterpreted as `float8_e8m0fnu`), coarse scale + (`float32`), and per-input-channel `mul_scale` (`float32`). BF16 fallback + layers have no quantization tensors and are routed via `ignored_layers`. + +4. **Scale dtype**: fine scales are stored as `uint8` in safetensors (same bit + layout as `float8_e8m0fnu`) and reinterpreted at load time without a lossy + float32 round-trip. + +5. **Cascade model config propagation**: in a cascade model (transformer + + transformer_2), vLLM-Omni reads each transformer's own `config.json` and + rebuilds the quant config locally when `ignored_layers` differs between + transformers, ensuring per-layer routing is accurate for each. The first + transformer's config is propagated to `od_config` so the second transformer + can reuse it as a starting point. + +6. **Self-attention QKV fusion**: Q, K, V projection weights are fused into a + single `QKVParallelLinear` layer at runtime. `ignored_layers` entries use the + fused name (`attn1.to_qkv`), written automatically by the merge script. + +7. W4A4 carries higher quantization noise than W8A8 (16 vs 256 levels). The + DualScale offline method mitigates this with calibrated `mul_scale` smooth + quantization. Use `ignored_layers` and `num_bf16_fallback_layers` to trade + off compression vs. accuracy for precision-sensitive layers. diff --git a/docs/user_guide/quantization/overview.md b/docs/user_guide/quantization/overview.md index de6c6388231..b3e0c58b216 100644 --- a/docs/user_guide/quantization/overview.md +++ b/docs/user_guide/quantization/overview.md @@ -42,7 +42,7 @@ otherwise. | Int8 W8A8 | [Int8](int8.md) | Online or serialized W8A8 | Qwen-Image; Wan2.2 is not validated | Validated for Qwen-Image and Z-Image | | ModelOpt | [ModelOpt](modelopt.md) | Pre-quantized FP8 checkpoints | Qwen-Image, Z-Image, FLUX.2, HunyuanImage-3.0 | Validated for ModelOpt FP8 diffusion checkpoints | | MXFP8 W8A8 | [MXFP8](mxfp8.md) | Online W8A8 or offline pre-quantized | Wan2.2-T2V-A14B, I2V-A14B, TI2V-5B | Ascend NPU only; validated for Wan2.2 | -| MXFP4 W4A4 | [MXFP4](mxfp4.md) | Online W4A4 (single-scale) or offline DualScale pre-quantized | Wan2.2-T2V-A14B, I2V-A14B | Ascend NPU only; validated for Wan2.2 A14B cascade models; TI2V-5B not supported (accuracy loss too large); offline uses dual-scale with calibrated `mul_scale` | +| MXFP4 W4A4 | [MXFP4](mxfp4.md) | `mxfp4`: online single-scale only; `mxfp4_dualscale`: online or offline dual-scale (offline recommended) | Wan2.2-T2V-A14B, I2V-A14B | Ascend NPU only; validated for Wan2.2 A14B cascade models; TI2V-5B not supported; offline `mxfp4_dualscale` uses calibrated `mul_scale` for best accuracy | | GGUF | [GGUF](gguf.md) | Pre-quantized transformer weights | Qwen-Image | Validated where a model-specific GGUF adapter exists | | AutoRound | [AutoRound](autoround.md) | Pre-quantized W4A16 checkpoints | FLUX.1-dev; Qwen-Image/Wan2.2 not validated | Checkpoint-driven | | msModelSlim | [msModelSlim](msmodelslim.md) | Pre-quantized Ascend checkpoints | Wan2.2 recipe; HunyuanImage-3.0 inference target | Ascend/NPU path | diff --git a/examples/offline_inference/image_to_video/image_to_video.py b/examples/offline_inference/image_to_video/image_to_video.py index a4bcf8ce71c..f0438c41323 100644 --- a/examples/offline_inference/image_to_video/image_to_video.py +++ b/examples/offline_inference/image_to_video/image_to_video.py @@ -222,8 +222,8 @@ def parse_args() -> argparse.Namespace: "--quantization", type=str, default=None, - choices=["fp8", "mxfp8", "mxfp4", "mxfp8_mxfp4_dualscale", "int8", "gguf"], - help="Quantization method for the transformer. mxfp8: W8A8 MXFP8 (NPU). mxfp4: W4A4 MXFP4 (NPU). mxfp8_mxfp4_dualscale: mixed MXFP8+MXFP4 dual-scale (NPU). fp8: online FP8 (GPU).", + choices=["fp8", "mxfp8", "mxfp4", "mxfp4_dualscale", "int8", "gguf"], + help="Quantization method for the transformer. mxfp8: W8A8 MXFP8 (NPU). mxfp4: W4A4 MXFP4 (NPU). mxfp4_dualscale: W4A4 MXFP4 dual-scale + BF16 fallback mixed (NPU). fp8: online FP8 (GPU).", ) parser.add_argument( "--enable-diffusion-pipeline-profiler", diff --git a/examples/offline_inference/text_to_video/text_to_video.py b/examples/offline_inference/text_to_video/text_to_video.py index 0df24baf152..a6d8f369519 100644 --- a/examples/offline_inference/text_to_video/text_to_video.py +++ b/examples/offline_inference/text_to_video/text_to_video.py @@ -206,8 +206,8 @@ def parse_args() -> argparse.Namespace: "--quantization", type=str, default=None, - choices=["fp8", "mxfp8", "mxfp4", "mxfp8_mxfp4_dualscale", "int8", "gguf"], - help="Quantization method for the transformer. mxfp8: W8A8 MXFP8 (NPU). mxfp4: W4A4 MXFP4 (NPU). mxfp8_mxfp4_dualscale: mixed MXFP8+MXFP4 dual-scale (NPU). fp8: online FP8 (GPU).", + choices=["fp8", "mxfp8", "mxfp4", "mxfp4_dualscale", "int8", "gguf"], + help="Quantization method for the transformer. mxfp8: W8A8 MXFP8 (NPU). mxfp4: W4A4 MXFP4 (NPU). mxfp4_dualscale: W4A4 MXFP4 dual-scale + BF16 fallback mixed (NPU). fp8: online FP8 (GPU).", ) parser.add_argument( "--use-hsdp", diff --git a/tests/diffusion/models/wan2_2/test_create_transformer_quant.py b/tests/diffusion/models/wan2_2/test_create_transformer_quant.py index 4e32ef546cb..493c8f93555 100644 --- a/tests/diffusion/models/wan2_2/test_create_transformer_quant.py +++ b/tests/diffusion/models/wan2_2/test_create_transformer_quant.py @@ -9,7 +9,7 @@ - Auto-detects the quant method when no CLI quant_config is provided - Rejects method mismatches (CLI vs disk) - Upgrades online → offline when disk marks is_checkpoint_*_serialized=True - - Rebuilds when the active num_mxfp8_blocks differs from the disk value + - Rebuilds when the active ignored_layers differs from the disk value Wan22Pipeline._create_transformer (~L456) - Propagates the auto-detected config to od_config so the second transformer @@ -92,10 +92,10 @@ def test_create_transformer_detects_mxfp8_serialized_from_config_json(monkeypatc def test_create_transformer_detects_mxfp4_dualscale_from_config_json(monkeypatch): - """config.json with mxfp8_mxfp4_dualscale + num_mxfp8_blocks must produce - a DiffusionMXFP8MXFP4DualScaleConfig with the correct block count and + """config.json with mxfp4_dualscale + ignored_layers must produce + a DiffusionMXFP4DualScaleMixedConfig with the correct ignored_layers and serialized flag.""" - from vllm_omni.quantization.mixed_mxfp_config import DiffusionMXFP8MXFP4DualScaleConfig + from vllm_omni.quantization.mxfp4_config import DiffusionMXFP4DualScaleMixedConfig FakeTransformer, captured = _make_fake_transformer() monkeypatch.setattr(wan22_module, "WanTransformer3DModel", FakeTransformer) @@ -103,16 +103,16 @@ def test_create_transformer_detects_mxfp4_dualscale_from_config_json(monkeypatch config = { **_MIN_CFG, "quantization_config": { - "quant_method": "mxfp8_mxfp4_dualscale", - "num_mxfp8_blocks": 7, + "quant_method": "mxfp4_dualscale", + "ignored_layers": ["blocks.0.attn1.to_q", "blocks.0.attn1.to_k"], "is_checkpoint_serialized": True, }, } create_transformer_from_config(config) qc = captured[0].get("quant_config") - assert isinstance(qc, DiffusionMXFP8MXFP4DualScaleConfig) - assert qc.num_mxfp8_blocks == 7 + assert isinstance(qc, DiffusionMXFP4DualScaleMixedConfig) + assert set(qc.ignored_layers) == {"blocks.0.attn1.to_q", "blocks.0.attn1.to_k"} assert qc.is_checkpoint_serialized is True @@ -134,7 +134,11 @@ def test_create_transformer_without_quantization_config_passes_no_quant(monkeypa def test_create_transformer_rejects_method_mismatch(monkeypatch): """Passing a CLI quant_config whose get_name() differs from the config.json - quant_method must raise ValueError immediately (prevents silent weight corruption).""" + quant_method must raise ValueError immediately (prevents silent weight corruption). + + fp8 (vLLM built-in, get_name()=='fp8') vs disk 'mxfp8' triggers the guard. + These are distinct methods; using the same method for both would not trigger it. + """ from vllm.model_executor.layers.quantization.fp8 import Fp8Config FakeTransformer, _ = _make_fake_transformer() @@ -182,49 +186,55 @@ def test_create_transformer_upgrades_to_serialized_when_disk_marks_it(monkeypatc # --------------------------------------------------------------------------- -# create_transformer_from_config — num_mxfp8_blocks rebuild +# create_transformer_from_config — ignored_layers rebuild # --------------------------------------------------------------------------- -def test_create_transformer_rebuilds_when_num_mxfp8_blocks_differs(monkeypatch): - """When the active quant_config has num_mxfp8_blocks=5 but config.json says 10, - the config must be rebuilt from disk so block routing is authoritative for - this specific transformer.""" - from vllm_omni.quantization.mixed_mxfp_config import DiffusionMXFP8MXFP4DualScaleConfig +def test_create_transformer_rebuilds_when_ignored_layers_differ(monkeypatch): + """When the active quant_config has different ignored_layers than config.json, + the config must be rebuilt from disk so per-layer BF16 routing is authoritative + for this specific transformer.""" + from vllm_omni.quantization.mxfp4_config import DiffusionMXFP4DualScaleMixedConfig FakeTransformer, captured = _make_fake_transformer() monkeypatch.setattr(wan22_module, "WanTransformer3DModel", FakeTransformer) - stale = DiffusionMXFP8MXFP4DualScaleConfig(num_mxfp8_blocks=5, is_checkpoint_serialized=True) + stale = DiffusionMXFP4DualScaleMixedConfig( + is_checkpoint_serialized=True, + ignored_layers=["blocks.0.attn1.to_q"], + ) config = { **_MIN_CFG, "quantization_config": { - "quant_method": "mxfp8_mxfp4_dualscale", - "num_mxfp8_blocks": 10, + "quant_method": "mxfp4_dualscale", + "ignored_layers": ["blocks.0.attn1.to_q", "blocks.1.attn1.to_q"], "is_checkpoint_serialized": True, }, } create_transformer_from_config(config, quant_config=stale) qc = captured[0].get("quant_config") - assert isinstance(qc, DiffusionMXFP8MXFP4DualScaleConfig) - assert qc.num_mxfp8_blocks == 10 + assert isinstance(qc, DiffusionMXFP4DualScaleMixedConfig) + assert set(qc.ignored_layers) == {"blocks.0.attn1.to_q", "blocks.1.attn1.to_q"} -def test_create_transformer_does_not_rebuild_when_num_mxfp8_blocks_matches(monkeypatch): - """When the active quant_config already has the correct num_mxfp8_blocks, +def test_create_transformer_does_not_rebuild_when_ignored_layers_match(monkeypatch): + """When the active quant_config already has the same ignored_layers, the same instance must be passed through unchanged (no unnecessary rebuild).""" - from vllm_omni.quantization.mixed_mxfp_config import DiffusionMXFP8MXFP4DualScaleConfig + from vllm_omni.quantization.mxfp4_config import DiffusionMXFP4DualScaleMixedConfig FakeTransformer, captured = _make_fake_transformer() monkeypatch.setattr(wan22_module, "WanTransformer3DModel", FakeTransformer) - matching = DiffusionMXFP8MXFP4DualScaleConfig(num_mxfp8_blocks=5, is_checkpoint_serialized=True) + matching = DiffusionMXFP4DualScaleMixedConfig( + is_checkpoint_serialized=True, + ignored_layers=["blocks.0.attn1.to_q"], + ) config = { **_MIN_CFG, "quantization_config": { - "quant_method": "mxfp8_mxfp4_dualscale", - "num_mxfp8_blocks": 5, + "quant_method": "mxfp4_dualscale", + "ignored_layers": ["blocks.0.attn1.to_q"], "is_checkpoint_serialized": True, }, } @@ -320,16 +330,17 @@ def test_pipeline_cascade_both_transformers_get_mxfp8_serialized_config(monkeypa assert captured[0]["quant_config"] is captured[1]["quant_config"] -def test_pipeline_cascade_mxfp4_dualscale_each_transformer_gets_correct_num_blocks(monkeypatch): - """Cascade with mxfp8_mxfp4_dualscale where transformer and transformer_2 have - different num_mxfp8_blocks in their config.json. +def test_pipeline_cascade_mxfp4_dualscale_each_transformer_gets_correct_ignored_layers(monkeypatch): + """Cascade with mxfp4_dualscale where transformer and transformer_2 have + different ignored_layers in their config.json. Expected outcome: - transformer → num_mxfp8_blocks=5 (auto-detected, propagated to od_config) - transformer_2 → num_mxfp8_blocks=10 (rebuilt from disk because 10 ≠ 5) - od_config → num_mxfp8_blocks=5 (unchanged; transformer_2's rebuild is local) + transformer → ignored_layers=["blocks.0.attn1.to_q"] (auto-detected, propagated to od_config) + transformer_2 → ignored_layers=["blocks.0.attn1.to_q", "blocks.1.attn1.to_q"] + (rebuilt from disk because ignored_layers differ) + od_config → ignored_layers=["blocks.0.attn1.to_q"] (unchanged; rebuild was local) """ - from vllm_omni.quantization.mixed_mxfp_config import DiffusionMXFP8MXFP4DualScaleConfig + from vllm_omni.quantization.mxfp4_config import DiffusionMXFP4DualScaleMixedConfig FakeTransformer, captured = _make_fake_transformer() monkeypatch.setattr(wan22_module, "WanTransformer3DModel", FakeTransformer) @@ -340,16 +351,16 @@ def test_pipeline_cascade_mxfp4_dualscale_each_transformer_gets_correct_num_bloc cfg1 = { **_MIN_CFG, "quantization_config": { - "quant_method": "mxfp8_mxfp4_dualscale", - "num_mxfp8_blocks": 5, + "quant_method": "mxfp4_dualscale", + "ignored_layers": ["blocks.0.attn1.to_q"], "is_checkpoint_serialized": True, }, } cfg2 = { **_MIN_CFG, "quantization_config": { - "quant_method": "mxfp8_mxfp4_dualscale", - "num_mxfp8_blocks": 10, + "quant_method": "mxfp4_dualscale", + "ignored_layers": ["blocks.0.attn1.to_q", "blocks.1.attn1.to_q"], "is_checkpoint_serialized": True, }, } @@ -361,10 +372,13 @@ def test_pipeline_cascade_mxfp4_dualscale_each_transformer_gets_correct_num_bloc qc1 = captured[0].get("quant_config") qc2 = captured[1].get("quant_config") - assert isinstance(qc1, DiffusionMXFP8MXFP4DualScaleConfig) - assert isinstance(qc2, DiffusionMXFP8MXFP4DualScaleConfig) - assert qc1.num_mxfp8_blocks == 5, f"transformer expected 5 blocks, got {qc1.num_mxfp8_blocks}" - assert qc2.num_mxfp8_blocks == 10, f"transformer_2 expected 10 blocks, got {qc2.num_mxfp8_blocks}" + assert isinstance(qc1, DiffusionMXFP4DualScaleMixedConfig) + assert isinstance(qc2, DiffusionMXFP4DualScaleMixedConfig) + assert set(qc1.ignored_layers) == {"blocks.0.attn1.to_q"}, f"transformer expected 1 layer, got {qc1.ignored_layers}" + assert set(qc2.ignored_layers) == { + "blocks.0.attn1.to_q", + "blocks.1.attn1.to_q", + }, f"transformer_2 expected 2 layers, got {qc2.ignored_layers}" # od_config retains the first transformer's config; the rebuild was local. - assert od_config.quantization_config.num_mxfp8_blocks == 5 + assert set(od_config.quantization_config.ignored_layers) == {"blocks.0.attn1.to_q"} diff --git a/tests/diffusion/quantization/test_mxfp4_config.py b/tests/diffusion/quantization/test_mxfp4_config.py index 649298223da..411081c2260 100644 --- a/tests/diffusion/quantization/test_mxfp4_config.py +++ b/tests/diffusion/quantization/test_mxfp4_config.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Tests for MXFP4 quantization configs and the mixed MXFP8+MXFP4 config.""" +"""Tests for MXFP4 quantization configs and the MXFP4 DualScale + BF16 mixed config.""" import pytest @@ -49,188 +49,298 @@ def test_mxfp4_config_from_config_modules_to_not_convert_fallback(): # --------------------------------------------------------------------------- -# DiffusionMXFP8MXFP4DualScaleConfig +# build_quant_config integration # --------------------------------------------------------------------------- -def test_mixed_config_get_name(): - from vllm_omni.quantization.mixed_mxfp_config import DiffusionMXFP8MXFP4DualScaleConfig +def test_build_quant_config_mxfp4_string(): + from vllm_omni.quantization import build_quant_config + from vllm_omni.quantization.mxfp4_config import DiffusionMXFP4Config + + cfg = build_quant_config("mxfp4") + assert isinstance(cfg, DiffusionMXFP4Config) + assert cfg.get_name() == "mxfp4" + assert cfg.is_checkpoint_mxfp4_serialized is False + + +def test_build_quant_config_mxfp4_dict(): + from vllm_omni.quantization import build_quant_config + from vllm_omni.quantization.mxfp4_config import DiffusionMXFP4Config - assert DiffusionMXFP8MXFP4DualScaleConfig.get_name() == "mxfp8_mxfp4_dualscale" + cfg = build_quant_config({"method": "mxfp4", "is_checkpoint_mxfp4_serialized": True}) + assert isinstance(cfg, DiffusionMXFP4Config) + assert cfg.is_checkpoint_mxfp4_serialized is True -def test_mixed_config_no_args_does_not_raise(): - """DiffusionMXFP8MXFP4DualScaleConfig() with no args must not raise TypeError.""" - from vllm_omni.quantization.mixed_mxfp_config import DiffusionMXFP8MXFP4DualScaleConfig +def test_build_quant_config_mxfp4_dualscale_string(): + from vllm_omni.quantization import build_quant_config + from vllm_omni.quantization.mxfp4_config import DiffusionMXFP4DualScaleMixedConfig - cfg = DiffusionMXFP8MXFP4DualScaleConfig() - assert cfg.num_mxfp8_blocks == 0 + cfg = build_quant_config("mxfp4_dualscale") + assert isinstance(cfg, DiffusionMXFP4DualScaleMixedConfig) assert cfg.is_checkpoint_serialized is False + assert cfg.num_bf16_fallback_layers == 5 assert cfg.ignored_layers == [] -def test_mixed_config_from_config_with_num_blocks(): - from vllm_omni.quantization.mixed_mxfp_config import DiffusionMXFP8MXFP4DualScaleConfig +def test_build_quant_config_mxfp4_dualscale_dict_offline(): + from vllm_omni.quantization import build_quant_config + from vllm_omni.quantization.mxfp4_config import DiffusionMXFP4DualScaleMixedConfig - cfg = DiffusionMXFP8MXFP4DualScaleConfig.from_config( + cfg = build_quant_config( { - "quant_method": "mxfp8_mxfp4_dualscale", - "num_mxfp8_blocks": 5, + "method": "mxfp4_dualscale", "is_checkpoint_serialized": True, + "ignored_layers": ["blocks.0.attn1.to_q", "blocks.0.attn1.to_k"], } ) - assert cfg.num_mxfp8_blocks == 5 + assert isinstance(cfg, DiffusionMXFP4DualScaleMixedConfig) assert cfg.is_checkpoint_serialized is True - assert cfg.ignored_layers == [] + assert cfg.ignored_layers == ["blocks.0.attn1.to_q", "blocks.0.attn1.to_k"] -def test_mixed_config_from_config_defaults(): - from vllm_omni.quantization.mixed_mxfp_config import DiffusionMXFP8MXFP4DualScaleConfig +def test_build_quant_config_mxfp4_dualscale_dict_online_custom_fallback(): + from vllm_omni.quantization import build_quant_config + from vllm_omni.quantization.mxfp4_config import DiffusionMXFP4DualScaleMixedConfig - cfg = DiffusionMXFP8MXFP4DualScaleConfig.from_config({}) - assert cfg.num_mxfp8_blocks == 0 - assert cfg.is_checkpoint_serialized is False + cfg = build_quant_config({"method": "mxfp4_dualscale", "num_bf16_fallback_layers": 10}) + assert isinstance(cfg, DiffusionMXFP4DualScaleMixedConfig) + assert cfg.num_bf16_fallback_layers == 10 -def test_mixed_config_from_config_ignored_layers(): - from vllm_omni.quantization.mixed_mxfp_config import DiffusionMXFP8MXFP4DualScaleConfig +# --------------------------------------------------------------------------- +# Block-index dispatch (_parse_block_idx) +# --------------------------------------------------------------------------- - cfg = DiffusionMXFP8MXFP4DualScaleConfig.from_config({"num_mxfp8_blocks": 3, "ignored_layers": ["proj_out"]}) - assert cfg.ignored_layers == ["proj_out"] + +def test_parse_block_idx_valid(): + from vllm_omni.quantization.mxfp4_config import _parse_block_idx + + assert _parse_block_idx("blocks.0.attn1.to_q") == 0 + assert _parse_block_idx("blocks.5.ffn.net.0.proj") == 5 + assert _parse_block_idx("blocks.40.norm1.weight") == 40 + + +def test_parse_block_idx_non_block_prefixes(): + """Prefixes that do not start with 'blocks.N.' must return None.""" + from vllm_omni.quantization.mxfp4_config import _parse_block_idx + + assert _parse_block_idx("condition_embedder.time_embedder.linear_1") is None + assert _parse_block_idx("proj_out.weight") is None + assert _parse_block_idx("model.layers.0.self_attn.q_proj") is None + assert _parse_block_idx("scale_shift_table") is None # --------------------------------------------------------------------------- -# build_quant_config integration +# SUPPORTED_QUANTIZATION_METHODS # --------------------------------------------------------------------------- -def test_build_quant_config_mxfp4_string(): - from vllm_omni.quantization import build_quant_config - from vllm_omni.quantization.mxfp4_config import DiffusionMXFP4Config +def test_supported_methods_include_mxfp4_variants(): + from vllm_omni.quantization import SUPPORTED_QUANTIZATION_METHODS - cfg = build_quant_config("mxfp4") - assert isinstance(cfg, DiffusionMXFP4Config) - assert cfg.get_name() == "mxfp4" - assert cfg.is_checkpoint_mxfp4_serialized is False + assert "mxfp4" in SUPPORTED_QUANTIZATION_METHODS + assert "mxfp8" in SUPPORTED_QUANTIZATION_METHODS + assert "mxfp4_dualscale" in SUPPORTED_QUANTIZATION_METHODS -def test_build_quant_config_mxfp4_dict(): - from vllm_omni.quantization import build_quant_config - from vllm_omni.quantization.mxfp4_config import DiffusionMXFP4Config +# --------------------------------------------------------------------------- +# DiffusionMXFP4DualScaleMixedConfig — config roundtrips +# --------------------------------------------------------------------------- - cfg = build_quant_config({"method": "mxfp4", "is_checkpoint_mxfp4_serialized": True}) - assert isinstance(cfg, DiffusionMXFP4Config) - assert cfg.is_checkpoint_mxfp4_serialized is True +def test_mixed_dualscale_config_get_name(): + from vllm_omni.quantization.mxfp4_config import DiffusionMXFP4DualScaleMixedConfig -def test_build_quant_config_mxfp8_mxfp4_dualscale_dict(): - from vllm_omni.quantization import build_quant_config - from vllm_omni.quantization.mixed_mxfp_config import DiffusionMXFP8MXFP4DualScaleConfig + assert DiffusionMXFP4DualScaleMixedConfig.get_name() == "mxfp4_dualscale" - cfg = build_quant_config( + +def test_mixed_dualscale_config_no_args_defaults(): + from vllm_omni.quantization.mxfp4_config import DiffusionMXFP4DualScaleMixedConfig + + cfg = DiffusionMXFP4DualScaleMixedConfig() + assert cfg.is_checkpoint_serialized is False + assert cfg.ignored_layers == [] + assert cfg.num_bf16_fallback_layers == 5 + + +def test_mixed_dualscale_config_from_config_offline(): + from vllm_omni.quantization.mxfp4_config import DiffusionMXFP4DualScaleMixedConfig + + cfg = DiffusionMXFP4DualScaleMixedConfig.from_config( { - "method": "mxfp8_mxfp4_dualscale", - "num_mxfp8_blocks": 5, + "quant_method": "mxfp4_dualscale", "is_checkpoint_serialized": True, + "ignored_layers": ["blocks.0.attn1.to_q", "proj_out"], } ) - assert isinstance(cfg, DiffusionMXFP8MXFP4DualScaleConfig) - assert cfg.num_mxfp8_blocks == 5 assert cfg.is_checkpoint_serialized is True + assert cfg.ignored_layers == ["blocks.0.attn1.to_q", "proj_out"] + assert cfg.num_bf16_fallback_layers == 5 # default -def test_build_quant_config_mxfp8_mxfp4_dualscale_warns_without_num_blocks( - monkeypatch: pytest.MonkeyPatch, -): - """build_quant_config('mxfp8_mxfp4_dualscale') must emit WARNING when - num_mxfp8_blocks is absent and default to 0 (all-MXFP4 DualScale mode). +def test_mixed_dualscale_config_from_config_online_custom_fallback(): + from vllm_omni.quantization.mxfp4_config import DiffusionMXFP4DualScaleMixedConfig - Uses monkeypatch instead of caplog because vllm's init_logger may configure - propagation in a way that prevents caplog from intercepting the messages. - """ - import vllm_omni.quantization.factory as factory_module - from vllm_omni.quantization import build_quant_config - from vllm_omni.quantization.mixed_mxfp_config import DiffusionMXFP8MXFP4DualScaleConfig + cfg = DiffusionMXFP4DualScaleMixedConfig.from_config({"num_bf16_fallback_layers": 10}) + assert cfg.is_checkpoint_serialized is False + assert cfg.num_bf16_fallback_layers == 10 + + +def test_mixed_dualscale_config_from_config_modules_to_not_convert_fallback(): + """modules_to_not_convert must be accepted as an alias for ignored_layers.""" + from vllm_omni.quantization.mxfp4_config import DiffusionMXFP4DualScaleMixedConfig - warning_messages: list[str] = [] - monkeypatch.setattr( - factory_module.logger, - "warning", - lambda msg, *args, **kw: warning_messages.append(msg), + cfg = DiffusionMXFP4DualScaleMixedConfig.from_config( + {"is_checkpoint_serialized": True, "modules_to_not_convert": ["proj_out"]} ) + assert cfg.ignored_layers == ["proj_out"] - cfg = build_quant_config("mxfp8_mxfp4_dualscale") - assert isinstance(cfg, DiffusionMXFP8MXFP4DualScaleConfig) - assert cfg.num_mxfp8_blocks == 0 - assert len(warning_messages) >= 1 - assert any("num_mxfp8_blocks" in msg for msg in warning_messages) +# --------------------------------------------------------------------------- +# DiffusionMXFP4DualScaleMixedConfig — get_quant_method dispatch +# --------------------------------------------------------------------------- -def test_build_quant_config_mxfp8_mxfp4_dualscale_info_with_num_blocks( +def test_mixed_dualscale_offline_ignored_layer_returns_unquantized( + mocker, monkeypatch: pytest.MonkeyPatch, ): - """When num_mxfp8_blocks is provided, INFO is logged and no WARNING is emitted.""" - import vllm_omni.quantization.factory as factory_module - from vllm_omni.quantization import build_quant_config + """Offline: a prefix in ignored_layers must return UnquantizedLinearMethod.""" + from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod - warning_messages: list[str] = [] - info_messages: list[str] = [] - monkeypatch.setattr( - factory_module.logger, - "warning", - lambda msg, *args, **kw: warning_messages.append(msg), + from vllm_omni.platforms import current_omni_platform + from vllm_omni.quantization.mxfp4_config import DiffusionMXFP4DualScaleMixedConfig + + cfg = DiffusionMXFP4DualScaleMixedConfig( + is_checkpoint_serialized=True, + ignored_layers=["blocks.0.attn1.to_q"], ) - monkeypatch.setattr( - factory_module.logger, - "info", - lambda msg, *args, **kw: info_messages.append(msg), + layer = mocker.Mock(spec=LinearBase) + monkeypatch.setattr(current_omni_platform, "is_npu", lambda: True) + + method = cfg.get_quant_method(layer, "blocks.0.attn1.to_q") + assert isinstance(method, UnquantizedLinearMethod) + + +def test_mixed_dualscale_offline_non_ignored_returns_mxfp4( + mocker, + monkeypatch: pytest.MonkeyPatch, +): + """Offline: a prefix NOT in ignored_layers must return NPUMxfp4DualScaleLinearMethod.""" + from vllm.model_executor.layers.linear import LinearBase + + from vllm_omni.platforms import current_omni_platform + from vllm_omni.quantization.mxfp4_config import DiffusionMXFP4DualScaleMixedConfig, NPUMxfp4DualScaleLinearMethod + + cfg = DiffusionMXFP4DualScaleMixedConfig( + is_checkpoint_serialized=True, + ignored_layers=["blocks.0.attn1.to_q"], ) + layer = mocker.Mock(spec=LinearBase) + monkeypatch.setattr(current_omni_platform, "is_npu", lambda: True) - cfg = build_quant_config( - { - "method": "mxfp8_mxfp4_dualscale", - "num_mxfp8_blocks": 5, - "is_checkpoint_serialized": True, - } + method = cfg.get_quant_method(layer, "blocks.1.attn1.to_q") + assert isinstance(method, NPUMxfp4DualScaleLinearMethod) + + +def test_mixed_dualscale_online_fallback_block_returns_unquantized( + mocker, + monkeypatch: pytest.MonkeyPatch, +): + """Online: blocks < num_bf16_fallback_layers must return UnquantizedLinearMethod.""" + from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod + + from vllm_omni.platforms import current_omni_platform + from vllm_omni.quantization.mxfp4_config import DiffusionMXFP4DualScaleMixedConfig + + cfg = DiffusionMXFP4DualScaleMixedConfig(is_checkpoint_serialized=False, num_bf16_fallback_layers=5) + layer = mocker.Mock(spec=LinearBase) + monkeypatch.setattr(current_omni_platform, "is_npu", lambda: True) + + assert isinstance(cfg.get_quant_method(layer, "blocks.0.attn1.to_q"), UnquantizedLinearMethod) + assert isinstance(cfg.get_quant_method(layer, "blocks.4.ffn.net.0.proj"), UnquantizedLinearMethod) + + +def test_mixed_dualscale_online_quantized_block_returns_mxfp4( + mocker, + monkeypatch: pytest.MonkeyPatch, +): + """Online: blocks >= num_bf16_fallback_layers must return NPUMxfp4DualScaleOnlineLinearMethod.""" + from vllm.model_executor.layers.linear import LinearBase + + from vllm_omni.platforms import current_omni_platform + from vllm_omni.quantization.mxfp4_config import ( + DiffusionMXFP4DualScaleMixedConfig, + NPUMxfp4DualScaleOnlineLinearMethod, ) - assert cfg.num_mxfp8_blocks == 5 - assert len(warning_messages) == 0 - assert any("num_mxfp8_blocks" in msg for msg in info_messages) + cfg = DiffusionMXFP4DualScaleMixedConfig(is_checkpoint_serialized=False, num_bf16_fallback_layers=5) + layer = mocker.Mock(spec=LinearBase) + monkeypatch.setattr(current_omni_platform, "is_npu", lambda: True) + assert isinstance(cfg.get_quant_method(layer, "blocks.5.attn1.to_q"), NPUMxfp4DualScaleOnlineLinearMethod) + assert isinstance(cfg.get_quant_method(layer, "blocks.40.ffn.net.0.proj"), NPUMxfp4DualScaleOnlineLinearMethod) -# --------------------------------------------------------------------------- -# Block-index dispatch (_parse_block_idx) -# --------------------------------------------------------------------------- +def test_mixed_dualscale_online_non_block_prefix_returns_mxfp4( + mocker, + monkeypatch: pytest.MonkeyPatch, +): + """Online: layers outside 'blocks.N.*' (condition_embedder etc.) always use MXFP4 online.""" + from vllm.model_executor.layers.linear import LinearBase -def test_parse_block_idx_valid(): - from vllm_omni.quantization.mixed_mxfp_config import _parse_block_idx + from vllm_omni.platforms import current_omni_platform + from vllm_omni.quantization.mxfp4_config import ( + DiffusionMXFP4DualScaleMixedConfig, + NPUMxfp4DualScaleOnlineLinearMethod, + ) - assert _parse_block_idx("blocks.0.attn1.to_q") == 0 - assert _parse_block_idx("blocks.5.ffn.net.0.proj") == 5 - assert _parse_block_idx("blocks.40.norm1.weight") == 40 + cfg = DiffusionMXFP4DualScaleMixedConfig(is_checkpoint_serialized=False, num_bf16_fallback_layers=5) + layer = mocker.Mock(spec=LinearBase) + monkeypatch.setattr(current_omni_platform, "is_npu", lambda: True) + method = cfg.get_quant_method(layer, "condition_embedder.time_embedder.linear_1") + assert isinstance(method, NPUMxfp4DualScaleOnlineLinearMethod) -def test_parse_block_idx_non_block_prefixes(): - """Prefixes that do not start with 'blocks.N.' must return None.""" - from vllm_omni.quantization.mixed_mxfp_config import _parse_block_idx - assert _parse_block_idx("condition_embedder.time_embedder.linear_1") is None - assert _parse_block_idx("proj_out.weight") is None - assert _parse_block_idx("model.layers.0.self_attn.q_proj") is None - assert _parse_block_idx("scale_shift_table") is None +def test_mixed_dualscale_online_ignored_layers_override( + mocker, + monkeypatch: pytest.MonkeyPatch, +): + """Online: explicit ignored_layers must return UnquantizedLinearMethod regardless of block index. + A layer that is NOT in the leading-block range (block 10 >= num_bf16_fallback_layers=5) + but IS listed in ignored_layers must still fall back to BF16. This lets power users + pin specific interleaved layers to BF16 during online quantization without needing an + offline checkpoint. + """ + from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod -# --------------------------------------------------------------------------- -# SUPPORTED_QUANTIZATION_METHODS -# --------------------------------------------------------------------------- + from vllm_omni.platforms import current_omni_platform + from vllm_omni.quantization.mxfp4_config import DiffusionMXFP4DualScaleMixedConfig + cfg = DiffusionMXFP4DualScaleMixedConfig( + is_checkpoint_serialized=False, + num_bf16_fallback_layers=5, + ignored_layers=["blocks.10.attn1.to_q"], + ) + layer = mocker.Mock(spec=LinearBase) + monkeypatch.setattr(current_omni_platform, "is_npu", lambda: True) -def test_supported_methods_include_mxfp4_variants(): - from vllm_omni.quantization import SUPPORTED_QUANTIZATION_METHODS + # block 10 is above the leading-block threshold but is in ignored_layers → BF16 + assert isinstance(cfg.get_quant_method(layer, "blocks.10.attn1.to_q"), UnquantizedLinearMethod) - assert "mxfp4" in SUPPORTED_QUANTIZATION_METHODS - assert "mxfp8" in SUPPORTED_QUANTIZATION_METHODS - assert "mxfp8_mxfp4_dualscale" in SUPPORTED_QUANTIZATION_METHODS + +def test_mixed_dualscale_non_linear_returns_none(monkeypatch: pytest.MonkeyPatch): + """Non-LinearBase layers (norms, embeddings) must return None → no quantization.""" + import torch + + from vllm_omni.platforms import current_omni_platform + from vllm_omni.quantization.mxfp4_config import DiffusionMXFP4DualScaleMixedConfig + + cfg = DiffusionMXFP4DualScaleMixedConfig() + monkeypatch.setattr(current_omni_platform, "is_npu", lambda: True) + + norm_layer = torch.nn.LayerNorm(64) + assert cfg.get_quant_method(norm_layer, "blocks.0.norm1") is None diff --git a/tests/diffusion/quantization/test_mxfp4_key_remap.py b/tests/diffusion/quantization/test_mxfp4_key_remap.py index 4f7edfded5f..89e1a4e0890 100644 --- a/tests/diffusion/quantization/test_mxfp4_key_remap.py +++ b/tests/diffusion/quantization/test_mxfp4_key_remap.py @@ -7,10 +7,10 @@ """ import pytest +import torch pytestmark = [pytest.mark.core_model, pytest.mark.diffusion, pytest.mark.cpu] - # --------------------------------------------------------------------------- # SUPPORTED_MODEL_TYPES # --------------------------------------------------------------------------- @@ -167,35 +167,223 @@ def test_classify_blocks_empty(): # --------------------------------------------------------------------------- -# _detect_num_mxfp8_blocks +# _is_mxfp4_tensor # --------------------------------------------------------------------------- -def test_detect_num_mxfp8_blocks_with_mxfp8_present(): - from vllm_omni.quantization.tools.merge_mxfp4_dualscale_checkpoint import _detect_num_mxfp8_blocks +def test_is_mxfp4_tensor_quantized_weight(): + from vllm_omni.quantization.tools.merge_mxfp4_dualscale_checkpoint import _is_mxfp4_tensor - quant_meta = { - "blocks.0.attn1.to_q.weight": "W8A8_MXFP8", - "blocks.1.attn1.to_q.weight": "W8A8_MXFP8", - "blocks.2.attn1.to_q.weight": "W4A4_MXFP4_DUALSCALE", - "blocks.3.attn1.to_q.weight": "W4A4_MXFP4_DUALSCALE", + assert _is_mxfp4_tensor("blocks.0.attn1.to_q.linear.weight", "W4A4_MXFP4_DUALSCALE") is True + assert _is_mxfp4_tensor("blocks.0.attn1.to_q.linear.weight_scale", "W4A4_MXFP4_DUALSCALE") is True + assert _is_mxfp4_tensor("blocks.0.attn1.to_q.linear.weight_dual_scale", "W4A4_MXFP4_DUALSCALE") is True + + +def test_is_mxfp4_tensor_companion_bias_and_mul_scale(): + """Companion tensors (bias, mul_scale) are FLOAT but belong to MXFP4 layers + and must be included in the merged checkpoint.""" + from vllm_omni.quantization.tools.merge_mxfp4_dualscale_checkpoint import _is_mxfp4_tensor + + # bias inside .linear. wrapper → must be included + assert _is_mxfp4_tensor("blocks.0.attn1.to_q.linear.bias", "FLOAT") is True + # mul_scale inside .div. wrapper → must be included + assert _is_mxfp4_tensor("blocks.0.attn1.to_q.div.mul_scale", "FLOAT") is True + + +def test_is_mxfp4_tensor_bf16_fallback_weight(): + """BF16 fallback linear layers have plain .weight keys (no wrapper) → must NOT be included.""" + from vllm_omni.quantization.tools.merge_mxfp4_dualscale_checkpoint import _is_mxfp4_tensor + + assert _is_mxfp4_tensor("blocks.0.attn1.to_q.weight", "FLOAT") is False + assert _is_mxfp4_tensor("condition_embedder.time_embedder.linear_1.weight", "FLOAT") is False + + +def test_is_mxfp4_tensor_norm_weight(): + """Norm layers are always FLOAT with plain .weight → not MXFP4 related.""" + from vllm_omni.quantization.tools.merge_mxfp4_dualscale_checkpoint import _is_mxfp4_tensor + + assert _is_mxfp4_tensor("blocks.0.norm1.weight", "FLOAT") is False + + +# --------------------------------------------------------------------------- +# _collect_ignored_layers +# --------------------------------------------------------------------------- + + +def test_collect_ignored_layers_basic(): + from vllm_omni.quantization.tools.merge_mxfp4_dualscale_checkpoint import _collect_ignored_layers + + merged = { + "blocks.0.attn1.to_q.weight": torch.zeros(4, 4), + "blocks.0.attn1.to_k.weight": torch.zeros(4, 4), + "blocks.1.attn1.to_q.weight": torch.zeros(4, 4), + "blocks.1.attn1.to_q.weight_scale": torch.zeros(4, 1), } - assert _detect_num_mxfp8_blocks(quant_meta) == 2 + mxfp4_prefixes = {"blocks.1.attn1.to_q"} + ignored = _collect_ignored_layers(merged, mxfp4_prefixes) + assert "blocks.0.attn1.to_q" in ignored + assert "blocks.0.attn1.to_k" in ignored + assert "blocks.1.attn1.to_q" not in ignored # MXFP4, not ignored -def test_detect_num_mxfp8_blocks_without_mxfp8_keys(): - """When MXFP8 blocks are absent from quant_meta, the first MXFP4 block - index equals num_mxfp8_blocks (msModelSlim may omit MXFP8 markers).""" - from vllm_omni.quantization.tools.merge_mxfp4_dualscale_checkpoint import _detect_num_mxfp8_blocks - quant_meta = { - "blocks.3.attn1.to_q.weight": "W4A4_MXFP4_DUALSCALE", - "blocks.4.attn1.to_q.weight": "W4A4_MXFP4_DUALSCALE", +def test_collect_ignored_layers_empty_mxfp4(): + """When no layers are MXFP4 (all BF16), all .weight prefixes become ignored_layers.""" + from vllm_omni.quantization.tools.merge_mxfp4_dualscale_checkpoint import _collect_ignored_layers + + merged = { + "blocks.0.attn1.to_q.weight": torch.zeros(4, 4), + "proj_out.weight": torch.zeros(4, 4), } - assert _detect_num_mxfp8_blocks(quant_meta) == 3 + ignored = _collect_ignored_layers(merged, set()) + assert sorted(ignored) == ["blocks.0.attn1.to_q", "proj_out"] + + +def test_collect_ignored_layers_returns_sorted(): + """ignored_layers must be sorted for deterministic config.json output.""" + from vllm_omni.quantization.tools.merge_mxfp4_dualscale_checkpoint import _collect_ignored_layers + merged = { + "blocks.2.attn1.to_q.weight": torch.zeros(1), + "blocks.0.attn1.to_q.weight": torch.zeros(1), + "blocks.1.attn1.to_q.weight": torch.zeros(1), + } + ignored = _collect_ignored_layers(merged, set()) + assert ignored == sorted(ignored) + + +# --------------------------------------------------------------------------- +# _build_quant_config (new mxfp4_dualscale format) +# --------------------------------------------------------------------------- + + +def test_build_quant_config_new_format(): + from vllm_omni.quantization.tools.merge_mxfp4_dualscale_checkpoint import _build_quant_config + + config = _build_quant_config(["blocks.0.attn1.to_q", "proj_out"]) + assert config["quant_method"] == "mxfp4_dualscale" + assert config["is_checkpoint_serialized"] is True + assert config["ignored_layers"] == ["blocks.0.attn1.to_q", "proj_out"] + assert "num_mxfp8_blocks" not in config + + +def test_build_quant_config_empty_ignored_layers(): + from vllm_omni.quantization.tools.merge_mxfp4_dualscale_checkpoint import _build_quant_config + + config = _build_quant_config([]) + assert config["ignored_layers"] == [] + + +# --------------------------------------------------------------------------- +# _diffusers_to_vllm_ignored +# --------------------------------------------------------------------------- -def test_detect_num_mxfp8_blocks_empty(): - from vllm_omni.quantization.tools.merge_mxfp4_dualscale_checkpoint import _detect_num_mxfp8_blocks - assert _detect_num_mxfp8_blocks({}) == 0 +def test_remap_self_attn_qkv_fusion(): + """All three of attn1.to_q/k/v in BF16 → fused attn1.to_qkv.""" + from vllm_omni.quantization.tools.merge_mxfp4_dualscale_checkpoint import _diffusers_to_vllm_ignored + + result = _diffusers_to_vllm_ignored(["blocks.0.attn1.to_k", "blocks.0.attn1.to_q", "blocks.0.attn1.to_v"]) + assert result == ["blocks.0.attn1.to_qkv"] + + +def test_remap_self_attn_qkv_partial_not_fused(): + """Only some of Q/K/V in ignored — cannot fuse; kept as individual names.""" + from vllm_omni.quantization.tools.merge_mxfp4_dualscale_checkpoint import _diffusers_to_vllm_ignored + + result = _diffusers_to_vllm_ignored(["blocks.0.attn1.to_q", "blocks.0.attn1.to_k"]) + # Only q and k present, not v → no fusion + assert "blocks.0.attn1.to_qkv" not in result + assert "blocks.0.attn1.to_q" in result + assert "blocks.0.attn1.to_k" in result + + +def test_remap_cross_attn_kept_separate(): + """Cross-attention (attn2) to_q/k/v must NOT be fused.""" + from vllm_omni.quantization.tools.merge_mxfp4_dualscale_checkpoint import _diffusers_to_vllm_ignored + + result = _diffusers_to_vllm_ignored(["blocks.0.attn2.to_k", "blocks.0.attn2.to_q", "blocks.0.attn2.to_v"]) + assert "blocks.0.attn2.to_qkv" not in result + assert "blocks.0.attn2.to_q" in result + assert "blocks.0.attn2.to_k" in result + assert "blocks.0.attn2.to_v" in result + + +def test_remap_ffn_dot_to_underscore(): + """ffn.net.0.proj → ffn.net_0.proj; ffn.net.2 → ffn.net_2.""" + from vllm_omni.quantization.tools.merge_mxfp4_dualscale_checkpoint import _diffusers_to_vllm_ignored + + result = _diffusers_to_vllm_ignored(["blocks.0.ffn.net.0.proj", "blocks.0.ffn.net.2"]) + assert "blocks.0.ffn.net_0.proj" in result + assert "blocks.0.ffn.net_2" in result + assert "blocks.0.ffn.net.0.proj" not in result + assert "blocks.0.ffn.net.2" not in result + + +def test_remap_to_out_strips_index(): + """attn1.to_out.0 and attn2.to_out.0 must have the .0 suffix removed.""" + from vllm_omni.quantization.tools.merge_mxfp4_dualscale_checkpoint import _diffusers_to_vllm_ignored + + result = _diffusers_to_vllm_ignored(["blocks.0.attn1.to_out.0", "blocks.0.attn2.to_out.0"]) + assert "blocks.0.attn1.to_out" in result + assert "blocks.0.attn2.to_out" in result + assert "blocks.0.attn1.to_out.0" not in result + assert "blocks.0.attn2.to_out.0" not in result + + +def test_remap_noop_for_non_block_layers(): + """Layers outside 'blocks.N.*' (condition_embedder, proj_out) pass through unchanged.""" + from vllm_omni.quantization.tools.merge_mxfp4_dualscale_checkpoint import _diffusers_to_vllm_ignored + + inputs = [ + "condition_embedder.time_embedder.linear_1", + "condition_embedder.text_embedder.linear_1", + "proj_out", + ] + result = _diffusers_to_vllm_ignored(inputs) + assert result == sorted(inputs) + + +def test_remap_mixed_block_layers(): + """Typical mixed block: QKV fused, to_out stripped, FFN renamed, cross-attn separate.""" + from vllm_omni.quantization.tools.merge_mxfp4_dualscale_checkpoint import _diffusers_to_vllm_ignored + + inputs = [ + "blocks.0.attn1.to_k", + "blocks.0.attn1.to_out.0", + "blocks.0.attn1.to_q", + "blocks.0.attn1.to_v", + "blocks.0.attn2.to_k", + "blocks.0.attn2.to_out.0", + "blocks.0.attn2.to_q", + "blocks.0.attn2.to_v", + "blocks.0.ffn.net.0.proj", + "blocks.0.ffn.net.2", + ] + result = _diffusers_to_vllm_ignored(inputs) + assert "blocks.0.attn1.to_qkv" in result + assert "blocks.0.attn1.to_out" in result + assert "blocks.0.attn2.to_q" in result + assert "blocks.0.attn2.to_k" in result + assert "blocks.0.attn2.to_v" in result + assert "blocks.0.attn2.to_out" in result + assert "blocks.0.ffn.net_0.proj" in result + assert "blocks.0.ffn.net_2" in result + # Old names must be gone + assert "blocks.0.attn1.to_q" not in result + assert "blocks.0.ffn.net.0.proj" not in result + + +def test_remap_returns_sorted(): + """Output of _diffusers_to_vllm_ignored must always be sorted.""" + from vllm_omni.quantization.tools.merge_mxfp4_dualscale_checkpoint import _diffusers_to_vllm_ignored + + inputs = [ + "blocks.2.ffn.net.0.proj", + "blocks.0.attn1.to_q", + "blocks.0.attn1.to_k", + "blocks.0.attn1.to_v", + "blocks.1.attn2.to_out.0", + ] + result = _diffusers_to_vllm_ignored(inputs) + assert result == sorted(result) diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py index 4a4581787cb..fd1a6d6bf4e 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py @@ -212,19 +212,16 @@ def create_transformer_from_config( qc_method, ) elif ( - "num_mxfp8_blocks" in qc_kwargs - and hasattr(quant_config, "num_mxfp8_blocks") - and qc_kwargs["num_mxfp8_blocks"] != quant_config.num_mxfp8_blocks + "ignored_layers" in qc_kwargs + and hasattr(quant_config, "ignored_layers") + and set(qc_kwargs.get("ignored_layers") or []) != set(quant_config.ignored_layers or []) ): - # The transformer's own config.json has a different num_mxfp8_blocks than - # the active quant_config (e.g. built from a stale enriched config or a - # different transformer's config.json in a cascade model). Rebuild from - # disk so the block routing is authoritative for this transformer. + # mxfp4_dualscale cascade: each transformer may have different BF16 fallback + # layers in its config.json. Rebuild so per-transformer routing is correct. quant_config = build_quant_config(qc_method, **qc_kwargs) logger.info( - "Disk config.json num_mxfp8_blocks=%d differs from active config; " + "Disk config.json ignored_layers differs from active config; " "rebuilding quant_config from transformer config.json.", - qc_kwargs["num_mxfp8_blocks"], ) elif isinstance(disk_qc, str) and quant_config is None: quant_config = build_quant_config(disk_qc) @@ -461,8 +458,8 @@ def _create_transformer(self, config: dict) -> WanTransformer3DModel: # the quant_config from the transformer's own config.json and propagate it back to # od_config. This has two effects: # 1. The first transformer's auto-detected config is reused by the second transformer - # in cascade models (e.g. Wan2.2-T2V-A14B), preventing stale/wrong num_mxfp8_blocks - # from an independent read of transformer_2/config.json. + # in cascade models (e.g. Wan2.2-T2V-A14B); if the second transformer's config.json + # has different ignored_layers, create_transformer_from_config rebuilds locally. # 2. od_config.quantization_config becomes non-None so _check_unloaded_weights can # filter expected quantization suffixes instead of raising on every unloaded param. if quant_config is None and "quantization_config" in config: diff --git a/vllm_omni/quantization/factory.py b/vllm_omni/quantization/factory.py index fa51324b61b..acdb1eb4586 100644 --- a/vllm_omni/quantization/factory.py +++ b/vllm_omni/quantization/factory.py @@ -55,35 +55,20 @@ def _build_mxfp4(**kw: Any) -> QuantizationConfig: return DiffusionMXFP4Config(**kw) -def _build_mxfp8_mxfp4_dualscale(**kw: Any) -> QuantizationConfig: - """Lazy import for MXFP8 (early blocks) + MXFP4 dual-scale (later blocks) config (NPU only). - - This method is checkpoint-topology-dependent: num_mxfp8_blocks is normally - injected into transformer/config.json by merge_mxfp4_dualscale_checkpoint.py - and auto-detected from there during offline checkpoint loading. - If invoked without num_mxfp8_blocks (e.g. via --quantization mxfp8_mxfp4_dualscale - on a BF16 checkpoint), num_mxfp8_blocks defaults to 0 and all blocks fall through - to the MXFP4 DualScale online path — the MXFP8 branch is never selected. +def _build_mxfp4_dualscale(**kw: Any) -> QuantizationConfig: + """Lazy import for MXFP4 DualScale + BF16 mixed diffusion config (NPU only). + + Offline mode (is_checkpoint_serialized=True): + ignored_layers from config.json marks interleaved BF16 fallback layers. + All other linear layers use W4A4 MXFP4 DualScale. + + Online mode (is_checkpoint_serialized=False): + num_bf16_fallback_layers leading transformer blocks use BF16 original weights + (default 5 when not specified). Remaining blocks use W4A4 MXFP4 DualScale online. """ - if "num_mxfp8_blocks" not in kw: - logger.warning( - "'mxfp8_mxfp4_dualscale' was requested without num_mxfp8_blocks. " - "Defaulting to num_mxfp8_blocks=0: all transformer blocks will use " - "MXFP4 DualScale online quantization and no MXFP8 blocks will be applied. " - "This mode is not recommended for online (BF16 checkpoint) use. " - "For the intended mixed MXFP8+MXFP4 mode, use a pre-quantized checkpoint " - "produced by merge_mxfp4_dualscale_checkpoint.py and omit --quantization " - "to let vllm-omni auto-detect num_mxfp8_blocks from transformer/config.json." - ) - else: - logger.info( - "Building mxfp8_mxfp4_dualscale config: num_mxfp8_blocks=%d, is_checkpoint_serialized=%s", - kw["num_mxfp8_blocks"], - kw.get("is_checkpoint_serialized", False), - ) - from .mixed_mxfp_config import DiffusionMXFP8MXFP4DualScaleConfig + from .mxfp4_config import DiffusionMXFP4DualScaleMixedConfig - return DiffusionMXFP8MXFP4DualScaleConfig(**kw) + return DiffusionMXFP4DualScaleMixedConfig(**kw) def _build_inc(**kw: Any) -> QuantizationConfig: @@ -105,7 +90,7 @@ def _build_inc(**kw: Any) -> QuantizationConfig: "int8": _build_int8, "mxfp8": _build_mxfp8, "mxfp4": _build_mxfp4, - "mxfp8_mxfp4_dualscale": _build_mxfp8_mxfp4_dualscale, + "mxfp4_dualscale": _build_mxfp4_dualscale, "inc": _build_inc, "auto-round": _build_inc, "auto_round": _build_inc, diff --git a/vllm_omni/quantization/mixed_mxfp_config.py b/vllm_omni/quantization/mixed_mxfp_config.py deleted file mode 100644 index b4123e52964..00000000000 --- a/vllm_omni/quantization/mixed_mxfp_config.py +++ /dev/null @@ -1,149 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Mixed-precision quantization configs for diffusion transformers. - -Each class in this file describes one specific combination of quantization methods -applied to different transformer blocks. New combinations should be added here. - -Current configs ---------------- -DiffusionMXFP8MXFP4DualScaleConfig ("mxfp8_mxfp4_dualscale") - Blocks 0..num_mxfp8_blocks-1 → W8A8 MXFP8 - Blocks num_mxfp8_blocks.. → W4A4 MXFP4 dual-scale - - Block-index dispatch requires linear layers to be constructed with a prefix - of the form "blocks.N.*", threaded through WanTransformerBlock in - wan2_2_transformer.py. - - Config injected by merge_mixed_mxfp_checkpoint.py: - { - "quant_method": "mxfp8_mxfp4_dualscale", - "num_mxfp8_blocks": , - "is_checkpoint_serialized": true - } -""" - -from __future__ import annotations - -import re -from typing import TYPE_CHECKING, Any - -import torch -from vllm.logger import init_logger -from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod -from vllm.model_executor.layers.quantization import QuantizationMethods -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, - QuantizeMethodBase, -) -from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped - -from vllm_omni.platforms import current_omni_platform -from vllm_omni.quantization.mxfp4_config import ( - NPUMxfp4DualScaleLinearMethod, - NPUMxfp4DualScaleOnlineLinearMethod, -) -from vllm_omni.quantization.mxfp8_config import NPUMxfp8LinearMethod, NPUMxfp8OnlineLinearMethod - -if TYPE_CHECKING: - from vllm.model_executor.models.utils import WeightsMapper - -logger = init_logger(__name__) - -_BLOCK_IDX_RE = re.compile(r"^blocks\.(\d+)\.") - - -def _parse_block_idx(prefix: str) -> int | None: - """Extract block index from prefix like 'blocks.5.attn1.to_q'.""" - m = _BLOCK_IDX_RE.match(prefix) - return int(m.group(1)) if m else None - - -class DiffusionMXFP8MXFP4DualScaleConfig(QuantizationConfig): - """W8A8 MXFP8 (early blocks) + W4A4 MXFP4 dual-scale (remaining blocks). - - Blocks 0 .. num_mxfp8_blocks-1 are quantized with MXFP8. - Blocks num_mxfp8_blocks .. end are quantized with MXFP4 dual-scale. - - offline mode (is_checkpoint_serialized=True): - MXFP8 blocks → NPUMxfp8LinearMethod - MXFP4 blocks → NPUMxfp4DualScaleLinearMethod - - online mode (is_checkpoint_serialized=False): - MXFP8 blocks → NPUMxfp8OnlineLinearMethod - MXFP4 blocks → NPUMxfp4DualScaleOnlineLinearMethod - - Layers with a prefix not matching "blocks.N.*" (e.g. condition_embedder) are - treated as outside the MXFP8 range and fall through to the MXFP4 dual-scale path. - """ - - def __init__( - self, - num_mxfp8_blocks: int = 0, - is_checkpoint_serialized: bool = False, - ignored_layers: list[str] | None = None, - ) -> None: - super().__init__() - self.num_mxfp8_blocks = num_mxfp8_blocks - self.is_checkpoint_serialized = is_checkpoint_serialized - self.ignored_layers = ignored_layers or [] - - @classmethod - def get_name(cls) -> QuantizationMethods: - return "mxfp8_mxfp4_dualscale" - - @classmethod - def get_supported_act_dtypes(cls) -> list[torch.dtype]: - return [torch.bfloat16, torch.float16] - - @classmethod - def get_min_capability(cls) -> int: - return 80 - - @classmethod - def get_config_filenames(cls) -> list[str]: - return [] - - def apply_vllm_mapper(self, hf_to_vllm_mapper: WeightsMapper) -> None: - if self.ignored_layers: - self.ignored_layers = hf_to_vllm_mapper.apply_list(self.ignored_layers) - - @classmethod - def from_config(cls, config: dict[str, Any]) -> DiffusionMXFP8MXFP4DualScaleConfig: - num_mxfp8_blocks = cls.get_from_keys_or(config, ["num_mxfp8_blocks"], 0) - is_serialized = cls.get_from_keys_or(config, ["is_checkpoint_serialized"], False) - ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None) - if not ignored_layers: - ignored_layers = cls.get_from_keys_or(config, ["modules_to_not_convert"], None) - return cls( - num_mxfp8_blocks=num_mxfp8_blocks, - is_checkpoint_serialized=is_serialized, - ignored_layers=ignored_layers, - ) - - def get_quant_method( - self, - layer: torch.nn.Module, - prefix: str, - ) -> QuantizeMethodBase | None: - if not isinstance(layer, LinearBase): - return None - if is_layer_skipped( - prefix=prefix, - ignored_layers=self.ignored_layers, - fused_mapping=self.packed_modules_mapping, - ): - return UnquantizedLinearMethod() - - if not current_omni_platform.is_npu(): - raise NotImplementedError( - "DiffusionMXFP8MXFP4DualScaleConfig is currently only supported on NPU (Ascend) platforms." - ) - - block_idx = _parse_block_idx(prefix) - in_mxfp8_range = block_idx is not None and block_idx < self.num_mxfp8_blocks - - if self.is_checkpoint_serialized: - return NPUMxfp8LinearMethod(self) if in_mxfp8_range else NPUMxfp4DualScaleLinearMethod(self) - else: - return NPUMxfp8OnlineLinearMethod(self) if in_mxfp8_range else NPUMxfp4DualScaleOnlineLinearMethod(self) diff --git a/vllm_omni/quantization/mxfp4_config.py b/vllm_omni/quantization/mxfp4_config.py index 21c48dbdde5..9ce2c0bd53c 100644 --- a/vllm_omni/quantization/mxfp4_config.py +++ b/vllm_omni/quantization/mxfp4_config.py @@ -10,6 +10,13 @@ NPUMxfp4DualScaleLinearMethod – NPU dual-scale offline (W4A4 MXFP4 DualScale) NPUMxfp4DualScaleOnlineLinearMethod – NPU dual-scale online (BF16 → FP4) +Quantization configs: + + DiffusionMXFP4Config – single-scale online/offline (quant_method="mxfp4") + DiffusionMXFP4DualScaleMixedConfig – dual-scale + per-layer BF16 fallback (quant_method="mxfp4_dualscale") + Offline: ignored_layers from config.json routes interleaved BF16 layers + Online: num_bf16_fallback_layers leading blocks stay in BF16 (default 5) + Key differences from MXFP8: 1. Precision: float4_e2m1fn_x2 (FP4 packed, 2 values per element). @@ -42,6 +49,7 @@ from __future__ import annotations +import re from typing import TYPE_CHECKING, Any import torch @@ -556,3 +564,140 @@ def process_weights_after_loading(self, layer: Module) -> None: replace_parameter(layer, "weight_dual_scale", ds) replace_parameter(layer, "mul_scale", ms) layer._already_called_process_weights_after_loading = True + + +# --------------------------------------------------------------------------- +# Block-index helper (shared by DiffusionMXFP4DualScaleMixedConfig) +# --------------------------------------------------------------------------- + +_BLOCK_IDX_RE = re.compile(r"^blocks\.(\d+)\.") + + +def _parse_block_idx(prefix: str) -> int | None: + """Extract block index from prefix like 'blocks.5.attn1.to_q'.""" + m = _BLOCK_IDX_RE.match(prefix) + return int(m.group(1)) if m else None + + +# --------------------------------------------------------------------------- +# Config: MXFP4 DualScale + per-layer BF16 fallback +# --------------------------------------------------------------------------- + + +class DiffusionMXFP4DualScaleMixedConfig(QuantizationConfig): + """W4A4 MXFP4 DualScale with per-layer BF16 fallback for diffusion transformers. + + Sensitive layers fall back to BF16 (original weights) while all other linear + layers use W4A4 MXFP4 DualScale. BF16 fallback layers may be interleaved + anywhere in the transformer. + + Offline mode (is_checkpoint_serialized=True): + Layers whose prefix appears in ignored_layers → UnquantizedLinearMethod (BF16) + All other linear layers → NPUMxfp4DualScaleLinearMethod + + ignored_layers is injected into transformer/config.json by the merge script + and contains the prefixes of all non-MXFP4 linear layers. + + Online mode (is_checkpoint_serialized=False): + Layer routing applies two rules in priority order: + 1. ignored_layers (explicit per-layer BF16 override, user-supplied) → BF16 + 2. Blocks 0 .. num_bf16_fallback_layers-1 (coarse leading-block rule) → BF16 + 3. All other linear layers → NPUMxfp4DualScaleOnlineLinearMethod + + num_bf16_fallback_layers defaults to 5 when not specified. + Set ignored_layers to pin arbitrary interleaved layers to BF16 without + needing an offline checkpoint (useful for accuracy debugging). + Layers outside "blocks.N.*" (condition_embedder etc.) always use online MXFP4 + unless they appear in ignored_layers. + + Config injected by merge_mxfp4_dualscale_checkpoint.py: + { + "quant_method": "mxfp4_dualscale", + "is_checkpoint_serialized": true, + "ignored_layers": ["blocks.0.attn1.to_q", ...] + } + """ + + def __init__( + self, + is_checkpoint_serialized: bool = False, + ignored_layers: list[str] | None = None, + num_bf16_fallback_layers: int = 5, + ) -> None: + super().__init__() + self.is_checkpoint_serialized = is_checkpoint_serialized + self.ignored_layers = ignored_layers or [] + self.num_bf16_fallback_layers = num_bf16_fallback_layers + + @classmethod + def get_name(cls) -> QuantizationMethods: + return "mxfp4_dualscale" + + @classmethod + def get_supported_act_dtypes(cls) -> list[torch.dtype]: + return [torch.bfloat16, torch.float16] + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_config_filenames(cls) -> list[str]: + return [] + + def apply_vllm_mapper(self, hf_to_vllm_mapper: WeightsMapper) -> None: + if self.ignored_layers: + self.ignored_layers = hf_to_vllm_mapper.apply_list(self.ignored_layers) + + @classmethod + def from_config(cls, config: dict[str, Any]) -> DiffusionMXFP4DualScaleMixedConfig: + is_serialized = cls.get_from_keys_or(config, ["is_checkpoint_serialized"], False) + ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None) + if not ignored_layers: + ignored_layers = cls.get_from_keys_or(config, ["modules_to_not_convert"], None) + num_bf16_fallback_layers = cls.get_from_keys_or(config, ["num_bf16_fallback_layers"], 5) + return cls( + is_checkpoint_serialized=is_serialized, + ignored_layers=ignored_layers, + num_bf16_fallback_layers=num_bf16_fallback_layers, + ) + + def get_quant_method( + self, + layer: torch.nn.Module, + prefix: str, + ) -> QuantizeMethodBase | None: + if not isinstance(layer, LinearBase): + return None + + if self.is_checkpoint_serialized: + # Offline: ignored_layers lists interleaved BF16 fallback layer prefixes. + if is_layer_skipped( + prefix=prefix, + ignored_layers=self.ignored_layers, + fused_mapping=self.packed_modules_mapping, + ): + return UnquantizedLinearMethod() + if not current_omni_platform.is_npu(): + raise NotImplementedError( + "DiffusionMXFP4DualScaleMixedConfig is currently only supported on NPU (Ascend) platforms." + ) + return NPUMxfp4DualScaleLinearMethod(self) + + # Online: explicit ignored_layers take priority (user-specified per-layer BF16 override), + # then fall back to the coarse leading-block rule. + if is_layer_skipped( + prefix=prefix, + ignored_layers=self.ignored_layers, + fused_mapping=self.packed_modules_mapping, + ): + return UnquantizedLinearMethod() + block_idx = _parse_block_idx(prefix) + if block_idx is not None and block_idx < self.num_bf16_fallback_layers: + return UnquantizedLinearMethod() + + if not current_omni_platform.is_npu(): + raise NotImplementedError( + "DiffusionMXFP4DualScaleMixedConfig is currently only supported on NPU (Ascend) platforms." + ) + return NPUMxfp4DualScaleOnlineLinearMethod(self) diff --git a/vllm_omni/quantization/tools/merge_mxfp4_dualscale_checkpoint.py b/vllm_omni/quantization/tools/merge_mxfp4_dualscale_checkpoint.py index efe682a8945..298f6101805 100644 --- a/vllm_omni/quantization/tools/merge_mxfp4_dualscale_checkpoint.py +++ b/vllm_omni/quantization/tools/merge_mxfp4_dualscale_checkpoint.py @@ -1,22 +1,27 @@ #!/usr/bin/env python3 # SPDX-License-Identifier: Apache-2.0 -"""Merge mixed MXFP8 + W4A4_MXFP4_DUALSCALE quantized Wan2.2 weights into HF Diffusers format. +"""Merge W4A4_MXFP4_DUALSCALE quantized Wan2.2 weights into HF Diffusers format. -msModelSlim produces a mixed-precision checkpoint where transformer blocks are split into: - - Early blocks (0..num_mxfp8_blocks-1): W8A8_MXFP8 - - Remaining blocks (num_mxfp8_blocks..): W4A4_MXFP4_DUALSCALE +msModelSlim produces a checkpoint where each linear layer is either: + - W4A4_MXFP4_DUALSCALE: quantized weights with fine/coarse scales and mul_scale + - FLOAT (BF16 fallback): kept as original BF16 weights for precision-sensitive layers -MXFP4_DUALSCALE key structure per linear layer ------------------------------------------------ +BF16 fallback layers may be interleaved anywhere in the transformer (not just leading +blocks). The merge script detects them from quant_model_description.json and writes +their prefixes into config.json as ignored_layers so the runtime dynamically routes +each layer to the correct quantization method at weight-loading time. + +MXFP4_DUALSCALE key structure per linear layer (msModelSlim naming) +--------------------------------------------------------------------- blocks.N.X.linear.weight W4A4_MXFP4_DUALSCALE – int8 (FP4 packed) blocks.N.X.linear.weight_scale W4A4_MXFP4_DUALSCALE – uint8 (float8_e8m0fnu fine scale, per-32K) blocks.N.X.linear.weight_dual_scale W4A4_MXFP4_DUALSCALE – float32 (coarse scale, per-512K) blocks.N.X.linear.bias FLOAT – bias (if present) blocks.N.X.div.mul_scale FLOAT – float32 per-input-channel activation pre-scale -MXFP8 key structure (no wrapper, same as merge_mxfp8_checkpoint.py): - blocks.N.X.weight W8A8_MXFP8 - blocks.N.X.weight_scale W8A8_MXFP8 +BF16 fallback key structure (no wrappers, plain linear weights): + blocks.N.X.weight FLOAT + condition_embedder.*.weight FLOAT (always BF16 — not quantized) Self-attention QKV notes ------------------------ @@ -45,8 +50,7 @@ --model-type Wan2.2-T2V-A14B \\ --original-model /path/to/Wan2.2-T2V-A14B-Diffusers \\ --quant-path /path/to/msmodelslim-output \\ - --output-path /path/to/merged-output \\ - --num-mxfp8-blocks 5 # auto-detected if omitted + --output-path /path/to/merged-output """ from __future__ import annotations @@ -56,7 +60,6 @@ import pathlib import re import shutil -import warnings from typing import Any import torch @@ -209,43 +212,6 @@ def _print_block_summary(block_types: dict[int, str]) -> None: print(f" blocks {range_str:>8}: {btype} ({count} block{'s' if count > 1 else ''})") -def _detect_num_mxfp8_blocks(quant_meta: dict[str, str]) -> int: - """Count leading MXFP8 blocks (blocks 0..N-1). - - Two cases handled: - - MXFP8 blocks present in quant_meta (W8A8_MXFP8 markers): - count the consecutive run from block 0. - - MXFP8 blocks absent from quant_meta (msModelSlim may omit them): - the index of the first MXFP4_DUALSCALE block equals num_mxfp8_blocks, - because all blocks before it are implicitly MXFP8. - """ - block_types = _classify_blocks(quant_meta) - if not block_types: - return 0 - - sorted_indices = sorted(block_types) - first_idx = sorted_indices[0] - - if block_types[first_idx] == _MXFP8_TYPE: - # MXFP8 blocks present: count consecutive run starting at block 0. - if first_idx != 0: - warnings.warn( - f"First classified block is {first_idx} (expected 0); " - "cannot determine num_mxfp8_blocks reliably. Returning 0." - ) - return 0 - count = 0 - for idx in sorted_indices: - if block_types[idx] == _MXFP8_TYPE: - count += 1 - else: - break - return count - - # MXFP8 blocks absent from quant_meta: the first MXFP4 block index is the boundary. - return first_idx - - # --------------------------------------------------------------------------- # Safetensors I/O # --------------------------------------------------------------------------- @@ -282,16 +248,80 @@ def _load_quant_meta(directory: pathlib.Path) -> dict[str, str]: # --------------------------------------------------------------------------- +def _is_mxfp4_tensor(key: str, qtype: str) -> bool: + """Return True for MXFP4 quantized tensors and their companion tensors. + + MXFP4 DualScale layers produce four tensor types: + .linear.weight / .linear.weight_scale / .linear.weight_dual_scale (W4A4_MXFP4_DUALSCALE) + .linear.bias / .div.mul_scale (FLOAT but companion to a quantized layer) + + BF16 fallback layers have plain .weight / .bias without .linear. / .div. wrappers. + """ + if qtype.startswith("W4A4_MXFP4_DUALSCALE"): + return True + # Companion tensors: bias and mul_scale that belong to an MXFP4 layer. + return ".linear." in key or ".div.mul_scale" in key + + +def _diffusers_to_vllm_ignored(diffusers_ignored: list[str]) -> list[str]: + """Translate diffusers checkpoint prefixes to vllm-omni model parameter names. + + Three transformations align the naming conventions: + + 1. Self-attention QKV fusion (attn1 only, not attn2): + attn1.to_q + attn1.to_k + attn1.to_v → attn1.to_qkv + Cross-attention (attn2) to_q/k/v are not fused and remain separate. + + 2. FFN naming: + ffn.net.0.proj → ffn.net_0.proj + ffn.net.2 → ffn.net_2 + + 3. to_out index: + attn1.to_out.0 / attn2.to_out.0 → attn1.to_out / attn2.to_out + """ + ignored_set = set(diffusers_ignored) + result: set[str] = set() + + for name in diffusers_ignored: + m = re.match(r"^(.*\.attn1)\.to_([qkv])$", name) + if m: + prefix = m.group(1) + if all(f"{prefix}.to_{c}" in ignored_set for c in ("q", "k", "v")): + result.add(f"{prefix}.to_qkv") + else: + result.add(name) + continue + + name = re.sub(r"\.ffn\.net\.0\.proj$", ".ffn.net_0.proj", name) + name = re.sub(r"\.ffn\.net\.2$", ".ffn.net_2", name) + name = re.sub(r"\.to_out\.0$", ".to_out", name) + result.add(name) + + return sorted(result) + + +def _collect_ignored_layers(merged: dict[str, Any], mxfp4_layer_prefixes: set[str]) -> list[str]: + """Collect vllm-omni parameter name prefixes for BF16 fallback layers. + + A layer is BF16 if it has a .weight tensor but its prefix is not in + mxfp4_layer_prefixes. Prefixes are returned in vllm-omni parameter naming + (QKV-fused, FFN underscored, to_out unindexed) so the list can be written + directly into ignored_layers in config.json without further translation. + """ + all_weight_prefixes = {key[: -len(".weight")] for key in merged if key.endswith(".weight")} + return _diffusers_to_vllm_ignored(sorted(all_weight_prefixes - mxfp4_layer_prefixes)) + + def _convert_transformer( quant_subdir: pathlib.Path, output_dir: pathlib.Path, original_transformer_dir: pathlib.Path, - num_mxfp8_blocks: int | None, -) -> int: - """Convert one transformer directory. Returns the resolved num_mxfp8_blocks.""" +) -> None: + """Convert one transformer directory to the mxfp4_dualscale + BF16 mixed format.""" output_dir.mkdir(parents=True, exist_ok=True) - # BF16 base: ensures non-quantized tensors that msModelSlim might omit are present. + # BF16 base: provides the scaffold for non-MXFP4 tensors (norms, embeddings, + # BF16 fallback linear layers that msModelSlim may omit or keep as FLOAT). print(f" Loading BF16 base from {original_transformer_dir} …") base_state = _load_safetensors_dir(original_transformer_dir) print(f" {len(base_state)} BF16 tensors loaded") @@ -301,54 +331,50 @@ def _convert_transformer( quant_meta = _load_quant_meta(quant_subdir) print(f" {len(quant_state)} quant tensors, {len(quant_meta)} meta entries") - # Classify blocks and auto-detect / validate num_mxfp8_blocks. + # Classify blocks for a compact summary (informational only in the new scheme). block_types = _classify_blocks(quant_meta) - detected = _detect_num_mxfp8_blocks(quant_meta) - # Fill in inferred MXFP8 blocks (may be absent from quant_meta). - for i in range(detected): - block_types.setdefault(i, _MXFP8_TYPE) _print_block_summary(block_types) - if num_mxfp8_blocks is None: - num_mxfp8_blocks = detected - print(f" Auto-detected num_mxfp8_blocks = {num_mxfp8_blocks}") - elif num_mxfp8_blocks != detected: - warnings.warn(f"--num-mxfp8-blocks={num_mxfp8_blocks} but auto-detected {detected}. Using the provided value.") - - # Remap all quantized keys. - # Key transformation is per-block: - # MXFP8 blocks → rename dict only (no .linear./.div. wrappers) - # MXFP4_DUALSCALE blocks → rename dict + strip .linear./.div. wrappers - # Non-block keys → rename dict only + + # Remap MXFP4 quantized tensors only. + # BF16 fallback tensors (FLOAT in quant_meta, no .linear./.div. wrapper) are + # skipped here; the base_state provides them unchanged. remapped: dict[str, torch.Tensor] = {} remapped_meta: dict[str, str] = {} + mxfp4_layer_prefixes: set[str] = set() skipped: list[str] = [] for key, tensor in quant_state.items(): - renamed = _apply_rename_dict(key) + qtype = quant_meta.get(key, "FLOAT") + + if not _is_mxfp4_tensor(key, qtype): + continue # BF16 fallback — base_state already covers this tensor - block_idx = _parse_block_idx(renamed) - if block_idx is not None and block_types.get(block_idx) == _MXFP4_DUALSCALE_TYPE: - final_key = _strip_mxfp4_wrapper(renamed) - else: - final_key = renamed + renamed = _apply_rename_dict(key) + final_key = _strip_mxfp4_wrapper(renamed) - # Skip non-tensor metadata keys that msModelSlim sometimes embeds - # (e.g. quant_type markers stored as scalar tensors). if final_key.endswith(".quant_type"): skipped.append(key) continue remapped[final_key] = tensor - if key in quant_meta: - remapped_meta[final_key] = quant_meta[key] + remapped_meta[final_key] = qtype + + # Track MXFP4 layer prefixes (via their .weight keys) for ignored_layers. + if final_key.endswith(".weight") and qtype.startswith("W4A4_MXFP4_DUALSCALE"): + mxfp4_layer_prefixes.add(final_key[: -len(".weight")]) if skipped: print(f" Skipped {len(skipped)} metadata keys (quant_type markers): {skipped[:5]}") + print(f" {len(remapped)} MXFP4 tensors remapped, {len(mxfp4_layer_prefixes)} quantized layers") - # Overlay: BF16 base provides the scaffold; quant tensors replace their BF16 counterparts + # Merge: base_state provides BF16 scaffold; MXFP4 tensors override their BF16 counterparts # and add the new scale tensors (weight_scale, weight_dual_scale, mul_scale). merged = {**base_state, **remapped} + # Determine ignored_layers: BF16 fallback layer prefixes in vllm-omni parameter naming. + ignored_layers = _collect_ignored_layers(merged, mxfp4_layer_prefixes) + print(f" {len(ignored_layers)} BF16 fallback layers → ignored_layers in config.json") + # Save weights. out_weights = output_dir / "diffusion_pytorch_model.safetensors" save_file(merged, str(out_weights)) @@ -364,26 +390,20 @@ def _convert_transformer( if src_config.is_file(): with open(src_config) as f: config = json.load(f) - config["quantization_config"] = _build_quant_config(num_mxfp8_blocks) + config["quantization_config"] = _build_quant_config(ignored_layers) out_config = output_dir / "config.json" with open(out_config, "w") as f: json.dump(config, f, indent=2) - print( - f" Injected quantization_config " - f"(mxfp8_mxfp4_dualscale, num_mxfp8_blocks={num_mxfp8_blocks}) " - f"→ {out_config}" - ) + print(f" Injected quantization_config (mxfp4_dualscale, {len(ignored_layers)} ignored_layers) → {out_config}") else: print(f" WARNING: No config.json at {src_config}; quantization_config not injected.") - return num_mxfp8_blocks - -def _build_quant_config(num_mxfp8_blocks: int) -> dict[str, Any]: +def _build_quant_config(ignored_layers: list[str]) -> dict[str, Any]: return { - "quant_method": "mxfp8_mxfp4_dualscale", - "num_mxfp8_blocks": num_mxfp8_blocks, + "quant_method": "mxfp4_dualscale", "is_checkpoint_serialized": True, + "ignored_layers": ignored_layers, } @@ -413,7 +433,6 @@ def repack( original_model_path: pathlib.Path, quant_path: pathlib.Path, output_path: pathlib.Path, - num_mxfp8_blocks: int | None, ) -> None: transformer_dirs = _get_transformer_dirs(model_type) @@ -429,9 +448,7 @@ def repack( out_tdir = output_path / tdir orig_tdir = original_model_path / tdir print(f"\nConverting {tdir} (quant source: {q_subdir.name}) …") - resolved_n = _convert_transformer(q_subdir, out_tdir, orig_tdir, num_mxfp8_blocks) - # Use the resolved value for subsequent transformers in the same cascade. - num_mxfp8_blocks = resolved_n + _convert_transformer(q_subdir, out_tdir, orig_tdir) print(f"\nDone. Merged model → {output_path}") print("\nRun inference (quantization auto-detected from config.json):") @@ -452,12 +469,6 @@ def main() -> None: parser.add_argument("--original-model", required=True, help="Original HF Diffusers model directory (BF16).") parser.add_argument("--quant-path", required=True, help="msModelSlim quantized weights directory.") parser.add_argument("--output-path", required=True, help="Output directory for merged model.") - parser.add_argument( - "--num-mxfp8-blocks", - type=int, - default=None, - help=("Number of leading MXFP8 blocks (0..N-1). Auto-detected from quant_model_description.json if omitted."), - ) args = parser.parse_args() repack( @@ -465,7 +476,6 @@ def main() -> None: original_model_path=pathlib.Path(args.original_model), quant_path=pathlib.Path(args.quant_path), output_path=pathlib.Path(args.output_path), - num_mxfp8_blocks=args.num_mxfp8_blocks, ) From d75030ac2b451b4d776137ebfcef47b0472294e7 Mon Sep 17 00:00:00 2001 From: hyh_hh Date: Mon, 18 May 2026 15:23:19 +0800 Subject: [PATCH 5/8] fix merge script Signed-off-by: hyh_hh --- .../quantization/test_mxfp4_key_remap.py | 40 ++++++++++++------- .../tools/merge_mxfp4_dualscale_checkpoint.py | 10 ++++- 2 files changed, 34 insertions(+), 16 deletions(-) diff --git a/tests/diffusion/quantization/test_mxfp4_key_remap.py b/tests/diffusion/quantization/test_mxfp4_key_remap.py index 89e1a4e0890..e0cbc84dfec 100644 --- a/tests/diffusion/quantization/test_mxfp4_key_remap.py +++ b/tests/diffusion/quantization/test_mxfp4_key_remap.py @@ -213,40 +213,46 @@ def test_is_mxfp4_tensor_norm_weight(): def test_collect_ignored_layers_basic(): from vllm_omni.quantization.tools.merge_mxfp4_dualscale_checkpoint import _collect_ignored_layers + # block 0: all three of Q/K/V are BF16 → fused as to_qkv in ignored_layers + # block 1: to_qkv is MXFP4 → not in ignored_layers merged = { "blocks.0.attn1.to_q.weight": torch.zeros(4, 4), "blocks.0.attn1.to_k.weight": torch.zeros(4, 4), + "blocks.0.attn1.to_v.weight": torch.zeros(4, 4), "blocks.1.attn1.to_q.weight": torch.zeros(4, 4), "blocks.1.attn1.to_q.weight_scale": torch.zeros(4, 1), } mxfp4_prefixes = {"blocks.1.attn1.to_q"} ignored = _collect_ignored_layers(merged, mxfp4_prefixes) - assert "blocks.0.attn1.to_q" in ignored - assert "blocks.0.attn1.to_k" in ignored - assert "blocks.1.attn1.to_q" not in ignored # MXFP4, not ignored + assert "blocks.0.attn1.to_qkv" in ignored # fused vllm-omni name + assert "blocks.0.attn1.to_q" not in ignored # diffusers name must not appear + assert "blocks.0.attn1.to_k" not in ignored + assert "blocks.1.attn1.to_qkv" not in ignored # MXFP4, not ignored def test_collect_ignored_layers_empty_mxfp4(): """When no layers are MXFP4 (all BF16), all .weight prefixes become ignored_layers.""" from vllm_omni.quantization.tools.merge_mxfp4_dualscale_checkpoint import _collect_ignored_layers + # Use non-attn layers to avoid the Q/K/V completeness requirement. merged = { - "blocks.0.attn1.to_q.weight": torch.zeros(4, 4), + "blocks.0.ffn.net.0.proj.weight": torch.zeros(4, 4), "proj_out.weight": torch.zeros(4, 4), } ignored = _collect_ignored_layers(merged, set()) - assert sorted(ignored) == ["blocks.0.attn1.to_q", "proj_out"] + assert sorted(ignored) == ["blocks.0.ffn.net_0.proj", "proj_out"] def test_collect_ignored_layers_returns_sorted(): """ignored_layers must be sorted for deterministic config.json output.""" from vllm_omni.quantization.tools.merge_mxfp4_dualscale_checkpoint import _collect_ignored_layers + # Use FFN layers (no Q/K/V fusion concern) to isolate the sort guarantee. merged = { - "blocks.2.attn1.to_q.weight": torch.zeros(1), - "blocks.0.attn1.to_q.weight": torch.zeros(1), - "blocks.1.attn1.to_q.weight": torch.zeros(1), + "blocks.2.ffn.net.0.proj.weight": torch.zeros(1), + "blocks.0.ffn.net.0.proj.weight": torch.zeros(1), + "blocks.1.ffn.net.0.proj.weight": torch.zeros(1), } ignored = _collect_ignored_layers(merged, set()) assert ignored == sorted(ignored) @@ -287,15 +293,19 @@ def test_remap_self_attn_qkv_fusion(): assert result == ["blocks.0.attn1.to_qkv"] -def test_remap_self_attn_qkv_partial_not_fused(): - """Only some of Q/K/V in ignored — cannot fuse; kept as individual names.""" +def test_remap_self_attn_qkv_partial_raises(): + """Partial Q/K/V in ignored_layers is invalid — must raise ValueError. + + Self-attention Q/K/V are fused into a single to_qkv layer at runtime; + partial precision (some BF16, some MXFP4) cannot be expressed and must + be caught early in the merge script. + """ + import pytest + from vllm_omni.quantization.tools.merge_mxfp4_dualscale_checkpoint import _diffusers_to_vllm_ignored - result = _diffusers_to_vllm_ignored(["blocks.0.attn1.to_q", "blocks.0.attn1.to_k"]) - # Only q and k present, not v → no fusion - assert "blocks.0.attn1.to_qkv" not in result - assert "blocks.0.attn1.to_q" in result - assert "blocks.0.attn1.to_k" in result + with pytest.raises(ValueError, match="Partial BF16 fallback"): + _diffusers_to_vllm_ignored(["blocks.0.attn1.to_q", "blocks.0.attn1.to_k"]) def test_remap_cross_attn_kept_separate(): diff --git a/vllm_omni/quantization/tools/merge_mxfp4_dualscale_checkpoint.py b/vllm_omni/quantization/tools/merge_mxfp4_dualscale_checkpoint.py index 298f6101805..0786d460707 100644 --- a/vllm_omni/quantization/tools/merge_mxfp4_dualscale_checkpoint.py +++ b/vllm_omni/quantization/tools/merge_mxfp4_dualscale_checkpoint.py @@ -289,7 +289,15 @@ def _diffusers_to_vllm_ignored(diffusers_ignored: list[str]) -> list[str]: if all(f"{prefix}.to_{c}" in ignored_set for c in ("q", "k", "v")): result.add(f"{prefix}.to_qkv") else: - result.add(name) + present = [f"to_{c}" for c in ("q", "k", "v") if f"{prefix}.to_{c}" in ignored_set] + missing = [f"to_{c}" for c in ("q", "k", "v") if f"{prefix}.to_{c}" not in ignored_set] + raise ValueError( + f"Partial BF16 fallback for '{prefix}': " + f"{', '.join(present)} in ignored_layers but {', '.join(missing)} is not. " + f"Self-attention Q/K/V are fused into a single to_qkv layer at runtime; " + f"all three must share the same precision. " + f"Either quantize all of to_q/to_k/to_v or keep all three in BF16." + ) continue name = re.sub(r"\.ffn\.net\.0\.proj$", ".ffn.net_0.proj", name) From 435bc0ae5af6de300ed94d5446d3f13a5d4178e6 Mon Sep 17 00:00:00 2001 From: hyh_hh Date: Tue, 19 May 2026 14:38:18 +0800 Subject: [PATCH 6/8] fix input_dim for TP sharding Signed-off-by: hyh_hh --- vllm_omni/quantization/mxfp4_config.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/vllm_omni/quantization/mxfp4_config.py b/vllm_omni/quantization/mxfp4_config.py index 9ce2c0bd53c..cac2792325c 100644 --- a/vllm_omni/quantization/mxfp4_config.py +++ b/vllm_omni/quantization/mxfp4_config.py @@ -388,12 +388,15 @@ def create_weights( ) # Fine scale: one uint8 exponent (float8_e8m0fnu bit pattern) per group of 32 K elements. + # input_dim=1: RowParallelLinear weight_loader shards along the K-group dimension (dim 1) + # so each rank receives only its (K/TP)//32 groups; ColumnParallelLinear only uses + # output_dim=0 for sharding and leaves dim 1 intact. num_groups_fine = (input_size_per_partition + 31) // 32 layer.register_parameter( "weight_scale", ModelWeightParameter( data=torch.empty(output_size_per_partition, num_groups_fine, dtype=torch.uint8), - input_dim=None, + input_dim=1, output_dim=0, weight_loader=weight_loader, ), @@ -402,12 +405,13 @@ def create_weights( # Coarse scale: one float32 per group of 512 K elements. # Shape (N, K_coarse, 1) matches checkpoint layout exactly, avoiding the # shape-mismatch assert in linear.py:1344. + # input_dim=1: same TP sharding rationale as weight_scale above. num_groups_coarse = (input_size_per_partition + 511) // 512 layer.register_parameter( "weight_dual_scale", ModelWeightParameter( data=torch.empty(output_size_per_partition, num_groups_coarse, 1, dtype=torch.float32), - input_dim=None, + input_dim=1, output_dim=0, weight_loader=weight_loader, ), @@ -415,11 +419,13 @@ def create_weights( # mul_scale is a float32 calibration tensor; register as float32 # to avoid precision loss from an implicit BF16 cast during weight loading. + # input_dim=0: RowParallelLinear shards the 1-D per-input-channel tensor along dim 0 + # so each rank receives only its K/TP channels; ColumnParallelLinear leaves it intact. layer.register_parameter( "mul_scale", ModelWeightParameter( data=torch.empty(input_size_per_partition, dtype=torch.float32), - input_dim=None, + input_dim=0, output_dim=None, weight_loader=weight_loader, ), From acde5283298bc4827aa2aeee8955deae6807f867 Mon Sep 17 00:00:00 2001 From: hyh_hh Date: Tue, 19 May 2026 15:35:40 +0800 Subject: [PATCH 7/8] add tp=2 ut Signed-off-by: hyh_hh --- .../quantization/test_mxfp4_config.py | 197 +++++++++++++++++- vllm_omni/quantization/mxfp4_config.py | 2 +- 2 files changed, 196 insertions(+), 3 deletions(-) diff --git a/tests/diffusion/quantization/test_mxfp4_config.py b/tests/diffusion/quantization/test_mxfp4_config.py index 411081c2260..af3a9c9c435 100644 --- a/tests/diffusion/quantization/test_mxfp4_config.py +++ b/tests/diffusion/quantization/test_mxfp4_config.py @@ -3,10 +3,19 @@ """Tests for MXFP4 quantization configs and the MXFP4 DualScale + BF16 mixed config.""" import pytest +import torch pytestmark = [pytest.mark.core_model, pytest.mark.diffusion, pytest.mark.cpu] +@pytest.fixture(autouse=True) +def _patch_tp_state(monkeypatch): + """Patch TP rank/world_size so ModelWeightParameter can be instantiated on CPU + without an initialized distributed group. Returns TP=1 rank=0 for all tests.""" + monkeypatch.setattr("vllm.model_executor.parameter.get_tensor_model_parallel_rank", lambda: 0) + monkeypatch.setattr("vllm.model_executor.parameter.get_tensor_model_parallel_world_size", lambda: 1) + + # --------------------------------------------------------------------------- # DiffusionMXFP4Config # --------------------------------------------------------------------------- @@ -334,8 +343,6 @@ def test_mixed_dualscale_online_ignored_layers_override( def test_mixed_dualscale_non_linear_returns_none(monkeypatch: pytest.MonkeyPatch): """Non-LinearBase layers (norms, embeddings) must return None → no quantization.""" - import torch - from vllm_omni.platforms import current_omni_platform from vllm_omni.quantization.mxfp4_config import DiffusionMXFP4DualScaleMixedConfig @@ -344,3 +351,189 @@ def test_mixed_dualscale_non_linear_returns_none(monkeypatch: pytest.MonkeyPatch norm_layer = torch.nn.LayerNorm(64) assert cfg.get_quant_method(norm_layer, "blocks.0.norm1") is None + + +# --------------------------------------------------------------------------- +# TP=2 create_weights: parameter shapes and input_dim/output_dim +# +# Two scenarios mirror real Wan2.2 A14B linear layer types: +# Column-parallel (to_q, ffn.net_0): output is sharded (N/TP), input is full (K). +# Row-parallel (to_out, ffn.net_2): input is sharded (K/TP), output is full (N). +# +# Tests verify: +# 1. Registered parameter shapes are correct for each partition configuration. +# 2. input_dim/output_dim attributes are set so RowParallelLinear.weight_loader +# can shard scale tensors correctly (the fix for the TP>1 shape-mismatch bug). +# 3. Simulated loader slicing: slicing the full checkpoint tensor along the +# declared input_dim produces the exact shape stored in the parameter — +# proving the dim declaration is consistent with the allocation. +# --------------------------------------------------------------------------- + +# K must be divisible by 32 (fine groups) and 512 (coarse groups). +_TP2_K, _TP2_N, _TP2 = 1024, 512, 2 + + +class _FakeLayer(torch.nn.Module): + """Bare nn.Module that accepts register_parameter without a real weight_loader.""" + + +def _create_weights(method, *, input_size_per_partition, output_partition_sizes): + layer = _FakeLayer() + method.create_weights( + layer=layer, + input_size_per_partition=input_size_per_partition, + output_partition_sizes=output_partition_sizes, + input_size=_TP2_K, + output_size=_TP2_N, + params_dtype=torch.bfloat16, + ) + return layer + + +def _shard(tensor, param, rank, tp, dim_attr): + """Slice `tensor` along the dimension given by `param.` for `rank`.""" + dim = getattr(param, dim_attr) + if dim is None: + return tensor # not sharded along this axis + shard_size = param.shape[dim] + slices = [slice(None)] * tensor.ndim + slices[dim] = slice(rank * shard_size, (rank + 1) * shard_size) + return tensor[tuple(slices)] + + +# ── DualScale method ───────────────────────────────────────────────────────── + + +def test_dualscale_column_parallel_tp2_shapes(): + """Column-parallel TP=2: output halved, fine/coarse groups stay full, mul_scale full.""" + from vllm_omni.quantization.mxfp4_config import DiffusionMXFP4DualScaleMixedConfig, NPUMxfp4DualScaleLinearMethod + + method = NPUMxfp4DualScaleLinearMethod(DiffusionMXFP4DualScaleMixedConfig()) + layer = _create_weights(method, input_size_per_partition=_TP2_K, output_partition_sizes=[_TP2_N // _TP2]) + + assert layer.weight.shape == (_TP2_N // _TP2, _TP2_K) + assert layer.weight_scale.shape == (_TP2_N // _TP2, _TP2_K // 32) + assert layer.weight_dual_scale.shape == (_TP2_N // _TP2, _TP2_K // 512, 1) + assert layer.mul_scale.shape == (_TP2_K,) + + +def test_dualscale_row_parallel_tp2_shapes(): + """Row-parallel TP=2: input halved, fine/coarse groups halved, mul_scale halved.""" + from vllm_omni.quantization.mxfp4_config import DiffusionMXFP4DualScaleMixedConfig, NPUMxfp4DualScaleLinearMethod + + method = NPUMxfp4DualScaleLinearMethod(DiffusionMXFP4DualScaleMixedConfig()) + layer = _create_weights(method, input_size_per_partition=_TP2_K // _TP2, output_partition_sizes=[_TP2_N]) + + assert layer.weight.shape == (_TP2_N, _TP2_K // _TP2) + assert layer.weight_scale.shape == (_TP2_N, (_TP2_K // _TP2) // 32) + assert layer.weight_dual_scale.shape == (_TP2_N, (_TP2_K // _TP2) // 512, 1) + assert layer.mul_scale.shape == (_TP2_K // _TP2,) + + +def test_dualscale_scale_parameter_input_dims(): + """weight_scale/weight_dual_scale must have input_dim=1; mul_scale must have input_dim=0. + + RowParallelLinear.weight_loader only shards a parameter when input_dim is set. + Without these, loading a full checkpoint tensor into a per-rank shape causes a + shape mismatch for TP>1 row-parallel layers (to_out, ffn.net_2). + """ + from vllm_omni.quantization.mxfp4_config import DiffusionMXFP4DualScaleMixedConfig, NPUMxfp4DualScaleLinearMethod + + method = NPUMxfp4DualScaleLinearMethod(DiffusionMXFP4DualScaleMixedConfig()) + layer = _create_weights(method, input_size_per_partition=_TP2_K, output_partition_sizes=[_TP2_N]) + + assert layer.weight_scale.input_dim == 1 + assert layer.weight_scale.output_dim == 0 + assert layer.weight_dual_scale.input_dim == 1 + assert layer.weight_dual_scale.output_dim == 0 + assert layer.mul_scale.input_dim == 0 + assert layer.mul_scale.output_dim is None + + +def test_dualscale_row_parallel_tp2_loader_simulation(): + """Slicing full checkpoint tensors along input_dim must match row-parallel parameter shapes. + + Simulates what RowParallelLinear.weight_loader does: for each scale parameter, + take the slice at rank*shard_size:(rank+1)*shard_size along input_dim. + The resulting shape must equal the per-rank parameter shape allocated by create_weights. + """ + from vllm_omni.quantization.mxfp4_config import DiffusionMXFP4DualScaleMixedConfig, NPUMxfp4DualScaleLinearMethod + + method = NPUMxfp4DualScaleLinearMethod(DiffusionMXFP4DualScaleMixedConfig()) + layer = _create_weights(method, input_size_per_partition=_TP2_K // _TP2, output_partition_sizes=[_TP2_N]) + + # Full checkpoint tensors (what the loader reads from disk). + ckpt_weight_scale = torch.zeros(_TP2_N, _TP2_K // 32) + ckpt_weight_dual_scale = torch.zeros(_TP2_N, _TP2_K // 512, 1) + ckpt_mul_scale = torch.zeros(_TP2_K) + + for rank in range(_TP2): + assert _shard(ckpt_weight_scale, layer.weight_scale, rank, _TP2, "input_dim").shape == layer.weight_scale.shape + assert ( + _shard(ckpt_weight_dual_scale, layer.weight_dual_scale, rank, _TP2, "input_dim").shape + == layer.weight_dual_scale.shape + ) + assert _shard(ckpt_mul_scale, layer.mul_scale, rank, _TP2, "input_dim").shape == layer.mul_scale.shape + + +def test_dualscale_column_parallel_tp2_loader_simulation(): + """Slicing full checkpoint tensors along output_dim must match column-parallel parameter shapes. + + For column-parallel layers, the loader shards along output_dim (rows). + mul_scale has output_dim=None → not sharded (full tensor, same for all ranks). + """ + from vllm_omni.quantization.mxfp4_config import DiffusionMXFP4DualScaleMixedConfig, NPUMxfp4DualScaleLinearMethod + + method = NPUMxfp4DualScaleLinearMethod(DiffusionMXFP4DualScaleMixedConfig()) + layer = _create_weights(method, input_size_per_partition=_TP2_K, output_partition_sizes=[_TP2_N // _TP2]) + + ckpt_weight_scale = torch.zeros(_TP2_N, _TP2_K // 32) + ckpt_weight_dual_scale = torch.zeros(_TP2_N, _TP2_K // 512, 1) + ckpt_mul_scale = torch.zeros(_TP2_K) + + for rank in range(_TP2): + assert _shard(ckpt_weight_scale, layer.weight_scale, rank, _TP2, "output_dim").shape == layer.weight_scale.shape + assert ( + _shard(ckpt_weight_dual_scale, layer.weight_dual_scale, rank, _TP2, "output_dim").shape + == layer.weight_dual_scale.shape + ) + # mul_scale: output_dim=None → no sharding → full tensor fits the column-parallel parameter + assert _shard(ckpt_mul_scale, layer.mul_scale, rank, _TP2, "output_dim").shape == layer.mul_scale.shape + + +# ── Single-scale method ─────────────────────────────────────────────────────── + + +def test_single_scale_row_parallel_tp2_shapes(): + """Row-parallel TP=2: input halved → weight_scale groups halved.""" + from vllm_omni.quantization.mxfp4_config import DiffusionMXFP4Config, NPUMxfp4LinearMethod + + method = NPUMxfp4LinearMethod(DiffusionMXFP4Config()) + layer = _create_weights(method, input_size_per_partition=_TP2_K // _TP2, output_partition_sizes=[_TP2_N]) + + assert layer.weight.shape == (_TP2_N, _TP2_K // _TP2) + assert layer.weight_scale.shape == (_TP2_N, (_TP2_K // _TP2) // 32) + + +def test_single_scale_scale_parameter_input_dims(): + """Single-scale weight_scale must have input_dim=1 for RowParallel TP sharding.""" + from vllm_omni.quantization.mxfp4_config import DiffusionMXFP4Config, NPUMxfp4LinearMethod + + method = NPUMxfp4LinearMethod(DiffusionMXFP4Config()) + layer = _create_weights(method, input_size_per_partition=_TP2_K, output_partition_sizes=[_TP2_N]) + + assert layer.weight_scale.input_dim == 1 + assert layer.weight_scale.output_dim == 0 + + +def test_single_scale_row_parallel_tp2_loader_simulation(): + """Slicing full checkpoint weight_scale along input_dim matches row-parallel parameter shape.""" + from vllm_omni.quantization.mxfp4_config import DiffusionMXFP4Config, NPUMxfp4LinearMethod + + method = NPUMxfp4LinearMethod(DiffusionMXFP4Config()) + layer = _create_weights(method, input_size_per_partition=_TP2_K // _TP2, output_partition_sizes=[_TP2_N]) + + ckpt_weight_scale = torch.zeros(_TP2_N, _TP2_K // 32) + + for rank in range(_TP2): + assert _shard(ckpt_weight_scale, layer.weight_scale, rank, _TP2, "input_dim").shape == layer.weight_scale.shape diff --git a/vllm_omni/quantization/mxfp4_config.py b/vllm_omni/quantization/mxfp4_config.py index cac2792325c..6a02d5edb87 100644 --- a/vllm_omni/quantization/mxfp4_config.py +++ b/vllm_omni/quantization/mxfp4_config.py @@ -217,7 +217,7 @@ def create_weights( "weight_scale", ModelWeightParameter( data=torch.empty(output_size_per_partition, num_groups, dtype=torch.uint8), - input_dim=None, + input_dim=1, output_dim=0, weight_loader=weight_loader, ), From 0cb3e5c8a91f8339820a8a5f14ef5128a8434be5 Mon Sep 17 00:00:00 2001 From: hyh_hh Date: Wed, 20 May 2026 10:56:06 +0800 Subject: [PATCH 8/8] add user guide Signed-off-by: hyh_hh --- docs/user_guide/quantization/mxfp4.md | 137 ++++++++++++++++++ .../wan2_2/test_create_transformer_quant.py | 47 +++--- .../models/wan2_2/pipeline_wan2_2.py | 106 +------------- vllm_omni/quantization/factory.py | 77 ++++++++++ 4 files changed, 240 insertions(+), 127 deletions(-) diff --git a/docs/user_guide/quantization/mxfp4.md b/docs/user_guide/quantization/mxfp4.md index 608026b3044..2fd4a8395dc 100644 --- a/docs/user_guide/quantization/mxfp4.md +++ b/docs/user_guide/quantization/mxfp4.md @@ -340,3 +340,140 @@ unless they appear in `ignored_layers`. DualScale offline method mitigates this with calibrated `mul_scale` smooth quantization. Use `ignored_layers` and `num_bf16_fallback_layers` to trade off compression vs. accuracy for precision-sensitive layers. + +## Adapting MXFP4 for a New Model + +This section is aimed at developers who want to add MXFP4 support to a model +other than Wan2.2. The three integration points are: (1) discovering the correct +runtime layer names, (2) wiring `ignored_layers` into the model, and (3) writing +a merge script for offline checkpoints. + +### Step 1 — Discover runtime layer names + +`ignored_layers` entries must match the **runtime parameter names** used inside +vllm-omni, which may differ from the names stored in the diffusers checkpoint. +The canonical source of truth is the model's own `named_parameters()`. + +```python +from vllm_omni import Omni + +# Load the model without quantization to inspect parameter names. +omni = Omni(model="/path/to/your-model") # no --quantization flag +for name, _ in omni.pipeline.transformer.named_parameters(): + if "weight" in name and "scale" not in name: + print(name) +``` + +Compare the printed names against the diffusers checkpoint keys +(`safetensors.safe_open` or `torch.load`) to identify any renames your model +applies. Common patterns that differ in Wan2.2 (and may appear in other +models): + +| Diffusers checkpoint name | vllm-omni runtime name | Reason | +|---------------------------|------------------------|--------| +| `attn1.to_q`, `attn1.to_k`, `attn1.to_v` | `attn1.to_qkv` | Self-attention Q/K/V fused into `QKVParallelLinear` | +| `ffn.net.0.proj` | `ffn.net_0.proj` | Dots in sub-module names replaced with underscores | +| `ffn.net.2` | `ffn.net_2` | Same underscore rule | +| `to_out.0` | `to_out` | Sequential index stripped | + +If your model has different fusion patterns, inspect `packed_modules_mapping` +on the model class — this dict records how checkpoint keys are mapped to +fused runtime parameters. + +!!! warning "Partial QKV fallback is not allowed" + If your model fuses Q, K, V into a single layer, `ignored_layers` must + include **all three or none**. A partial fallback (e.g. `to_q` in BF16 but + `to_k`, `to_v` quantized) cannot be expressed at runtime because they share + one `QKVParallelLinear`. The merge script enforces this and raises an error + if only some of the trio appear as non-quantized. + +### Step 2 — Add ignored_layers to the model + +#### Online mode + +Pass `ignored_layers` directly in the quantization config using the **runtime +names** discovered in Step 1. No code changes to the model are required. + +```python +omni = Omni( + model="/path/to/your-model", + quantization={ + "method": "mxfp4_dualscale", + "ignored_layers": [ + "blocks.0.attn1.to_qkv", # runtime name, not diffusers name + "blocks.0.attn1.to_out", + "blocks.0.ffn.net_0.proj", + ], + }, +) +``` + +```bash +# CLI does not support list-typed ignored_layers directly. +# Use the Python API or set ignored_layers in config.json (offline). +python your_script.py --model /path/to/your-model --quantization mxfp4_dualscale +``` + +The `num_bf16_fallback_layers` coarse rule is an alternative to listing layers +individually: set it to N to keep all linear layers in blocks 0 … N-1 in BF16. +The right value depends on the model's sensitivity; evaluate on a validation +set and pick the smallest N that meets your accuracy target. + +#### Offline mode + +For offline checkpoints, `ignored_layers` is written into each transformer's +`config.json` by the merge script (see Step 3). No manual editing is needed if +the merge script is correct. The injected block: + +```json +{ + "quant_method": "mxfp4_dualscale", + "is_checkpoint_serialized": true, + "ignored_layers": [ + "blocks.0.attn1.to_qkv", + "blocks.0.attn1.to_out" + ] +} +``` + +To add a layer manually (e.g. to pin an additional layer to BF16 without +re-running the merge script), edit `config.json` inside the transformer +subfolder. Use runtime names, not diffusers checkpoint names. + +### Step 3 — Write a merge script for offline mode + +The merge script for a new model mirrors +`vllm_omni/quantization/tools/merge_mxfp4_dualscale_checkpoint.py`. The four +things it must do: + +1. **Remap tensor names** from the quantization tool convention to diffusers + convention (strip wrappers like `.linear.`, `.div.`; fix any prefix + differences). + +2. **Collect ignored_layers**: after loading, enumerate all `*.weight` keys that + have no corresponding `*.weight_scale` (i.e. layers the tool left in BF16). + Convert diffusers names to vllm-omni runtime names (fuse QKV, rename FFN + sub-modules, etc.). Write the result to `config.json`. + +3. **Inject `quantization_config`** into `config.json`: + ```python + config["quantization_config"] = { + "quant_method": "mxfp4_dualscale", + "is_checkpoint_serialized": True, + "ignored_layers": ignored_layers, # runtime names + } + ``` + +4. **Save** the merged safetensors and the updated `config.json`. + +The key helper to implement is the diffusers-to-runtime name translator +(equivalent to `_diffusers_to_vllm_ignored` in the Wan2.2 merge script). +For each non-quantized diffusers weight key, apply your model's specific +renaming rules and collect the results. + +!!! tip "Validate before serving" + After producing the offline checkpoint, load it without a `--quantization` + flag and verify that vLLM-Omni auto-detects the correct method. Check that + the layer count reported in the startup log matches expectations: quantized + layer count + `ignored_layers` count should equal total linear layer count. + Any mismatch indicates a name-mapping bug in the merge script. diff --git a/tests/diffusion/models/wan2_2/test_create_transformer_quant.py b/tests/diffusion/models/wan2_2/test_create_transformer_quant.py index 493c8f93555..bb95363fde8 100644 --- a/tests/diffusion/models/wan2_2/test_create_transformer_quant.py +++ b/tests/diffusion/models/wan2_2/test_create_transformer_quant.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Regression tests for transformer quant-config auto-detection and cascade propagation. +"""Regression tests for transformer quant-config auto-detection. The loader path at pipeline_wan2_2.py carries two quantization contracts: @@ -12,9 +12,10 @@ - Rebuilds when the active ignored_layers differs from the disk value Wan22Pipeline._create_transformer (~L456) - - Propagates the auto-detected config to od_config so the second transformer - in a cascade model reuses the same config rather than re-reading independently - - Does NOT overwrite od_config.quantization_config when it is already set + - Passes od_config.quantization_config (set by CLI or externally) to + create_transformer_from_config for each transformer independently. + - When od_config.quantization_config is None, each transformer auto-detects + from its own config.json; od_config is NOT modified by this call. All tests are pure-CPU and do not load model weights. """ @@ -244,17 +245,17 @@ def test_create_transformer_does_not_rebuild_when_ignored_layers_match(monkeypat # --------------------------------------------------------------------------- -# Wan22Pipeline._create_transformer — od_config propagation +# Wan22Pipeline._create_transformer — od_config passthrough # --------------------------------------------------------------------------- -def test_pipeline_create_transformer_propagates_quant_config_to_od_config(monkeypatch): +def test_pipeline_create_transformer_auto_detects_from_config_json(monkeypatch): """When od_config.quantization_config is None, _create_transformer must - auto-detect the quant method from config.json and propagate the built config - back to od_config so the next call can reuse it.""" + auto-detect the quant method from config.json and pass it to the transformer. + od_config itself must remain unchanged (None).""" from vllm_omni.quantization.mxfp8_config import DiffusionMXFP8Config - FakeTransformer, _ = _make_fake_transformer() + FakeTransformer, captured = _make_fake_transformer() monkeypatch.setattr(wan22_module, "WanTransformer3DModel", FakeTransformer) od_config = SimpleNamespace(quantization_config=None) @@ -269,14 +270,16 @@ def test_pipeline_create_transformer_propagates_quant_config_to_od_config(monkey } pipeline._create_transformer(config) - assert isinstance(od_config.quantization_config, DiffusionMXFP8Config) - assert od_config.quantization_config.is_checkpoint_mxfp8_serialized is True + qc = captured[0].get("quant_config") + assert isinstance(qc, DiffusionMXFP8Config) + assert qc.is_checkpoint_mxfp8_serialized is True + # od_config must NOT be modified — each transformer auto-detects independently. + assert od_config.quantization_config is None def test_pipeline_create_transformer_does_not_overwrite_existing_od_config(monkeypatch): - """If od_config.quantization_config is already set (propagated from the first - transformer), _create_transformer must leave it unchanged — the propagated - config is the authority for the cascade.""" + """If od_config.quantization_config is already set (e.g. via CLI), _create_transformer + must pass it through to the transformer unchanged and leave od_config unmodified.""" from vllm_omni.quantization.mxfp8_config import DiffusionMXFP8Config FakeTransformer, _ = _make_fake_transformer() @@ -305,8 +308,7 @@ def test_pipeline_create_transformer_does_not_overwrite_existing_od_config(monke def test_pipeline_cascade_both_transformers_get_mxfp8_serialized_config(monkeypatch): """Cascade model (transformer + transformer_2) with MXFP8 checkpoint: - - First transformer: auto-detects serialized config, propagates to od_config. - - Second transformer: reuses the propagated config (same instance). + - Each transformer auto-detects from its own config.json independently. Both must receive is_checkpoint_mxfp8_serialized=True.""" from vllm_omni.quantization.mxfp8_config import DiffusionMXFP8Config @@ -326,19 +328,16 @@ def test_pipeline_cascade_both_transformers_get_mxfp8_serialized_config(monkeypa assert isinstance(qc, DiffusionMXFP8Config), f"transformer[{i}]: expected DiffusionMXFP8Config, got {type(qc)}" assert qc.is_checkpoint_mxfp8_serialized is True, f"transformer[{i}]: expected serialized=True" - # Second transformer must reuse the propagated instance — no unnecessary rebuild. - assert captured[0]["quant_config"] is captured[1]["quant_config"] - def test_pipeline_cascade_mxfp4_dualscale_each_transformer_gets_correct_ignored_layers(monkeypatch): """Cascade with mxfp4_dualscale where transformer and transformer_2 have different ignored_layers in their config.json. Expected outcome: - transformer → ignored_layers=["blocks.0.attn1.to_q"] (auto-detected, propagated to od_config) + transformer → ignored_layers=["blocks.0.attn1.to_q"] (auto-detected from its config.json) transformer_2 → ignored_layers=["blocks.0.attn1.to_q", "blocks.1.attn1.to_q"] - (rebuilt from disk because ignored_layers differ) - od_config → ignored_layers=["blocks.0.attn1.to_q"] (unchanged; rebuild was local) + (auto-detected from its own config.json independently) + od_config → quantization_config remains None (not modified by _create_transformer) """ from vllm_omni.quantization.mxfp4_config import DiffusionMXFP4DualScaleMixedConfig @@ -380,5 +379,5 @@ def test_pipeline_cascade_mxfp4_dualscale_each_transformer_gets_correct_ignored_ "blocks.1.attn1.to_q", }, f"transformer_2 expected 2 layers, got {qc2.ignored_layers}" - # od_config retains the first transformer's config; the rebuild was local. - assert set(od_config.quantization_config.ignored_layers) == {"blocks.0.attn1.to_q"} + # od_config must remain unchanged — _create_transformer does not modify it. + assert od_config.quantization_config is None diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py index 665124727ff..3ee46ffb003 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py @@ -121,21 +121,6 @@ def load_transformer_config(model_path: str, subfolder: str = "transformer", loc return {} -_SERIALIZED_FLAGS = ( - "is_checkpoint_mxfp8_serialized", - "is_checkpoint_mxfp4_serialized", - "is_checkpoint_serialized", -) - - -def _disk_marks_serialized(qc_kwargs: dict, quant_config: object) -> bool: - """Return True when config.json says serialized but the active quant_config does not.""" - for flag in _SERIALIZED_FLAGS: - if qc_kwargs.get(flag, False) and hasattr(quant_config, flag) and not getattr(quant_config, flag): - return True - return False - - def create_transformer_from_config( config: dict, quant_config: QuantizationConfig | None = None, @@ -174,63 +159,10 @@ def create_transformer_from_config( if "pos_embed_seq_len" in config: kwargs["pos_embed_seq_len"] = config["pos_embed_seq_len"] - # Auto-detect quantization from transformer's config.json when not explicitly provided. - # merge_mxfp8_checkpoint.py injects quantization_config into config.json so that - # offline quantized checkpoints are recognized here without a CLI flag. if "quantization_config" in config: - from vllm_omni.quantization.factory import build_quant_config - - disk_qc = config["quantization_config"] - if isinstance(disk_qc, dict) and "quant_method" in disk_qc: - qc_method = disk_qc["quant_method"] - qc_kwargs = {k: v for k, v in disk_qc.items() if k != "quant_method"} - if quant_config is None: - # No CLI flag: full auto-detection. - quant_config = build_quant_config(qc_method, **qc_kwargs) - logger.info( - "Auto-detected quantization from transformer config.json: method=%s kwargs=%s", - qc_method, - qc_kwargs, - ) - elif quant_config.get_name() != qc_method: - # The caller supplied a quant_config of a different method than - # what the checkpoint was built with. Loading serialized tensors - # (e.g. MXFP8 weight scales) with the wrong linear method would - # produce corrupt output or a shape mismatch crash. Reject early - # so the user gets a clear message instead of a silent failure. - raise ValueError( - f"Checkpoint config.json declares quant_method={qc_method!r} but the " - f"active quantization config is {quant_config.get_name()!r}. " - "Pass a matching --quantization flag or omit it for auto-detection." - ) - elif _disk_marks_serialized(qc_kwargs, quant_config): - # Same method: CLI provided online mode but config.json marks this - # as a pre-quantized offline checkpoint. Switch to offline mode so - # users can pass --quantization mxfp8 without knowing the - # online/offline distinction. - quant_config = build_quant_config(qc_method, **qc_kwargs) - logger.info( - "config.json marks checkpoint as serialized; switching to offline %s mode.", - qc_method, - ) - elif ( - "ignored_layers" in qc_kwargs - and hasattr(quant_config, "ignored_layers") - and set(qc_kwargs.get("ignored_layers") or []) != set(quant_config.ignored_layers or []) - ): - # mxfp4_dualscale cascade: each transformer may have different BF16 fallback - # layers in its config.json. Rebuild so per-transformer routing is correct. - quant_config = build_quant_config(qc_method, **qc_kwargs) - logger.info( - "Disk config.json ignored_layers differs from active config; " - "rebuilding quant_config from transformer config.json.", - ) - elif isinstance(disk_qc, str) and quant_config is None: - quant_config = build_quant_config(disk_qc) - logger.info( - "Auto-detected quantization from transformer config.json: method=%s", - disk_qc, - ) + from vllm_omni.quantization.factory import resolve_quant_config_from_disk + + quant_config = resolve_quant_config_from_disk(quant_config, config["quantization_config"]) if quant_config is not None: kwargs["quant_config"] = quant_config @@ -457,38 +389,6 @@ def __init__( def _create_transformer(self, config: dict) -> WanTransformer3DModel: """Create a transformer from a config dict. Respects od_config.quantization_config.""" quant_config = getattr(self.od_config, "quantization_config", None) - - # When od_config.quantization_config is None (no CLI --quantization flag), pre-build - # the quant_config from the transformer's own config.json and propagate it back to - # od_config. This has two effects: - # 1. The first transformer's auto-detected config is reused by the second transformer - # in cascade models (e.g. Wan2.2-T2V-A14B); if the second transformer's config.json - # has different ignored_layers, create_transformer_from_config rebuilds locally. - # 2. od_config.quantization_config becomes non-None so _check_unloaded_weights can - # filter expected quantization suffixes instead of raising on every unloaded param. - if quant_config is None and "quantization_config" in config: - from vllm_omni.quantization.factory import build_quant_config - - disk_qc = config["quantization_config"] - if isinstance(disk_qc, dict) and "quant_method" in disk_qc: - qc_method = disk_qc["quant_method"] - qc_kwargs = {k: v for k, v in disk_qc.items() if k != "quant_method"} - quant_config = build_quant_config(qc_method, **qc_kwargs) - self.od_config.quantization_config = quant_config - logger.info( - "Auto-detected quantization from transformer config.json and propagated to od_config: " - "method=%s kwargs=%s", - qc_method, - qc_kwargs, - ) - elif isinstance(disk_qc, str): - quant_config = build_quant_config(disk_qc) - self.od_config.quantization_config = quant_config - logger.info( - "Auto-detected quantization from transformer config.json and propagated to od_config: method=%s", - disk_qc, - ) - return create_transformer_from_config(config, quant_config=quant_config) @property diff --git a/vllm_omni/quantization/factory.py b/vllm_omni/quantization/factory.py index acdb1eb4586..955f97cef85 100644 --- a/vllm_omni/quantization/factory.py +++ b/vllm_omni/quantization/factory.py @@ -280,3 +280,80 @@ def build_quant_config( return _build_from_method_and_config(method, merged) raise TypeError(f"quantization config must be str, dict, QuantizationConfig, or None, got {type(spec).__name__}") + + +def _disk_marks_serialized(qc_kwargs: dict[str, Any], quant_config: object) -> bool: + """Return True when config.json says serialized but the active quant_config does not. + + Matches any flag following the is_checkpoint_*_serialized naming convention, + so new quant methods don't require updating an explicit allowlist. + """ + for key, val in qc_kwargs.items(): + if key.startswith("is_checkpoint_") and key.endswith("_serialized"): + if val and hasattr(quant_config, key) and not getattr(quant_config, key): + return True + return False + + +def resolve_quant_config_from_disk( + quant_config: QuantizationConfig | None, + disk_qc: dict[str, Any] | str | None, +) -> QuantizationConfig | None: + """Reconcile an active quant_config against quantization_config from a transformer's config.json. + + Used when loading individual transformer blocks that each have their own config.json + (e.g. cascade models with separate transformer and transformer_2 directories). + + Rules: + - disk_qc is None: return quant_config unchanged. + - quant_config is None: auto-detect from disk_qc (full build). + - Methods mismatch: raise ValueError — prevents silent weight corruption. + - Disk marks serialized but quant_config is online: rebuild from disk. + - ignored_layers differ: rebuild from disk (per-transformer BF16 routing). + """ + if disk_qc is None: + return quant_config + + if isinstance(disk_qc, str): + if quant_config is None: + logger.info("Auto-detected quantization from config.json: method=%s", disk_qc) + return build_quant_config(disk_qc) + return quant_config + + if not isinstance(disk_qc, Mapping) or "quant_method" not in disk_qc: + return quant_config + + qc_method: str = disk_qc["quant_method"] + qc_kwargs: dict[str, Any] = {k: v for k, v in disk_qc.items() if k != "quant_method"} + + if quant_config is None: + logger.info( + "Auto-detected quantization from config.json: method=%s kwargs=%s", + qc_method, + qc_kwargs, + ) + return build_quant_config(qc_method, **qc_kwargs) + + if quant_config.get_name() != qc_method: + raise ValueError( + f"Checkpoint config.json declares quant_method={qc_method!r} but the " + f"active quantization config is {quant_config.get_name()!r}. " + "Pass a matching --quantization flag or omit it for auto-detection." + ) + + if _disk_marks_serialized(qc_kwargs, quant_config): + logger.info( + "config.json marks checkpoint as serialized; switching to offline %s mode.", + qc_method, + ) + return build_quant_config(qc_method, **qc_kwargs) + + if ( + "ignored_layers" in qc_kwargs + and hasattr(quant_config, "ignored_layers") + and set(qc_kwargs.get("ignored_layers") or []) != set(quant_config.ignored_layers or []) + ): + logger.info("config.json ignored_layers differs from active config; rebuilding quant_config.") + return build_quant_config(qc_method, **qc_kwargs) + + return quant_config