-
Notifications
You must be signed in to change notification settings - Fork 1.1k
[AutoRound] Add offline quantized W4A16 model support
#1777
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
18 commits
Select commit
Hold shift + click to select a range
c376e17
[Feature] Add INC/AutoRound W4A16 quantization for diffusion models
yiliu30 e89966e
[Feature] Add quantizable AdaLayerNorm and prefix propagation for FLUX
yiliu30 d881555
[Test] Add tests for INC quantization, AdaLayerNorm, and FLUX prefix
yiliu30 b510ba7
[Docs] Add AutoRound quantization guide
yiliu30 2d621d5
Merge branch 'main' into feats/ar-w4a16
yiliu30 f5af240
Merge branch 'main' into feats/ar-w4a16
yiliu30 dd756a0
Address review feedback: fix weight validation, model name, and getat…
yiliu30 9543d6f
Merge branch 'main' into feats/ar-w4a16
yiliu30 f444a6e
Merge branch 'main' into feats/ar-w4a16
yiliu30 21fd0e9
Merge branch 'main' into feats/ar-w4a16
yiliu30 4c318d3
[Bugfix] fix CI unit test failures
yiliu30 f706df5
Merge branch 'main' into feats/ar-w4a16
yiliu30 7c82e87
fix
yiliu30 bde02bd
update
yiliu30 b06f196
Merge branch 'main' into feats/ar-w4a16
yiliu30 b56a289
Merge branch 'main' into feats/ar-w4a16
yiliu30 23fe86a
Merge branch 'main' into feats/ar-w4a16
yiliu30 13292d5
Merge branch 'main' into feats/ar-w4a16
hsliuustc0106 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,91 @@ | ||
| # AutoRound Quantization | ||
|
|
||
| ## Overview | ||
|
|
||
| [AutoRound](https://github.com/intel/auto-round) is an advanced quantization toolkit designed for Large Language Models (LLMs), Vision-Language Models (VLMs), and diffusion models. It achieves high accuracy at ultra-low bit widths (2–4 bits) with minimal tuning by leveraging sign-gradient descent, while providing broad hardware compatibility with multi-datatype support. | ||
|
|
||
| The quantization config is auto-detected from the checkpoint's `config.json` (`quantization_config.quant_method = "auto-round"`). No extra CLI flags are needed. | ||
|
|
||
| ### Supported Schemes | ||
|
|
||
| | Scheme | Bits | Status | | ||
| |--------|------|--------| | ||
| | W4A16 | 4 | ✅ Supported | | ||
| | W8A16 | 8 | Planned | | ||
|
|
||
| W4A16 is the first supported scheme. Additional schemes will be added in future releases. | ||
|
|
||
| ## Configuration | ||
|
|
||
| 1. **Python API**: point `model` at a pre-quantized checkpoint. The quantization is detected automatically. | ||
|
|
||
| ```python | ||
| from vllm_omni import Omni | ||
| from vllm_omni.inputs.data import OmniDiffusionSamplingParams | ||
|
|
||
| omni = Omni(model="vllm-project-org/FLUX.1-dev-AutoRound-w4a16") | ||
|
|
||
|
yiliu30 marked this conversation as resolved.
|
||
| outputs = omni.generate( | ||
| "A cat sitting on a windowsill", | ||
| OmniDiffusionSamplingParams(num_inference_steps=28), | ||
| ) | ||
| outputs[0].save_images("output.png") | ||
| ``` | ||
|
|
||
| 2. **CLI**: pass the quantized model path directly. | ||
|
|
||
| ```bash | ||
| python examples/offline_inference/text_to_image/text_to_image.py \ | ||
| --model vllm-project-org/FLUX.1-dev-AutoRound-w4a16 \ | ||
| --prompt "A cat sitting on a windowsill" \ | ||
| --num-inference-steps 28 \ | ||
| --output outputs/flux_w4a16.png | ||
| ``` | ||
|
|
||
| No `--quantization` flag is needed — the quantization method is read from the checkpoint. | ||
|
|
||
| ## How It Works | ||
|
|
||
| The checkpoint's `config.json` contains: | ||
|
|
||
| ```json | ||
| { | ||
| "quantization_config": { | ||
| "quant_method": "auto-round", | ||
| "bits": 4, | ||
| "group_size": 128, | ||
| "sym": true, | ||
| "packing_format": "auto_round:auto_gptq", | ||
| "block_name_to_quantize": "transformer_blocks,single_transformer_blocks" | ||
| } | ||
| } | ||
| ``` | ||
|
|
||
| At load time: | ||
|
|
||
| 1. `TransformerConfig.from_dict()` parses the `quantization_config` section and builds a vLLM `INCConfig` via `build_quant_config("auto-round", ...)`. | ||
| 2. `OmniDiffusionConfig.set_tf_model_config()` propagates the detected config to the engine. | ||
| 3. The appropriate compute kernel (e.g. GPTQ-Marlin for W4A16) is selected automatically based on the checkpoint's bit-width and packing format. | ||
|
|
||
| ## Supported Models | ||
|
|
||
| | Model | HF Checkpoint | Scheme | Group Size | Backend | | ||
| |-------|--------------|--------|------------|---------| | ||
| | FLUX.1-dev | `vllm-project-org/FLUX.1-dev-AutoRound-w4a16` | W4A16 | 128 | GPTQ-Marlin | | ||
|
|
||
| ## Creating a Quantized Checkpoint | ||
|
|
||
| Use [AutoRound](https://github.com/intel/auto-round) to quantize a BF16 model. The `--scheme` flag selects the quantization scheme: | ||
|
|
||
| ```bash | ||
| # W4A16 (4-bit weight, 16-bit activation) | ||
| auto-round \ | ||
| --model black-forest-labs/FLUX.1-dev \ | ||
| --scheme W4A16 \ | ||
| --batch_size 1 \ | ||
| --disable_opt_rtn \ | ||
| --dataset coco2014 \ | ||
| --iters 0 | ||
| ``` | ||
|
|
||
| The output directory can be used directly as the `model` argument. See the [AutoRound documentation](https://github.com/intel/auto-round) for all available schemes and options. | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,237 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
| """Unit tests for shared AdaLayerNorm layers used by FLUX and other models.""" | ||
|
|
||
| import os | ||
|
|
||
| import pytest | ||
| import torch | ||
|
|
||
| pytestmark = [pytest.mark.core_model, pytest.mark.diffusion, pytest.mark.cpu] | ||
|
|
||
|
|
||
| @pytest.fixture(autouse=True) | ||
| def _init_distributed(): | ||
| """Initialize the minimal distributed environment required by | ||
| ReplicatedLinear (tensor-parallel group must exist).""" | ||
| from vllm.distributed.parallel_state import ( | ||
| cleanup_dist_env_and_memory, | ||
| init_distributed_environment, | ||
| initialize_model_parallel, | ||
| ) | ||
|
|
||
| os.environ.setdefault("MASTER_ADDR", "localhost") | ||
| os.environ.setdefault("MASTER_PORT", "29501") | ||
| init_distributed_environment( | ||
| world_size=1, | ||
| rank=0, | ||
| local_rank=0, | ||
| distributed_init_method="env://", | ||
| ) | ||
| initialize_model_parallel() | ||
| yield | ||
| cleanup_dist_env_and_memory() | ||
|
|
||
|
|
||
| @pytest.fixture(autouse=True) | ||
| def _force_default_gemm(monkeypatch): | ||
| """Force CPU-compatible GEMM dispatch for tests using CPU tensors. | ||
|
|
||
| vLLM's dispatch_unquantized_gemm() selects the backend by platform | ||
| (e.g. rocm_unquantized_gemm on AMD machines), not by tensor device. | ||
| CPU test tensors crash with NotImplementedError on ROCm. Monkeypatch | ||
| the dispatcher to always return the default (torch.nn.functional.linear) | ||
| implementation which works on any device.""" | ||
| from vllm.model_executor.layers.utils import default_unquantized_gemm | ||
|
|
||
| monkeypatch.setattr( | ||
| "vllm.model_executor.layers.linear.dispatch_unquantized_gemm", | ||
| lambda: default_unquantized_gemm, | ||
| ) | ||
|
|
||
|
|
||
| def test_adalayernorm_import_from_shared_module(): | ||
| """Verify imports work from the shared adalayernorm module.""" | ||
| from vllm_omni.diffusion.layers.adalayernorm import ( # noqa: F401 | ||
| AdaLayerNormContinuous, | ||
| AdaLayerNormZero, | ||
| AdaLayerNormZeroSingle, | ||
| ) | ||
|
|
||
|
|
||
| def test_adalayernorm_zero_forward_shape(): | ||
| """AdaLayerNormZero produces correct output shapes (x, gate, shift, scale, gate).""" | ||
| from vllm_omni.diffusion.layers.adalayernorm import AdaLayerNormZero | ||
|
|
||
| dim = 64 | ||
| batch = 2 | ||
| seq_len = 4 | ||
| norm = AdaLayerNormZero(dim) | ||
|
|
||
| x = torch.randn(batch, seq_len, dim) | ||
| emb = torch.randn(batch, dim) | ||
|
|
||
| out_x, gate_msa, shift_mlp, scale_mlp, gate_mlp = norm(x, emb) | ||
|
|
||
| assert out_x.shape == (batch, seq_len, dim) | ||
| assert gate_msa.shape == (batch, dim) | ||
| assert shift_mlp.shape == (batch, dim) | ||
| assert scale_mlp.shape == (batch, dim) | ||
| assert gate_mlp.shape == (batch, dim) | ||
|
|
||
|
|
||
| def test_adalayernorm_zero_single_forward_shape(): | ||
| """AdaLayerNormZeroSingle produces (x, gate) with correct shapes.""" | ||
| from vllm_omni.diffusion.layers.adalayernorm import AdaLayerNormZeroSingle | ||
|
|
||
| dim = 64 | ||
| batch = 2 | ||
| seq_len = 4 | ||
| norm = AdaLayerNormZeroSingle(dim) | ||
|
|
||
| x = torch.randn(batch, seq_len, dim) | ||
| emb = torch.randn(batch, dim) | ||
|
|
||
| out_x, gate = norm(x, emb) | ||
|
|
||
| assert out_x.shape == (batch, seq_len, dim) | ||
| assert gate.shape == (batch, dim) | ||
|
|
||
|
|
||
| def test_adalayernorm_continuous_forward_shape(): | ||
| """AdaLayerNormContinuous produces correct output shape.""" | ||
| from vllm_omni.diffusion.layers.adalayernorm import AdaLayerNormContinuous | ||
|
|
||
| dim = 64 | ||
| cond_dim = 64 | ||
| batch = 2 | ||
| seq_len = 4 | ||
| norm = AdaLayerNormContinuous(dim, cond_dim) | ||
|
|
||
| x = torch.randn(batch, seq_len, dim) | ||
| conditioning = torch.randn(batch, cond_dim) | ||
|
|
||
| out = norm(x, conditioning) | ||
|
|
||
| assert out.shape == (batch, seq_len, dim) | ||
|
|
||
|
|
||
| def test_adalayernorm_zero_accepts_quant_config(): | ||
| """Constructor accepts quant_config=None and prefix='test' without error.""" | ||
| from vllm_omni.diffusion.layers.adalayernorm import ( | ||
| AdaLayerNormContinuous, | ||
| AdaLayerNormZero, | ||
| AdaLayerNormZeroSingle, | ||
| ) | ||
|
|
||
| # Should not raise with quant_config=None and prefix | ||
| AdaLayerNormZero(64, quant_config=None, prefix="test.norm1") | ||
| AdaLayerNormZeroSingle(64, quant_config=None, prefix="test.norm") | ||
| AdaLayerNormContinuous(64, 64, quant_config=None, prefix="test.norm_out") | ||
|
|
||
|
|
||
| def test_adalayernorm_uses_replicated_linear(): | ||
| """Verify .linear is a ReplicatedLinear instance (not nn.Linear).""" | ||
| from vllm.model_executor.layers.linear import ReplicatedLinear | ||
|
|
||
| from vllm_omni.diffusion.layers.adalayernorm import ( | ||
| AdaLayerNormContinuous, | ||
| AdaLayerNormZero, | ||
| AdaLayerNormZeroSingle, | ||
| ) | ||
|
|
||
| norm_zero = AdaLayerNormZero(64) | ||
| assert isinstance(norm_zero.linear, ReplicatedLinear) | ||
|
|
||
| norm_zero_single = AdaLayerNormZeroSingle(64) | ||
| assert isinstance(norm_zero_single.linear, ReplicatedLinear) | ||
|
|
||
| norm_continuous = AdaLayerNormContinuous(64, 64) | ||
| assert isinstance(norm_continuous.linear, ReplicatedLinear) | ||
|
|
||
|
|
||
| # ── Numerical equivalence tests against diffusers originals ── | ||
|
|
||
|
|
||
| def _copy_weights(src_linear, dst_replicated_linear): | ||
| """Copy weights from nn.Linear to ReplicatedLinear for comparison.""" | ||
| dst_replicated_linear.weight.data.copy_(src_linear.weight.data) | ||
| if src_linear.bias is not None and dst_replicated_linear.bias is not None: | ||
| dst_replicated_linear.bias.data.copy_(src_linear.bias.data) | ||
|
|
||
|
|
||
| def test_adalayernorm_zero_matches_diffusers(): | ||
| """Verify AdaLayerNormZero produces identical output to diffusers original.""" | ||
| from diffusers.models.normalization import ( | ||
| AdaLayerNormZero as DiffusersAdaLayerNormZero, | ||
| ) | ||
|
|
||
| from vllm_omni.diffusion.layers.adalayernorm import AdaLayerNormZero | ||
|
|
||
| dim = 64 | ||
| torch.manual_seed(42) | ||
| ours = AdaLayerNormZero(dim) | ||
| ref = DiffusersAdaLayerNormZero(dim) | ||
|
|
||
| # Copy weights: nn.Linear -> ReplicatedLinear | ||
| _copy_weights(ref.linear, ours.linear) | ||
|
|
||
| x = torch.randn(2, 4, dim) | ||
| emb = torch.randn(2, dim) | ||
|
|
||
| out_ours = ours(x, emb) | ||
| out_ref = ref(x, emb=emb) | ||
|
|
||
| for o, r in zip(out_ours, out_ref): | ||
| torch.testing.assert_close(o, r, atol=1e-5, rtol=1e-5) | ||
|
|
||
|
|
||
| def test_adalayernorm_zero_single_matches_diffusers(): | ||
| """Verify AdaLayerNormZeroSingle produces identical output to diffusers original.""" | ||
| from diffusers.models.normalization import ( | ||
| AdaLayerNormZeroSingle as DiffusersAdaLayerNormZeroSingle, | ||
| ) | ||
|
|
||
| from vllm_omni.diffusion.layers.adalayernorm import AdaLayerNormZeroSingle | ||
|
|
||
| dim = 64 | ||
| torch.manual_seed(42) | ||
| ours = AdaLayerNormZeroSingle(dim) | ||
| ref = DiffusersAdaLayerNormZeroSingle(dim) | ||
|
|
||
| _copy_weights(ref.linear, ours.linear) | ||
|
|
||
| x = torch.randn(2, 4, dim) | ||
| emb = torch.randn(2, dim) | ||
|
|
||
| out_ours = ours(x, emb) | ||
| out_ref = ref(x, emb=emb) | ||
|
|
||
| for o, r in zip(out_ours, out_ref): | ||
| torch.testing.assert_close(o, r, atol=1e-5, rtol=1e-5) | ||
|
|
||
|
|
||
| def test_adalayernorm_continuous_matches_diffusers(): | ||
| """Verify AdaLayerNormContinuous produces identical output to diffusers original.""" | ||
| from diffusers.models.normalization import ( | ||
| AdaLayerNormContinuous as DiffusersAdaLayerNormContinuous, | ||
| ) | ||
|
|
||
| from vllm_omni.diffusion.layers.adalayernorm import AdaLayerNormContinuous | ||
|
|
||
| dim = 64 | ||
| cond_dim = 64 | ||
| torch.manual_seed(42) | ||
| # Match constructor args: diffusers defaults elementwise_affine=True, eps=1e-5 | ||
| ours = AdaLayerNormContinuous(dim, cond_dim, elementwise_affine=False, eps=1e-6) | ||
| ref = DiffusersAdaLayerNormContinuous(dim, cond_dim, elementwise_affine=False, eps=1e-6) | ||
|
|
||
| _copy_weights(ref.linear, ours.linear) | ||
|
|
||
| x = torch.randn(2, 4, dim) | ||
| cond = torch.randn(2, cond_dim) | ||
|
|
||
| out_ours = ours(x, cond) | ||
| out_ref = ref(x, cond) | ||
|
|
||
| torch.testing.assert_close(out_ours, out_ref, atol=1e-5, rtol=1e-5) |
Empty file.
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.