diff --git a/docs/user_guide/diffusion/quantization/msmodelslim.md b/docs/user_guide/diffusion/quantization/msmodelslim.md new file mode 100644 index 00000000000..5492cd9272b --- /dev/null +++ b/docs/user_guide/diffusion/quantization/msmodelslim.md @@ -0,0 +1,56 @@ +# msModelSlim Quantization + +## Overview + +[msModelSlim](https://github.com/Ascend/msmodelslim) is an Ascend-friendly compression tool focused on acceleration, using compression techniques, and built for Ascend hardware. It includes a series of inference optimization technologies such as quantization and compression, aiming to accelerate large language dense models, MoE models, multimodal understanding models, multimodal generation models, etc. + +Once you have a quantized model which is generated by **msModelSlim**, you can use vLLM Omni for inference by specifying the --quantization ascend parameter to enable quantization features. + +### Supported Schemes + +| Scheme | Bits | Status | +|--------|------|--------| +| W8A8 | 8 | ✅ Supported | +| W4A4 | 4 | Planned | + +W8A8 is the first supported scheme. Additional schemes will be added in future releases. + +## Model Quantization + +The following example shows how to generate W8A8 quantized weights for the [Wan2_2 model](https://gitcode.com/Ascend/msmodelslim/blob/master/example/multimodal_sd/Wan2_2/README.md). + +**Quantization Script:** + +```bash +msmodelslim quant \ + --model_path /path/to/wan2_2_t2v_float_weights \ + --save_path /path/to/wan2_2_t2v_quantized_weights \ + --device npu \ + --model_type Wan2_2 \ + --config_path /lab_practice/wan2_2/wan2_2_w8a8f8_mxfp_t2v.yaml \ + --trust_remote_code True +``` + +After quantization completes, the output directory will contain the quantized model files. + +For more examples, refer to the [official examples](https://gitcode.com/Ascend/msit/tree/master/msmodelslim/example). + +## Configuration + +1. **CLI**: pass `--quantization ascend`. + +```bash +# Offline inference +python text_to_image.py --model --quantization ascend + +# Online serving +vllm serve --omni --quantization ascend +``` + +## Supported Models + +| Model | HF Models | Recommendation | `ignored_layers` | +|-------|-----------|---------------|------------------| +| HunyuanImage-3.0 | - | All layers | None | + +Currently, quantized HunyuanImage-3.0 weights have not been uploaded to public model platforms such as Hugging Face. You can use a [HunyuanImage-3.0-adapted msModelSlim version](https://gitcode.com/betta18/msmodelslim/tree/hyimage3_mxfp8) to generate the quantized weights manually. We will upload the quantized weights as soon as possible. diff --git a/examples/offline_inference/text_to_image/text_to_image.py b/examples/offline_inference/text_to_image/text_to_image.py index 0f9f10dce97..bc18c685912 100644 --- a/examples/offline_inference/text_to_image/text_to_image.py +++ b/examples/offline_inference/text_to_image/text_to_image.py @@ -300,6 +300,7 @@ def parse_args() -> argparse.Namespace: default=None, help=("Custom system prompt. Used when --use-system-prompt is custom. "), ) + current_omni_platform.pre_register_and_update(parser) from vllm_omni.engine.arg_utils import nullify_stage_engine_defaults nullify_stage_engine_defaults(parser) diff --git a/vllm_omni/diffusion/data.py b/vllm_omni/diffusion/data.py index 40c943a41cd..cf6841fd21d 100644 --- a/vllm_omni/diffusion/data.py +++ b/vllm_omni/diffusion/data.py @@ -759,7 +759,7 @@ def from_kwargs(cls, **kwargs: Any) -> "OmniDiffusionConfig": # Backwards-compatibility: map "quantization" to "quantization_config" # so callers using the old field name still work. - if "quantization" in kwargs and kwargs.get("quantization_config") is None: + if "quantization" in kwargs and kwargs.get("quantization_config", None) is None: kwargs["quantization_config"] = kwargs.pop("quantization") else: kwargs.pop("quantization", None) diff --git a/vllm_omni/diffusion/models/hunyuan_image3/hunyuan_image3_transformer.py b/vllm_omni/diffusion/models/hunyuan_image3/hunyuan_image3_transformer.py index fbdacddaf34..0f3c33389c5 100644 --- a/vllm_omni/diffusion/models/hunyuan_image3/hunyuan_image3_transformer.py +++ b/vllm_omni/diffusion/models/hunyuan_image3/hunyuan_image3_transformer.py @@ -1484,7 +1484,7 @@ def __init__( config.hidden_size, config.num_experts, bias=False, - quant_config=None, + quant_config=quant_config, prefix=f"{prefix}.gate", ) if config.use_mixed_mlp_moe > 0: @@ -1658,8 +1658,10 @@ def forward( custom_pos_emb: tuple[torch.FloatTensor] | None = None, **kwargs, ) -> torch.Tensor: - bsz, q_len, _ = hidden_states.size() + bsz, q_len, hidden_size = hidden_states.size() + hidden_states = hidden_states.reshape(-1, hidden_size) qkv, _ = self.qkv_proj(hidden_states) + qkv = qkv.reshape(bsz, q_len, -1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) past_key_value: Cache | None = kwargs.get("past_key_value", None) @@ -1723,7 +1725,7 @@ def __init__( rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, - quant_config=None, + quant_config=quant_config, bias=attention_bias, cache_config=None, prefix=f"{prefix}.self_attn", @@ -1933,7 +1935,7 @@ def __init__(self, config: HunyuanImage3Config, quant_config=None, prefix: str = layer_idx=int(prefix.split(".")[-1]), prefix=prefix, ), - prefix=f"{prefix}.layers", + prefix=f"{prefix}.layers" if prefix else "layers", ) if get_pp_group().is_last_rank: self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -1948,7 +1950,7 @@ def _split_qkv_weight(self, qkv: torch.Tensor): num_attention_heads = self.config.num_attention_heads num_kv_heads = getattr(self.config, "num_key_value_heads", self.config.num_attention_heads) num_key_value_groups = num_attention_heads // num_kv_heads - hidden_size = self.config.hidden_size + hidden_size = qkv.shape[1] if hasattr(self.config, "head_dim"): attention_head_dim = self.config.head_dim @@ -2001,8 +2003,15 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): split_params_mapping = [ (".gate_up_proj", ".gate_and_up_proj", 2, [(1, 1), (0, 1)], None), ( - ".qkv_proj", - ".qkv_proj", + ".qkv_proj.weight", + ".qkv_proj.weight", + num_attention_heads + num_kv_heads * 2, + [("q", num_attention_heads), ("k", num_kv_heads), ("v", num_kv_heads)], + self._split_qkv_weight, + ), + ( + ".qkv_proj.weight_scale", + ".qkv_proj.weight_scale", num_attention_heads + num_kv_heads * 2, [("q", num_attention_heads), ("k", num_kv_heads), ("v", num_kv_heads)], self._split_qkv_weight, @@ -2101,6 +2110,8 @@ def contains_unexpected_keyword(name, keywords): continue if "mlp.experts" in name: continue + if ".qkv_proj" in name and not name.endswith(weight_name): + continue name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: diff --git a/vllm_omni/diffusion/models/hunyuan_image3/pipeline_hunyuan_image3.py b/vllm_omni/diffusion/models/hunyuan_image3/pipeline_hunyuan_image3.py index 3de0ab31016..84a7787ad11 100644 --- a/vllm_omni/diffusion/models/hunyuan_image3/pipeline_hunyuan_image3.py +++ b/vllm_omni/diffusion/models/hunyuan_image3/pipeline_hunyuan_image3.py @@ -14,7 +14,7 @@ from transformers.models.siglip2 import Siglip2VisionConfig, Siglip2VisionModel from transformers.utils.generic import ModelOutput from vllm.config.vllm import get_current_vllm_config -from vllm.model_executor.models.utils import AutoWeightsLoader +from vllm.model_executor.models.utils import AutoWeightsLoader, WeightsMapper from vllm.transformers_utils.config import get_config from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig @@ -64,6 +64,15 @@ def to_device(data, device): class HunyuanImage3Pipeline(HunyuanImage3PreTrainedModel, GenerationMixin, DiffusionPipelineProfilerMixin): + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "model.": "", + }, + orig_to_new_substr={ + "mlp.gate.wg.": "mlp.gate.", + "gate_and_up_proj.": "gate_up_proj.", + }, + ) _PROFILER_TARGETS = [ "model.forward", "model.layers[0].forward", diff --git a/vllm_omni/diffusion/registry.py b/vllm_omni/diffusion/registry.py index ed7a03c2b11..ebcd263c143 100644 --- a/vllm_omni/diffusion/registry.py +++ b/vllm_omni/diffusion/registry.py @@ -5,6 +5,7 @@ import torch.nn as nn from vllm.logger import init_logger +from vllm.model_executor.model_loader.utils import configure_quant_config from vllm.model_executor.models.registry import _LazyRegisteredModel, _ModelRegistry from vllm_omni.diffusion.data import OmniDiffusionConfig @@ -13,6 +14,7 @@ from vllm_omni.diffusion.forward_context import get_forward_context from vllm_omni.diffusion.hooks.sequence_parallel import apply_sequence_parallel from vllm_omni.diffusion.utils.tf_utils import find_module_with_attr +from vllm_omni.platforms import current_omni_platform logger = init_logger(__name__) @@ -242,6 +244,22 @@ } +def _prepare_diffusion_quant_config( + od_config: OmniDiffusionConfig, + model_class: type[nn.Module], +) -> None: + """Prepare diffusion quant config using vLLM-style model bindings.""" + quant_config = od_config.quantization_config + if quant_config is None: + return + if hasattr(quant_config, "maybe_update_config"): + quant_config.maybe_update_config(od_config.model) + diffusion_packed_modules_mapping = current_omni_platform.get_diffusion_packed_modules_mapping(model_class) + if diffusion_packed_modules_mapping is not None: + model_class.packed_modules_mapping = diffusion_packed_modules_mapping + configure_quant_config(quant_config, model_class) + + def initialize_model( od_config: OmniDiffusionConfig, ) -> nn.Module: @@ -264,6 +282,7 @@ def initialize_model( """ model_class = DiffusionModelRegistry._try_load_model_cls(od_config.model_class_name) if model_class is not None: + _prepare_diffusion_quant_config(od_config, model_class) model = model_class(od_config=od_config) vae_pp_size = od_config.parallel_config.vae_patch_parallel_size diff --git a/vllm_omni/diffusion/worker/diffusion_worker.py b/vllm_omni/diffusion/worker/diffusion_worker.py index 621f7eebd86..927bbeb1a2a 100644 --- a/vllm_omni/diffusion/worker/diffusion_worker.py +++ b/vllm_omni/diffusion/worker/diffusion_worker.py @@ -13,6 +13,7 @@ import os from collections.abc import Iterable from contextlib import AbstractContextManager, nullcontext +from types import SimpleNamespace from typing import Any import torch @@ -21,6 +22,7 @@ from vllm.distributed.device_communicators.shm_broadcast import MessageQueue from vllm.logger import init_logger from vllm.profiler.wrapper import CudaProfilerWrapper, WorkerProfiler +from vllm.transformers_utils.config import get_config, get_hf_text_config from vllm.utils.import_utils import resolve_obj_by_qualname from vllm.utils.mem_utils import GiB_bytes from vllm.v1.worker.workspace import init_workspace_manager @@ -120,6 +122,20 @@ def init_device(self) -> None: vllm_config.parallel_config.data_parallel_size = self.od_config.parallel_config.data_parallel_size vllm_config.parallel_config.enable_expert_parallel = self.od_config.parallel_config.enable_expert_parallel vllm_config.profiler_config = self.od_config.profiler_config + try: + hf_config = get_config(self.od_config.model, trust_remote_code=self.od_config.trust_remote_code) + except ValueError: + hf_config = None + logger.info("Skipping hf_config loading for diffusion model %r", self.od_config.model_class_name) + hf_text_config = get_hf_text_config(hf_config) if hf_config is not None else None + vllm_config.model_config = SimpleNamespace( + hf_config=hf_config, + hf_text_config=hf_text_config, + enforce_eager=self.od_config.enforce_eager, + dtype=self.od_config.dtype, + enable_return_routed_experts=False, + ) + vllm_config.quant_config = self.od_config.quantization_config self.vllm_config = vllm_config # Initialize distributed environment diff --git a/vllm_omni/engine/async_omni_engine.py b/vllm_omni/engine/async_omni_engine.py index 6ddaf993e19..8299a577a41 100644 --- a/vllm_omni/engine/async_omni_engine.py +++ b/vllm_omni/engine/async_omni_engine.py @@ -1451,16 +1451,17 @@ def _resolve_stage_configs(self, model: str, kwargs: dict[str, Any]) -> tuple[st if lora_scale is not None: if not hasattr(cfg.engine_args, "lora_scale") or cfg.engine_args.lora_scale is None: cfg.engine_args.lora_scale = lora_scale - # Prefer explicit quantization_config; fallback to legacy --quantization. quantization_config = kwargs.get("quantization_config") - if quantization_config is None: - quantization_config = kwargs.get("quantization") if quantization_config is not None: if ( not hasattr(cfg.engine_args, "quantization_config") or cfg.engine_args.quantization_config is None ): cfg.engine_args.quantization_config = quantization_config + quantization = kwargs.get("quantization") + if quantization is not None: + if not hasattr(cfg.engine_args, "quantization") or cfg.engine_args.quantization is None: + cfg.engine_args.quantization = quantization except Exception as e: logger.warning("Failed to inject LoRA config for stage: %s", e) diff --git a/vllm_omni/platforms/interface.py b/vllm_omni/platforms/interface.py index b69731a67d5..11eec76acdf 100644 --- a/vllm_omni/platforms/interface.py +++ b/vllm_omni/platforms/interface.py @@ -6,6 +6,7 @@ from typing import Any import torch +import torch.nn as nn from vllm.logger import init_logger from vllm.platforms import Platform @@ -71,6 +72,13 @@ def get_diffusion_model_impl_qualname(cls, op_name: str) -> str: def prepare_diffusion_op_runtime(cls, op_name: str, **kwargs: Any) -> None: return None + @classmethod + def get_diffusion_packed_modules_mapping( + cls, + model_class: type[nn.Module], + ) -> dict[str, list[str]] | None: + return None + @classmethod def get_diffusion_attn_backend_cls( cls, diff --git a/vllm_omni/platforms/npu/models/hunyuan_fused_moe.py b/vllm_omni/platforms/npu/models/hunyuan_fused_moe.py index fad4c0edfc3..05079a7e4ae 100644 --- a/vllm_omni/platforms/npu/models/hunyuan_fused_moe.py +++ b/vllm_omni/platforms/npu/models/hunyuan_fused_moe.py @@ -107,12 +107,6 @@ class AscendHunyuanFusedMoE(AscendSharedFusedMoE): def __init__(self, *, prefix: str = "", **kwargs: Any) -> None: super().__init__(prefix=prefix, **kwargs) self._prefix = prefix - self._init_hook_handle = self.register_forward_pre_hook(self._initialize_kernel_hook, with_kwargs=True) - - def _initialize_kernel_hook(self, module: Any, args: Any, kwargs: Any) -> None: - if self.quant_method: - self.quant_method.process_weights_after_loading(self) - self._init_hook_handle.remove() def forward(self, hidden_states: Any, router_logits: Any) -> Any: _set_hunyuan_fused_moe_forward_context(hidden_states.shape[0]) diff --git a/vllm_omni/platforms/npu/platform.py b/vllm_omni/platforms/npu/platform.py index 53ffe6775a5..1d3e221ffe7 100644 --- a/vllm_omni/platforms/npu/platform.py +++ b/vllm_omni/platforms/npu/platform.py @@ -5,6 +5,7 @@ from typing import Any import torch +import torch.nn as nn from vllm.logger import init_logger from vllm_ascend.platform import NPUPlatform @@ -13,6 +14,12 @@ logger = init_logger(__name__) +_DIFFUSION_PACKED_MODULES_MAPPING = { + "HunyuanImage3Pipeline": { + "experts": ["experts.0.gate_up_proj", "experts.0.down_proj"], + }, +} + class NPUOmniPlatform(OmniPlatform, NPUPlatform): """NPU/Ascend implementation of OmniPlatform. @@ -53,6 +60,13 @@ def prepare_diffusion_op_runtime(cls, op_name: str, **kwargs: Any) -> None: prepare_hunyuan_fused_moe_runtime() + @classmethod + def get_diffusion_packed_modules_mapping( + cls, + model_class: type[nn.Module], + ) -> dict[str, list[str]] | None: + return _DIFFUSION_PACKED_MODULES_MAPPING.get(model_class.__name__, None) + @classmethod def get_diffusion_attn_backend_cls( cls,