diff --git a/.buildkite/test-nightly.yml b/.buildkite/test-nightly.yml index 1b61044affa..ae93d2353a3 100644 --- a/.buildkite/test-nightly.yml +++ b/.buildkite/test-nightly.yml @@ -959,7 +959,7 @@ steps: - label: ":full_moon: Diffusion X2V ยท Function Test" timeout_in_minutes: 90 commands: - - pytest -s -v tests/e2e/online_serving/test_wan22_expansion.py tests/e2e/online_serving/test_wan_2_1_vace_expansion.py tests/e2e/online_serving/test_hunyuan_video_15_expansion.py -m "full_model and cuda" --run-level "full_model" + - pytest -s -v tests/e2e/online_serving/test_wan22_expansion.py tests/e2e/online_serving/test_wan_2_1_vace_expansion.py tests/e2e/online_serving/test_hunyuan_video_15_expansion.py tests/e2e/offline_inference/test_wan22_autoround_w4a16_expansion.py -m "full_model and cuda" --run-level "full_model" agents: queue: "mithril-h100-pool" plugins: diff --git a/docs/user_guide/quantization/autoround.md b/docs/user_guide/quantization/autoround.md index 2261d79a57c..88fed3b62b3 100644 --- a/docs/user_guide/quantization/autoround.md +++ b/docs/user_guide/quantization/autoround.md @@ -32,7 +32,9 @@ guide. AutoRound is Intel-supported. |-------|------------|-------|--------|---------| | FLUX.1-dev | `vllm-project-org/FLUX.1-dev-AutoRound-w4a16` | Diffusion transformer | W4A16 | GPTQ-Marlin or Intel-supported AutoRound backend | | Qwen-Image | Not listed | Diffusion transformer | W4A16 | Not validated | -| Wan2.2 | Not listed | Diffusion transformer | W4A16 | Not validated | +| Wan2.2-I2V | `Intel/Wan2.2-I2V-A14B-Diffusers-int4-AutoRound` | Diffusion transformer | W4A16 | GPTQ-Marlin or Intel-supported AutoRound backend | +| Wan2.2-T2V | `Intel/Wan2.2-T2V-A14B-Diffusers-int4-AutoRound` | Diffusion transformer | W4A16 | GPTQ-Marlin or Intel-supported AutoRound backend | +| Wan2.2-TI2V | `Intel/Wan2.2-TI2V-5B-Diffusers-int4-AutoRound` | Diffusion transformer | W4A16 | GPTQ-Marlin or Intel-supported AutoRound backend | ### Multi-Stage Omni/TTS Model (Qwen3-Omni, Qwen3-TTS) diff --git a/docs/user_guide/quantization/mxfp4.md b/docs/user_guide/quantization/mxfp4.md index 7463ada23ee..401a55ad4d8 100644 --- a/docs/user_guide/quantization/mxfp4.md +++ b/docs/user_guide/quantization/mxfp4.md @@ -397,7 +397,7 @@ names** discovered in Step 1. No code changes to the model are required. ```python omni = Omni( model="/path/to/your-model", - quantization_config={ + quantization={ "method": "mxfp4_dualscale", "ignored_layers": [ "blocks.0.attn1.to_qkv", # runtime name, not diffusers name diff --git a/tests/diffusion/models/wan2_2/test_wan22_quant_config_propagation.py b/tests/diffusion/models/wan2_2/test_wan22_quant_config_propagation.py new file mode 100644 index 00000000000..b1405993312 --- /dev/null +++ b/tests/diffusion/models/wan2_2/test_wan22_quant_config_propagation.py @@ -0,0 +1,301 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unit tests for Wan2.2 quant_config propagation through transformer creation. + +Tests cover: +- create_transformer_from_config passes quant_config and prefix +- create_vace_transformer_from_config passes quant_config and prefix +- set_tf_model_config propagates quant_config to OmniDiffusionConfig +- patch_wan_rms_norm safely iterates sys.modules with concurrent modifications +- I2V transformer_2 quant_config is built from config dict +""" + +import sys +from types import SimpleNamespace + +import pytest +from pytest_mock import MockerFixture + +import vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 as wan22_module +import vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2_vace as wan22_vace_module +from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import ( + create_transformer_from_config, +) +from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2_vace import ( + create_vace_transformer_from_config, +) + +pytestmark = [pytest.mark.core_model, pytest.mark.cpu, pytest.mark.diffusion] + + +# --------------------------------------------------------------------------- +# create_transformer_from_config: quant_config / prefix forwarding +# --------------------------------------------------------------------------- + + +class TestCreateTransformerQuant: + """Verify quant_config and prefix are forwarded to WanTransformer3DModel.""" + + def test_quant_config_passed_through(self, mocker: MockerFixture): + captured = {} + + class FakeTransformer: + def __init__(self, **kwargs): + captured.update(kwargs) + + mocker.patch.object(wan22_module, "WanTransformer3DModel", FakeTransformer) + + fake_qc = mocker.MagicMock() + create_transformer_from_config( + {"patch_size": [1, 2, 2], "num_layers": 2}, + quant_config=fake_qc, + ) + assert captured.get("quant_config") is fake_qc + + def test_prefix_passed_through(self, mocker: MockerFixture): + captured = {} + + class FakeTransformer: + def __init__(self, **kwargs): + captured.update(kwargs) + + mocker.patch.object(wan22_module, "WanTransformer3DModel", FakeTransformer) + + create_transformer_from_config( + {"patch_size": [1, 2, 2]}, + prefix="model.transformer.", + ) + assert captured.get("prefix") == "model.transformer." + + def test_quant_config_none_by_default(self, mocker: MockerFixture): + captured = {} + + class FakeTransformer: + def __init__(self, **kwargs): + captured.update(kwargs) + + mocker.patch.object(wan22_module, "WanTransformer3DModel", FakeTransformer) + + create_transformer_from_config({"patch_size": [1, 2, 2]}) + # When quant_config is None and prefix is "", they are not added + assert "quant_config" not in captured or captured["quant_config"] is None + + def test_quant_config_and_prefix_together(self, mocker: MockerFixture): + captured = {} + + class FakeTransformer: + def __init__(self, **kwargs): + captured.update(kwargs) + + mocker.patch.object(wan22_module, "WanTransformer3DModel", FakeTransformer) + + fake_qc = mocker.MagicMock() + create_transformer_from_config( + {"patch_size": [1, 2, 2], "num_attention_heads": 4}, + quant_config=fake_qc, + prefix="blocks.", + ) + assert captured["quant_config"] is fake_qc + assert captured["prefix"] == "blocks." + + +# --------------------------------------------------------------------------- +# create_vace_transformer_from_config: quant_config / prefix forwarding +# --------------------------------------------------------------------------- + + +class TestCreateVaceTransformerQuant: + """Verify quant_config and prefix are forwarded to WanVACETransformer3DModel.""" + + def test_quant_config_passed_through(self, mocker: MockerFixture): + captured = {} + + class FakeVACETransformer: + def __init__(self, **kwargs): + captured.update(kwargs) + + mocker.patch.object(wan22_vace_module, "WanVACETransformer3DModel", FakeVACETransformer) + + fake_qc = mocker.MagicMock() + create_vace_transformer_from_config( + {"patch_size": [1, 2, 2], "num_layers": 2}, + quant_config=fake_qc, + ) + assert captured.get("quant_config") is fake_qc + + def test_prefix_passed_through(self, mocker: MockerFixture): + captured = {} + + class FakeVACETransformer: + def __init__(self, **kwargs): + captured.update(kwargs) + + mocker.patch.object(wan22_vace_module, "WanVACETransformer3DModel", FakeVACETransformer) + + create_vace_transformer_from_config( + {"patch_size": [1, 2, 2]}, + prefix="vace.", + ) + assert captured.get("prefix") == "vace." + + +# --------------------------------------------------------------------------- +# set_tf_model_config: propagation of quant_config +# --------------------------------------------------------------------------- + + +class TestSetTfModelConfig: + """Test that set_tf_model_config propagates quant_config correctly.""" + + def _make_od_config(self): + """Create a minimal OmniDiffusionConfig-like object for testing.""" + from vllm_omni.diffusion.data import OmniDiffusionConfig + + cfg = object.__new__(OmniDiffusionConfig) + cfg.quantization_config = None + cfg.tf_model_config = None + return cfg + + def test_propagates_quant_config_when_none(self, mocker: MockerFixture): + cfg = self._make_od_config() + fake_qc = mocker.MagicMock() + tf_config = SimpleNamespace(quant_config=fake_qc, quant_method="auto-round") + + cfg.set_tf_model_config(tf_config) + + assert cfg.tf_model_config is tf_config + assert cfg.quantization_config is fake_qc + + def test_does_not_overwrite_existing_quantization_config(self, mocker: MockerFixture): + cfg = self._make_od_config() + existing_qc = mocker.MagicMock() + cfg.quantization_config = existing_qc + tf_config = SimpleNamespace(quant_config=mocker.MagicMock()) + + cfg.set_tf_model_config(tf_config) + + assert cfg.tf_model_config is tf_config + assert cfg.quantization_config is existing_qc # not overwritten + + def test_no_propagation_when_tf_quant_config_is_none(self, mocker: MockerFixture): + cfg = self._make_od_config() + tf_config = SimpleNamespace(quant_config=None) + + cfg.set_tf_model_config(tf_config) + + assert cfg.tf_model_config is tf_config + assert cfg.quantization_config is None + + +# --------------------------------------------------------------------------- +# patch_wan_rms_norm: sys.modules snapshot safety +# --------------------------------------------------------------------------- + + +class TestPatchWanRmsNorm: + """Test that patch_wan_rms_norm doesn't raise on concurrent module registration.""" + + def test_patches_modules_with_wan_rms_norm(self): + from vllm_omni.diffusion.layers.norm import RMSNormVAE + from vllm_omni.diffusion.models.wan2_2.patch_diffusers import patch_wan_rms_norm + + # Create a fake module that has WanRMS_norm + fake_module = SimpleNamespace(WanRMS_norm=lambda x: x) + sys.modules["_test_fake_wan_module"] = fake_module + + try: + patch_wan_rms_norm() + assert fake_module.WanRMS_norm is RMSNormVAE + finally: + del sys.modules["_test_fake_wan_module"] + + def test_no_error_when_modules_change_during_iteration(self): + """Regression test: list() snapshot prevents RuntimeError.""" + from vllm_omni.diffusion.models.wan2_2.patch_diffusers import patch_wan_rms_norm + + # Simulate a module being added during iteration by a side effect + original_items = sys.modules.items + + def items_with_side_effect(): + # This would cause RuntimeError without list() snapshot + result = list(original_items()) + # Add a new module to simulate concurrent modification + sys.modules["_test_dynamic_module"] = SimpleNamespace() + return result + + try: + # The function uses list(sys.modules.items()) so it takes a snapshot + # Just verify it doesn't raise + patch_wan_rms_norm() + finally: + sys.modules.pop("_test_dynamic_module", None) + + +# --------------------------------------------------------------------------- +# I2V transformer_2 quant_config extraction +# --------------------------------------------------------------------------- + + +class TestI2VTransformer2QuantConfig: + """Test the transformer_2 quant_config build logic from pipeline_wan2_2_i2v.""" + + def test_transformer_2_quant_config_built_from_dict(self): + """When transformer_2 config has quantization_config dict, build_quant_config is called.""" + from vllm_omni.quantization.factory import build_quant_config + + t2_config = { + "patch_size": [1, 2, 2], + "num_layers": 2, + "quantization_config": { + "quant_method": "auto-round", + "bits": 4, + "group_size": 128, + "sym": True, + "packing_format": "auto_round:auto_gptq", + }, + } + + # Replicate the logic from pipeline_wan2_2_i2v.py + t2_quant = t2_config.get("quantization_config") + if isinstance(t2_quant, dict) and "quant_method" in t2_quant: + method = t2_quant["quant_method"] + kwargs = {k: v for k, v in t2_quant.items() if k != "quant_method"} + t2_quant = build_quant_config(method, **kwargs) + else: + t2_quant = None + + from vllm.model_executor.layers.quantization.inc import INCConfig + + assert isinstance(t2_quant, INCConfig) + assert t2_quant.weight_bits == 4 + assert t2_quant.group_size == 128 + + def test_transformer_2_quant_config_none_when_missing(self): + """When transformer_2 config has no quantization_config, result is None.""" + t2_config = { + "patch_size": [1, 2, 2], + "num_layers": 2, + } + + t2_quant = t2_config.get("quantization_config") + if isinstance(t2_quant, dict) and "quant_method" in t2_quant: + pass # won't enter + else: + t2_quant = None + + assert t2_quant is None + + def test_transformer_2_quant_config_none_when_dict_lacks_method(self): + """When quantization_config is a dict but missing quant_method, result is None.""" + t2_config = { + "patch_size": [1, 2, 2], + "quantization_config": {"bits": 4}, # no quant_method key + } + + t2_quant = t2_config.get("quantization_config") + if isinstance(t2_quant, dict) and "quant_method" in t2_quant: + pass + else: + t2_quant = None + + assert t2_quant is None diff --git a/tests/e2e/offline_inference/test_wan22_autoround_w4a16_expansion.py b/tests/e2e/offline_inference/test_wan22_autoround_w4a16_expansion.py new file mode 100644 index 00000000000..044435cb571 --- /dev/null +++ b/tests/e2e/offline_inference/test_wan22_autoround_w4a16_expansion.py @@ -0,0 +1,311 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""E2E tests for Wan2.2 AutoRound W4A16 quantized inference. + +These tests cover I2V (image-to-video) and T2V (text-to-video) generation +with quantized weights. + +Requirements: + - CUDA GPU (H100 or equivalent, ~36 GiB for quantized model) + - The quantized model checkpoint (Intel/Wan2.2-I2V-A14B-Diffusers-int4-AutoRound, + Intel/Wan2.2-T2V-A14B-Diffusers-int4-AutoRound) +""" + +import gc +import os as _os + +import numpy as np +import pytest +import torch +from PIL import Image + +from tests.helpers.env import DeviceMemoryMonitor +from tests.helpers.mark import hardware_test +from vllm_omni.inputs.data import OmniDiffusionSamplingParams +from vllm_omni.platforms import current_omni_platform + +_os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + +QUANTIZED_MODEL_I2V = "Intel/Wan2.2-I2V-A14B-Diffusers-int4-AutoRound" +BASELINE_MODEL_I2V = "Wan-AI/Wan2.2-I2V-A14B-Diffusers" +QUANTIZED_MODEL_T2V = "Intel/Wan2.2-T2V-A14B-Diffusers-int4-AutoRound" +BASELINE_MODEL_T2V = "Wan-AI/Wan2.2-T2V-A14B-Diffusers" + +QUANTIZED_MODEL_I2V = _os.environ.get("WAN22_I2V_AUTOROUND_MODEL", QUANTIZED_MODEL_I2V) +BASELINE_MODEL_I2V = _os.environ.get("WAN22_I2V_BASELINE_MODEL", BASELINE_MODEL_I2V) +QUANTIZED_MODEL_T2V = _os.environ.get("WAN22_T2V_AUTOROUND_MODEL", QUANTIZED_MODEL_T2V) +BASELINE_MODEL_T2V = _os.environ.get("WAN22_T2V_BASELINE_MODEL", BASELINE_MODEL_T2V) + +pytestmark = [ + pytest.mark.full_model, + pytest.mark.diffusion, +] + +# Small resolution to keep GPU memory & time manageable +HEIGHT = 480 +WIDTH = 640 +NUM_FRAMES = 5 # must satisfy num_frames % 4 == 1 for Wan2.2 +NUM_STEPS = 2 # minimal for smoke-test + +# Parametrise: (model, stage_config_path=None, extra_omni_kwargs) +# When stage_config_path is None, the engine auto-resolves from the model's own config. +quant_i2v_params = [(QUANTIZED_MODEL_I2V, None, {"enforce_eager": True})] +baseline_i2v_params = [(BASELINE_MODEL_I2V, None, {"enforce_eager": True})] +quant_t2v_params = [(QUANTIZED_MODEL_T2V, None, {"enforce_eager": True})] +baseline_t2v_params = [(BASELINE_MODEL_T2V, None, {"enforce_eager": True})] + +# Module-level storage for peak memory results across tests +_memory_results: dict[str, float] = {} + + +def _sampling_params_i2v() -> OmniDiffusionSamplingParams: + """Create sampling params for I2V generation.""" + return OmniDiffusionSamplingParams( + height=HEIGHT, + width=WIDTH, + num_frames=NUM_FRAMES, + num_inference_steps=NUM_STEPS, + guidance_scale=5.0, + guidance_scale_2=6.0, + boundary_ratio=0.875, + generator=torch.Generator(device=current_omni_platform.device_type).manual_seed(42), + ) + + +def _sampling_params_t2v() -> OmniDiffusionSamplingParams: + """Create sampling params for T2V generation.""" + return OmniDiffusionSamplingParams( + height=HEIGHT, + width=WIDTH, + num_frames=NUM_FRAMES, + num_inference_steps=NUM_STEPS, + guidance_scale=4.0, + generator=torch.Generator(device=current_omni_platform.device_type).manual_seed(42), + ) + + +def _create_test_image(width: int = WIDTH, height: int = HEIGHT) -> Image.Image: + """Create a deterministic test image for I2V tests.""" + rng = np.random.RandomState(42) + arr = rng.randint(0, 256, (height, width, 3), dtype=np.uint8) + return Image.fromarray(arr) + + +def _generate_i2v_video(omni_runner_handler, prompt: str = "A cat sitting on a table, smooth motion") -> tuple: + """Generate one I2V video, return (frames, peak_memory_mb).""" + gc.collect() + current_omni_platform.empty_cache() + device_index = current_omni_platform.current_device() + current_omni_platform.reset_peak_memory_stats() + monitor = DeviceMemoryMonitor(device_index=device_index, interval=0.02) + monitor.start() + + image = _create_test_image() + response = omni_runner_handler.send_diffusion_request( + { + "prompt": prompt, + "images": image, + "sampling_params": _sampling_params_i2v(), + }, + ) + + peak = monitor.peak_used_mb + monitor.stop() + + assert response.success, f"Request failed: {response.error_message}" + assert response.images is not None and len(response.images) > 0, "Expected image output" + frames = response.images[0] + + gc.collect() + current_omni_platform.empty_cache() + + return frames, peak + + +def _generate_t2v_video(omni_runner_handler, prompt: str = "A cat sitting on a table") -> tuple: + """Generate one T2V video, return (frames, peak_memory_mb).""" + gc.collect() + current_omni_platform.empty_cache() + device_index = current_omni_platform.current_device() + current_omni_platform.reset_peak_memory_stats() + monitor = DeviceMemoryMonitor(device_index=device_index, interval=0.02) + monitor.start() + + response = omni_runner_handler.send_diffusion_request( + { + "prompt": prompt, + "sampling_params": _sampling_params_t2v(), + }, + ) + + peak = monitor.peak_used_mb + monitor.stop() + + assert response.success, f"Request failed: {response.error_message}" + assert response.images is not None and len(response.images) > 0, "Expected image output" + frames = response.images[0] + + gc.collect() + current_omni_platform.empty_cache() + + return frames, peak + + +# ------------------------------------------------------------------ +# Test: I2V quantized model generates valid video +# ------------------------------------------------------------------ + + +@hardware_test(res={"cuda": "H100"}) +@pytest.mark.parametrize("omni_runner", quant_i2v_params, indirect=True) +def test_wan22_i2v_autoround_w4a16_generates_video(omni_runner, omni_runner_handler): + """Load the W4A16 quantized Wan2.2 I2V model and verify it produces a valid video.""" + frames, _ = _generate_i2v_video(omni_runner_handler) + + assert frames is not None, "Expected video frames output" + assert hasattr(frames, "shape"), "Expected frames to have a shape attribute" + + # frames shape: (batch, num_frames, height, width, channels) + assert frames.shape[1] == NUM_FRAMES, f"Expected {NUM_FRAMES} frames, got {frames.shape[1]}" + assert frames.shape[2] == HEIGHT, f"Expected height {HEIGHT}, got {frames.shape[2]}" + assert frames.shape[3] == WIDTH, f"Expected width {WIDTH}, got {frames.shape[3]}" + + # Sanity: video should not be blank (frames are [0, 1] floats) + arr = np.asarray(frames) + assert arr.std() > 0.01, "Generated video appears blank (std โ‰ˆ 0)" + + +# ------------------------------------------------------------------ +# Test: T2V quantized model generates valid video +# ------------------------------------------------------------------ + + +@hardware_test(res={"cuda": "H100"}) +@pytest.mark.parametrize("omni_runner", quant_t2v_params, indirect=True) +def test_wan22_t2v_autoround_w4a16_generates_video(omni_runner, omni_runner_handler): + """Load the W4A16 quantized Wan2.2 T2V model and verify it produces a valid video.""" + frames, _ = _generate_t2v_video(omni_runner_handler) + + assert frames is not None, "Expected video frames output" + assert hasattr(frames, "shape"), "Expected frames to have a shape attribute" + + assert frames.shape[1] == NUM_FRAMES, f"Expected {NUM_FRAMES} frames, got {frames.shape[1]}" + assert frames.shape[2] == HEIGHT, f"Expected height {HEIGHT}, got {frames.shape[2]}" + assert frames.shape[3] == WIDTH, f"Expected width {WIDTH}, got {frames.shape[3]}" + + arr = np.asarray(frames) + assert arr.std() > 0.01, "Generated video appears blank (std โ‰ˆ 0)" + + +# ------------------------------------------------------------------ +# Test: I2V quantized peak memory +# ------------------------------------------------------------------ + + +@hardware_test(res={"cuda": "H100"}) +@pytest.mark.parametrize("omni_runner", quant_i2v_params, indirect=True) +def test_wan22_i2v_autoround_w4a16_quant_peak(omni_runner, omni_runner_handler): + """Measure peak GPU memory of W4A16 quantized I2V model.""" + frames, peak = _generate_i2v_video(omni_runner_handler) + + assert frames is not None, "Expected video frames output" + _memory_results["quant_i2v"] = peak + print(f"\nQuantized I2V (W4A16) peak memory: {peak:.0f} MB") + + +# ------------------------------------------------------------------ +# Test: I2V baseline peak memory +# ------------------------------------------------------------------ + + +@hardware_test(res={"cuda": "H100"}) +@pytest.mark.parametrize("omni_runner", baseline_i2v_params, indirect=True) +def test_wan22_i2v_autoround_w4a16_baseline_peak(omni_runner, omni_runner_handler): + """Measure peak GPU memory of BF16 baseline I2V model.""" + frames, peak = _generate_i2v_video(omni_runner_handler) + + assert frames is not None, "Expected video frames output" + _memory_results["baseline_i2v"] = peak + print(f"\nBaseline I2V (BF16) peak memory: {peak:.0f} MB") + + +# ------------------------------------------------------------------ +# Test: I2V memory savings +# ------------------------------------------------------------------ + + +@hardware_test(res={"cuda": "H100"}) +def test_wan22_i2v_autoround_w4a16_memory_savings(): + """Assert quantized I2V model uses meaningfully less memory than BF16 baseline.""" + quant_peak = _memory_results["quant_i2v"] + baseline_peak = _memory_results["baseline_i2v"] + + savings = baseline_peak - quant_peak + print(f"\nQuantized I2V (W4A16) peak memory: {quant_peak:.0f} MB") + print(f"Baseline I2V (BF16) peak memory: {baseline_peak:.0f} MB") + print(f"Savings: {savings:.0f} MB") + + # Wan2.2 I2V A14B transformer is ~28 GB in BF16; W4A16 should save ~20 GB. + # Use a conservative threshold to account for activations and overhead. + min_savings_mb = 5000 + assert quant_peak + min_savings_mb < baseline_peak, ( + f"Quantized model ({quant_peak:.0f} MB) should use at least " + f"{min_savings_mb} MB less than baseline ({baseline_peak:.0f} MB)" + ) + + +# ------------------------------------------------------------------ +# Test: T2V quantized peak memory +# ------------------------------------------------------------------ + + +@hardware_test(res={"cuda": "H100"}) +@pytest.mark.parametrize("omni_runner", quant_t2v_params, indirect=True) +def test_wan22_t2v_autoround_w4a16_quant_peak(omni_runner, omni_runner_handler): + """Measure peak GPU memory of W4A16 quantized T2V model.""" + frames, peak = _generate_t2v_video(omni_runner_handler) + + assert frames is not None, "Expected video frames output" + _memory_results["quant_t2v"] = peak + print(f"\nQuantized T2V (W4A16) peak memory: {peak:.0f} MB") + + +# ------------------------------------------------------------------ +# Test: T2V baseline peak memory +# ------------------------------------------------------------------ + + +@hardware_test(res={"cuda": "H100"}) +@pytest.mark.parametrize("omni_runner", baseline_t2v_params, indirect=True) +def test_wan22_t2v_autoround_w4a16_baseline_peak(omni_runner, omni_runner_handler): + """Measure peak GPU memory of BF16 baseline T2V model.""" + frames, peak = _generate_t2v_video(omni_runner_handler) + + assert frames is not None, "Expected video frames output" + _memory_results["baseline_t2v"] = peak + print(f"\nBaseline T2V (BF16) peak memory: {peak:.0f} MB") + + +# ------------------------------------------------------------------ +# Test: T2V memory savings +# ------------------------------------------------------------------ + + +@hardware_test(res={"cuda": "H100"}) +def test_wan22_t2v_autoround_w4a16_memory_savings(): + """Assert quantized T2V model uses meaningfully less memory than BF16 baseline.""" + quant_peak = _memory_results["quant_t2v"] + baseline_peak = _memory_results["baseline_t2v"] + + savings = baseline_peak - quant_peak + print(f"\nQuantized T2V (W4A16) peak memory: {quant_peak:.0f} MB") + print(f"Baseline T2V (BF16) peak memory: {baseline_peak:.0f} MB") + print(f"Savings: {savings:.0f} MB") + + # Wan2.2 T2V A14B transformer is ~28 GB in BF16; W4A16 should save ~20 GB. + # Use a conservative threshold to account for activations and overhead. + min_savings_mb = 5000 + assert quant_peak + min_savings_mb < baseline_peak, ( + f"Quantized model ({quant_peak:.0f} MB) should use at least " + f"{min_savings_mb} MB less than baseline ({baseline_peak:.0f} MB)" + ) diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py index 3ee46ffb003..2d8c752a4eb 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py @@ -122,8 +122,7 @@ def load_transformer_config(model_path: str, subfolder: str = "transformer", loc def create_transformer_from_config( - config: dict, - quant_config: QuantizationConfig | None = None, + config: dict, quant_config: QuantizationConfig | None = None, prefix: str = "" ) -> WanTransformer3DModel: """Create WanTransformer3DModel from config dict.""" kwargs: dict = {} @@ -166,6 +165,8 @@ def create_transformer_from_config( if quant_config is not None: kwargs["quant_config"] = quant_config + if prefix: + kwargs["prefix"] = prefix return WanTransformer3DModel(**kwargs) diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py index 42c4eff6add..5ff3742051f 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py @@ -231,10 +231,25 @@ def __init__( # Transformers (weights loaded via load_weights) # Load config from model directory or HF Hub to get correct in_channels for I2V models transformer_config = load_transformer_config(model, "transformer", local_files_only) - self.transformer = self._create_transformer(transformer_config) + self.transformer = create_transformer_from_config( + transformer_config, + quant_config=od_config.quantization_config, + ) if self.has_transformer_2: transformer_2_config = load_transformer_config(model, "transformer_2", local_files_only) - self.transformer_2 = self._create_transformer(transformer_2_config) + t2_quant = transformer_2_config.get("quantization_config") + if isinstance(t2_quant, dict) and "quant_method" in t2_quant: + from vllm_omni.quantization.factory import build_quant_config + + method = t2_quant["quant_method"] + kwargs = {k: v for k, v in t2_quant.items() if k != "quant_method"} + t2_quant = build_quant_config(method, **kwargs) + else: + t2_quant = None + self.transformer_2 = create_transformer_from_config( + transformer_2_config, + quant_config=t2_quant, + ) else: self.transformer_2 = None diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_vace.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_vace.py index 75bdac27f2a..5ba2c6c690f 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_vace.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_vace.py @@ -45,6 +45,7 @@ def create_vace_transformer_from_config( config: dict, quant_config: QuantizationConfig | None = None, + prefix: str = "", ) -> WanVACETransformer3DModel: """Create WanVACETransformer3DModel from config dict.""" kwargs = {} @@ -84,6 +85,8 @@ def create_vace_transformer_from_config( kwargs["vace_in_channels"] = config["vace_in_channels"] if quant_config is not None: kwargs["quant_config"] = quant_config + if prefix: + kwargs["prefix"] = prefix return WanVACETransformer3DModel(**kwargs) diff --git a/vllm_omni/quantization/factory.py b/vllm_omni/quantization/factory.py index 955f97cef85..3766e4596cd 100644 --- a/vllm_omni/quantization/factory.py +++ b/vllm_omni/quantization/factory.py @@ -99,6 +99,43 @@ def _build_inc(**kw: Any) -> QuantizationConfig: SUPPORTED_QUANTIZATION_METHODS: list[str] = list(dict.fromkeys(QUANTIZATION_METHODS + list(_OVERRIDES.keys()))) +def _build_reverse_alias_map() -> dict[str, str]: + """Build a mapping from normalized method aliases to canonical names. + + All keys in _OVERRIDES that share the same builder function are considered + aliases of each other. The canonical name is the first key (in definition + order) that maps to a given builder โ€” i.e. the one returned by + builder().get_name(). + """ + builder_to_first_key: dict[Callable[..., QuantizationConfig], str] = {} + for key in _OVERRIDES: + builder = _OVERRIDES[key] + if builder not in builder_to_first_key: + builder_to_first_key[builder] = key + + result: dict[str, str] = {} + for key, builder in _OVERRIDES.items(): + canonical = builder_to_first_key[builder] + result[key.lower().replace("-", "_")] = canonical + return result + + +_CACHED_ALIAS_MAP: dict[str, str] | None = None + + +def _normalize_quant_method_alias(method: str | None) -> str | None: + """Map a method name (or any of its aliases) to its canonical internal name. + Returns the input unchanged if it is not a known alias. + """ + if method is None: + return None + global _CACHED_ALIAS_MAP + if _CACHED_ALIAS_MAP is None: + _CACHED_ALIAS_MAP = _build_reverse_alias_map() + normalized = method.lower().replace("-", "_") + return _CACHED_ALIAS_MAP.get(normalized, normalized) + + _MODEL_OPT_METHODS = { "modelopt", } @@ -334,7 +371,9 @@ def resolve_quant_config_from_disk( ) return build_quant_config(qc_method, **qc_kwargs) - if quant_config.get_name() != qc_method: + active_method = _normalize_quant_method_alias(quant_config.get_name()) + disk_method = _normalize_quant_method_alias(qc_method) + if active_method != disk_method: raise ValueError( f"Checkpoint config.json declares quant_method={qc_method!r} but the " f"active quantization config is {quant_config.get_name()!r}. "