diff --git a/docs/user_guide/diffusion/quantization/fp8.md b/docs/user_guide/diffusion/quantization/fp8.md index 95a62b05d7..9b105463b5 100644 --- a/docs/user_guide/diffusion/quantization/fp8.md +++ b/docs/user_guide/diffusion/quantization/fp8.md @@ -34,7 +34,7 @@ outputs = omni.generate( ) ``` -2. **CLI**: pass `--quantization fp8` and optionally `--ignored-layers`. +2. **CLI**: pass `--diffusion-quantization fp8` (for `vllm serve --omni`) and optionally `--ignored-layers`. ```bash # All layers @@ -44,7 +44,7 @@ python text_to_image.py --model --quantization fp8 python text_to_image.py --model --quantization fp8 --ignored-layers "img_mlp" # Online serving -vllm serve --omni --quantization fp8 +vllm serve --omni --diffusion-quantization fp8 ``` | Parameter | Type | Default | Description | diff --git a/pyproject.toml b/pyproject.toml index 15e408a966..9e97de4f1f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,6 +54,10 @@ dev = [ "pyttsx3>=2.99" ] +quant = [ + "bitsandbytes>=0.49.0", +] + docs = [ "mkdocs>=1.5.0", "mkdocs-api-autonav", diff --git a/tests/diffusion/test_bitsandbytes_quantization.py b/tests/diffusion/test_bitsandbytes_quantization.py new file mode 100644 index 0000000000..acf1dbab13 --- /dev/null +++ b/tests/diffusion/test_bitsandbytes_quantization.py @@ -0,0 +1,275 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import builtins +import sys +import types + +import pytest +import torch +import torch.nn as nn + +from vllm_omni.diffusion.data import OmniDiffusionConfig +from vllm_omni.diffusion.quantization import bitsandbytes as bnb_module +from vllm_omni.diffusion.quantization.bitsandbytes import ( + DiffusionBitsAndBytesConfig, + apply_bnb_quantization, + patch_transformers_for_bnb_load, +) + +_DUMMY_BNB_WEIGHT = 0.123 + + +def _install_dummy_bnb(monkeypatch: pytest.MonkeyPatch): + class DummyLinear8bitLt(nn.Linear): + def __init__(self, in_features, out_features, bias=True, has_fp16_weights=False, device=None, **kwargs): + super().__init__(in_features, out_features, bias=bias, device=device) + self.has_fp16_weights = has_fp16_weights + nn.init.constant_(self.weight, _DUMMY_BNB_WEIGHT) + if self.bias is not None: + nn.init.zeros_(self.bias) + + class DummyLinear4bit(nn.Linear): + def __init__( + self, + in_features, + out_features, + bias=True, + compute_dtype=None, + compress_statistics=False, + quant_type="fp4", + device=None, + **kwargs, + ): + super().__init__(in_features, out_features, bias=bias, device=device) + self.compute_dtype = compute_dtype + self.compress_statistics = compress_statistics + self.quant_type = quant_type + + dummy_bnb = types.SimpleNamespace( + nn=types.SimpleNamespace( + Linear8bitLt=DummyLinear8bitLt, + Linear4bit=DummyLinear4bit, + ) + ) + monkeypatch.setitem(sys.modules, "bitsandbytes", dummy_bnb) + return dummy_bnb + + +def test_quant_config_normalization(): + cfg = OmniDiffusionConfig( + model="dummy-model", + quantization="BNB_4BIT", + quantization_config={ + "modules": "transformer, text_encoder_2", + "bnb_4bit_compute_dtype": "fp16", + }, + ) + assert isinstance(cfg.quantization_config, DiffusionBitsAndBytesConfig) + assert cfg.quantization_config.load_in_4bit is True + assert cfg.quantization_config.load_in_8bit is False + assert cfg.quantization_config.modules == ["transformer", "text_encoder_2"] + assert cfg.quantization_config.bnb_4bit_compute_dtype == torch.float16 + + +def test_apply_bnb_quantization_replaces_linear_modules(monkeypatch): + bnb = 
_install_dummy_bnb(monkeypatch) + + class DummyPipeline(nn.Module): + def __init__(self): + super().__init__() + self.transformer = nn.Sequential( + nn.Linear(4, 8, bias=True), + nn.ReLU(), + nn.ModuleList([nn.Linear(8, 8, bias=False), nn.Sequential(nn.Linear(8, 4))]), + ) + + pipeline = DummyPipeline() + cfg = OmniDiffusionConfig( + model="dummy-model", + quantization="bitsandbytes", + quantization_config={"load_in_8bit": True, "modules": ["transformer"]}, + ) + assert isinstance(cfg.quantization_config, DiffusionBitsAndBytesConfig) + apply_bnb_quantization(pipeline, cfg.quantization_config) + + assert isinstance(pipeline.transformer[0], bnb.nn.Linear8bitLt) + assert isinstance(pipeline.transformer[2][0], bnb.nn.Linear8bitLt) + assert isinstance(pipeline.transformer[2][1][0], bnb.nn.Linear8bitLt) + + +def test_apply_bnb_quantization_copy_weights_false_pre_replace(monkeypatch): + _install_dummy_bnb(monkeypatch) + + class DummyPipeline(nn.Module): + def __init__(self): + super().__init__() + self.transformer = nn.Sequential(nn.Linear(4, 4, bias=False)) + + pipeline = DummyPipeline() + pipeline.transformer[0].weight.data.zero_() + cfg = OmniDiffusionConfig( + model="dummy-model", + quantization="bitsandbytes", + quantization_config={"load_in_8bit": True, "modules": ["transformer"]}, + ) + apply_bnb_quantization(pipeline, cfg.quantization_config, copy_weights=False) + + assert isinstance(pipeline.transformer[0], nn.Linear) + assert torch.allclose( + pipeline.transformer[0].weight, + torch.full_like(pipeline.transformer[0].weight, _DUMMY_BNB_WEIGHT), + ) + + +def test_bnb_llm_int8_has_fp16_weight_passed(monkeypatch): + _install_dummy_bnb(monkeypatch) + + class DummyPipeline(nn.Module): + def __init__(self): + super().__init__() + self.transformer = nn.Sequential(nn.Linear(4, 4, bias=False)) + + pipeline = DummyPipeline() + cfg = OmniDiffusionConfig( + model="dummy-model", + quantization="bitsandbytes", + quantization_config={ + "load_in_8bit": True, + "modules": ["transformer"], + "llm_int8_has_fp16_weight": True, + }, + ) + apply_bnb_quantization(pipeline, cfg.quantization_config, copy_weights=False) + + assert getattr(pipeline.transformer[0], "has_fp16_weights", False) is True + + +def test_bnb_pre_replace_no_false_warning(monkeypatch, caplog): + from vllm.logger import _print_warning_once + + _print_warning_once.cache_clear() + _install_dummy_bnb(monkeypatch) + + class DummyPipeline(nn.Module): + def __init__(self): + super().__init__() + self.transformer = nn.Sequential(nn.Linear(4, 4, bias=False)) + + pipeline = DummyPipeline() + cfg = OmniDiffusionConfig( + model="dummy", + quantization="bitsandbytes", + quantization_config={"load_in_8bit": True, "modules": ["transformer"]}, + ) + + with caplog.at_level("WARNING"): + apply_bnb_quantization(pipeline, cfg.quantization_config, copy_weights=False) + apply_bnb_quantization(pipeline, cfg.quantization_config, copy_weights=True) + + assert not any("no Linear layers replaced" in r.message for r in caplog.records) + + +def test_hf_bnb_patch_inject_and_restore(monkeypatch): + from vllm.logger import _print_warning_once + + _print_warning_once.cache_clear() + + class DummyBitsAndBytesConfig: + def __init__(self, **kwargs): + self.kwargs = kwargs + + class DummyPreTrainedModel: + @classmethod + def from_pretrained(cls, *args, **kwargs): + return kwargs + + transformers_mod = types.ModuleType("transformers") + transformers_mod.BitsAndBytesConfig = DummyBitsAndBytesConfig + modeling_utils_mod = types.ModuleType("transformers.modeling_utils") + 
modeling_utils_mod.PreTrainedModel = DummyPreTrainedModel + + monkeypatch.setitem(sys.modules, "transformers", transformers_mod) + monkeypatch.setitem(sys.modules, "transformers.modeling_utils", modeling_utils_mod) + monkeypatch.setattr(torch.cuda, "is_available", lambda: True) + + cfg = DiffusionBitsAndBytesConfig(load_in_8bit=True, modules=["transformer"]) + orig_attr = DummyPreTrainedModel.__dict__["from_pretrained"] + + with patch_transformers_for_bnb_load(cfg, device=torch.device("cuda")) as used: + out = DummyPreTrainedModel.from_pretrained("transformer", subfolder="transformer") + assert "quantization_config" in out + assert "device_map" in out + assert "transformer" in used + + assert DummyPreTrainedModel.__dict__["from_pretrained"] is orig_attr + + +def test_vllm_linear_bnb4_return_bias_semantics(monkeypatch): + dummy_bnb = _install_dummy_bnb(monkeypatch) + + def matmul_4bit(x, w_t, quant_state): + return x @ w_t + + dummy_bnb.matmul_4bit = matmul_4bit + monkeypatch.setitem(sys.modules, "bitsandbytes", dummy_bnb) + + class DummyVllmLinear(nn.Module): + def __init__(self, return_bias: bool, skip_bias_add: bool): + super().__init__() + self.weight = nn.Parameter(torch.randn(4, 4)) + self.bias = nn.Parameter(torch.randn(4)) + self.return_bias = return_bias + self.skip_bias_add = skip_bias_add + self.quant_method = None + + def forward(self, x: torch.Tensor): + bias = self.bias if not self.skip_bias_add else None + out = self.quant_method.apply(self, x, bias) + if not self.return_bias: + return out + output_bias = self.bias if self.skip_bias_add else None + return out, output_bias + + method = bnb_module._DiffusionBnbLinearMethod(compute_dtype=torch.float32) + + x = torch.randn(2, 4) + linear = DummyVllmLinear(return_bias=True, skip_bias_add=True) + linear.weight.quant_state = object() + linear.quant_method = method + out, out_bias = linear(x) + assert torch.allclose(out, x @ linear.weight.t()) + assert out_bias is linear.bias + + linear2 = DummyVllmLinear(return_bias=True, skip_bias_add=False) + linear2.weight.quant_state = object() + linear2.quant_method = method + out2, out_bias2 = linear2(x) + assert torch.allclose(out2, x @ linear2.weight.t() + linear2.bias) + assert out_bias2 is None + + +def test_apply_bnb_quantization_missing_bnb_raises(monkeypatch): + orig_import = builtins.__import__ + + def _fake_import(name, *args, **kwargs): + if name == "bitsandbytes": + raise ImportError("bitsandbytes missing") + return orig_import(name, *args, **kwargs) + + monkeypatch.setattr(builtins, "__import__", _fake_import) + + pipeline = nn.Sequential(nn.Linear(4, 4)) + cfg = OmniDiffusionConfig( + model="dummy-model", + quantization="bitsandbytes", + quantization_config={"load_in_8bit": True}, + ) + + with pytest.raises(ImportError, match="bitsandbytes is required"): + apply_bnb_quantization(pipeline, cfg.quantization_config) + + +def test_bnb_config_requires_load_in_flag(): + with pytest.raises(ValueError, match="requires load_in_8bit or load_in_4bit"): + DiffusionBitsAndBytesConfig(load_in_8bit=False, load_in_4bit=False) diff --git a/tests/diffusion/test_offload_bnb_interaction.py b/tests/diffusion/test_offload_bnb_interaction.py new file mode 100644 index 0000000000..6f722b00a0 --- /dev/null +++ b/tests/diffusion/test_offload_bnb_interaction.py @@ -0,0 +1,37 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch +import torch.nn as nn + +from vllm_omni.diffusion.offloader import sequential_backend +from 
vllm_omni.diffusion.offloader.base import OffloadConfig, OffloadStrategy +from vllm_omni.diffusion.quantization.bitsandbytes import set_bnb_offload_skip_components + + +def test_model_level_offload_respects_bnb_skip(monkeypatch): + captured: dict[str, list[nn.Module]] = {} + + def _fake_apply_sequential_offload(*, offload_dit_modules, offload_encoder_modules, **kwargs): + captured["offload_dit_modules"] = list(offload_dit_modules) + captured["offload_encoder_modules"] = list(offload_encoder_modules) + + monkeypatch.setattr(sequential_backend, "apply_sequential_offload", _fake_apply_sequential_offload) + + class DummyPipeline(nn.Module): + def __init__(self): + super().__init__() + self.transformer = nn.Linear(4, 4) + self.text_encoder = nn.Linear(4, 4) + + pipeline = DummyPipeline() + set_bnb_offload_skip_components(pipeline, {"transformer"}) + + backend = sequential_backend.ModelLevelOffloadBackend( + OffloadConfig(strategy=OffloadStrategy.MODEL_LEVEL), + device=torch.device("cpu"), + ) + backend.enable(pipeline) + + assert pipeline.transformer not in captured["offload_dit_modules"] + assert pipeline.text_encoder in captured["offload_encoder_modules"] diff --git a/tests/entrypoints/test_omni_stage_diffusion_config.py b/tests/entrypoints/test_omni_stage_diffusion_config.py index f464c55fd6..99885cd6e1 100644 --- a/tests/entrypoints/test_omni_stage_diffusion_config.py +++ b/tests/entrypoints/test_omni_stage_diffusion_config.py @@ -13,6 +13,8 @@ def test_build_od_config_includes_diffusion_fields(): "cache_backend": "cache_dit", "cache_config": {"Fn_compute_blocks": 2}, "vae_use_slicing": True, + "quantization": "bitsandbytes", + "quantization_config": {"method": "bitsandbytes", "modules": ["text_encoder"], "load_in_8bit": True}, } od_config = _build_od_config(engine_args, model="dummy-model") @@ -20,6 +22,8 @@ def test_build_od_config_includes_diffusion_fields(): assert od_config["cache_backend"] == "cache_dit" assert od_config["cache_config"]["Fn_compute_blocks"] == 2 assert od_config["vae_use_slicing"] is True + assert od_config["quantization"] == "bitsandbytes" + assert od_config["quantization_config"]["modules"] == ["text_encoder"] def test_build_od_config_respects_explicit_config(): diff --git a/vllm_omni/diffusion/data.py b/vllm_omni/diffusion/data.py index b67736c818..a2dc62e8c7 100644 --- a/vllm_omni/diffusion/data.py +++ b/vllm_omni/diffusion/data.py @@ -13,10 +13,13 @@ from typing_extensions import Self from vllm.config.utils import config from vllm.logger import init_logger +from vllm.platforms import current_platform from vllm_omni.diffusion.quantization import ( + DiffusionBitsAndBytesConfig, DiffusionQuantizationConfig, get_diffusion_quant_config, + normalize_diffusion_quant_method, ) from vllm_omni.diffusion.utils.network_utils import is_port_available @@ -328,6 +331,11 @@ class OmniDiffusionConfig: # Attention attention_backend: str | None = None + # Quantization (optional) + # NOTE: This is diffusion-only quantization (not LLM quantization). 
+ quantization: str | None = None + quantization_config: "DiffusionQuantizationConfig | dict[str, Any] | None" = None + # Running mode # mode: ExecutionMode = ExecutionMode.INFERENCE @@ -559,6 +567,12 @@ def __post_init__(self): # If it's neither dict nor DiffusionCacheConfig, convert to empty config self.cache_config = DiffusionCacheConfig() + def _normalize_method_for_compare(method: str | None) -> str | None: + normalized = normalize_diffusion_quant_method(method) + if normalized in ("bitsandbytes_8bit", "bitsandbytes_4bit"): + return "bitsandbytes" + return normalized + # Convert quantization config (deferred import to avoid circular imports) if self.quantization is not None or self.quantization_config is not None: from vllm_omni.diffusion.quantization import ( @@ -573,9 +587,14 @@ def __post_init__(self): quant_method = config_dict.get("method", self.quantization) # Filter out "method" key for kwargs quant_kwargs = {k: v for k, v in config_dict.items() if k != "method"} + quant_method = normalize_diffusion_quant_method(quant_method, quant_kwargs) # Validate conflicting methods - if self.quantization is not None and quant_method is not None and quant_method != self.quantization: + if ( + self.quantization is not None + and quant_method is not None + and _normalize_method_for_compare(self.quantization) != _normalize_method_for_compare(quant_method) + ): logger.warning( f"Conflicting quantization methods: quantization={self.quantization!r}, " f"quantization_config['method']={quant_method!r}. Using quantization_config['method']." @@ -583,13 +602,27 @@ def __post_init__(self): self.quantization_config = get_diffusion_quant_config(quant_method, **quant_kwargs) elif self.quantization_config is None and self.quantization is not None: - self.quantization_config = get_diffusion_quant_config(self.quantization) + quant_kwargs: dict[str, Any] = {} + quant_method = normalize_diffusion_quant_method(self.quantization, quant_kwargs) + self.quantization_config = get_diffusion_quant_config(quant_method, **quant_kwargs) elif not isinstance(self.quantization_config, DiffusionQuantizationConfig): raise TypeError( f"quantization_config must be a DiffusionQuantizationConfig, dict, or None, " f"got {type(self.quantization_config)!r}" ) + if isinstance(self.quantization_config, DiffusionBitsAndBytesConfig): + if current_platform.is_rocm(): + try: + from vllm.platforms.rocm import on_gfx9 + except ImportError: + on_gfx9 = None + if on_gfx9 is not None and on_gfx9(): + raise ValueError("bitsandbytes is not supported on ROCm gfx9 GPUs.") + if self.quantization_config.load_in_8bit and not self.enforce_eager: + logger.warning("CUDA graph is not supported on BitsAndBytes 8bit yet, fallback to the eager mode.") + self.enforce_eager = True + if self.max_cpu_loras is None: self.max_cpu_loras = 1 elif self.max_cpu_loras < 1: diff --git a/vllm_omni/diffusion/model_loader/diffusers_loader.py b/vllm_omni/diffusion/model_loader/diffusers_loader.py index daccef3439..3a5ac6ece5 100644 --- a/vllm_omni/diffusion/model_loader/diffusers_loader.py +++ b/vllm_omni/diffusion/model_loader/diffusers_loader.py @@ -6,6 +6,7 @@ import re import time from collections.abc import Generator, Iterable +from contextlib import nullcontext from pathlib import Path from typing import cast @@ -32,6 +33,14 @@ from vllm_omni.diffusion.data import OmniDiffusionConfig from vllm_omni.diffusion.distributed.hsdp import HSDPInferenceConfig from vllm_omni.diffusion.model_loader.gguf_adapters import get_gguf_adapter +from 
vllm_omni.diffusion.quantization.bitsandbytes import ( + DiffusionBitsAndBytesConfig, + _is_vllm_linear, + apply_bnb_quantization, + matches_bnb_module_name, + patch_transformers_for_bnb_load, + set_bnb_quantized_components, +) from vllm_omni.diffusion.registry import initialize_model logger = init_logger(__name__) @@ -265,24 +274,76 @@ def load_model( logger.info(f"Quantization enabled with CPU offload, using {load_device} for weight loading") target_device = torch.device(load_device) + enable_offload = bool( + getattr(od_config, "enable_cpu_offload", False) or getattr(od_config, "enable_layerwise_offload", False) + ) with set_default_torch_dtype(od_config.dtype): + bnb_config = ( + od_config.quantization_config + if isinstance(od_config.quantization_config, DiffusionBitsAndBytesConfig) + else None + ) if od_config.parallel_config.use_hsdp: model = self._load_model_with_hsdp( od_config, load_format=load_format, custom_pipeline_name=custom_pipeline_name ) else: with target_device: - if load_format == "default": - model = initialize_model(od_config) - elif load_format == "custom_pipeline": - model_cls = resolve_obj_by_qualname(custom_pipeline_name) - model = model_cls(od_config=od_config) - logger.debug("Loading weights on %s ...", load_device) - if self._is_gguf_quantization(od_config): - self._load_weights_with_gguf(model, od_config) - else: - # Quantization does not happen in `load_weights` but after it - self.load_weights(model) + enable_hf_bnb_load = bool( + bnb_config is not None and not enable_offload and target_device.type == "cuda" + ) + patch_context = ( + patch_transformers_for_bnb_load( + bnb_config, + device=target_device, + enable_cpu_offload=getattr(od_config, "enable_cpu_offload", False), + enable_hf_bnb_load=enable_hf_bnb_load, + ) + if bnb_config is not None + else nullcontext(set()) + ) + with patch_context as bnb_loaded_components: + if load_format == "default": + model = initialize_model(od_config) + elif load_format == "custom_pipeline": + model_cls = resolve_obj_by_qualname(custom_pipeline_name) + model = model_cls(od_config=od_config) + + # Pre-replace transformer Linear modules before loading weights to reduce peak memory. 
+ if bnb_config is not None: + quantized_components = set(bnb_loaded_components) + pre_replace_modules = set() + sources = getattr(model, "weights_sources", ()) + for source in sources: + prefix = getattr(source, "prefix", "") + if prefix: + module_name = prefix.split(".", 1)[0] + if module_name: + pre_replace_modules.add(module_name) + if pre_replace_modules: + only_modules = [] + requested_modules = bnb_config.get_modules() + for module_name in pre_replace_modules: + if not matches_bnb_module_name(module_name, requested_modules): + continue + component = getattr(model, module_name, None) + if component is not None and _component_contains_vllm_linear(component): + continue + only_modules.append(module_name) + quantized_components |= apply_bnb_quantization( + model, + bnb_config, + copy_weights=False, + only_modules=only_modules, + ) + set_bnb_quantized_components(model, quantized_components) + + logger.debug("Loading weights on %s ...", load_device) + if self._is_gguf_quantization(od_config): + self._load_weights_with_gguf(model, od_config) + else: + # Quantization does not happen in `load_weights` but after it + self.load_weights(model) # Process weights after loading for quantization (e.g., FP8 online quantization) # This is needed for vLLM's quantization methods that need to transform weights @@ -497,3 +558,8 @@ def _load_model_with_hsdp( logger.debug("Applying HSDP to %s", name) apply_hsdp_to_model(trans, hsdp_config) return model + + +def _component_contains_vllm_linear(component: nn.Module) -> bool: + """Return True if component includes vLLM LinearBase modules.""" + return any(_is_vllm_linear(module) for module in component.modules()) diff --git a/vllm_omni/diffusion/models/z_image/pipeline_z_image.py b/vllm_omni/diffusion/models/z_image/pipeline_z_image.py index b16d24bf7b..3ad14bff73 100644 --- a/vllm_omni/diffusion/models/z_image/pipeline_z_image.py +++ b/vllm_omni/diffusion/models/z_image/pipeline_z_image.py @@ -218,6 +218,12 @@ def encode_prompt( ) else: negative_prompt_embeds = [] + # Ensure prompt embeddings match transformer dtype for downstream linear ops. 
+ transformer_dtype = getattr(self.transformer, "dtype", None) + if transformer_dtype is not None: + prompt_embeds = [pe.to(transformer_dtype) for pe in prompt_embeds] + if negative_prompt_embeds: + negative_prompt_embeds = [npe.to(transformer_dtype) for npe in negative_prompt_embeds] return prompt_embeds, negative_prompt_embeds def _encode_prompt( diff --git a/vllm_omni/diffusion/offloader/layerwise_backend.py b/vllm_omni/diffusion/offloader/layerwise_backend.py index 5b66ae5ee2..24350073c6 100644 --- a/vllm_omni/diffusion/offloader/layerwise_backend.py +++ b/vllm_omni/diffusion/offloader/layerwise_backend.py @@ -9,6 +9,7 @@ from vllm.logger import init_logger from vllm_omni.diffusion.hooks import HookRegistry, ModelHook +from vllm_omni.diffusion.quantization.bitsandbytes import get_bnb_offload_skip_components from vllm_omni.platforms import current_omni_platform from .base import OffloadBackend, OffloadConfig @@ -257,6 +258,21 @@ def enable(self, pipeline: nn.Module) -> None: logger.warning("No DiT/transformer modules found, skipping layer-wise offloading") return + skip_components = get_bnb_offload_skip_components(pipeline) + if skip_components: + logger.debug("Skipping offload for quantized components: %s", sorted(skip_components)) + + dit_modules = [] + dit_names = [] + for name, module in zip(modules.dit_names, modules.dits): + if name in skip_components: + continue + dit_modules.append(module) + dit_names.append(name) + if not dit_modules: + logger.warning("All DiT modules are quantized; skipping layer-wise offloading") + return + # Move encoders to GPU (they stay resident) for enc in modules.encoders: enc.to(self.device) @@ -268,12 +284,12 @@ def enable(self, pipeline: nn.Module) -> None: except Exception as exc: logger.debug("Failed to move VAE to GPU: %s", exc) - logger.info("Applying layer-wise offloading on %s", modules.dit_names) + logger.info("Applying layer-wise offloading on %s", dit_names) # Apply block-wise offloading hook for each of the blocks in DiT model(s) # Note that there might exist multiple DiT models in specific pipelines - for i, dit_module in enumerate(modules.dits): - dit_name = modules.dit_names[i] + for i, dit_module in enumerate(dit_modules): + dit_name = dit_names[i] logger.info(f"Applying hooks on {dit_name} ({dit_module.__class__.__name__})") blocks_attr_name = LayerWiseOffloadBackend.get_blocks_attr_name(dit_module) diff --git a/vllm_omni/diffusion/offloader/sequential_backend.py b/vllm_omni/diffusion/offloader/sequential_backend.py index ebf85624ff..61f379619c 100644 --- a/vllm_omni/diffusion/offloader/sequential_backend.py +++ b/vllm_omni/diffusion/offloader/sequential_backend.py @@ -6,6 +6,7 @@ from vllm.logger import init_logger from vllm_omni.diffusion.hooks import HookRegistry, ModelHook +from vllm_omni.diffusion.quantization.bitsandbytes import get_bnb_offload_skip_components from vllm_omni.platforms import current_omni_platform from .base import OffloadBackend, OffloadConfig @@ -107,6 +108,9 @@ def apply_sequential_offload( encoder_modules: list[nn.Module], device: torch.device, pin_memory: bool = True, + *, + offload_dit_modules: list[nn.Module] | None = None, + offload_encoder_modules: list[nn.Module] | None = None, ) -> None: """Apply sequential offloading hooks to DiT and encoder modules. 
@@ -119,6 +123,8 @@ def apply_sequential_offload( encoder_modules: Encoder modules to register hooks on device: Target GPU device for loading pin_memory: Whether to pin CPU memory for faster transfers + offload_dit_modules: Optional subset of DiT modules to offload + offload_encoder_modules: Optional subset of encoder modules to offload Example: >>> apply_sequential_offload( @@ -128,11 +134,14 @@ def apply_sequential_offload( ... ) >>> # Modules of pipeline now automatically swap between CPU and GPU """ + offload_dit_modules = offload_dit_modules if offload_dit_modules is not None else dit_modules + offload_encoder_modules = offload_encoder_modules if offload_encoder_modules is not None else encoder_modules + # Register hooks on DiT modules (offload encoders when DiT runs) for dit_mod in dit_modules: registry = HookRegistry.get_or_create(dit_mod) hook = SequentialOffloadHook( - offload_targets=encoder_modules, + offload_targets=offload_encoder_modules, device=device, pin_memory=pin_memory, ) @@ -143,7 +152,7 @@ def apply_sequential_offload( for enc in encoder_modules: registry = HookRegistry.get_or_create(enc) hook = SequentialOffloadHook( - offload_targets=dit_modules, + offload_targets=offload_dit_modules, device=device, pin_memory=pin_memory, ) @@ -191,6 +200,17 @@ def enable(self, pipeline: nn.Module) -> None: logger.warning("No encoder modules found, skipping model-level offloading") return + skip_components = get_bnb_offload_skip_components(pipeline) + offload_dits = [module for name, module in zip(modules.dit_names, modules.dits) if name not in skip_components] + offload_encoders = [ + module for name, module in zip(modules.encoder_names, modules.encoders) if name not in skip_components + ] + if skip_components: + logger.debug("Skipping offload for quantized components: %s", sorted(skip_components)) + if not offload_dits and not offload_encoders: + logger.warning("All offload candidates are quantized; skipping model-level offloading") + return + # Move encoders to GPU for enc in modules.encoders: enc.to(self.device) @@ -208,6 +228,8 @@ def enable(self, pipeline: nn.Module) -> None: encoder_modules=modules.encoders, device=self.device, pin_memory=self.config.pin_cpu_memory, + offload_dit_modules=offload_dits, + offload_encoder_modules=offload_encoders, ) # Track modules for cleanup diff --git a/vllm_omni/diffusion/quantization/__init__.py b/vllm_omni/diffusion/quantization/__init__.py index d297d51f18..29df4747e9 100644 --- a/vllm_omni/diffusion/quantization/__init__.py +++ b/vllm_omni/diffusion/quantization/__init__.py @@ -22,11 +22,17 @@ linear_layer = QKVParallelLinear(..., quant_config=vllm_config) """ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from vllm.logger import init_logger from .base import DiffusionQuantizationConfig +from .bitsandbytes import ( + DiffusionBitsAndBytesConfig, + apply_bnb_quantization, + get_bnb_module_kwargs, + patch_transformers_for_bnb_load, +) from .fp8 import DiffusionFp8Config from .gguf import DiffusionGgufConfig @@ -41,11 +47,53 @@ # To add a new method, create a new config class and register it here _QUANT_CONFIG_REGISTRY: dict[str, type[DiffusionQuantizationConfig]] = { "fp8": DiffusionFp8Config, + "bitsandbytes": DiffusionBitsAndBytesConfig, "gguf": DiffusionGgufConfig, } SUPPORTED_QUANTIZATION_METHODS = list(_QUANT_CONFIG_REGISTRY.keys()) +DIFFUSION_QUANTIZATION_ALIASES: dict[str, str] = { + "": "none", + "null": "none", + "none": "none", + "no": "none", + "false": "none", + "bnb_8bit": "bitsandbytes_8bit", + 
"bnb8": "bitsandbytes_8bit", + "bitsandbytes_8bit": "bitsandbytes_8bit", + "bitsandbytes8bit": "bitsandbytes_8bit", + "bnb_4bit": "bitsandbytes_4bit", + "bnb4": "bitsandbytes_4bit", + "bitsandbytes_4bit": "bitsandbytes_4bit", + "bitsandbytes4bit": "bitsandbytes_4bit", +} + + +def normalize_diffusion_quant_method( + method: str | None, + quant_kwargs: dict[str, Any] | None = None, +) -> str | None: + """Normalize diffusion quantization method names and aliases. + + When quant_kwargs is provided, bitsandbytes 4bit/8bit aliases are normalized + to "bitsandbytes" and the corresponding load_in_* flags are populated. + """ + if method is None: + return None + if isinstance(method, str): + method = method.strip().lower() + method = DIFFUSION_QUANTIZATION_ALIASES.get(method, method) + if method in ("none", ""): + return None + if method in ("bitsandbytes_8bit", "bitsandbytes_4bit"): + if quant_kwargs is not None: + quant_kwargs.setdefault("load_in_8bit", method.endswith("8bit")) + quant_kwargs.setdefault("load_in_4bit", method.endswith("4bit")) + return "bitsandbytes" + return method + return method + def get_diffusion_quant_config( quantization: str | None, @@ -74,10 +122,9 @@ def get_diffusion_quant_config( ignored_layers=["proj_out"], ) """ - if quantization is None or quantization.lower() == "none": + quantization = normalize_diffusion_quant_method(quantization, kwargs) + if quantization is None: return None - - quantization = quantization.lower() if quantization not in _QUANT_CONFIG_REGISTRY: raise ValueError( f"Unknown quantization method: {quantization!r}. Supported methods: {SUPPORTED_QUANTIZATION_METHODS}" @@ -109,9 +156,15 @@ def get_vllm_quant_config_for_layers( __all__ = [ "DiffusionQuantizationConfig", + "DiffusionBitsAndBytesConfig", "DiffusionFp8Config", + "apply_bnb_quantization", + "get_bnb_module_kwargs", + "patch_transformers_for_bnb_load", "DiffusionGgufConfig", "get_diffusion_quant_config", "get_vllm_quant_config_for_layers", "SUPPORTED_QUANTIZATION_METHODS", + "DIFFUSION_QUANTIZATION_ALIASES", + "normalize_diffusion_quant_method", ] diff --git a/vllm_omni/diffusion/quantization/bitsandbytes.py b/vllm_omni/diffusion/quantization/bitsandbytes.py new file mode 100644 index 0000000000..98e3f5676a --- /dev/null +++ b/vllm_omni/diffusion/quantization/bitsandbytes.py @@ -0,0 +1,811 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import contextvars +import fnmatch +import os +import threading +import weakref +from collections.abc import Iterable, Iterator, Mapping, Sequence +from contextlib import contextmanager +from dataclasses import dataclass, field +from typing import Any, Literal + +import torch +import torch.nn as nn +from vllm.logger import init_logger + +from .base import DiffusionQuantizationConfig + +BnbBackend = Literal["bitsandbytes_8bit", "bitsandbytes_4bit"] +DEFAULT_BNB_MODULES = ("transformer", "text_encoder*") + +logger = init_logger(__name__) + +_BNB_LOAD_CONTEXT: contextvars.ContextVar[dict[str, Any] | None] = contextvars.ContextVar( + "diffusion_bnb_load_context", + default=None, +) +_BNB_PATCH_LOCK = threading.Lock() +_BNB_PATCH_APPLIED = False +_BNB_PATCH_REFCOUNT = 0 +_BNB_PATCH_ORIG = None +_BNB_PATCH_TARGET = None + + +@dataclass +class _BnbPipelineState: + quantized_components: set[str] = field(default_factory=set) + offload_skip_components: set[str] = field(default_factory=set) + + +_BNB_PIPELINE_STATE: weakref.WeakKeyDictionary[Any, _BnbPipelineState] 
= weakref.WeakKeyDictionary() + + +def _get_bnb_pipeline_state(pipeline: Any) -> _BnbPipelineState: + try: + state = _BNB_PIPELINE_STATE.get(pipeline) + except TypeError: + state = getattr(pipeline, "_bnb_pipeline_state", None) + if state is None: + state = _BnbPipelineState() + try: + setattr(pipeline, "_bnb_pipeline_state", state) + except Exception: + pass + return state + if state is None: + state = _BnbPipelineState() + _BNB_PIPELINE_STATE[pipeline] = state + return state + + +def get_bnb_quantized_components(pipeline: Any) -> set[str]: + return set(_get_bnb_pipeline_state(pipeline).quantized_components) + + +def set_bnb_quantized_components(pipeline: Any, components: set[str] | Iterable[str]) -> None: + state = _get_bnb_pipeline_state(pipeline) + state.quantized_components = set(components) + + +def update_bnb_quantized_components(pipeline: Any, components: Iterable[str]) -> None: + state = _get_bnb_pipeline_state(pipeline) + state.quantized_components.update(components) + + +def get_bnb_offload_skip_components(pipeline: Any) -> set[str]: + return set(_get_bnb_pipeline_state(pipeline).offload_skip_components) + + +def set_bnb_offload_skip_components(pipeline: Any, components: Iterable[str]) -> None: + state = _get_bnb_pipeline_state(pipeline) + state.offload_skip_components = set(components) + + +def _normalize_modules(modules: Sequence[str] | str | None) -> list[str] | None: + if modules is None: + return None + if isinstance(modules, str): + items = [m.strip() for m in modules.split(",")] + else: + items = [str(m).strip() for m in modules] + return [m for m in items if m] + + +def matches_bnb_module_name(name: str, patterns: Sequence[str]) -> bool: + return any(fnmatch.fnmatchcase(name, pattern) for pattern in patterns) + + +def _resolve_module_patterns(available: Iterable[str], patterns: Sequence[str]) -> list[str]: + ordered_available = list(dict.fromkeys(available)) + matched: list[str] = [] + for pattern in patterns: + for name in ordered_available: + if fnmatch.fnmatchcase(name, pattern) and name not in matched: + matched.append(name) + return matched + + +def _get_pipeline_component_names(pipeline: Any) -> list[str]: + names: list[str] = [] + if isinstance(pipeline, nn.Module): + names.extend(pipeline._modules.keys()) + components = getattr(pipeline, "components", None) + if isinstance(components, Mapping): + for name in components.keys(): + if name not in names: + names.append(name) + return names + + +def _normalize_bnb_compute_dtype(value: torch.dtype | str | None) -> torch.dtype | None: + if value is None: + return None + if isinstance(value, torch.dtype): + return value + if isinstance(value, str): + dtype_str = value.strip().lower() + if dtype_str in ("", "auto"): + return None + dtype_map = { + "bfloat16": torch.bfloat16, + "bf16": torch.bfloat16, + "float16": torch.float16, + "fp16": torch.float16, + "half": torch.float16, + "float32": torch.float32, + "fp32": torch.float32, + "float": torch.float32, + } + if dtype_str not in dtype_map: + raise ValueError( + f"Unknown bnb_4bit_compute_dtype {value!r}. 
Supported: {sorted(dtype_map.keys()) + ['auto']}" + ) + return dtype_map[dtype_str] + raise TypeError(f"bnb_4bit_compute_dtype must be a torch.dtype, str, or None (got {type(value)!r})") + + +@dataclass +class DiffusionBitsAndBytesConfig(DiffusionQuantizationConfig): + """Diffusion bitsandbytes config aligned with vLLM bitsandbytes fields.""" + + load_in_8bit: bool = False + load_in_4bit: bool = True + bnb_4bit_compute_dtype: torch.dtype | str | None = "bfloat16" + bnb_4bit_quant_type: str = "nf4" + bnb_4bit_use_double_quant: bool = True + llm_int8_enable_fp32_cpu_offload: bool = False + llm_int8_has_fp16_weight: bool = False + modules: Sequence[str] | str | None = None + + def __post_init__(self) -> None: + if self.load_in_8bit and self.load_in_4bit: + # Prefer 8bit if both are set (avoid ambiguous defaults). + self.load_in_4bit = False + if not self.load_in_8bit and not self.load_in_4bit: + raise ValueError("bitsandbytes config requires load_in_8bit or load_in_4bit to be True") + self.bnb_4bit_compute_dtype = _normalize_bnb_compute_dtype(self.bnb_4bit_compute_dtype) + self.modules = _normalize_modules(self.modules) + self._vllm_config = None + + @classmethod + def get_name(cls) -> str: + return "bitsandbytes" + + @classmethod + def get_min_capability(cls) -> int: + return 70 + + def get_backend(self) -> BnbBackend: + return "bitsandbytes_8bit" if self.load_in_8bit else "bitsandbytes_4bit" + + def get_modules(self) -> list[str]: + if self.modules: + return list(self.modules) + return list(DEFAULT_BNB_MODULES) + + +def get_bnb_module_kwargs( + quant_config: DiffusionBitsAndBytesConfig | None, + module_name: str | None, + device: torch.device, + *, + enable_cpu_offload: bool | None = None, +) -> dict[str, Any]: + """Build kwargs for transformers.from_pretrained to enable bnb quant at load time.""" + if quant_config is None: + return {} + + if not module_name: + return {} + + if not matches_bnb_module_name(module_name, quant_config.get_modules()): + return {} + + if device.type != "cuda": + return {} + + if not torch.cuda.is_available(): + return {} + + try: + from transformers import BitsAndBytesConfig # type: ignore[import-not-found] + except Exception: + return {} + + backend = quant_config.get_backend() + if backend == "bitsandbytes_8bit": + bnb_config = BitsAndBytesConfig( + load_in_8bit=True, + llm_int8_enable_fp32_cpu_offload=bool(enable_cpu_offload), + llm_int8_has_fp16_weight=bool(quant_config.llm_int8_has_fp16_weight), + ) + else: + compute_dtype = quant_config.bnb_4bit_compute_dtype or torch.float32 + bnb_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=compute_dtype, + bnb_4bit_quant_type=quant_config.bnb_4bit_quant_type, + bnb_4bit_use_double_quant=quant_config.bnb_4bit_use_double_quant, + llm_int8_enable_fp32_cpu_offload=bool(enable_cpu_offload), + llm_int8_has_fp16_weight=bool(quant_config.llm_int8_has_fp16_weight), + ) + + return { + "device_map": {"": str(device)}, + "low_cpu_mem_usage": True, + "quantization_config": bnb_config, + } + + +def _infer_bnb_module_name( + model_name_or_path: Any, + subfolder: str | None, +) -> str | None: + if subfolder: + return str(subfolder) + if isinstance(model_name_or_path, (str, os.PathLike)): + base = os.path.basename(str(model_name_or_path).rstrip("/")) + if matches_bnb_module_name(base, DEFAULT_BNB_MODULES): + return base + return None + + +def _ensure_transformers_bnb_patch() -> bool: + """Ensure a thread-safe from_pretrained hook is installed. 
+ + The hook is a no-op unless a context var is set by + patch_transformers_for_bnb_load. + """ + global _BNB_PATCH_APPLIED, _BNB_PATCH_REFCOUNT, _BNB_PATCH_ORIG, _BNB_PATCH_TARGET + try: + from transformers.modeling_utils import PreTrainedModel # type: ignore[import-not-found] + except Exception: + return False + + with _BNB_PATCH_LOCK: + if _BNB_PATCH_APPLIED: + _BNB_PATCH_REFCOUNT += 1 + return True + + orig_attr = PreTrainedModel.__dict__.get("from_pretrained") + if orig_attr is None: + return False + + orig_func = orig_attr.__func__ + + def _wrapped_from_pretrained(cls, model_name_or_path, *args, **kwargs): # type: ignore[no-untyped-def] + ctx = _BNB_LOAD_CONTEXT.get() + if ctx is None: + return orig_func(cls, model_name_or_path, *args, **kwargs) + + module_name = _infer_bnb_module_name(model_name_or_path, kwargs.get("subfolder")) + bnb_kwargs = get_bnb_module_kwargs( + ctx["quant_config"], + module_name, + ctx["device"], + enable_cpu_offload=ctx.get("enable_cpu_offload"), + ) + if bnb_kwargs: + merged_kwargs = dict(kwargs) + for key, value in bnb_kwargs.items(): + merged_kwargs.setdefault(key, value) + kwargs = merged_kwargs + if module_name is not None: + ctx["quantized_components"].add(module_name) + return orig_func(cls, model_name_or_path, *args, **kwargs) + + PreTrainedModel.from_pretrained = classmethod(_wrapped_from_pretrained) + _BNB_PATCH_APPLIED = True + _BNB_PATCH_REFCOUNT = 1 + _BNB_PATCH_ORIG = orig_attr + _BNB_PATCH_TARGET = PreTrainedModel + return True + + +def _release_transformers_bnb_patch() -> None: + global _BNB_PATCH_APPLIED, _BNB_PATCH_REFCOUNT, _BNB_PATCH_ORIG, _BNB_PATCH_TARGET + with _BNB_PATCH_LOCK: + if not _BNB_PATCH_APPLIED: + return + _BNB_PATCH_REFCOUNT = max(0, _BNB_PATCH_REFCOUNT - 1) + if _BNB_PATCH_REFCOUNT > 0: + return + if _BNB_PATCH_TARGET is not None and _BNB_PATCH_ORIG is not None: + _BNB_PATCH_TARGET.from_pretrained = _BNB_PATCH_ORIG + _BNB_PATCH_APPLIED = False + _BNB_PATCH_ORIG = None + _BNB_PATCH_TARGET = None + + +@contextmanager +def patch_transformers_for_bnb_load( + quant_config: DiffusionBitsAndBytesConfig | None, + *, + device: torch.device, + enable_cpu_offload: bool | None = None, + enable_hf_bnb_load: bool = True, +) -> Iterator[set[str]]: + """Temporarily inject bitsandbytes kwargs into transformers.from_pretrained. + + Returns a set of component names that were loaded with bitsandbytes config. + """ + if quant_config is None or not enable_hf_bnb_load: + yield set() + return + + if device.type != "cuda": + yield set() + return + + if not _ensure_transformers_bnb_patch(): + yield set() + return + + quantized_components: set[str] = set() + ctx = { + "quant_config": quant_config, + "device": device, + "enable_cpu_offload": enable_cpu_offload, + "quantized_components": quantized_components, + } + token = _BNB_LOAD_CONTEXT.set(ctx) + try: + yield quantized_components + finally: + _BNB_LOAD_CONTEXT.reset(token) + _release_transformers_bnb_patch() + + +def apply_bnb_quantization( + pipeline: nn.Module, + quant_config: DiffusionBitsAndBytesConfig | None, + *, + copy_weights: bool = True, + only_modules: Iterable[str] | None = None, + skip_modules: Iterable[str] | None = None, +) -> set[str]: + """Apply bitsandbytes weight-only quantization to selected pipeline components. + + This function is best-effort: + - Only replaces `torch.nn.Linear` modules (MVP scope). + - Skips missing component names in the configured module list. + + Returns: + Set of component names that were quantized. 
+ """ + + if quant_config is None: + return set() + quant_backend = quant_config.get_backend() + + try: + import bitsandbytes as bnb # type: ignore[import-not-found] + except ImportError as exc: + raise ImportError( + "bitsandbytes is required for diffusion quantization='bitsandbytes'. " + "Install with: `pip install bitsandbytes`." + ) from exc + + if quant_config.llm_int8_enable_fp32_cpu_offload: + logger.warning_once( + "llm_int8_enable_fp32_cpu_offload only applies to HF load-time quantization; " + "it is ignored for post-hoc quantization." + ) + + requested_modules = quant_config.get_modules() + available_modules = list(only_modules) if only_modules is not None else _get_pipeline_component_names(pipeline) + quant_modules = _resolve_module_patterns(available_modules, requested_modules) + if skip_modules is not None: + skip = set(skip_modules) + quant_modules = [m for m in quant_modules if m not in skip] + if not quant_modules: + if only_modules is None: + logger.warning_once( + "bitsandbytes: none of the configured modules were found on the pipeline (%s).", + tuple(requested_modules), + ) + return set() + + logger.info_once("Applying bitsandbytes quantization to modules=%s", tuple(quant_modules)) + + bnb_compute_dtype = quant_config.bnb_4bit_compute_dtype or torch.float32 + quantized_components: set[str] = set() + replaced_any = False + + for module_name in quant_modules: + component = getattr(pipeline, module_name, None) + if component is None: + continue + if not isinstance(component, nn.Module): + continue + already_quantized = _contains_bnb_linear(component, bnb) + num_replaced = _apply_to_component( + component, + bnb=bnb, + backend=quant_backend, + bnb_4bit_quant_type=quant_config.bnb_4bit_quant_type, + bnb_4bit_compute_dtype=bnb_compute_dtype, + bnb_4bit_use_double_quant=quant_config.bnb_4bit_use_double_quant, + llm_int8_has_fp16_weight=quant_config.llm_int8_has_fp16_weight, + copy_weights=copy_weights, + ) + if num_replaced > 0: + quantized_components.add(module_name) + replaced_any = True + elif already_quantized: + quantized_components.add(module_name) + replaced_any = True + else: + logger.warning_once( + "bitsandbytes: no Linear layers replaced in module '%s' (%s).", + module_name, + component.__class__.__name__, + ) + + if not replaced_any: + logger.warning_once("bitsandbytes: no Linear layers replaced; quantization may be ineffective.") + + return quantized_components + + +def _apply_to_component( + component: nn.Module, + *, + bnb: Any, + backend: BnbBackend, + bnb_4bit_quant_type: str, + bnb_4bit_compute_dtype: torch.dtype | str | None, + bnb_4bit_use_double_quant: bool, + llm_int8_has_fp16_weight: bool, + copy_weights: bool, +) -> int: + original_device = _get_module_device(component) + + # Linear4bit requires fp weights before .to("cuda") triggers internal packing. + # We avoid migrating the whole component; per-layer CPU copies are handled in _load_linear_weights. 
+ + num_replaced = _replace_linear_modules_inplace( + component, + bnb=bnb, + backend=backend, + bnb_4bit_quant_type=bnb_4bit_quant_type, + bnb_4bit_compute_dtype=bnb_4bit_compute_dtype, + bnb_4bit_use_double_quant=bnb_4bit_use_double_quant, + llm_int8_has_fp16_weight=llm_int8_has_fp16_weight, + copy_weights=copy_weights, + ) + logger.info("Quantized %d Linear layers with %s in %s", num_replaced, backend, component.__class__.__name__) + if original_device is not None and original_device.type != "cpu": + component.to(original_device) + return num_replaced + + +def _replace_linear_modules_inplace( + root: nn.Module, + *, + bnb: Any, + backend: BnbBackend, + bnb_4bit_quant_type: str, + bnb_4bit_compute_dtype: torch.dtype | str | None, + bnb_4bit_use_double_quant: bool, + llm_int8_has_fp16_weight: bool, + copy_weights: bool, +) -> int: + replaced = 0 + for child_name, child in list(root.named_children()): + if _is_bnb_linear(child, bnb): + continue + if _is_vllm_linear(child) and backend == "bitsandbytes_4bit" and copy_weights: + if _quantize_vllm_linear_inplace( + child, + bnb=bnb, + bnb_4bit_quant_type=bnb_4bit_quant_type, + bnb_4bit_compute_dtype=bnb_4bit_compute_dtype, + bnb_4bit_use_double_quant=bnb_4bit_use_double_quant, + ): + replaced += 1 + continue + if _is_supported_linear(child, copy_weights=copy_weights): + new_child = _convert_linear_module( + child, + bnb=bnb, + backend=backend, + bnb_4bit_quant_type=bnb_4bit_quant_type, + bnb_4bit_compute_dtype=bnb_4bit_compute_dtype, + bnb_4bit_use_double_quant=bnb_4bit_use_double_quant, + llm_int8_has_fp16_weight=llm_int8_has_fp16_weight, + copy_weights=copy_weights, + ) + _set_child_module(root, child_name, new_child) + replaced += 1 + continue + + replaced += _replace_linear_modules_inplace( + child, + bnb=bnb, + backend=backend, + bnb_4bit_quant_type=bnb_4bit_quant_type, + bnb_4bit_compute_dtype=bnb_4bit_compute_dtype, + bnb_4bit_use_double_quant=bnb_4bit_use_double_quant, + llm_int8_has_fp16_weight=llm_int8_has_fp16_weight, + copy_weights=copy_weights, + ) + + return replaced + + +class _DiffusionBnbLinearMethod: + """Minimal bnb 4bit method for vLLM Linear modules.""" + + def __init__(self, compute_dtype: torch.dtype | None) -> None: + self.compute_dtype = compute_dtype + + def apply( + self, + layer: nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + ) -> torch.Tensor: + from bitsandbytes import matmul_4bit + + original_type = x.dtype + original_shape = x.shape + reshape_after_matmul = False + if x.ndim > 2: + x = x.reshape(-1, x.size(-1)) + reshape_after_matmul = True + + compute_dtype = self.compute_dtype or x.dtype + if compute_dtype != x.dtype: + x = x.to(compute_dtype) + + weight = layer.weight + out = matmul_4bit(x, weight.t(), weight.quant_state) + if out.dtype != original_type: + out = out.to(original_type) + + if reshape_after_matmul: + out = out.view(*original_shape[:-1], out.size(-1)) + + if bias is not None: + out += bias + return out + + +def _quantize_vllm_linear_inplace( + linear: nn.Module, + *, + bnb: Any, + bnb_4bit_quant_type: str, + bnb_4bit_compute_dtype: torch.dtype | str | None, + bnb_4bit_use_double_quant: bool, +) -> bool: + if getattr(linear, "tp_size", 1) != 1: + return False + weight = getattr(linear, "weight", None) + if weight is None: + return False + if getattr(weight, "quant_state", None) is not None: + return False + original_device = weight.device + if original_device.type != "cuda": + return False + + in_features, out_features = _get_linear_io_features(linear) + compute_dtype = 
bnb_4bit_compute_dtype or torch.float32 + + temp = bnb.nn.Linear4bit( + in_features, + out_features, + bias=False, + compute_dtype=compute_dtype, + compress_statistics=bnb_4bit_use_double_quant, + quant_type=bnb_4bit_quant_type, + device=torch.device("cpu"), + ) + _load_linear_weights(linear, temp, include_bias=False) + temp = temp.to(original_device) + + linear._parameters["weight"] = temp.weight + linear.quant_method = _DiffusionBnbLinearMethod( + compute_dtype=getattr(temp, "compute_dtype", compute_dtype), + ) + return True + + +def _convert_linear_module( + linear: nn.Module, + *, + bnb: Any, + backend: BnbBackend, + bnb_4bit_quant_type: str, + bnb_4bit_compute_dtype: torch.dtype | str | None, + bnb_4bit_use_double_quant: bool, + llm_int8_has_fp16_weight: bool, + copy_weights: bool, +) -> nn.Module: + original_device = linear.weight.device + in_features, out_features = _get_linear_io_features(linear) + bias = getattr(linear, "bias", None) + has_bias = bias is not None + is_vllm_linear = _is_vllm_linear(linear) + return_bias = bool(getattr(linear, "return_bias", False)) + skip_bias_add = bool(getattr(linear, "skip_bias_add", False)) + + if copy_weights: + target_device = torch.device("cpu") + else: + target_device = original_device + + # Bias handling truth table (only affects inner module bias allocation): + # - return_bias=False, skip_bias_add=False -> keep bias + # - return_bias=False, skip_bias_add=True -> keep bias (caller ignores) + # - return_bias=True, skip_bias_add=False -> keep bias (bias added inside) + # - return_bias=True, skip_bias_add=True -> drop bias (bias returned separately) + inner_has_bias = has_bias if not (is_vllm_linear and return_bias and skip_bias_add) else False + if backend == "bitsandbytes_8bit": + new_linear = bnb.nn.Linear8bitLt( + in_features, + out_features, + bias=inner_has_bias, + has_fp16_weights=llm_int8_has_fp16_weight, + device=target_device, + ) + elif backend == "bitsandbytes_4bit": + new_linear = bnb.nn.Linear4bit( + in_features, + out_features, + bias=inner_has_bias, + compute_dtype=bnb_4bit_compute_dtype, + compress_statistics=bnb_4bit_use_double_quant, + quant_type=bnb_4bit_quant_type, + device=target_device, + ) + else: + raise ValueError(f"Unknown backend: {backend}") + + if copy_weights: + _load_linear_weights(linear, new_linear, include_bias=inner_has_bias) + if original_device.type != "cpu": + new_linear = new_linear.to(original_device) + + if is_vllm_linear and return_bias: + bias_param = None + if has_bias and skip_bias_add: + bias_param = bias.detach().clone() + return _BnbLinearReturnBiasWrapper( + linear=new_linear, + bias=bias_param, + return_bias=return_bias, + skip_bias_add=skip_bias_add, + meta=_collect_vllm_linear_meta(linear), + ) + + return new_linear + + +def _is_bnb_linear(module: nn.Module, bnb: Any) -> bool: + return isinstance(module, (bnb.nn.Linear8bitLt, bnb.nn.Linear4bit)) + + +def _contains_bnb_linear(module: nn.Module, bnb: Any) -> bool: + return any(_is_bnb_linear(child, bnb) for child in module.modules()) + + +def _is_vllm_linear(module: nn.Module) -> bool: + module_path = getattr(module.__class__, "__module__", "") + if module_path.startswith("vllm.model_executor.layers.linear"): + return True + try: + from vllm.model_executor.layers.linear import LinearBase + except Exception: + return False + return isinstance(module, LinearBase) + + +def _is_supported_linear(module: nn.Module, *, copy_weights: bool) -> bool: + if isinstance(module, nn.Linear): + return True + if not _is_vllm_linear(module): + return False + # 
Avoid pre-replace for vLLM linear modules to keep weight loading safe. + if not copy_weights: + return False + tp_size = getattr(module, "tp_size", 1) + if tp_size != 1: + return False + try: + state_keys = set(module.state_dict().keys()) + except Exception: + return False + return state_keys.issubset({"weight", "bias"}) + + +def _get_linear_io_features(module: nn.Module) -> tuple[int, int]: + if hasattr(module, "in_features") and hasattr(module, "out_features"): + return int(module.in_features), int(module.out_features) + if hasattr(module, "input_size") and hasattr(module, "output_size"): + return int(module.input_size), int(module.output_size) + weight = getattr(module, "weight", None) + if weight is not None and hasattr(weight, "shape") and len(weight.shape) >= 2: + return int(weight.shape[1]), int(weight.shape[0]) + raise ValueError(f"Cannot infer linear features for module {module.__class__.__name__}") + + +def _load_linear_weights(src: nn.Module, dst: nn.Module, *, include_bias: bool = True) -> None: + # Loading via state_dict ensures bitsandbytes hooks see fp weights, + # and .to("cuda") triggers internal quantization. + state = src.state_dict() + keys = {"weight", "bias"} if include_bias else {"weight"} + filtered = {k: v for k, v in state.items() if k in keys} + if any(getattr(t, "is_cuda", False) for t in filtered.values()): + filtered = {k: v.detach().cpu() for k, v in filtered.items()} + # bitsandbytes 4bit is sensitive to bf16 weights; cast to compute dtype if set. + compute_dtype = getattr(dst, "compute_dtype", None) + if compute_dtype is not None: + filtered = {k: (v.to(compute_dtype) if torch.is_floating_point(v) else v) for k, v in filtered.items()} + dst.load_state_dict(filtered, strict=False) + + +def _collect_vllm_linear_meta(module: nn.Module) -> dict[str, object]: + meta = {} + for name in ("num_heads", "num_kv_heads", "head_dim", "tp_size", "tp_rank", "input_size", "output_size"): + if hasattr(module, name): + meta[name] = getattr(module, name) + return meta + + +class _BnbLinearReturnBiasWrapper(nn.Module): + """Wrapper to preserve (output, bias) semantics for vLLM Linear modules.""" + + def __init__( + self, + *, + linear: nn.Module, + bias: torch.Tensor | None, + return_bias: bool, + skip_bias_add: bool, + meta: dict[str, object] | None = None, + ) -> None: + super().__init__() + self.linear = linear + self.return_bias = return_bias + self.skip_bias_add = skip_bias_add + if bias is None: + self.register_parameter("bias", None) + else: + self.bias = nn.Parameter(bias, requires_grad=False) + for key, value in (meta or {}).items(): + setattr(self, key, value) + + def forward(self, x: torch.Tensor): + out = self.linear(x) + if not self.return_bias: + return out + return out, (self.bias if self.skip_bias_add else None) + + +def _get_module_device(module: nn.Module) -> torch.device | None: + try: + param = next(module.parameters()) + return param.device + except StopIteration: + try: + buf = next(module.buffers()) + return buf.device + except StopIteration: + return None + + +def _set_child_module(parent: nn.Module, name: str, child: nn.Module) -> None: + if isinstance(parent, nn.ModuleList): + parent[int(name)] = child + return + if isinstance(parent, nn.ModuleDict): + parent[name] = child + return + setattr(parent, name, child) diff --git a/vllm_omni/diffusion/worker/diffusion_model_runner.py b/vllm_omni/diffusion/worker/diffusion_model_runner.py index accb173e1a..2bb2b38055 100644 --- a/vllm_omni/diffusion/worker/diffusion_model_runner.py +++ 
b/vllm_omni/diffusion/worker/diffusion_model_runner.py @@ -27,6 +27,13 @@ from vllm_omni.diffusion.forward_context import set_forward_context from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.offloader import get_offload_backend +from vllm_omni.diffusion.quantization.bitsandbytes import ( + DiffusionBitsAndBytesConfig, + apply_bnb_quantization, + get_bnb_quantized_components, + set_bnb_offload_skip_components, + update_bnb_quantized_components, +) from vllm_omni.diffusion.registry import _NO_CACHE_ACCELERATION from vllm_omni.diffusion.request import OmniDiffusionRequest from vllm_omni.distributed.omni_connectors.kv_transfer_manager import OmniKVTransferManager @@ -138,6 +145,32 @@ def get_memory_context(): ) logger.info("Model runner: Model loaded successfully.") + # Optional diffusion quantization (e.g., bitsandbytes). + bnb_config = ( + self.od_config.quantization_config + if isinstance(self.od_config.quantization_config, DiffusionBitsAndBytesConfig) + else None + ) + if bnb_config is not None: + # Run a post-load pass to quantize any components that weren't + # covered by load-time bnb injection (or were loaded later). + skip_modules = get_bnb_quantized_components(self.pipeline) + quantized_now = apply_bnb_quantization(self.pipeline, bnb_config, skip_modules=skip_modules) + if quantized_now: + update_bnb_quantized_components(self.pipeline, quantized_now) + offload_requested = bool( + getattr(self.od_config, "enable_cpu_offload", False) + or getattr(self.od_config, "enable_layerwise_offload", False) + ) + if offload_requested: + if bnb_config.load_in_8bit: + skip_components = get_bnb_quantized_components(self.pipeline) + set_bnb_offload_skip_components(self.pipeline, skip_components) + logger.warning("bitsandbytes 8bit + offload: disabled for quantized components (stability).") + else: + set_bnb_offload_skip_components(self.pipeline, set()) + logger.info("bitsandbytes 4bit + offload: enabled.") + # Apply CPU offloading self.offload_backend = get_offload_backend(self.od_config, device=self.device) if self.offload_backend is not None: diff --git a/vllm_omni/entrypoints/async_omni.py b/vllm_omni/entrypoints/async_omni.py index 474834d16a..248feb50d7 100644 --- a/vllm_omni/entrypoints/async_omni.py +++ b/vllm_omni/entrypoints/async_omni.py @@ -158,6 +158,7 @@ def _create_default_diffusion_stage_cfg(self, kwargs: dict[str, Any]) -> dict[st # TODO: here is different from the Omni class. We should merge the two in the future. 
cache_backend = kwargs.get("cache_backend", "none") cache_config = self._normalize_cache_config(cache_backend, kwargs.get("cache_config", None)) + quantization_config = self._normalize_quantization_config(kwargs.get("quantization_config", None)) devices = "0" if "parallel_config" in kwargs: @@ -226,6 +227,7 @@ def _create_default_diffusion_stage_cfg(self, kwargs: dict[str, Any]) -> dict[st "diffusion_load_format": kwargs.get("diffusion_load_format", "default"), "custom_pipeline_args": kwargs.get("custom_pipeline_args", None), "quantization": kwargs.get("quantization", None), + "quantization_config": quantization_config, "worker_extension_cls": kwargs.get("worker_extension_cls", None), "enable_sleep_mode": kwargs.get("enable_sleep_mode", False), "enable_multithread_weight_load": kwargs.get("enable_multithread_weight_load", True), diff --git a/vllm_omni/entrypoints/cli/serve.py b/vllm_omni/entrypoints/cli/serve.py index 7365e73bb1..184738c9be 100644 --- a/vllm_omni/entrypoints/cli/serve.py +++ b/vllm_omni/entrypoints/cli/serve.py @@ -276,6 +276,26 @@ def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgu help="Enable cache-dit summary logging after diffusion forward passes.", ) + # Quantization parameters (diffusion-only) + omni_config_group.add_argument( + "--diffusion-quantization", + type=str, + default=None, + dest="quantization", + help="Diffusion quantization method (e.g., bitsandbytes).", + ) + omni_config_group.add_argument( + "--diffusion-quantization-config", + type=str, + default=None, + dest="quantization_config", + help=( + "JSON string for diffusion quantization config. For bitsandbytes, " + 'use fields like \'{"load_in_4bit": true, "bnb_4bit_compute_dtype": "bfloat16", ' + '"bnb_4bit_quant_type": "nf4", "bnb_4bit_use_double_quant": true}\'.' + ), + ) + # VAE memory optimization parameters omni_config_group.add_argument( "--vae-use-slicing", diff --git a/vllm_omni/entrypoints/omni.py b/vllm_omni/entrypoints/omni.py index 488b986d8e..be3b612bd2 100644 --- a/vllm_omni/entrypoints/omni.py +++ b/vllm_omni/entrypoints/omni.py @@ -23,6 +23,7 @@ from vllm_omni.config.stage_config import StageConfigFactory from vllm_omni.config.yaml_util import create_config +from vllm_omni.diffusion.quantization import get_diffusion_quant_config, normalize_diffusion_quant_method from vllm_omni.distributed.omni_connectors import ( get_stage_connector_config, initialize_orchestrator_connectors, @@ -225,6 +226,15 @@ def _normalize_cache_config(self, cache_backend: str | None, cache_config: Any | cache_config = self._get_default_cache_config(cache_backend) return cache_config + def _normalize_quantization_config(self, quantization_config: Any | None) -> Any | None: + if isinstance(quantization_config, str): + try: + quantization_config = json.loads(quantization_config) + except json.JSONDecodeError: + logger.warning("Invalid quantization_config JSON, disabling quantization.") + quantization_config = None + return quantization_config + def _create_default_diffusion_stage_cfg(self, kwargs: dict[str, Any]) -> list[dict[str, Any]]: """Create default diffusion stage configuration. 
@@ -246,6 +256,29 @@ def _create_default_diffusion_stage_cfg(self, kwargs: dict[str, Any]) -> list[di # Normalize cache config before passing to factory cache_backend = kwargs.get("cache_backend", "none") cache_config = self._normalize_cache_config(cache_backend, kwargs.get("cache_config", None)) + quantization_config = self._normalize_quantization_config(kwargs.get("quantization_config", None)) + quant_method = kwargs.get("quantization", None) + if isinstance(quantization_config, dict) and "method" in quantization_config: + quant_method = quantization_config.get("method") + if isinstance(quant_method, str): + normalized = normalize_diffusion_quant_method(quant_method) + if normalized is None: + logger.warning("Invalid diffusion quantization method %r; disabling quantization.", quant_method) + kwargs["quantization"] = None + if isinstance(quantization_config, dict): + quantization_config.pop("method", None) + else: + try: + get_diffusion_quant_config(normalized) + except ValueError: + logger.warning( + "Unsupported diffusion quantization method %r; disabling quantization.", quant_method + ) + kwargs["quantization"] = None + if isinstance(quantization_config, dict): + quantization_config.pop("method", None) + if quantization_config is not None: + kwargs["quantization_config"] = quantization_config # Update kwargs with normalized values kwargs_copy = dict(kwargs)
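Editor's addendum (not part of the patch): a minimal sketch of how the new diffusion bitsandbytes configuration path is exercised, mirroring `test_quant_config_normalization` from this diff — an alias such as `bnb_4bit` is normalized to `bitsandbytes` and the `load_in_*` flags are populated on the resulting `DiffusionBitsAndBytesConfig`. On the serving side the same settings map to the new `--diffusion-quantization` / `--diffusion-quantization-config` flags added in `serve.py`. Assumes the vllm-omni package with this patch applied is importable; the exact call is illustrative, not normative.

```python
# Minimal sketch, assuming this patch is applied and vllm_omni is installed.
import torch

from vllm_omni.diffusion.data import OmniDiffusionConfig
from vllm_omni.diffusion.quantization.bitsandbytes import DiffusionBitsAndBytesConfig

cfg = OmniDiffusionConfig(
    model="dummy-model",
    quantization="bnb_4bit",  # alias; normalized by normalize_diffusion_quant_method
    quantization_config={
        "modules": "transformer, text_encoder_2",  # comma-separated string is accepted
        "bnb_4bit_compute_dtype": "fp16",
    },
)

# The alias resolves to the bitsandbytes config with 4-bit loading enabled.
assert isinstance(cfg.quantization_config, DiffusionBitsAndBytesConfig)
assert cfg.quantization_config.load_in_4bit is True
assert cfg.quantization_config.modules == ["transformer", "text_encoder_2"]
assert cfg.quantization_config.bnb_4bit_compute_dtype == torch.float16
```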