diff --git a/docs/.nav.yml b/docs/.nav.yml index 8042ffb706f..b7d08e77e91 100644 --- a/docs/.nav.yml +++ b/docs/.nav.yml @@ -58,6 +58,7 @@ nav: - Quantization: - Overview: user_guide/diffusion/quantization/overview.md - FP8: user_guide/diffusion/quantization/fp8.md + - Int8: user_guide/diffusion/quantization/int8.md - GGUF: user_guide/diffusion/quantization/gguf.md - Parallelism Acceleration: user_guide/diffusion/parallelism_acceleration.md - CPU Offloading: user_guide/diffusion/cpu_offload_diffusion.md diff --git a/docs/user_guide/diffusion/quantization/int8.md b/docs/user_guide/diffusion/quantization/int8.md new file mode 100644 index 00000000000..1e7853c3eb4 --- /dev/null +++ b/docs/user_guide/diffusion/quantization/int8.md @@ -0,0 +1,75 @@ +# Int8 Quantization + +## Overview + +Int8 quantization converts BF16/FP16 weights to Int8 at model load time. No calibration or pre-quantized checkpoint is needed. + +Depending on the model, either all layers can be quantized, or some sensitive layers should stay in BF16/FP16. See the [per-model table](#supported-models) for which case applies. + +## Configuration + +1. **Python API**: set `quantization="int8"`. To skip sensitive layers, use `quantization_config` with `ignored_layers`. + +```python +from vllm_omni import Omni +from vllm_omni.inputs.data import OmniDiffusionSamplingParams + +# All layers quantized +omni = Omni(model="<your-model>", quantization="int8") + +# Skip sensitive layers +omni = Omni( + model="<your-model>", + quantization_config={ + "method": "int8", + "ignored_layers": ["<layer-name>"], + }, +) + +outputs = omni.generate( + "A cat sitting on a windowsill", + OmniDiffusionSamplingParams(num_inference_steps=50), +) +``` + +2. **CLI**: pass `--quantization int8` and optionally `--ignored-layers`. 
+ + ```bash +# All layers +python text_to_image.py --model <your-model> --quantization int8 + +# Skip sensitive layers +python text_to_image.py --model <your-model> --quantization int8 --ignored-layers "img_mlp" + +# Online serving +vllm serve <your-model> --omni --quantization int8 +``` + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `method` | str | — | Quantization method (`"int8"`) | +| `ignored_layers` | list[str] | `[]` | Layer name patterns to keep in BF16/FP16 | +| `activation_scheme` | str | `"dynamic"` | `"dynamic"` (no calibration) | + + +The available `ignored_layers` names depend on the model architecture (e.g., `to_qkv`, `to_out`, `img_mlp`, `txt_mlp`). Consult the transformer source for your target model. + +## Supported Models + +| Model | HF Models | Recommendation | `ignored_layers` | +|-------|-----------|---------------|------------------| +| Z-Image | `Tongyi-MAI/Z-Image-Turbo` | All layers | None | +| Qwen-Image | `Qwen/Qwen-Image`, `Qwen/Qwen-Image-2512` | All layers | None | + +## Combining with Other Features + +Int8 quantization can be combined with cache acceleration: + +```python +omni = Omni( + model="<your-model>", + quantization="int8", + cache_backend="tea_cache", + cache_config={"rel_l1_thresh": 0.2}, +) +``` diff --git a/docs/user_guide/diffusion/quantization/overview.md b/docs/user_guide/diffusion/quantization/overview.md index e4ce69677c3..a95afdbf498 100644 --- a/docs/user_guide/diffusion/quantization/overview.md +++ b/docs/user_guide/diffusion/quantization/overview.md @@ -7,12 +7,20 @@ vLLM-Omni supports quantization of DiT linear layers to reduce memory usage and | Method | Guide | |--------|-------| | FP8 | [FP8](fp8.md) | +| Int8 | [Int8](int8.md) | | GGUF | [GGUF](gguf.md) | -## Device Compatibility +## Device Compatibility for FP8 | GPU Generation | Example GPUs | FP8 Mode | |---------------|-------------------|----------| | Ada/Hopper (SM 89+) | RTX 4090, H100, H200 | Full W8A8 with native hardware | Kernel selection is 
automatic. + +## Device Compatibility for Int8 + +| Device Type | Generation | Example | Int8 Mode | +|-------------|---------------|-------------------|----------| +| NVIDIA GPU | Ada/Hopper (SM 89+) | RTX 4090, H100, H200 | Full W8A8 with native hardware | +| Ascend NPU | Atlas A2/Atlas A3 | Atlas 800T A2/Atlas 900 A3 | Full W8A8 with native hardware | diff --git a/docs/user_guide/diffusion_acceleration.md b/docs/user_guide/diffusion_acceleration.md index e9179b5752a..7ef47075692 100644 --- a/docs/user_guide/diffusion_acceleration.md +++ b/docs/user_guide/diffusion_acceleration.md @@ -16,7 +16,7 @@ Both methods can provide significant speedups (typically **1.5x-2.0x**) while ma vLLM-Omni also supports quantization methods: -3. **[FP8 Quantization](diffusion/quantization/overview.md)** - Reduces DiT linear layers from BF16 to FP8, providing ~1.28x speedup with minimal quality loss. Supports per-layer skip for sensitive layers. +3. **[Quantization](diffusion/quantization/overview.md)** - Reduces DiT linear layers from BF16 to FP8 or Int8, providing ~1.28x speedup with minimal quality loss. Supports per-layer skip for sensitive layers. 
vLLM-Omni also supports parallelism methods for diffusion models, including: @@ -46,6 +46,7 @@ vLLM-Omni also supports parallelism methods for diffusion models, including: | Method | Configuration | Description | Best For | |--------|--------------|-------------|----------| | **FP8** | `quantization="fp8"` | FP8 W8A8 on Ada/Hopper, weight-only on older GPUs | Memory reduction, inference speedup | +| **Int8** | `quantization="int8"` | Int8 W8A8 | Memory reduction, inference speedup | ## Supported Models @@ -81,11 +82,11 @@ The following table shows which models are currently supported by each accelerat ### Quantization -| Model | Model Identifier | FP8 | -|-------|------------------|:---:| -| **Qwen-Image** | `Qwen/Qwen-Image` | ✅ | -| **Qwen-Image-2512** | `Qwen/Qwen-Image-2512` | ✅ | -| **Z-Image** | `Tongyi-MAI/Z-Image-Turbo` | ✅ | +| Model | Model Identifier | FP8 | Int8 | +|-------|------------------|:---:|:---:| +| **Qwen-Image** | `Qwen/Qwen-Image` | ✅ | ✅ | +| **Qwen-Image-2512** | `Qwen/Qwen-Image-2512` | ✅ | ✅ | +| **Z-Image** | `Tongyi-MAI/Z-Image-Turbo` | ✅ | ✅ | ## Performance Benchmarks @@ -338,13 +339,30 @@ outputs = omni.generate( ) ``` +### Using Int8 Quantization + +```python +from vllm_omni import Omni +from vllm_omni.inputs.data import OmniDiffusionSamplingParams + +omni = Omni( + model="<your-model>", + quantization="int8", +) + +outputs = omni.generate( + "A cat sitting on a windowsill", + OmniDiffusionSamplingParams(num_inference_steps=50), +) +``` + ## Documentation For detailed information on each acceleration method: - **[TeaCache Guide](diffusion/teacache.md)** - Complete TeaCache documentation, configuration options, and best practices - **[Cache-DiT Acceleration Guide](diffusion/cache_dit_acceleration.md)** - Comprehensive Cache-DiT guide covering DBCache, TaylorSeer, SCM, and configuration parameters -- **[FP8 Quantization Guide](diffusion/quantization/overview.md)** - FP8 quantization for DiT models with per-layer control +- **[Quantization 
Guide](diffusion/quantization/overview.md)** - Quantization for DiT models with per-layer control - **[Tensor Parallelism](diffusion/parallelism_acceleration.md#tensor-parallelism)** - Guidance on how to enable TP for diffusion models. - **[Sequence Parallelism](diffusion/parallelism_acceleration.md#sequence-parallelism)** - Guidance on how to set sequence parallelism with configuration. - **[CFG-Parallel](diffusion/parallelism_acceleration.md#cfg-parallel)** - Guidance on how to set CFG-Parallel to run positive/negative branches across ranks. diff --git a/examples/offline_inference/text_to_image/text_to_image.py b/examples/offline_inference/text_to_image/text_to_image.py index e2463619a30..ea17e6858c5 100644 --- a/examples/offline_inference/text_to_image/text_to_image.py +++ b/examples/offline_inference/text_to_image/text_to_image.py @@ -131,12 +131,10 @@ def parse_args() -> argparse.Namespace: "--quantization", type=str, default=None, - choices=["fp8", "gguf"], - help=( - "Quantization method for the transformer. " - "Options: 'fp8' (FP8 W8A8), 'gguf' (GGUF quantized weights). " - "Default: None (no quantization, uses BF16)." - ), + choices=["fp8", "int8", "gguf"], + help="Quantization method for the transformer. " + "Options: 'fp8' (FP8 W8A8 on Ada/Hopper, weight-only on older GPUs), 'int8' (Int8 W8A8), 'gguf' (GGUF quantized weights). 
" + "Default: None (no quantization, uses BF16).", ) parser.add_argument( "--gguf-model", diff --git a/tests/diffusion/quantization/test_int8_config.py b/tests/diffusion/quantization/test_int8_config.py new file mode 100644 index 00000000000..9b5d67fcbaf --- /dev/null +++ b/tests/diffusion/quantization/test_int8_config.py @@ -0,0 +1,463 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unit tests for Int8 quantization config.""" + +from unittest.mock import MagicMock, patch + +import pytest +import torch +from pytest_mock import MockerFixture +from torch.nn import Module, Parameter +from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod + +from vllm_omni.diffusion.quantization import ( + get_diffusion_quant_config, + get_vllm_quant_config_for_layers, +) +from vllm_omni.platforms import current_omni_platform + +npu_available = pytest.mark.skipif(not current_omni_platform.is_npu(), reason="NPU platform not available.") + +cuda_available = pytest.mark.skipif(not current_omni_platform.is_cuda(), reason="GPU platform not available.") + + +def test_int8_config_creation(): + """Test that Int8 config can be created.""" + config = get_diffusion_quant_config("int8") + assert config is not None + assert config.get_name() == "int8" + + +def test_vllm_config_extraction(): + """Test that vLLM config can be extracted from diffusion config.""" + diff_config = get_diffusion_quant_config("int8") + vllm_config = get_vllm_quant_config_for_layers(diff_config) + assert vllm_config is not None + assert vllm_config.activation_scheme == "dynamic" + + +def test_none_quantization(): + """Test that None quantization returns None config.""" + config = get_diffusion_quant_config(None) + assert config is None + vllm_config = get_vllm_quant_config_for_layers(config) + assert vllm_config is None + + +def test_invalid_quantization(): + """Test that invalid quantization method raises error.""" + with 
pytest.raises(ValueError, match="Unknown quantization method"): + get_diffusion_quant_config("invalid_method") + + +def test_int8_config_with_custom_params(): + """Test Int8 config with custom parameters.""" + config = get_diffusion_quant_config( + "int8", + activation_scheme="dynamic", + ignored_layers=["proj_out"], + ) + assert config is not None + assert config.activation_scheme == "dynamic" + assert "proj_out" in config.ignored_layers + + +def test_supported_methods(): + """Test that supported methods list is correct.""" + from vllm_omni.diffusion.quantization import SUPPORTED_QUANTIZATION_METHODS + + assert "int8" in SUPPORTED_QUANTIZATION_METHODS + + +def test_quantization_integration(): + """Test end-to-end quantization flow through OmniDiffusionConfig.""" + from vllm_omni.diffusion.data import OmniDiffusionConfig + + # Test with quantization string only + config = OmniDiffusionConfig(model="test", quantization="int8") + assert config.quantization_config is not None + assert config.quantization_config.get_name() == "int8" + + # Test with quantization_config dict + config2 = OmniDiffusionConfig( + model="test", + quantization_config={"method": "int8", "activation_scheme": "dynamic"}, + ) + assert config2.quantization_config is not None + assert config2.quantization_config.get_name() == "int8" + assert config2.quantization_config.activation_scheme == "dynamic" + + # Test that vLLM config can be extracted + vllm_config = config.quantization_config.get_vllm_quant_config() + assert vllm_config is not None + + +def test_quantization_dict_not_mutated(): + """Test that passing a dict to quantization_config doesn't mutate it.""" + from vllm_omni.diffusion.data import OmniDiffusionConfig + + original_dict = {"method": "int8", "activation_scheme": "dynamic"} + dict_copy = original_dict.copy() + + OmniDiffusionConfig(model="test", quantization_config=original_dict) + + # Original dict should be unchanged + assert original_dict == dict_copy + + +def 
test_quantization_conflicting_methods_warning(caplog): + """Test warning when quantization and quantization_config['method'] conflict.""" + import logging + + from vllm_omni.diffusion.data import OmniDiffusionConfig + + with caplog.at_level(logging.WARNING): + config = OmniDiffusionConfig( + model="test", + quantization="int8", # This should be overridden + quantization_config={"method": "int8", "activation_scheme": "dynamic"}, + ) + # No warning when methods match + assert config.quantization_config is not None + + +def test_get_quant_method(mocker: MockerFixture): + """Test for get_quant_method method for GPU""" + from vllm_omni.diffusion.quantization.int8 import Int8OnlineLinearMethod + + diff_config = get_diffusion_quant_config("int8") + vllm_config = get_vllm_quant_config_for_layers(diff_config) + + def _fake_init(self, quant_config): + pass + + layer = MagicMock(spec=LinearBase) + mocker.patch.object(Int8OnlineLinearMethod, "__init__", _fake_init) + + prefix = "test_layer" + + # Mock the platform to be GPU + with ( + patch("vllm_omni.platforms.current_omni_platform.is_cuda", return_value=True), + patch("vllm_omni.platforms.current_omni_platform.is_npu", return_value=False), + ): + method = vllm_config.get_quant_method(layer, prefix) + assert isinstance(method, Int8OnlineLinearMethod) + + # Test skipping quantization for a layer + vllm_config.ignored_layers = [prefix] + method = vllm_config.get_quant_method(layer, prefix) + assert isinstance(method, UnquantizedLinearMethod) + + +def test_get_npu_quant_method(): + """Test for get_quant_method method for NPU""" + from vllm_omni.diffusion.quantization.int8 import NPUInt8OnlineLinearMethod + + diff_config = get_diffusion_quant_config("int8") + vllm_config = get_vllm_quant_config_for_layers(diff_config) + + layer = MagicMock(spec=LinearBase) + prefix = "test_layer" + + # Mock the platform to be NPU + with ( + patch("vllm_omni.platforms.current_omni_platform.is_cuda", return_value=False), + 
patch("vllm_omni.platforms.current_omni_platform.is_npu", return_value=True), + ): + method = vllm_config.get_quant_method(layer, prefix) + assert isinstance(method, NPUInt8OnlineLinearMethod) + + # Test skipping quantization for a layer + vllm_config.ignored_layers = [prefix] + method = vllm_config.get_quant_method(layer, prefix) + assert isinstance(method, UnquantizedLinearMethod) + + +class TestInt8LinearMethod: + @pytest.fixture + def mock_quant_config(self, mocker): + return mocker.Mock() + + @pytest.fixture + def mock_kernel(self, mocker): + kernel = mocker.Mock() + kernel.process_weights_after_loading = mocker.Mock() + kernel.apply_weights = mocker.Mock(return_value=torch.randn(1, 10)) + return kernel + + @pytest.fixture + def patch_deps(self, mocker, mock_kernel): + # mock init_int8_linear_kernel + mocker.patch("vllm_omni.diffusion.quantization.int8.init_int8_linear_kernel", return_value=mock_kernel) + return mock_kernel + + def test_init(self, patch_deps, mock_quant_config): + # test for Int8LinearMethod init + from vllm_omni.diffusion.quantization.int8 import Int8LinearMethod, init_int8_linear_kernel + + method = Int8LinearMethod(mock_quant_config) + + assert method.quant_config == mock_quant_config + init_int8_linear_kernel.assert_called_once_with( + is_channelwise=False, is_static_input_scheme=False, input_symmetric=True, module_name="Int8LinearMethod" + ) + assert method.int8_linear == patch_deps + + def test_process_weights_after_loading(self, patch_deps, mock_quant_config): + from vllm_omni.diffusion.quantization.int8 import Int8LinearMethod + + method = Int8LinearMethod(mock_quant_config) + layer = Module() + + method.process_weights_after_loading(layer) + patch_deps.process_weights_after_loading.assert_called_once_with(layer) + + def test_apply(self, patch_deps, mock_quant_config): + from vllm_omni.diffusion.quantization.int8 import Int8LinearMethod + + method = Int8LinearMethod(mock_quant_config) + layer = Module() + x = torch.randn(1, 128) + bias 
= torch.randn(128) + + output = method.apply(layer, x, bias) + + patch_deps.apply_weights.assert_called_once_with(layer, x, bias) + assert isinstance(output, torch.Tensor) + + +class TestInt8OnlineLinearMethod: + @pytest.fixture + def mock_quant_config(self, mocker): + return mocker.Mock() + + @pytest.fixture + def mock_deps(self, mocker): + # mock kernel + mock_kernel = mocker.Mock() + mock_kernel.layer_param_names = ("weight", "weight_scale", "input_scale", "input_zero_point", "azp_adj") + mocker.patch("vllm_omni.diffusion.quantization.int8.init_int8_linear_kernel", return_value=mock_kernel) + mocker.patch("vllm_omni.diffusion.quantization.int8.replace_parameter") + + # mock scaled_int8_quant return value + mock_qweight = torch.ones((128, 64), dtype=torch.int8) + mock_scale = torch.randn(128) + mock_quant = mocker.patch( + "vllm_omni.diffusion.quantization.int8.ops.scaled_int8_quant", return_value=(mock_qweight, mock_scale, None) + ) + return {"kernel": mock_kernel, "quant": mock_quant, "mock_qweight": mock_qweight, "mock_scale": mock_scale} + + def test_process_weights_after_loading(self, mock_deps, mock_quant_config): + from vllm_omni.diffusion.quantization.int8 import Int8OnlineLinearMethod + + method = Int8OnlineLinearMethod(mock_quant_config) + layer = Module() + layer.weight = Parameter(torch.randn(128, 64)) + method.process_weights_after_loading(layer) + mock_deps["quant"].assert_called_once_with(layer.weight, scale=None) + + +@npu_available +class TestNPUInt8LinearMethod: + qweight_mock = torch.randn((128, 64)).to(dtype=torch.int8) + scale_mock = torch.randn(128) + out_mock = torch.randn((16, 128)) + + @pytest.fixture + def mock_torch_npu(self, mocker): + torch_npu = MagicMock() + + mocker.patch("vllm_omni.diffusion.quantization.int8.torch_npu", return_value=torch_npu) + mocker.patch( + "vllm_omni.diffusion.quantization.int8.torch_npu.npu_dynamic_quant", + return_value=(self.qweight_mock, self.scale_mock), + ) + 
mocker.patch("vllm_omni.diffusion.quantization.int8.torch_npu.npu_quant_matmul", return_value=self.out_mock) + return torch_npu + + @pytest.fixture + def mock_quant_config(self, mocker): + return mocker.Mock() + + @pytest.fixture + def mock_layer(self, mocker): + layer = torch.nn.Module() + layer.weight = torch.nn.Parameter(self.qweight_mock, requires_grad=False) + layer.weight_scale = torch.nn.Parameter(self.scale_mock, requires_grad=False) + return layer + + def test_npu_int8_process_weights_after_loading(self, mock_layer, mock_quant_config, mock_torch_npu): + from vllm_omni.diffusion.quantization.int8 import NPUInt8LinearMethod + + method = NPUInt8LinearMethod(mock_quant_config) + ori_weight_shape = mock_layer.weight.shape + + method.process_weights_after_loading(mock_layer) + + assert mock_layer.weight.shape == ori_weight_shape[::-1] + assert mock_layer.weight.is_contiguous() + + def test_npu_int8_apply(self, mock_layer, mock_quant_config, mock_torch_npu): + from vllm_omni.diffusion.quantization.int8 import NPUInt8LinearMethod + + method = NPUInt8LinearMethod(mock_quant_config) + x = torch.randn(1, 16, 64) + + output = method.apply(mock_layer, x) + assert output.shape == (1, 16, 128) + + def test_npu_int8_online_process_weights(self, mock_layer, mock_quant_config, mock_torch_npu): + from vllm_omni.diffusion.quantization.int8 import NPUInt8OnlineLinearMethod + + method = NPUInt8OnlineLinearMethod(mock_quant_config) + method.process_weights_after_loading(mock_layer) + + assert mock_layer.weight.shape == (64, 128) + assert torch.equal(mock_layer.weight_scale, self.scale_mock) + + +@pytest.fixture +def quant_config(): + """Shared quant config fixture for smoke tests.""" + from vllm_omni.diffusion.quantization.int8 import Int8Config + + return Int8Config( + is_checkpoint_int8_serialized=False, + activation_scheme="dynamic", + ) + + +@npu_available +class TestNPUInt8Smoke: + """Smoke tests using real torch_npu, only run on NPU.""" + + @pytest.fixture + def 
real_layer(self): + """Create a real linear layer with fp16 weights on NPU""" + layer = torch.nn.Module() + layer.weight = torch.nn.Parameter( + torch.randn(128, 64, dtype=torch.float16, device="npu"), + requires_grad=False, + ) + layer.logical_widths = [128] + layer.input_size_per_partition = 64 + layer.output_size_per_partition = 128 + layer.orig_dtype = torch.float16 + return layer + + def test_real_npu_dynamic_quant_shape_contract(self, quant_config, real_layer): + """Smoke test: verify npu_dynamic_quant returns correct shapes.""" + import torch_npu + + # Call real torch_npu.npu_dynamic_quant + weight = real_layer.weight + qweight, scale = torch_npu.npu_dynamic_quant(weight) + + assert qweight.shape == weight.shape + assert qweight.dtype == torch.int8 + assert scale.shape == (weight.shape[0],) + + def test_real_npu_online_process_weights_after_loading(self, quant_config, real_layer): + """Smoke test: full process_weights_after_loading with real torch_npu.""" + from vllm_omni.diffusion.quantization.int8 import NPUInt8OnlineLinearMethod + + method = NPUInt8OnlineLinearMethod(quant_config) + + method.process_weights_after_loading(real_layer) + + assert real_layer.weight.shape == (64, 128) + assert real_layer.weight.dtype == torch.int8 + assert hasattr(real_layer, "weight_scale") + assert real_layer.weight_scale.shape == (128,) + + def test_real_npu_int8_apply_forward(self, quant_config): + """Smoke test: forward pass with real npu_quant_matmul.""" + import torch_npu + + from vllm_omni.diffusion.quantization.int8 import NPUInt8LinearMethod + + method = NPUInt8LinearMethod(quant_config) + + # Create layer with pre-processed weights on NPU + layer = torch.nn.Module() + weight_fp16 = torch.randn(128, 64, dtype=torch.float16, device="npu") + qweight, scale = torch_npu.npu_dynamic_quant(weight_fp16) + layer.weight = torch.nn.Parameter(qweight.t().contiguous(), requires_grad=False) + layer.weight_scale = torch.nn.Parameter(scale.squeeze(), requires_grad=False) + + # 
Forward pass on NPU + x = torch.randn(2, 16, 64, dtype=torch.float16, device="npu") + output = method.apply(layer, x) + + assert output.shape == (2, 16, 128) + assert output.dtype == torch.float16 + + +@cuda_available +class TestCudaInt8Smoke: + """Smoke tests using real CUDA kernels, only on CUDA""" + + @pytest.fixture + def real_layer(self): + """Create a real linear layer with fp16 weights on CUDA""" + layer = torch.nn.Module() + layer.weight = torch.nn.Parameter( + torch.randn(128, 64, dtype=torch.float16, device="cuda"), + requires_grad=False, + ) + layer.logical_widths = [128] + layer.input_size_per_partition = 64 + layer.output_size_per_partition = 128 + layer.orig_dtype = torch.float16 + return layer + + def test_real_cuda_scaled_int8_quant_shape_contract(self, quant_config): + """Smoke test: verify scaled_int8_quant returns correct shapes.""" + from vllm import _custom_ops as ops + + weight = torch.randn(128, 64, dtype=torch.float16, device="cuda") + qweight, scale, _ = ops.scaled_int8_quant(weight, scale=None) + + assert qweight.shape == weight.shape + assert qweight.dtype == torch.int8 + assert scale.shape == (weight.shape[0], 1) + + def test_real_cuda_online_process_weights_after_loading(self, quant_config, real_layer): + """Smoke test: full process_weights_after_loading with real CUDA ops.""" + from vllm_omni.diffusion.quantization.int8 import Int8OnlineLinearMethod + + method = Int8OnlineLinearMethod(quant_config) + + method.process_weights_after_loading(real_layer) + + assert real_layer.weight.shape == (64, 128) + assert real_layer.weight.dtype == torch.int8 + assert hasattr(real_layer, "weight_scale") + + def test_real_cuda_int8_apply_forward(self, quant_config): + """Smoke test: forward pass with real CUDA int8 kernel.""" + from vllm import _custom_ops as ops + + from vllm_omni.diffusion.quantization.int8 import Int8LinearMethod + + method = Int8LinearMethod(quant_config) + + # Create layer with pre-processed weights + layer = torch.nn.Module() + 
weight_fp16 = torch.randn(128, 64, dtype=torch.float16, device="cuda") + qweight, scale, _ = ops.scaled_int8_quant(weight_fp16, scale=None) + layer.weight = torch.nn.Parameter(qweight.t(), requires_grad=False) + layer.weight_scale = torch.nn.Parameter(scale, requires_grad=False) + + # Set required attributes for kernel + layer.input_scale = None + layer.input_zero_point = None + layer.azp_adj = None + + # Forward pass + x = torch.randn(2, 16, 64, dtype=torch.float16, device="cuda") + output = method.apply(layer, x) + + assert output.shape == (2, 16, 128) + assert output.dtype == torch.float16 diff --git a/vllm_omni/diffusion/models/z_image/z_image_transformer.py b/vllm_omni/diffusion/models/z_image/z_image_transformer.py index b68b6f70d94..b7367e9b324 100644 --- a/vllm_omni/diffusion/models/z_image/z_image_transformer.py +++ b/vllm_omni/diffusion/models/z_image/z_image_transformer.py @@ -279,6 +279,7 @@ def __init__( total_num_kv_heads=num_kv_heads, bias=False, quant_config=quant_config, + prefix="to_qkv", ) assert qk_norm is True @@ -297,6 +298,7 @@ def __init__( input_is_parallel=True, return_bias=False, quant_config=quant_config, + prefix="to_out", ) ] ) @@ -361,6 +363,7 @@ def __init__( dim: int, hidden_dim: int, quant_config: "QuantizationConfig | None" = None, + prefix: str = "", ): super().__init__() self.w13 = MergedColumnParallelLinear( @@ -369,6 +372,7 @@ def __init__( bias=False, return_bias=False, quant_config=quant_config, + prefix=prefix, ) self.act = SiluAndMul() self.w2 = RowParallelLinear( @@ -378,6 +382,7 @@ def __init__( input_is_parallel=True, return_bias=False, quant_config=quant_config, + prefix=prefix, ) def forward(self, x): @@ -409,9 +414,7 @@ def __init__( ) self.feed_forward = FeedForward( - dim=dim, - hidden_dim=int(dim / 3 * 8), - quant_config=quant_config, + dim=dim, hidden_dim=int(dim / 3 * 8), quant_config=quant_config, prefix="feed_forward" ) self.layer_id = layer_id diff --git a/vllm_omni/diffusion/quantization/__init__.py 
b/vllm_omni/diffusion/quantization/__init__.py index d297d51f18a..eb4b8ea9f5e 100644 --- a/vllm_omni/diffusion/quantization/__init__.py +++ b/vllm_omni/diffusion/quantization/__init__.py @@ -29,6 +29,7 @@ from .base import DiffusionQuantizationConfig from .fp8 import DiffusionFp8Config from .gguf import DiffusionGgufConfig +from .int8 import DiffusionInt8Config if TYPE_CHECKING: from vllm.model_executor.layers.quantization.base_config import ( @@ -41,6 +42,7 @@ # To add a new method, create a new config class and register it here _QUANT_CONFIG_REGISTRY: dict[str, type[DiffusionQuantizationConfig]] = { "fp8": DiffusionFp8Config, + "int8": DiffusionInt8Config, "gguf": DiffusionGgufConfig, } @@ -110,6 +112,7 @@ def get_vllm_quant_config_for_layers( __all__ = [ "DiffusionQuantizationConfig", "DiffusionFp8Config", + "DiffusionInt8Config", "DiffusionGgufConfig", "get_diffusion_quant_config", "get_vllm_quant_config_for_layers", diff --git a/vllm_omni/diffusion/quantization/int8.py b/vllm_omni/diffusion/quantization/int8.py new file mode 100644 index 00000000000..798038c9f7a --- /dev/null +++ b/vllm_omni/diffusion/quantization/int8.py @@ -0,0 +1,475 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""INT8 quantization config for diffusion transformers.""" + +from collections.abc import Callable +from typing import TYPE_CHECKING, Any, Optional + +import torch +from torch.nn import Module +from vllm import _custom_ops as ops +from vllm.logger import init_logger +from vllm.model_executor.kernels.linear import ( + init_int8_linear_kernel, +) +from vllm.model_executor.layers.linear import ( + LinearBase, + LinearMethodBase, + UnquantizedLinearMethod, +) +from vllm.model_executor.layers.quantization import QuantizationMethods +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, + QuantizeMethodBase, +) +from vllm.model_executor.layers.quantization.fp8 import CopyNumelCounter, 
_copy_missing_attrs +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + is_layer_skipped, +) +from vllm.model_executor.model_loader.weight_utils import initialize_single_dummy_weight +from vllm.model_executor.parameter import ( + ChannelQuantScaleParameter, + ModelWeightParameter, +) +from vllm.model_executor.utils import replace_parameter + +from vllm_omni.platforms import current_omni_platform + +if current_omni_platform.is_npu(): + import torch_npu +else: + torch_npu = None + +from .base import DiffusionQuantizationConfig + +if TYPE_CHECKING: + from vllm.model_executor.models.utils import WeightsMapper + +# Dynamic quantization is supported first. +ACTIVATION_SCHEMES = ["dynamic"] + +logger = init_logger(__name__) + + +def create_weight_parameter( + output_size_per_partition: int, + input_size_per_partition: int, + weight_loader: Callable | None, + params_dtype: torch.dtype, +) -> torch.nn.Parameter: + """ + Create int8 weight parameter. + """ + return ModelWeightParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=params_dtype, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + + +class Int8Config(QuantizationConfig): + """ + Config class for Int8. 
+ """ + + def __init__( + self, + is_checkpoint_int8_serialized: bool = False, + activation_scheme: str = "dynamic", + ignored_layers: list[str] | None = None, + ) -> None: + super().__init__() + + self.is_checkpoint_int8_serialized = is_checkpoint_int8_serialized + + if activation_scheme not in ACTIVATION_SCHEMES: + raise ValueError(f"Unsupported activation scheme {activation_scheme}") + self.activation_scheme = activation_scheme + self.ignored_layers = ignored_layers or [] + + @classmethod + def get_name(cls) -> QuantizationMethods: + return "int8" + + @classmethod + def get_supported_act_dtypes(cls) -> list[torch.dtype]: + return [torch.bfloat16, torch.float16] + + @classmethod + def get_min_capability(cls) -> int: + # Have verified on A100 and H20, but not on oldest versions. + return 80 + + @classmethod + def get_config_filenames(cls) -> list[str]: + return [] + + def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"): + if self.ignored_layers is not None: + self.ignored_layers = hf_to_vllm_mapper.apply_list(self.ignored_layers) + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "Int8Config": + quant_method = cls.get_from_keys(config, ["quant_method"]) + is_checkpoint_int8_serialized = "int8" in quant_method + activation_scheme = cls.get_from_keys_or(config, ["activation_scheme"], "dynamic") + ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None) + + if not ignored_layers: + ignored_layers = cls.get_from_keys_or(config, ["modules_to_not_convert"], None) + return cls( + is_checkpoint_int8_serialized=is_checkpoint_int8_serialized, + activation_scheme=activation_scheme, + ignored_layers=ignored_layers, + ) + + def get_quant_method( + self, + layer: torch.nn.Module, + prefix: str, + ) -> Optional["QuantizeMethodBase"]: + if isinstance(layer, LinearBase): + if is_layer_skipped( + prefix=prefix, + ignored_layers=self.ignored_layers, + fused_mapping=self.packed_modules_mapping, + ): + return UnquantizedLinearMethod() + if 
not self.is_checkpoint_int8_serialized: + if current_omni_platform.is_cuda(): + online_method = Int8OnlineLinearMethod(self) + elif current_omni_platform.is_npu(): + online_method = NPUInt8OnlineLinearMethod(self) + else: + raise NotImplementedError("The current platform is not supported int8 online quant.") + return online_method + else: + if current_omni_platform.is_cuda(): + offline_method = Int8LinearMethod(self) + elif current_omni_platform.is_npu(): + offline_method = NPUInt8LinearMethod(self) + else: + raise NotImplementedError("The current platform is not supported int8 offline quant.") + return offline_method + return None + + +class BaseInt8LinearMethod(LinearMethodBase): + """ + Linear method for Int8 + Supports loading Int8 checkpoints with static weight scale and dynamic activation scale. + + Args: + quant_config: The quantization config. + """ + + def __init__(self, quant_config: Int8Config): + self.quant_config = quant_config + self.out_dtype = torch.get_default_dtype() + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + layer.logical_widths = output_partition_sizes + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + layer.orig_dtype = params_dtype + + params_dtype = torch.int8 if self.quant_config.is_checkpoint_int8_serialized else params_dtype + weight = create_weight_parameter( + output_size_per_partition=output_size_per_partition, + input_size_per_partition=input_size_per_partition, + weight_loader=weight_loader, + params_dtype=params_dtype, + ) + layer.register_parameter("weight", weight) + + if self.quant_config.is_checkpoint_int8_serialized: + scale = ChannelQuantScaleParameter( + 
data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32), + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight_scale", scale) + + def process_weights_after_loading(self, layer: Module) -> None: + raise NotImplementedError("No BaseInt8LinearMethod process_weights_after_loading implementation.") + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + ) -> torch.Tensor: + raise NotImplementedError("No BaseInt8LinearMethod apply implementation.") + + +class LazyWeightMixin: + """ + Mixin for lazy weight loading with meta device. + weighs are created on meta device and materialized just-in-time during loadding. + """ + + uses_meta_device: bool = True + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + layer.logical_widths = output_partition_sizes + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + layer.orig_dtype = params_dtype + layer.weight_block_size = None + + # WEIGHT + def patched_weight_loader(param, loaded_weight, *args, **kwargs): + # track how many elements we have updated + if not hasattr(layer, "_loaded_numel"): + layer._loaded_numel = 0 + + # when the first `loaded_weight` is about to be + # loaded to `param`, materialize `param` just-in-time + weight = ModelWeightParameter( + data=torch.empty_like(layer.weight, device=layer._load_device), + input_dim=1, + output_dim=0, + weight_loader=patched_weight_loader, + ) + _copy_missing_attrs(layer.weight, weight) + layer.register_parameter("weight", weight) + del layer._load_device + + # refresh the reference to `param` to reflect just-in-time + # 
materialization + param = layer.weight + + # load the current weight chunk + copy_numel_counter = CopyNumelCounter() + with copy_numel_counter: + res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc] + layer._loaded_numel += copy_numel_counter.copied_numel + + # if we have loaded all of the elements, call + # process_weights_after_loading + target_loaded_numel = layer.weight.numel() + if layer._loaded_numel == target_loaded_numel: + self.process_weights_after_loading(layer) + + # Prevent the usual `process_weights_after_loading` call from doing + # anything + layer._already_called_process_weights_after_loading = True + + # Note that we keep `layer._loaded_numel` around just in case + # there is logic added to vllm in the future which calls a + # weight loader twice - we do not want to re-initialize in + # that case. + + return res + + weight = ModelWeightParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition, + # materialized just-in-time in `patched_weight_loader` + device="meta", + dtype=params_dtype, + ), + input_dim=1, + output_dim=0, + weight_loader=patched_weight_loader, + ) + # stash the correct device for `patched_weight_loader` + layer._load_device = torch.get_default_device() + layer.register_parameter("weight", weight) + + +class Int8LinearMethod(BaseInt8LinearMethod): + """ + Linear method for Int8 + Supports loading Int8 checkpoints. + + Args: + quant_config: The quantization config. 
+ """ + + def __init__(self, quant_config: Int8Config): + super().__init__(quant_config) + + self.int8_linear = init_int8_linear_kernel( + is_channelwise=False, + is_static_input_scheme=False, + input_symmetric=True, + module_name=self.__class__.__name__, + ) + + def process_weights_after_loading(self, layer: Module) -> None: + self.int8_linear.process_weights_after_loading(layer) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + ) -> torch.Tensor: + return self.int8_linear.apply_weights(layer, x, bias) + + +class NPUInt8LinearMethod(BaseInt8LinearMethod): + """ + NPU Linear method for Int8 + Supports loading Int8 checkpoints. + + Args: + quant_config: The quantization config. + """ + + def __init__(self, quant_config: Int8Config): + super().__init__(quant_config) + + def process_weights_after_loading(self, layer: Module) -> None: + layer.weight.data = layer.weight.data.t().contiguous() + layer.weight_scale.data = layer.weight_scale.data.squeeze() + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + ) -> torch.Tensor: + ori_shape = x.shape + ori_dtype = x.dtype + + x = x.reshape(-1, ori_shape[-1]) + quantized_x, pertoken_scale = torch_npu.npu_dynamic_quant(x) + + output = torch_npu.npu_quant_matmul( + quantized_x, + layer.weight, + layer.weight_scale, + bias=bias, + pertoken_scale=pertoken_scale, + output_dtype=ori_dtype, + ) + output = output.reshape(*ori_shape[:-1], -1) + return output + + +class Int8OnlineLinearMethod(LazyWeightMixin, Int8LinearMethod): + """ + Online version of Int8LinearMethod, loads the fp16/bf16 checkpoint + and quantized the weights during loading. 
+ """ + + def process_weights_after_loading(self, layer: Module) -> None: + if getattr(layer, "_already_called_process_weights_after_loading", False): + return + + if layer.weight.device == torch.device("meta"): + weight = ModelWeightParameter( + data=torch.empty_like(layer.weight, device=layer._load_device), + input_dim=1, + output_dim=0, + weight_loader=layer.weight.weight_loader, + ) + _copy_missing_attrs(layer.weight, weight) + layer.register_parameter("weight", weight) + initialize_single_dummy_weight(layer.weight) + + w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = self.int8_linear.layer_param_names + qweight, weight_scale, _ = ops.scaled_int8_quant(layer.weight, scale=None) + + # Update layer with new values. + replace_parameter(layer, w_q_name, torch.nn.Parameter(qweight.t().data, requires_grad=False)) + replace_parameter(layer, w_s_name, torch.nn.Parameter(weight_scale.data, requires_grad=False)) + + setattr(layer, i_s_name, None) + setattr(layer, i_zp_name, None) + setattr(layer, azp_adj_name, None) + + # Prevent duplicate processing (e.g., during weight reload) + layer._already_called_process_weights_after_loading = True + + +class NPUInt8OnlineLinearMethod(LazyWeightMixin, NPUInt8LinearMethod): + """ + NPU Online version of Int8LinearMethod, loads the fp16/bf16 checkpoint + and quantized the weights during loading. 
+ """ + + def process_weights_after_loading(self, layer: Module) -> None: + if getattr(layer, "_already_called_process_weights_after_loading", False): + return + + if layer.weight.device == torch.device("meta"): + weight = ModelWeightParameter( + data=torch.empty_like(layer.weight, device=layer._load_device), + input_dim=1, + output_dim=0, + weight_loader=layer.weight.weight_loader, + ) + _copy_missing_attrs(layer.weight, weight) + layer.register_parameter("weight", weight) + initialize_single_dummy_weight(layer.weight) + + weight = layer.weight + qweight, weight_scale = torch_npu.npu_dynamic_quant(weight) + + qweight = qweight.t().contiguous() + + # Update layer with new values. + replace_parameter(layer, "weight", qweight) + replace_parameter(layer, "weight_scale", weight_scale) + + # Prevent duplicate processing (e.g., during weight reload) + layer._already_called_process_weights_after_loading = True + + +class DiffusionInt8Config(DiffusionQuantizationConfig): + """ + Int8 quantization config optimized for diffusion transformers. + + Args: + activation_scheme: Activation quantization scheme. + - "dynamic": Per-token dynamic scaling (default, no calibration) + ignored_layers: List of layer name patterns to skip quantization. 
+ """ + + quant_config_cls = Int8Config + + def __init__( + self, + activation_scheme: str = "dynamic", + ignored_layers: list[str] | None = None, + ): + self.activation_scheme = activation_scheme + self.ignored_layers = ignored_layers or [] + + # Create underlying vLLM Int8 config + self._vllm_config = Int8Config( + is_checkpoint_int8_serialized=False, # Online quantization + activation_scheme=activation_scheme, + ignored_layers=ignored_layers, + ) diff --git a/vllm_omni/engine/async_omni_engine.py b/vllm_omni/engine/async_omni_engine.py index 50fb6759793..56eb3b5cc56 100644 --- a/vllm_omni/engine/async_omni_engine.py +++ b/vllm_omni/engine/async_omni_engine.py @@ -844,6 +844,7 @@ def _create_default_diffusion_stage_cfg(kwargs: dict[str, Any]) -> list: "enable_sleep_mode": kwargs.get("enable_sleep_mode", False), "enable_multithread_weight_load": kwargs.get("enable_multithread_weight_load", True), "num_weight_load_threads": kwargs.get("num_weight_load_threads", 4), + "quantization": kwargs.get("quantization", None), "enable_diffusion_pipeline_profiler": kwargs.get("enable_diffusion_pipeline_profiler", False), }, "final_output": True,