Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions docs/user_guide/diffusion/quantization/msmodelslim.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# msModelSlim Quantization

## Overview

[msModelSlim](https://github.com/Ascend/msmodelslim) is an Ascend-friendly compression tool focused on acceleration, using compression techniques, and built for Ascend hardware. It includes a series of inference optimization technologies such as quantization and compression, aiming to accelerate large language dense models, MoE models, multimodal understanding models, multimodal generation models, etc.

Once you have a quantized model which is generated by **msModelSlim**, you can use vLLM Omni for inference by specifying the --quantization ascend parameter to enable quantization features.

### Supported Schemes

| Scheme | Bits | Status |
|--------|------|--------|
| W8A8 | 8 | ✅ Supported |
| W4A4 | 4 | Planned |

W8A8 is the first supported scheme. Additional schemes will be added in future releases.

## Model Quantization

Comment thread
jiangmengyu18 marked this conversation as resolved.
The following example shows how to generate W8A8 quantized weights for the [Wan2_2 model](https://gitcode.com/Ascend/msmodelslim/blob/master/example/multimodal_sd/Wan2_2/README.md).

**Quantization Script:**

```bash
msmodelslim quant \
--model_path /path/to/wan2_2_t2v_float_weights \
--save_path /path/to/wan2_2_t2v_quantized_weights \
--device npu \
--model_type Wan2_2 \
--config_path /lab_practice/wan2_2/wan2_2_w8a8f8_mxfp_t2v.yaml \
--trust_remote_code True
```

After quantization completes, the output directory will contain the quantized model files.

For more examples, refer to the [official examples](https://gitcode.com/Ascend/msit/tree/master/msmodelslim/example).

## Configuration

1. **CLI**: pass `--quantization ascend`.
Comment thread
jiangmengyu18 marked this conversation as resolved.

```bash
# Offline inference
python text_to_image.py --model <your-model> --quantization ascend

# Online serving
vllm serve <your-model> --omni --quantization ascend
```

## Supported Models

| Model | HF Models | Recommendation | `ignored_layers` |
|-------|-----------|---------------|------------------|
| HunyuanImage-3.0 | - | All layers | None |
Comment thread
jiangmengyu18 marked this conversation as resolved.

Currently, quantized HunyuanImage-3.0 weights have not been uploaded to public model platforms such as Hugging Face. You can use a [HunyuanImage-3.0-adapted msModelSlim version](https://gitcode.com/betta18/msmodelslim/tree/hyimage3_mxfp8) to generate the quantized weights manually. We will upload the quantized weights as soon as possible.
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,7 @@ def parse_args() -> argparse.Namespace:
default=None,
help=("Custom system prompt. Used when --use-system-prompt is custom. "),
)
current_omni_platform.pre_register_and_update(parser)
from vllm_omni.engine.arg_utils import nullify_stage_engine_defaults

nullify_stage_engine_defaults(parser)
Expand Down
2 changes: 1 addition & 1 deletion vllm_omni/diffusion/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,7 +759,7 @@ def from_kwargs(cls, **kwargs: Any) -> "OmniDiffusionConfig":

# Backwards-compatibility: map "quantization" to "quantization_config"
# so callers using the old field name still work.
if "quantization" in kwargs and kwargs.get("quantization_config") is None:
if "quantization" in kwargs and kwargs.get("quantization_config", None) is None:
kwargs["quantization_config"] = kwargs.pop("quantization")
else:
kwargs.pop("quantization", None)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1484,7 +1484,7 @@ def __init__(
config.hidden_size,
config.num_experts,
bias=False,
quant_config=None,
quant_config=quant_config,
prefix=f"{prefix}.gate",
)
if config.use_mixed_mlp_moe > 0:
Expand Down Expand Up @@ -1658,8 +1658,10 @@ def forward(
custom_pos_emb: tuple[torch.FloatTensor] | None = None,
**kwargs,
) -> torch.Tensor:
bsz, q_len, _ = hidden_states.size()
bsz, q_len, hidden_size = hidden_states.size()
hidden_states = hidden_states.reshape(-1, hidden_size)
qkv, _ = self.qkv_proj(hidden_states)
qkv = qkv.reshape(bsz, q_len, -1)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

past_key_value: Cache | None = kwargs.get("past_key_value", None)
Expand Down Expand Up @@ -1723,7 +1725,7 @@ def __init__(
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
quant_config=None,
quant_config=quant_config,
bias=attention_bias,
cache_config=None,
prefix=f"{prefix}.self_attn",
Expand Down Expand Up @@ -1933,7 +1935,7 @@ def __init__(self, config: HunyuanImage3Config, quant_config=None, prefix: str =
layer_idx=int(prefix.split(".")[-1]),
prefix=prefix,
),
prefix=f"{prefix}.layers",
prefix=f"{prefix}.layers" if prefix else "layers",
)
if get_pp_group().is_last_rank:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
Expand All @@ -1948,7 +1950,7 @@ def _split_qkv_weight(self, qkv: torch.Tensor):
num_attention_heads = self.config.num_attention_heads
num_kv_heads = getattr(self.config, "num_key_value_heads", self.config.num_attention_heads)
num_key_value_groups = num_attention_heads // num_kv_heads
hidden_size = self.config.hidden_size
hidden_size = qkv.shape[1]

if hasattr(self.config, "head_dim"):
attention_head_dim = self.config.head_dim
Expand Down Expand Up @@ -2001,8 +2003,15 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
split_params_mapping = [
(".gate_up_proj", ".gate_and_up_proj", 2, [(1, 1), (0, 1)], None),
(
".qkv_proj",
".qkv_proj",
".qkv_proj.weight",
".qkv_proj.weight",
num_attention_heads + num_kv_heads * 2,
[("q", num_attention_heads), ("k", num_kv_heads), ("v", num_kv_heads)],
self._split_qkv_weight,
),
(
".qkv_proj.weight_scale",
".qkv_proj.weight_scale",
num_attention_heads + num_kv_heads * 2,
[("q", num_attention_heads), ("k", num_kv_heads), ("v", num_kv_heads)],
self._split_qkv_weight,
Expand Down Expand Up @@ -2101,6 +2110,8 @@ def contains_unexpected_keyword(name, keywords):
continue
if "mlp.experts" in name:
continue
if ".qkv_proj" in name and not name.endswith(weight_name):
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from transformers.models.siglip2 import Siglip2VisionConfig, Siglip2VisionModel
from transformers.utils.generic import ModelOutput
from vllm.config.vllm import get_current_vllm_config
from vllm.model_executor.models.utils import AutoWeightsLoader
from vllm.model_executor.models.utils import AutoWeightsLoader, WeightsMapper
from vllm.transformers_utils.config import get_config

from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig
Expand Down Expand Up @@ -64,6 +64,15 @@ def to_device(data, device):


class HunyuanImage3Pipeline(HunyuanImage3PreTrainedModel, GenerationMixin, DiffusionPipelineProfilerMixin):
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"model.": "",
},
orig_to_new_substr={
"mlp.gate.wg.": "mlp.gate.",
"gate_and_up_proj.": "gate_up_proj.",
},
)
_PROFILER_TARGETS = [
"model.forward",
"model.layers[0].forward",
Expand Down
19 changes: 19 additions & 0 deletions vllm_omni/diffusion/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import torch.nn as nn
from vllm.logger import init_logger
from vllm.model_executor.model_loader.utils import configure_quant_config
from vllm.model_executor.models.registry import _LazyRegisteredModel, _ModelRegistry

from vllm_omni.diffusion.data import OmniDiffusionConfig
Expand All @@ -13,6 +14,7 @@
from vllm_omni.diffusion.forward_context import get_forward_context
from vllm_omni.diffusion.hooks.sequence_parallel import apply_sequence_parallel
from vllm_omni.diffusion.utils.tf_utils import find_module_with_attr
from vllm_omni.platforms import current_omni_platform

logger = init_logger(__name__)

Expand Down Expand Up @@ -242,6 +244,22 @@
}


def _prepare_diffusion_quant_config(
od_config: OmniDiffusionConfig,
model_class: type[nn.Module],
) -> None:
"""Prepare diffusion quant config using vLLM-style model bindings."""
quant_config = od_config.quantization_config
if quant_config is None:
return
if hasattr(quant_config, "maybe_update_config"):
quant_config.maybe_update_config(od_config.model)
diffusion_packed_modules_mapping = current_omni_platform.get_diffusion_packed_modules_mapping(model_class)
if diffusion_packed_modules_mapping is not None:
model_class.packed_modules_mapping = diffusion_packed_modules_mapping
configure_quant_config(quant_config, model_class)


def initialize_model(
od_config: OmniDiffusionConfig,
) -> nn.Module:
Expand All @@ -264,6 +282,7 @@ def initialize_model(
"""
model_class = DiffusionModelRegistry._try_load_model_cls(od_config.model_class_name)
if model_class is not None:
_prepare_diffusion_quant_config(od_config, model_class)
model = model_class(od_config=od_config)

vae_pp_size = od_config.parallel_config.vae_patch_parallel_size
Expand Down
16 changes: 16 additions & 0 deletions vllm_omni/diffusion/worker/diffusion_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import os
from collections.abc import Iterable
from contextlib import AbstractContextManager, nullcontext
from types import SimpleNamespace
from typing import Any

import torch
Expand All @@ -21,6 +22,7 @@
from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
from vllm.logger import init_logger
from vllm.profiler.wrapper import CudaProfilerWrapper, WorkerProfiler
from vllm.transformers_utils.config import get_config, get_hf_text_config
from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.utils.mem_utils import GiB_bytes
from vllm.v1.worker.workspace import init_workspace_manager
Expand Down Expand Up @@ -120,6 +122,20 @@ def init_device(self) -> None:
vllm_config.parallel_config.data_parallel_size = self.od_config.parallel_config.data_parallel_size
vllm_config.parallel_config.enable_expert_parallel = self.od_config.parallel_config.enable_expert_parallel
vllm_config.profiler_config = self.od_config.profiler_config
try:
hf_config = get_config(self.od_config.model, trust_remote_code=self.od_config.trust_remote_code)
except ValueError:
hf_config = None
logger.info("Skipping hf_config loading for diffusion model %r", self.od_config.model_class_name)
hf_text_config = get_hf_text_config(hf_config) if hf_config is not None else None
vllm_config.model_config = SimpleNamespace(
hf_config=hf_config,
hf_text_config=hf_text_config,
enforce_eager=self.od_config.enforce_eager,
dtype=self.od_config.dtype,
enable_return_routed_experts=False,
)
vllm_config.quant_config = self.od_config.quantization_config
Comment on lines +125 to +138
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this PR, Only this block is needed to review carefully. @SamitHuang @wtomin @ZJY0516

self.vllm_config = vllm_config

# Initialize distributed environment
Expand Down
7 changes: 4 additions & 3 deletions vllm_omni/engine/async_omni_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1451,16 +1451,17 @@ def _resolve_stage_configs(self, model: str, kwargs: dict[str, Any]) -> tuple[st
if lora_scale is not None:
if not hasattr(cfg.engine_args, "lora_scale") or cfg.engine_args.lora_scale is None:
cfg.engine_args.lora_scale = lora_scale
# Prefer explicit quantization_config; fallback to legacy --quantization.
quantization_config = kwargs.get("quantization_config")
if quantization_config is None:
quantization_config = kwargs.get("quantization")
if quantization_config is not None:
if (
not hasattr(cfg.engine_args, "quantization_config")
or cfg.engine_args.quantization_config is None
):
cfg.engine_args.quantization_config = quantization_config
quantization = kwargs.get("quantization")
if quantization is not None:
if not hasattr(cfg.engine_args, "quantization") or cfg.engine_args.quantization is None:
cfg.engine_args.quantization = quantization
except Exception as e:
logger.warning("Failed to inject LoRA config for stage: %s", e)

Expand Down
8 changes: 8 additions & 0 deletions vllm_omni/platforms/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Any

import torch
import torch.nn as nn
from vllm.logger import init_logger
from vllm.platforms import Platform

Expand Down Expand Up @@ -71,6 +72,13 @@ def get_diffusion_model_impl_qualname(cls, op_name: str) -> str:
def prepare_diffusion_op_runtime(cls, op_name: str, **kwargs: Any) -> None:
return None

@classmethod
def get_diffusion_packed_modules_mapping(
cls,
model_class: type[nn.Module],
) -> dict[str, list[str]] | None:
return None

@classmethod
def get_diffusion_attn_backend_cls(
cls,
Expand Down
6 changes: 0 additions & 6 deletions vllm_omni/platforms/npu/models/hunyuan_fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,6 @@ class AscendHunyuanFusedMoE(AscendSharedFusedMoE):
def __init__(self, *, prefix: str = "", **kwargs: Any) -> None:
super().__init__(prefix=prefix, **kwargs)
self._prefix = prefix
self._init_hook_handle = self.register_forward_pre_hook(self._initialize_kernel_hook, with_kwargs=True)

def _initialize_kernel_hook(self, module: Any, args: Any, kwargs: Any) -> None:
if self.quant_method:
self.quant_method.process_weights_after_loading(self)
self._init_hook_handle.remove()

def forward(self, hidden_states: Any, router_logits: Any) -> Any:
_set_hunyuan_fused_moe_forward_context(hidden_states.shape[0])
Expand Down
14 changes: 14 additions & 0 deletions vllm_omni/platforms/npu/platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Any

import torch
import torch.nn as nn
from vllm.logger import init_logger
from vllm_ascend.platform import NPUPlatform

Expand All @@ -13,6 +14,12 @@

logger = init_logger(__name__)

_DIFFUSION_PACKED_MODULES_MAPPING = {
Comment thread
jiangmengyu18 marked this conversation as resolved.
"HunyuanImage3Pipeline": {
"experts": ["experts.0.gate_up_proj", "experts.0.down_proj"],
},
}


class NPUOmniPlatform(OmniPlatform, NPUPlatform):
"""NPU/Ascend implementation of OmniPlatform.
Expand Down Expand Up @@ -53,6 +60,13 @@ def prepare_diffusion_op_runtime(cls, op_name: str, **kwargs: Any) -> None:

prepare_hunyuan_fused_moe_runtime()

@classmethod
def get_diffusion_packed_modules_mapping(
cls,
model_class: type[nn.Module],
) -> dict[str, list[str]] | None:
return _DIFFUSION_PACKED_MODULES_MAPPING.get(model_class.__name__, None)

@classmethod
def get_diffusion_attn_backend_cls(
cls,
Expand Down
Loading