diff --git a/python/sglang/multimodal_gen/runtime/layers/quantization/__init__.py b/python/sglang/multimodal_gen/runtime/layers/quantization/__init__.py index 5e3eaf940228..9967879148e3 100644 --- a/python/sglang/multimodal_gen/runtime/layers/quantization/__init__.py +++ b/python/sglang/multimodal_gen/runtime/layers/quantization/__init__.py @@ -10,8 +10,9 @@ ModelOptFp4Config, ) from sglang.multimodal_gen.runtime.layers.quantization.modelslim import ModelSlimConfig +from sglang.multimodal_gen.runtime.layers.quantization.mxfp8_npu import MXFP8Config -QuantizationMethods = Literal["fp8", "modelopt_fp4", "modelslim"] +QuantizationMethods = Literal["fp8", "modelopt_fp4", "modelslim", "mxfp8"] QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods)) @@ -20,6 +21,7 @@ "modelopt_fp4": ModelOptFp4Config, "modelslim": ModelSlimConfig, "fp8": Fp8Config, + "mxfp8": MXFP8Config, } diff --git a/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim.py b/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim.py index afb9a31e4db9..4a9b96f9c9c9 100644 --- a/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim.py +++ b/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim.py @@ -119,6 +119,12 @@ def _get_scheme_from_parts( return ModelSlimW4A4Int4( quant_config=self.quant_description, prefix=layer_name ) + elif quant_type == "W8A8_MXFP8": + from sglang.multimodal_gen.runtime.layers.quantization.modelslim_mxfp8_scheme import ( + ModelSlimMXFP8Scheme, + ) + + return ModelSlimMXFP8Scheme() raise NotImplementedError("No modelslim compatible scheme was found.") def get_scheme( diff --git a/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp8_scheme.py b/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp8_scheme.py new file mode 100644 index 000000000000..c12464d691f7 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp8_scheme.py @@ -0,0 +1,118 @@ +"""ModelSlim MXFP8 scheme for pre-quantized weight inference on Ascend NPU. + +Loads weights pre-quantized by msmodelslim (float8_e4m3fn weights, +uint8 scales) and runs MXFP8 matmul at inference. +""" + +from typing import List, Optional + +import torch +import torch_npu + +from sglang.multimodal_gen.runtime.models.parameter import ( + GroupQuantScaleParameter, + ModelWeightParameter, +) +from sglang.srt.layers.quantization.modelslim.schemes import ModelSlimLinearScheme + +MXFP8_BLOCK_SIZE = 32 + + +class ModelSlimMXFP8Scheme(ModelSlimLinearScheme): + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + weight_loader = extra_weight_attrs.get("weight_loader") + output_size_per_partition = sum(output_partition_sizes) + + # msmodelslim exports weight as float8_e4m3fn, shape [out, in] + weight = ModelWeightParameter( + data=torch.empty( + (output_size_per_partition, input_size_per_partition), + dtype=torch.float8_e4m3fn, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + # msmodelslim exports weight_scale as uint8, shape [out, in/32]. + # NOTE: This parameter is intentionally named "weight_scale" (not + # "weight_scale_inv" as used in mxfp8_npu.py) because the weight loader + # matches parameter names to checkpoint keys, and msmodelslim checkpoints + # store this tensor under the key ".weight_scale". + scale_dim = input_size_per_partition // MXFP8_BLOCK_SIZE + weight_scale = GroupQuantScaleParameter( + data=torch.empty( + (output_size_per_partition, scale_dim), + dtype=torch.uint8, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight_scale", weight_scale) + + def process_weights_after_loading(self, layer: torch.nn.Module): + # weight is already float8_e4m3fn, no cast needed + weight = layer.weight.data + layer.weight = torch.nn.Parameter(weight, requires_grad=False) + + # Reshape weight_scale: [out, in/32] -> [out, in/32//2, 2] + weight_scale = layer.weight_scale.data + weight_scale = weight_scale.reshape(weight_scale.shape[0], -1, 2) + layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) + + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + original_dtype = x.dtype + if original_dtype not in (torch.float16, torch.bfloat16): + # npu_dynamic_mx_quant only accepts fp16/bf16 activations + x = x.to(torch.bfloat16) + original_dtype = torch.bfloat16 + + # npu_dynamic_mx_quant requires a 2D input [tokens, hidden_size]. + # Diffusion transformer inputs are typically 3D [batch, seq, hidden] or + # higher. Flattening to 2D merges all leading dimensions into a single + # token axis so the NPU kernel can compute per-token MXFP8 scales, then + # we restore the original shape from the output. + input_shape = x.shape + x_2d = x.reshape(-1, x.shape[-1]) + + # Dynamic MXFP8 activation quantisation + qx, input_scale = torch_npu.npu_dynamic_mx_quant( + x_2d, dst_type=torch_npu.float8_e4m3fn + ) + + # MXFP8 matmul + output = torch_npu.npu_quant_matmul( + qx, + layer.weight.transpose(0, 1), + layer.weight_scale.transpose(0, 1), + scale_dtype=torch_npu.float8_e8m0fnu, + pertoken_scale=input_scale, + pertoken_scale_dtype=torch_npu.float8_e8m0fnu, + bias=bias.to(torch.float32) if bias is not None else None, + output_dtype=original_dtype, + group_sizes=[1, 1, MXFP8_BLOCK_SIZE], + ) + + # Restore original shape + output_shape = list(input_shape[:-1]) + [output.shape[-1]] + output = output.reshape(output_shape) + + return output diff --git a/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp8_npu.py b/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp8_npu.py new file mode 100644 index 000000000000..b3dd4612460d --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp8_npu.py @@ -0,0 +1,170 @@ +"""Online MXFP8 quantization for Diffusion models on Ascend NPU. + +Provides ``MXFP8Config`` (registered as ``"mxfp8"``) and +``NPUMXFP8DiffusionLinearMethod`` which quantise FP16/BF16 weights to MXFP8 +at load time and use ``npu_dynamic_mx_quant`` + ``npu_quant_matmul`` for +inference, mirroring the LLM-side ``NPUMXFP8LinearMethod``. +""" + +from __future__ import annotations + +from typing import Any, Dict, List, Optional + +import torch +import torch_npu +from torch.nn.parameter import Parameter + +from sglang.multimodal_gen.runtime.layers.linear import LinearBase, LinearMethodBase +from sglang.multimodal_gen.runtime.layers.quantization.configs.base_config import ( + QuantizationConfig, + QuantizeMethodBase, +) +from sglang.multimodal_gen.runtime.models.parameter import ModelWeightParameter +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + +MXFP8_BLOCK_SIZE = 32 + + +class MXFP8Config(QuantizationConfig): + """Config for online MXFP8 quantization on Ascend NPU (Diffusion).""" + + def __init__(self) -> None: + super().__init__() + + @classmethod + def get_name(cls) -> str: + return "mxfp8" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.bfloat16, torch.float16] + + @classmethod + def get_min_capability(cls) -> int: + return 0 # NPU, not CUDA + + @classmethod + def get_config_filenames(cls) -> List[str]: + return [] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "MXFP8Config": + return cls() + + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional[QuantizeMethodBase]: + if isinstance(layer, LinearBase): + return NPUMXFP8DiffusionLinearMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class NPUMXFP8DiffusionLinearMethod(LinearMethodBase): + """Ascend NPU MXFP8 linear method for Diffusion models. + + Online mode: loads FP16/BF16 weights → quantises to MXFP8 at load time. + Inference: dynamic MXFP8 activation quant + MXFP8 matmul (block_size=32). + """ + + def __init__(self, quant_config: MXFP8Config): + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + + layer.logical_widths = output_partition_sizes + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + layer.orig_dtype = params_dtype + + # Load weights in original dtype; quantise later in process_weights_after_loading + weight = ModelWeightParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=params_dtype, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + + weight_fp = layer.weight.data + if weight_fp.dtype not in (torch.float16, torch.bfloat16): + weight_fp = weight_fp.to(torch.bfloat16) + + # Move weight to NPU if needed. We intentionally use a conditional + # move rather than an assert because `dit_cpu_offload` defaults to + # True in ServerArgs, which causes fsdp_load to move every parameter + # back to CPU after loading (even when the target device is NPU). + # npu_dynamic_mx_quant requires an NPU tensor, so we must transfer + # here. The quantized fp8 weights produced below will remain on NPU + # for inference; if the model still needs to be offloaded after + # quantization (e.g. very large model on a small NPU), a higher-level + # offload pass can move them back afterwards. + if not weight_fp.is_npu: + weight_fp = weight_fp.to(f"npu:{torch.npu.current_device()}") + + # Online MXFP8 quantisation of weights (block_size=32) + qw, w_scale = torch_npu.npu_dynamic_mx_quant( + weight_fp, dst_type=torch_npu.float8_e4m3fn + ) + layer.weight = Parameter(qw, requires_grad=False) + layer.weight_scale_inv = Parameter(w_scale, requires_grad=False) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + original_dtype = x.dtype + if original_dtype not in (torch.float16, torch.bfloat16): + x = x.to(torch.bfloat16) + original_dtype = torch.bfloat16 + + # Flatten to 2D [tokens, hidden] so npu_dynamic_mx_quant returns 3D scale + input_shape = x.shape + x_2d = x.reshape(-1, x.shape[-1]) + + # Dynamic MXFP8 activation quantisation + qx, input_scale = torch_npu.npu_dynamic_mx_quant( + x_2d, dst_type=torch_npu.float8_e4m3fn + ) + + # MXFP8 matmul + output = torch_npu.npu_quant_matmul( + qx, + layer.weight.transpose(0, 1), + layer.weight_scale_inv.transpose(0, 1), + scale_dtype=torch_npu.float8_e8m0fnu, + pertoken_scale=input_scale, + pertoken_scale_dtype=torch_npu.float8_e8m0fnu, + bias=bias.to(torch.float32) if bias is not None else None, + output_dtype=original_dtype, + group_sizes=[1, 1, MXFP8_BLOCK_SIZE], + ) + + # Restore original shape (replace last dim with output features) + output_shape = list(input_shape[:-1]) + [output.shape[-1]] + output = output.reshape(output_shape) + + return output diff --git a/python/sglang/multimodal_gen/runtime/loader/fsdp_load.py b/python/sglang/multimodal_gen/runtime/loader/fsdp_load.py index 385931b19e02..c99cfe41cfa5 100644 --- a/python/sglang/multimodal_gen/runtime/loader/fsdp_load.py +++ b/python/sglang/multimodal_gen/runtime/loader/fsdp_load.py @@ -462,6 +462,7 @@ def load_model_from_full_model_state_dict( "bias", "norm_q", "norm_k", + "weight_scale", ] for new_param_name in unused_keys: meta_sharded_param = meta_sd.get(new_param_name) diff --git a/python/sglang/multimodal_gen/runtime/loader/transformer_load_utils.py b/python/sglang/multimodal_gen/runtime/loader/transformer_load_utils.py index 23c60043ed2e..fd433d37c444 100644 --- a/python/sglang/multimodal_gen/runtime/loader/transformer_load_utils.py +++ b/python/sglang/multimodal_gen/runtime/loader/transformer_load_utils.py @@ -274,8 +274,17 @@ def _resolve_quant_config( ) -> Optional[QuantizationConfig]: """ resolve quant config from checkpoints' metadata - priority: model config.json -> safetensors metadata -> format-specific fallback + priority: explicit --quantization flag -> model config.json -> safetensors metadata -> format-specific fallback """ + # priority: explicit --quantization flag (e.g. mxfp8, mxfp4, modelslim) + if server_args.quantization is not None: + from sglang.multimodal_gen.runtime.layers.quantization import ( + get_quantization_config, + ) + + quant_cls = get_quantization_config(server_args.quantization) + return quant_cls.from_config({}) + quant_config = get_quant_config(hf_config, component_model_path) if quant_config is None and server_args.transformer_weights_path: for safetensors_file in safetensors_list: diff --git a/python/sglang/multimodal_gen/runtime/server_args.py b/python/sglang/multimodal_gen/runtime/server_args.py index 4fc3e2964ea7..a8f6a413cdf8 100644 --- a/python/sglang/multimodal_gen/runtime/server_args.py +++ b/python/sglang/multimodal_gen/runtime/server_args.py @@ -183,6 +183,10 @@ class ServerArgs: disable_autocast: bool | None = None + # Explicit quantization method override (e.g. "mxfp8", "fp8", "modelslim"). + # When set, the transformer loader will use this instead of auto-detection. + quantization: str | None = None + # Quantization / Nunchaku SVDQuant configuration nunchaku_config: NunchakuSVDQuantArgs | NunchakuConfig | None = field( default_factory=NunchakuSVDQuantArgs, repr=False @@ -771,6 +775,14 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: help="Disable autocast for denoising loop and vae decoding in pipeline sampling", ) + parser.add_argument( + "--quantization", + type=str, + default=None, + help='Quantization method override (e.g. "mxfp8", "fp8", "modelslim"). ' + "When set, the transformer loader will use this instead of auto-detection.", + ) + # Nunchaku SVDQuant quantization parameters NunchakuSVDQuantArgs.add_cli_args(parser) diff --git a/python/sglang/multimodal_gen/runtime/utils/quantization_utils.py b/python/sglang/multimodal_gen/runtime/utils/quantization_utils.py index 553b48de8bf3..c51b7f56bd6b 100644 --- a/python/sglang/multimodal_gen/runtime/utils/quantization_utils.py +++ b/python/sglang/multimodal_gen/runtime/utils/quantization_utils.py @@ -16,9 +16,16 @@ def find_quant_modelslim_config(model_config, component_model_path): + # Try exact name first, then glob for variant filenames (e.g. after repack) quant_config_file = Path(component_model_path, "quant_model_description.json") + if not quant_config_file.is_file(): + candidates = sorted( + Path(component_model_path).glob("quant_model_description*.json") + ) + quant_config_file = candidates[0] if candidates else None + quant_cfg = None - if quant_config_file.is_file(): + if quant_config_file is not None and Path(quant_config_file).is_file(): with open(quant_config_file) as f: quant_cfg = json.load(f) # This field is required for flagless model loading but is not present in diff --git a/python/sglang/multimodal_gen/tools/wan_repack.py b/python/sglang/multimodal_gen/tools/wan_repack.py index 2d7132747e7a..308b229d8593 100644 --- a/python/sglang/multimodal_gen/tools/wan_repack.py +++ b/python/sglang/multimodal_gen/tools/wan_repack.py @@ -1,115 +1,225 @@ -### Based on https://github.com/huggingface/diffusers/blob/main/scripts/convert_wan_to_diffusers.py - -import argparse -import json -import pathlib -from typing import Any, Dict, Tuple - -from safetensors.torch import load_file, save_file - -TRANSFORMER_KEYS_RENAME_DICT = { - "time_embedding.0": "condition_embedder.time_embedder.linear_1", - "time_embedding.2": "condition_embedder.time_embedder.linear_2", - "text_embedding.0": "condition_embedder.text_embedder.linear_1", - "text_embedding.2": "condition_embedder.text_embedder.linear_2", - "time_projection.1": "condition_embedder.time_proj", - "head.modulation": "scale_shift_table", - "head.head": "proj_out", - "modulation": "scale_shift_table", - "ffn.0": "ffn.net.0.proj", - "ffn.2": "ffn.net.2", - # Hack to swap the layer names - # The original model calls the norms in following order: norm1, norm3, norm2 - # We convert it to: norm1, norm2, norm3 - "norm2": "norm__placeholder", - "norm3": "norm2", - "norm__placeholder": "norm3", - # For the I2V model - "img_emb.proj.0": "condition_embedder.image_embedder.norm1", - "img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj", - "img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2", - "img_emb.proj.4": "condition_embedder.image_embedder.norm2", - # for the FLF2V model - "img_emb.emb_pos": "condition_embedder.image_embedder.pos_embed", - # Add attention component mappings - "self_attn.q": "attn1.to_q", - "self_attn.k": "attn1.to_k", - "self_attn.v": "attn1.to_v", - "self_attn.o": "attn1.to_out.0", - "self_attn.norm_q": "attn1.norm_q", - "self_attn.norm_k": "attn1.norm_k", - "cross_attn.q": "attn2.to_q", - "cross_attn.k": "attn2.to_k", - "cross_attn.v": "attn2.to_v", - "cross_attn.o": "attn2.to_out.0", - "cross_attn.norm_q": "attn2.norm_q", - "cross_attn.norm_k": "attn2.norm_k", - "attn2.to_k_img": "attn2.add_k_proj", - "attn2.to_v_img": "attn2.add_v_proj", - "attn2.norm_k_img": "attn2.norm_added_k", -} - - -def get_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]: - if model_type == "Wan-T2V-14B": - RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT - return RENAME_DICT - - -def update_dict_(dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]: - dict[new_key] = dict.pop(old_key) - - -def load_sharded_safetensors(path: pathlib.Path): - file_path = path - state_dict = {} - state_dict.update(load_file(file_path)) - return state_dict - - -def convert_transformer(model_type: str, model_dir: str, output_dir: str): - pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True) - RENAME_DICT = get_transformer_config(model_type) - - original_state_dict = load_sharded_safetensors( - pathlib.Path(model_dir, "*model*.safetensors") - ) - with open(pathlib.Path(model_dir, "*quant_model_description*.json")) as f: - original_quant_config = json.load(f) - - for key in list(original_state_dict.keys()): - new_key = key[:] - for replace_key, rename_key in RENAME_DICT.items(): - new_key = new_key.replace(replace_key, rename_key) - update_dict_(original_state_dict, key, new_key) - update_dict_(original_quant_config, key, new_key) - - save_file( - original_state_dict, - pathlib.Path(output_dir, "diffusion_pytorch_model.safetensors"), - ) - - with open(pathlib.Path(output_dir, "quant_model_description.json"), "w") as f: - json.dump(original_quant_config, f) - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument("--input-path", type=str, required=True) - parser.add_argument("--output-path", type=str, required=True) - return parser.parse_args() - - -if __name__ == "__main__": - args = get_args() - - convert_transformer( - "Wan-T2V-14B", - model_dir=pathlib.Path(args.input_path, "high_noise_model"), - output_dir=pathlib.Path(args.output_path, "transformer"), - ) - convert_transformer( - "Wan-T2V-14B", - model_dir=pathlib.Path(args.input_path, "low_noise_model"), - output_dir=pathlib.Path(args.output_path, "transformer_2"), - ) +### Based on https://github.com/huggingface/diffusers/blob/main/scripts/convert_wan_to_diffusers.py + +import argparse +import json +import pathlib +import shutil +from typing import Any, Dict, List + +from safetensors.torch import load_file, save_file + +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + +TRANSFORMER_KEYS_RENAME_DICT = { + "time_embedding.0": "condition_embedder.time_embedder.linear_1", + "time_embedding.2": "condition_embedder.time_embedder.linear_2", + "text_embedding.0": "condition_embedder.text_embedder.linear_1", + "text_embedding.2": "condition_embedder.text_embedder.linear_2", + "time_projection.1": "condition_embedder.time_proj", + "head.modulation": "scale_shift_table", + "head.head": "proj_out", + "modulation": "scale_shift_table", + "ffn.0": "ffn.net.0.proj", + "ffn.2": "ffn.net.2", + # Hack to swap the layer names + # The original model calls the norms in following order: norm1, norm3, norm2 + # We convert it to: norm1, norm2, norm3 + "norm2": "norm__placeholder", + "norm3": "norm2", + "norm__placeholder": "norm3", + # For the I2V model + "img_emb.proj.0": "condition_embedder.image_embedder.norm1", + "img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj", + "img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2", + "img_emb.proj.4": "condition_embedder.image_embedder.norm2", + # for the FLF2V model + "img_emb.emb_pos": "condition_embedder.image_embedder.pos_embed", + # Add attention component mappings + "self_attn.q": "attn1.to_q", + "self_attn.k": "attn1.to_k", + "self_attn.v": "attn1.to_v", + "self_attn.o": "attn1.to_out.0", + "self_attn.norm_q": "attn1.norm_q", + "self_attn.norm_k": "attn1.norm_k", + "cross_attn.q": "attn2.to_q", + "cross_attn.k": "attn2.to_k", + "cross_attn.v": "attn2.to_v", + "cross_attn.o": "attn2.to_out.0", + "cross_attn.norm_q": "attn2.norm_q", + "cross_attn.norm_k": "attn2.norm_k", + "attn2.to_k_img": "attn2.add_k_proj", + "attn2.to_v_img": "attn2.add_v_proj", + "attn2.norm_k_img": "attn2.norm_added_k", +} + +SUPPORTED_MODEL_TYPES = ["Wan2.2-T2V-A14B", "Wan2.2-I2V-A14B", "Wan2.2-TI2V-5B"] + +# Cascade models have two transformers (high_noise + low_noise) +CASCADE_MODEL_TYPES = {"Wan2.2-T2V-A14B", "Wan2.2-I2V-A14B"} + + +def get_transformer_config(model_type: str) -> Dict[str, Any]: + if model_type in SUPPORTED_MODEL_TYPES: + return TRANSFORMER_KEYS_RENAME_DICT + else: + raise ValueError( + f"Unsupported model_type: {model_type}. Supported: {SUPPORTED_MODEL_TYPES}" + ) + + +def get_transformer_dirs(model_type: str) -> List[str]: + """Return the list of transformer directory names for a given model type.""" + if model_type in CASCADE_MODEL_TYPES: + return ["transformer", "transformer_2"] + return ["transformer"] + + +def get_quant_subpath( + model_type: str, quant_path: pathlib.Path, transformer_dir: str +) -> pathlib.Path: + """Return the quant weights subdirectory for a given transformer.""" + if model_type in CASCADE_MODEL_TYPES: + sub = ( + "high_noise_model" + if transformer_dir == "transformer" + else "low_noise_model" + ) + return quant_path / sub + return quant_path + + +def update_dict_(d: Dict[str, Any], old_key: str, new_key: str) -> None: + d[new_key] = d.pop(old_key) + + +def load_sharded_safetensors(directory: pathlib.Path, pattern: str) -> dict: + candidates = sorted(directory.glob(pattern)) + if not candidates: + raise FileNotFoundError(f"No file matching '{pattern}' found in {directory}") + if len(candidates) > 1: + raise FileNotFoundError( + f"Multiple files matching '{pattern}' found in {directory}: {candidates}" + ) + + state_dict = {} + state_dict.update(load_file(candidates[0])) + return state_dict + + +def convert_transformer( + model_type: str, model_dir: pathlib.Path, output_dir: pathlib.Path +) -> None: + """Convert a single quantized transformer directory into Diffusers format.""" + model_path = pathlib.Path(model_dir) + out_path = pathlib.Path(output_dir) + out_path.mkdir(parents=True, exist_ok=True) + RENAME_DICT = get_transformer_config(model_type) + + state_dict = load_sharded_safetensors(model_path, "quant_model_weight*.safetensors") + + json_candidates = sorted(model_path.glob("quant_model_description*.json")) + if not json_candidates: + raise FileNotFoundError( + f"No quant_model_description*.json found in {model_path}" + ) + with open(json_candidates[0]) as f: + quant_config = json.load(f) + + for key in list(state_dict.keys()): + new_key = key[:] + for replace_key, rename_key in RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + if new_key != key: + update_dict_(state_dict, key, new_key) + # The quant JSON only covers quantized layers, not all model keys + if key in quant_config: + update_dict_(quant_config, key, new_key) + + save_file(state_dict, out_path / "diffusion_pytorch_model.safetensors") + + with open(out_path / "quant_model_description.json", "w") as f: + json.dump(quant_config, f, indent=2) + + +def repack( + model_type: str, + original_model_path: pathlib.Path, + quant_path: pathlib.Path, + output_path: pathlib.Path, +) -> None: + """ + Full one-step repack workflow: + 1. Copy the original HF Diffusers model to output_path, excluding transformer dir(s). + 2. For each transformer: convert quant weights and copy config.json from original. + """ + transformer_dirs = get_transformer_dirs(model_type) + + # Step 1: Copy original model, skipping transformer dirs (they will be replaced) + logger.debug(f"Step 1: Copying original model to {output_path}") + logger.debug(f" (skipping: {transformer_dirs})") + shutil.copytree( + str(original_model_path), + str(output_path), + ignore=shutil.ignore_patterns(*transformer_dirs), + ) + + # Step 2+: Convert each transformer + for i, tdir in enumerate(transformer_dirs): + q_path = get_quant_subpath(model_type, quant_path, tdir) + out_tdir = output_path / tdir + logger.debug( + f"\nStep {i + 2}: Converting {tdir} (quant source: {q_path.name})..." + ) + convert_transformer(model_type, q_path, out_tdir) + + # Copy config.json from the original transformer dir + src_config = original_model_path / tdir / "config.json" + if src_config.is_file(): + shutil.copy2(str(src_config), str(out_tdir / "config.json")) + logger.debug(f" Copied config.json from original {tdir}/") + + logger.info(f"\nDone! Repacked model saved to: {output_path}") + + +def get_args(): + parser = argparse.ArgumentParser( + description="Repack msmodelslim quantized Wan2.2 weights into HF Diffusers format" + ) + parser.add_argument( + "--model-type", + type=str, + required=True, + choices=SUPPORTED_MODEL_TYPES, + help="Model type to convert", + ) + parser.add_argument( + "--original-model-path", + type=str, + required=True, + help="Path to the original HF Diffusers model (e.g., /weights/Wan2.2-TI2V-5B-Diffusers)", + ) + parser.add_argument( + "--quant-path", + type=str, + required=True, + help="Path to msmodelslim quantized weights directory", + ) + parser.add_argument( + "--output-path", + type=str, + required=True, + help="Output path for the repacked model (e.g., /weights/Wan2.2-TI2V-5B-Diffusers-MXFP8)", + ) + return parser.parse_args() + + +if __name__ == "__main__": + args = get_args() + repack( + model_type=args.model_type, + original_model_path=pathlib.Path(args.original_model_path), + quant_path=pathlib.Path(args.quant_path), + output_path=pathlib.Path(args.output_path), + ) diff --git a/python/sglang/srt/hardware_backend/npu/quantization/linear_method_npu.py b/python/sglang/srt/hardware_backend/npu/quantization/linear_method_npu.py index 788620a317bb..393c966c06fc 100644 --- a/python/sglang/srt/hardware_backend/npu/quantization/linear_method_npu.py +++ b/python/sglang/srt/hardware_backend/npu/quantization/linear_method_npu.py @@ -1,6 +1,8 @@ from typing import TYPE_CHECKING, Optional import torch +import torch_npu +from torch.nn.parameter import Parameter from sglang.srt.hardware_backend.npu.utils import npu_format_cast from sglang.srt.layers.quantization.base_config import LinearMethodBase @@ -8,6 +10,11 @@ if TYPE_CHECKING: from sglang.srt.layers.quantization.base_config import QuantizationConfig +MXFP8_BLOCK_SIZE = 32 +_FLOAT8_E8M0FNU_DTYPE = getattr( + torch_npu, "float8_e8m0fnu", getattr(torch, "float8_e8m0fnu", None) +) + class _NPULinearMethodBase(LinearMethodBase): @@ -111,6 +118,363 @@ def apply( ) +class NPUMXFP8LinearMethod(_NPULinearMethodBase): + """Ascend NPU MXFP8 linear method for LLM (SRT) models. + + Online mode: loads FP16/BF16 weights → quantises to MXFP8 at load time. + Inference: dynamic MXFP8 activation quant + MXFP8 matmul (block_size=32). + """ + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes, + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + from sglang.srt.layers.parameter import ModelWeightParameter + + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + + layer.logical_widths = output_partition_sizes + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + layer.orig_dtype = params_dtype + + # Load weights in original dtype; quantise later in process_weights_after_loading + weight = ModelWeightParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=params_dtype, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + weight_fp = layer.weight.data + if weight_fp.dtype not in (torch.float16, torch.bfloat16): + weight_fp = weight_fp.to(torch.bfloat16) + + # Move weight to NPU if needed (cpu offload may have moved it back to CPU) + if not weight_fp.is_npu: + weight_fp = weight_fp.to(f"npu:{torch.npu.current_device()}") + + # Online MXFP8 quantisation of weights (block_size=32) + qw, w_scale = torch_npu.npu_dynamic_mx_quant( + weight_fp, dst_type=torch_npu.float8_e4m3fn + ) + # Pre-transpose to [in, out] for npu_quant_matmul (avoid per-call transpose) + layer.weight = Parameter(qw.transpose(0, 1).contiguous(), requires_grad=False) + layer.weight_scale_inv = Parameter( + w_scale.transpose(0, 1).contiguous(), requires_grad=False + ) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + original_dtype = x.dtype + if original_dtype not in (torch.float16, torch.bfloat16): + x = x.to(torch.bfloat16) + original_dtype = torch.bfloat16 + + # Flatten to 2D [tokens, hidden] for npu_dynamic_mx_quant + input_shape = x.shape + x_2d = x.reshape(-1, x.shape[-1]) + + # Dynamic MXFP8 activation quantisation + qx, input_scale = torch_npu.npu_dynamic_mx_quant( + x_2d, dst_type=torch_npu.float8_e4m3fn + ) + + # MXFP8 matmul (weight & scale already transposed at load time) + output = torch_npu.npu_quant_matmul( + qx, + layer.weight, + layer.weight_scale_inv, + scale_dtype=_FLOAT8_E8M0FNU_DTYPE, + pertoken_scale=input_scale, + pertoken_scale_dtype=_FLOAT8_E8M0FNU_DTYPE, + bias=bias.to(torch.float32) if bias is not None else None, + output_dtype=original_dtype, + group_sizes=[1, 1, MXFP8_BLOCK_SIZE], + ) + + # Restore original shape (replace last dim with output features) + output_shape = list(input_shape[:-1]) + [output.shape[-1]] + return output.reshape(output_shape) + + +class NPUMXFP4W4A8LinearMethod(_NPULinearMethodBase): + """Ascend NPU W4A8 online quantization: MXFP4 weights + MXFP8 activations. + + Weight quantization flow (process_weights_after_loading): + BF16/FP16 weight → npu_dynamic_dual_level_mx_quant → FP4 + l0_scale(FP32) + l1_scale(FP8_E8M0) + → npu_format_cast to FRACTAL_NZ (required by npu_dual_level_quant_matmul) + → w_dual_scale transposed to [in/512, out] (required by matmul API) + + Inference flow (apply): + FP16/BF16 activation → npu_dynamic_dual_level_mx_quant → FP4 + act_l0_scale + act_l1_scale + → npu_dual_level_quant_matmul(FP4_act, FP4_weight, scales...) → FP16/BF16 output + + Note: The "A8" refers to the MXFP8 intermediate scale format (FP8_E8M0 l1_scale). + The actual matmul compute is W4A4 (both operands in FP4) since there is no + W4A8 mixed-precision kernel in the current torch_npu public API. + """ + + _FLOAT4_E2M1FN_X2_DTYPE = getattr( + torch_npu, "float4_e2m1fn_x2", getattr(torch, "float4_e2m1fn_x2", None) + ) + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes, + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + from sglang.srt.layers.parameter import ModelWeightParameter + + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + + layer.logical_widths = output_partition_sizes + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + layer.orig_dtype = params_dtype + + # Load weights in original dtype; quantise to MXFP4 in process_weights_after_loading + weight = ModelWeightParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=params_dtype, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + import logging + + from sglang.srt.utils import get_npu_memory_capacity + + _logger = logging.getLogger(__name__) + + # Heuristic hardware check: npu_dynamic_dual_level_mx_quant requires Ascend 950. + # Atlas A2/A3 have ≤64 GB per card; Ascend 950 has ≥96 GB per card. + npu_mem_mb = get_npu_memory_capacity() + if npu_mem_mb < 96 * 1024: + _logger.warning( + "MXFP4 W4A8 dual-level quantization may not be supported on this " + "hardware (detected NPU memory %.1f GB < 96 GB). " + "npu_dynamic_dual_level_mx_quant requires Ascend 950 (Atlas A3). " + "Continuing — expect a RuntimeError if the kernel is unavailable.", + npu_mem_mb / 1024, + ) + + weight_fp = layer.weight.data + if weight_fp.dtype not in (torch.float16, torch.bfloat16): + weight_fp = weight_fp.to(torch.bfloat16) + + # Move to NPU if needed (cpu offload may have put it on CPU) + if not weight_fp.is_npu: + weight_fp = weight_fp.to(f"npu:{torch.npu.current_device()}") + + # Online MXFP4 dual-level quantisation of weights + # qw: float4_e2m1fn_x2, shape [out, in] + # w_dual_scale: float32, shape [out, in/512, 1] (L0) + # w_scale: float8_e8m0, shape [out, (ceil(in/32)+1)//2, 2] (L1) + try: + qw, w_dual_scale, w_scale = torch_npu.npu_dynamic_dual_level_mx_quant( + weight_fp, smooth_scale=None + ) + except (RuntimeError, AttributeError) as e: + raise RuntimeError( + "npu_dynamic_dual_level_mx_quant failed — this operation requires " + "Ascend 950 (Atlas A3). Atlas 800I A2/A3 and earlier chips do NOT " + "support DualLevelQuantBatchMatmul. " + f"Original error: {e}" + ) from e + + # npu_dual_level_quant_matmul requires x2 in FRACTAL_NZ format (format=29) + # view as int8 first because npu_format_cast only accepts int-dtype tensors + qw = torch_npu.npu_format_cast(qw.view(torch.int8), 29) + + # npu_dual_level_quant_matmul expects x2_level0_scale shape [in/512, out]: + # squeeze the trailing dim-1 axis, then transpose + w_dual_scale = w_dual_scale.squeeze(-1).transpose(0, 1).contiguous() + + layer.weight = Parameter(qw, requires_grad=False) + layer.weight_dual_scale = Parameter(w_dual_scale, requires_grad=False) + layer.weight_scale = Parameter(w_scale, requires_grad=False) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + original_dtype = x.dtype + if original_dtype not in (torch.float16, torch.bfloat16): + x = x.to(torch.bfloat16) + original_dtype = torch.bfloat16 + + # Flatten to 2D [tokens, hidden] for dual-level quant API + input_shape = x.shape + x_2d = x.reshape(-1, x.shape[-1]) + + # Dynamic MXFP4 activation quantisation (W4 activations → A4 for matmul) + qx, act_l0_scale, act_l1_scale = torch_npu.npu_dynamic_dual_level_mx_quant( + x_2d, smooth_scale=None + ) + + # MXFP4 matmul: W4A4 compute (weight already in NZ format + transposed scales) + output = torch_npu.npu_dual_level_quant_matmul( + qx, + layer.weight, + act_l0_scale, + layer.weight_dual_scale, + act_l1_scale, + layer.weight_scale, + bias=bias.to(torch.float32) if bias is not None else None, + output_dtype=original_dtype, + ) + + # Restore original shape (replace last dim with output features) + output_shape = list(input_shape[:-1]) + [output.shape[-1]] + return output.reshape(output_shape) + + +class NPUW4A8DynamicLinearMethod(_NPULinearMethodBase): + """Ascend NPU W4A8 offline quantization linear method. + + Offline mode: loads ModelSlim pre-quantized INT4 weights. + For ``new_quant_version=True`` (version "1.0.0"): 2 int4 values are pre-packed + into 1 int8 in the checkpoint (shape ``[N/2, K]``). + For old version: plain int4 stored as int8 (shape ``[N, K]``). + + Uses ``torch_npu.npu_weight_quant_batchmatmul`` for inference — activations + stay in high precision and INT4 weights are dequantized on-the-fly. + """ + + def __init__( + self, + group_size: int = 256, + new_quant_version: bool = True, + ): + super().__init__() + self.group_size = group_size + self.new_quant_version = new_quant_version + + @staticmethod + def _process_scale_second( + weight: torch.Tensor, + scale: torch.Tensor, + per_group_scale: torch.Tensor, + is_new_quant: bool = False, + ): + """Merge per-channel (L1) and per-group (L2) scales into antiquant_scale. + + Args: + weight: weight after transpose, shape ``[K, N/2]`` (new) or ``[K, N]`` (old) + scale: per-channel L1 scale, shape ``[N]`` + per_group_scale: per-group L2 scale after transpose, shape ``[K//group_size, N]`` + is_new_quant: whether weight dim is compressed (N/2) + + Returns: + (antiquant_scale, bias): ``antiquant_scale`` shape ``[K//group_size, N]``; + ``bias`` is non-None only for old version (asymmetric compensation term). + """ + k, n_compressed = weight.shape + group_num, n_scale = per_group_scale.shape + + # Logical N dimension + n = n_compressed * 2 if is_new_quant else n_compressed + + bias = None + if not is_new_quant: + weight_high = weight.to(torch.float32).reshape( + group_num, -1, n + ) * per_group_scale.reshape(group_num, 1, n) + weight_high = weight_high.reshape(k, n) + bias = 8 * (weight_high.to(torch.float32) * scale).sum(dim=0) + + antiquant_scale = (scale * per_group_scale).reshape(group_num, n) + return antiquant_scale, bias + + def process_weights_after_loading(self, layer: torch.nn.Module): + from sglang.srt.hardware_backend.npu.utils import npu_format_cast + + # Transpose [N, K] → [K, N] (or [N/2, K] → [K, N/2] for packed) + layer.weight.data = layer.weight.data.transpose(0, 1).contiguous() + # Cast to FRACTAL_NZ format for NPU matmul efficiency + layer.weight.data = npu_format_cast(layer.weight.data) + + # Flatten per-channel scales to 1-D float32 + layer.weight_scale.data = layer.weight_scale.data.flatten().to(torch.float32) + layer.weight_offset.data = layer.weight_offset.data.flatten() + + # Merge L1/L2 scales: weight_scale_second loaded as [N, K//group_size], + # transpose to [K//group_size, N] for process_scale_second + layer.weight_scale_second.data, scale_bias = self._process_scale_second( + layer.weight.data, + layer.weight_scale.data, + layer.weight_scale_second.data.transpose(0, 1).contiguous(), + is_new_quant=self.new_quant_version, + ) + + if self.new_quant_version: + # Handle optional scale_bias parameter + if hasattr(layer, "scale_bias"): + if layer.scale_bias.data.shape[1] == 1: + layer.scale_bias.data = layer.scale_bias.data.flatten() + else: + layer.scale_bias.data = layer.scale_bias.data.contiguous() + # Pack 4 int8 (2×int4) into int32 for NPU kernel + assert ( + layer.weight.data.shape[-1] % 4 == 0 + ), f"Last dim of weight must be divisible by 4, got {layer.weight.data.shape}" + layer.weight.data = layer.weight.data.view(torch.int32).contiguous() + else: + # Old version: use NPU int4-pack conversion + if scale_bias is not None: + param = torch.nn.Parameter(scale_bias, requires_grad=False) + layer.register_parameter("weight_scale_bias", param) + layer.weight.data = torch_npu.npu_convert_weight_to_int4pack( + layer.weight.data.to(torch.int32) + ) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # Weight-dequant path: INT4 weights dequantized on-the-fly, activations in high precision + return torch_npu.npu_weight_quant_batchmatmul( + x, + layer.weight, + antiquant_scale=layer.weight_scale_second.to(x.dtype), + antiquant_group_size=self.group_size, + ) + + class NPU_W4A4DynamicLinearMethod(_NPULinearMethodBase): def process_weights_after_loading(self, layer): diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py index 8a6b1b06e193..d07117c8a7cd 100644 --- a/python/sglang/srt/layers/quantization/__init__.py +++ b/python/sglang/srt/layers/quantization/__init__.py @@ -36,6 +36,7 @@ def override_quantization_method(self, *args, **kwargs): from sglang.srt.layers.quantization.modelslim.modelslim import ModelSlimConfig from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config from sglang.srt.layers.quantization.mxfp4 import Mxfp4Config +from sglang.srt.layers.quantization.npu_mxfp4 import NPUMxfp4Config from sglang.srt.layers.quantization.petit import PetitNvFp4Config from sglang.srt.layers.quantization.qoq import QoQConfig from sglang.srt.layers.quantization.quark.quark import QuarkConfig @@ -77,6 +78,7 @@ def override_quantization_method(self, *args, **kwargs): "auto-round": AutoRoundConfig, "modelslim": ModelSlimConfig, "quark_int4fp8_moe": QuarkInt4Fp8Config, + "mxfp4_w4a8_npu": NPUMxfp4Config, } diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 7182e3d57ba6..689e85534c1f 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -57,10 +57,7 @@ requant_weight_ue8m0_inplace, ) from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod -from sglang.srt.layers.quantization.marlin_utils_fp8 import ( - apply_fp8_marlin_linear, - prepare_fp8_layer_for_marlin, -) +from sglang.srt.layers.quantization.marlin_utils_fp8 import prepare_fp8_layer_for_marlin from sglang.srt.layers.quantization.unquant import ( UnquantizedFusedMoEMethod, UnquantizedLinearMethod, @@ -172,6 +169,8 @@ def get_supported_act_dtypes(cls) -> List[torch.dtype]: return [torch.bfloat16, torch.half] def get_min_capability(self) -> int: + if is_npu(): + return 0 # NPU bypasses CUDA capability checks return 100 if self.use_mxfp8 else 80 @classmethod @@ -225,6 +224,12 @@ def get_quant_method( prefix, self.ignored_layers, fused_mapping=self.packed_modules_mapping ): return UnquantizedLinearMethod() + if is_npu() and self.use_mxfp8: + from sglang.srt.hardware_backend.npu.quantization.linear_method_npu import ( + NPUMXFP8LinearMethod, + ) + + return NPUMXFP8LinearMethod(self) return Fp8LinearMethod(self) elif isinstance(layer, FusedMoE): if is_layer_skipped( @@ -646,7 +651,7 @@ def apply( bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: if self.use_marlin: - return apply_fp8_marlin_linear( + return torch.ops.sglang.apply_fp8_marlin_linear( input=x, weight=layer.weight, weight_scale=layer.weight_scale, @@ -1001,15 +1006,23 @@ def process_weights_after_loading_block_quant(self, layer: Module) -> None: w2_weight_scale, requires_grad=False ) layer.w2_input_scale = None - - if _use_aiter: + if _use_aiter: + # add this section for MI300 + # Pre-shuffle weights + layer.w13_weight.data = shuffle_weight( + layer.w13_weight.contiguous(), (16, 16) + ) + layer.w2_weight.data = shuffle_weight( + layer.w2_weight.contiguous(), (16, 16) + ) + elif _use_aiter: # Pre-shuffle weights - t = shuffle_weight(layer.w13_weight, (16, 16)) - layer.w13_weight.copy_(t) - del t - t = shuffle_weight(layer.w2_weight, (16, 16)) - layer.w2_weight.copy_(t) - del t + layer.w13_weight.data = shuffle_weight( + layer.w13_weight.contiguous(), (16, 16) + ) + layer.w2_weight.data = shuffle_weight( + layer.w2_weight.contiguous(), (16, 16) + ) elif _is_cpu: assert ( _is_cpu_amx_available diff --git a/python/sglang/srt/layers/quantization/modelslim/modelslim.py b/python/sglang/srt/layers/quantization/modelslim/modelslim.py index 84acecccc415..c6c2e8f81f3e 100644 --- a/python/sglang/srt/layers/quantization/modelslim/modelslim.py +++ b/python/sglang/srt/layers/quantization/modelslim/modelslim.py @@ -17,6 +17,7 @@ from sglang.srt.layers.quantization.modelslim.schemes import ( ModelSlimW4A4Int4, ModelSlimW4A4Int4MoE, + ModelSlimW4A8Int8, ModelSlimW4A8Int8MoE, ModelSlimW8A8Int8, ModelSlimW8A8Int8MoE, @@ -190,6 +191,22 @@ def _get_scheme_from_parts( return ModelSlimW4A4Int4( quant_config=self.quant_description, prefix=layer_name ) + elif quant_type == "W4A8_DYNAMIC": + return ModelSlimW4A8Int8( + quant_config=self.quant_description, prefix=layer_name + ) + elif quant_type == "W8A8_MXFP8": + from sglang.srt.layers.quantization.modelslim.schemes.modelslim_mxfp8 import ( + ModelSlimMXFP8Scheme, + ) + + return ModelSlimMXFP8Scheme() + elif quant_type == "W4A8_MXFP": + from sglang.srt.layers.quantization.modelslim.schemes.modelslim_mxfp4_w4a8 import ( + ModelSlimMXFP4W4A8Scheme, + ) + + return ModelSlimMXFP4W4A8Scheme() raise NotImplementedError("No modelslim compatible scheme was found.") def get_linear_scheme( diff --git a/python/sglang/srt/layers/quantization/modelslim/schemes/__init__.py b/python/sglang/srt/layers/quantization/modelslim/schemes/__init__.py index c349fd3c4251..c7755427a946 100644 --- a/python/sglang/srt/layers/quantization/modelslim/schemes/__init__.py +++ b/python/sglang/srt/layers/quantization/modelslim/schemes/__init__.py @@ -1,8 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 +# NOTE: Import order is critical to avoid circular dependency. +# modelslim_mxfp8 imports ModelSlimLinearScheme from this package, +# so the base class must be imported first. +# isort: off from .modelslim_scheme import ModelSlimLinearScheme, ModelSlimMoEScheme +from .modelslim_mxfp8 import ModelSlimMXFP8Scheme +from .modelslim_mxfp4_w4a8 import ModelSlimMXFP4W4A8Scheme + +# isort: on from .modelslim_w4a4_int4 import ModelSlimW4A4Int4 from .modelslim_w4a4_int4_moe import ModelSlimW4A4Int4MoE +from .modelslim_w4a8_int8 import ModelSlimW4A8Int8 from .modelslim_w4a8_int8_moe import ModelSlimW4A8Int8MoE from .modelslim_w8a8_int8 import ModelSlimW8A8Int8 from .modelslim_w8a8_int8_moe import ModelSlimW8A8Int8MoE @@ -10,9 +19,12 @@ __all__ = [ "ModelSlimLinearScheme", "ModelSlimMoEScheme", + "ModelSlimMXFP8Scheme", + "ModelSlimMXFP4W4A8Scheme", "ModelSlimW8A8Int8", "ModelSlimW4A4Int4", "ModelSlimW4A4Int4MoE", + "ModelSlimW4A8Int8", "ModelSlimW4A8Int8MoE", "ModelSlimW8A8Int8MoE", ] diff --git a/python/sglang/srt/layers/quantization/modelslim/schemes/modelslim_mxfp4_w4a8.py b/python/sglang/srt/layers/quantization/modelslim/schemes/modelslim_mxfp4_w4a8.py new file mode 100644 index 000000000000..1e4913d989f2 --- /dev/null +++ b/python/sglang/srt/layers/quantization/modelslim/schemes/modelslim_mxfp4_w4a8.py @@ -0,0 +1,104 @@ +"""ModelSlim W4A8_MXFP scheme for pre-quantized weight inference on Ascend NPU (SRT). + +Loads weights pre-quantized by msmodelslim: + weight: float8_e4m3fn, shape [out, in], group_size=32 + weight_scale: uint8 (+127 biased), shape [out, in/32] + +Inference: + activation → npu_dynamic_mx_quant(float8_e4m3fn) → qx + per-token scale + npu_quant_matmul(qx, weight, weight_scale, scale_dtype=FP8_E8M0) +""" + +from typing import List, Optional + +import torch +import torch_npu + +from sglang.srt.layers.parameter import GroupQuantScaleParameter, ModelWeightParameter +from sglang.srt.layers.quantization.modelslim.schemes import ModelSlimLinearScheme + +MXFP4_W4A8_BLOCK_SIZE = 32 + +_FLOAT8_E8M0FNU_DTYPE = getattr( + torch_npu, "float8_e8m0fnu", getattr(torch, "float8_e8m0fnu", None) +) + + +class ModelSlimMXFP4W4A8Scheme(ModelSlimLinearScheme): + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + weight_loader = extra_weight_attrs.get("weight_loader") + output_size_per_partition = sum(output_partition_sizes) + + weight = ModelWeightParameter( + data=torch.empty( + (output_size_per_partition, input_size_per_partition), + dtype=torch.float8_e4m3fn, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + scale_dim = input_size_per_partition // MXFP4_W4A8_BLOCK_SIZE + weight_scale = GroupQuantScaleParameter( + data=torch.empty( + (output_size_per_partition, scale_dim), + dtype=torch.uint8, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight_scale", weight_scale) + + def process_weights_after_loading(self, layer: torch.nn.Module): + # weight_scale: [out, in/32] → reshape [out, in/64, 2] → transpose [in/64, out, 2] + # weight: [out, in] → transpose [in, out] + n_dim, k_dim = layer.weight_scale.data.shape + layer.weight_scale.data = layer.weight_scale.data.reshape(n_dim, k_dim // 2, 2) + layer.weight.data = layer.weight.data.transpose(0, 1) + layer.weight_scale.data = layer.weight_scale.data.transpose(0, 1) + + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + original_dtype = x.dtype + if original_dtype not in (torch.float16, torch.bfloat16): + x = x.to(torch.bfloat16) + original_dtype = torch.bfloat16 + + input_shape = x.shape + x_2d = x.reshape(-1, x.shape[-1]) + + qx, input_scale = torch_npu.npu_dynamic_mx_quant( + x_2d, dst_type=torch_npu.float8_e4m3fn + ) + + output = torch_npu.npu_quant_matmul( + qx, + layer.weight, + layer.weight_scale, + scale_dtype=_FLOAT8_E8M0FNU_DTYPE, + pertoken_scale=input_scale, + pertoken_scale_dtype=_FLOAT8_E8M0FNU_DTYPE, + bias=bias.to(torch.float32) if bias is not None else None, + output_dtype=original_dtype, + group_sizes=[1, 1, MXFP4_W4A8_BLOCK_SIZE], + ) + + output_shape = list(input_shape[:-1]) + [output.shape[-1]] + return output.reshape(output_shape) diff --git a/python/sglang/srt/layers/quantization/modelslim/schemes/modelslim_mxfp8.py b/python/sglang/srt/layers/quantization/modelslim/schemes/modelslim_mxfp8.py new file mode 100644 index 000000000000..02fd515db594 --- /dev/null +++ b/python/sglang/srt/layers/quantization/modelslim/schemes/modelslim_mxfp8.py @@ -0,0 +1,108 @@ +"""ModelSlim MXFP8 scheme for pre-quantized weight inference on Ascend NPU (SRT). + +Loads weights pre-quantized by msmodelslim (float8_e4m3fn weights, +uint8 scales) and runs MXFP8 matmul at inference. +""" + +from typing import List, Optional + +import torch +import torch_npu + +from sglang.srt.layers.parameter import GroupQuantScaleParameter, ModelWeightParameter +from sglang.srt.layers.quantization.modelslim.schemes import ModelSlimLinearScheme + +MXFP8_BLOCK_SIZE = 32 +_FLOAT8_E8M0FNU_DTYPE = getattr( + torch_npu, "float8_e8m0fnu", getattr(torch, "float8_e8m0fnu", None) +) + + +class ModelSlimMXFP8Scheme(ModelSlimLinearScheme): + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + weight_loader = extra_weight_attrs.get("weight_loader") + output_size_per_partition = sum(output_partition_sizes) + + # msmodelslim exports weight as float8_e4m3fn, shape [out, in] + weight = ModelWeightParameter( + data=torch.empty( + (output_size_per_partition, input_size_per_partition), + dtype=torch.float8_e4m3fn, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + # msmodelslim exports weight_scale as uint8, shape [out, in/32]. + # NOTE: Named "weight_scale" (not "weight_scale_inv") to match the + # checkpoint key exported by msmodelslim. + scale_dim = input_size_per_partition // MXFP8_BLOCK_SIZE + weight_scale = GroupQuantScaleParameter( + data=torch.empty( + (output_size_per_partition, scale_dim), + dtype=torch.uint8, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight_scale", weight_scale) + + def process_weights_after_loading(self, layer: torch.nn.Module): + # Pre-transpose weight and scale to [in, out] for npu_quant_matmul. + # Use .data assignment without .contiguous() to preserve the transpose + # view strides — npu_quant_matmul reads strides correctly and calling + # .contiguous() would reorder data, breaking the block-scale mapping. + n_dim, k_dim = layer.weight_scale.data.shape + layer.weight_scale.data = layer.weight_scale.data.reshape(n_dim, k_dim // 2, 2) + layer.weight.data = layer.weight.data.transpose(0, 1) + layer.weight_scale.data = layer.weight_scale.data.transpose(0, 1) + + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + original_dtype = x.dtype + if original_dtype not in (torch.float16, torch.bfloat16): + x = x.to(torch.bfloat16) + original_dtype = torch.bfloat16 + + # npu_dynamic_mx_quant requires a 2D input [tokens, hidden_size] + input_shape = x.shape + x_2d = x.reshape(-1, x.shape[-1]) + + # Dynamic MXFP8 activation quantisation + qx, input_scale = torch_npu.npu_dynamic_mx_quant( + x_2d, dst_type=torch_npu.float8_e4m3fn + ) + + # MXFP8 matmul (weight & scale already transposed at load time) + output = torch_npu.npu_quant_matmul( + qx, + layer.weight, + layer.weight_scale, + scale_dtype=_FLOAT8_E8M0FNU_DTYPE, + pertoken_scale=input_scale, + pertoken_scale_dtype=_FLOAT8_E8M0FNU_DTYPE, + bias=bias.to(torch.float32) if bias is not None else None, + output_dtype=original_dtype, + group_sizes=[1, 1, MXFP8_BLOCK_SIZE], + ) + + # Restore original shape (replace last dim with output features) + output_shape = list(input_shape[:-1]) + [output.shape[-1]] + return output.reshape(output_shape) diff --git a/python/sglang/srt/layers/quantization/modelslim/schemes/modelslim_w4a8_int8.py b/python/sglang/srt/layers/quantization/modelslim/schemes/modelslim_w4a8_int8.py new file mode 100644 index 000000000000..4a07ffb67746 --- /dev/null +++ b/python/sglang/srt/layers/quantization/modelslim/schemes/modelslim_w4a8_int8.py @@ -0,0 +1,131 @@ +# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Dict, List, Optional + +import torch + +from sglang.srt.hardware_backend.npu.quantization.linear_method_npu import ( + NPUW4A8DynamicLinearMethod, +) +from sglang.srt.layers.quantization.modelslim.schemes import ModelSlimLinearScheme +from sglang.srt.utils import set_weight_attrs + + +class ModelSlimW4A8Int8(ModelSlimLinearScheme): + """ModelSlim offline W4A8 Dense Linear scheme. + + Handles ``W4A8_DYNAMIC`` quant_type from ``quant_model_description.json``. + + Weight layout in the checkpoint: + - ``new_quant_version`` (version == "1.0.0"): INT4×2 pre-packed into INT8, + so on-disk shape is ``[N/2, K]``. + - Old version: each INT8 stores one INT4, on-disk shape is ``[N, K]``. + + Delegates weight processing and matmul to ``NPUW4A8DynamicLinearMethod`` + which uses ``torch_npu.npu_weight_quant_batchmatmul`` (weight-dequant path). + """ + + def __init__( + self, + quant_config: Dict[str, Any], + prefix: str, + ): + self.quant_config = quant_config + self.group_size: int = quant_config.get("group_size", 256) + self.new_quant_version: bool = quant_config.get("version", "0") == "1.0.0" + self.kernel = NPUW4A8DynamicLinearMethod( + group_size=self.group_size, + new_quant_version=self.new_quant_version, + ) + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ) -> None: + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + + # ── Weight ────────────────────────────────────────────────────────── + # new_quant_version: 2 INT4 packed per INT8 → shape [N/2, K] + # old version : 1 INT4 per INT8 → shape [N, K] + weight_n = ( + output_size_per_partition // 2 + if self.new_quant_version + else output_size_per_partition + ) + weight = torch.nn.Parameter( + torch.empty(weight_n, input_size_per_partition, dtype=torch.int8), + requires_grad=False, + ) + set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + layer.register_parameter("weight", weight) + set_weight_attrs(weight, extra_weight_attrs) + + # ── Per-channel L1 scale & offset: [N, 1] ─────────────────────────── + weight_scale = torch.nn.Parameter( + torch.empty(output_size_per_partition, 1, dtype=params_dtype), + requires_grad=False, + ) + set_weight_attrs(weight_scale, {"output_dim": 0}) + layer.register_parameter("weight_scale", weight_scale) + set_weight_attrs(weight_scale, extra_weight_attrs) + + weight_offset = torch.nn.Parameter( + torch.empty(output_size_per_partition, 1, dtype=params_dtype), + requires_grad=False, + ) + set_weight_attrs(weight_offset, {"output_dim": 0}) + layer.register_parameter("weight_offset", weight_offset) + set_weight_attrs(weight_offset, extra_weight_attrs) + + # ── Per-group L2 scale & offset: [N, K//group_size] ───────────────── + # Note: for RowParallelLinear (K partitioned), input_dim=1 would be needed; + # for ColumnParallelLinear (N partitioned), output_dim=0 suffices. + # Initial implementation covers the column-parallel case. + group_num = input_size_per_partition // self.group_size + weight_scale_second = torch.nn.Parameter( + torch.empty(output_size_per_partition, group_num, dtype=params_dtype), + requires_grad=False, + ) + set_weight_attrs(weight_scale_second, {"output_dim": 0}) + layer.register_parameter("weight_scale_second", weight_scale_second) + set_weight_attrs(weight_scale_second, extra_weight_attrs) + + weight_offset_second = torch.nn.Parameter( + torch.empty(output_size_per_partition, group_num, dtype=params_dtype), + requires_grad=False, + ) + set_weight_attrs(weight_offset_second, {"output_dim": 0}) + layer.register_parameter("weight_offset_second", weight_offset_second) + set_weight_attrs(weight_offset_second, extra_weight_attrs) + + # ── scale_bias (new_quant_version only): [N, 1] ───────────────────── + # Shape is [N, 16] for RowParallelLinear (down_proj / o_proj), + # but [N, 1] for ColumnParallelLinear. Using [N, 1] for simplicity; + # process_weights_after_loading handles both shapes dynamically. + if self.new_quant_version: + scale_bias = torch.nn.Parameter( + torch.empty(output_size_per_partition, 1, dtype=torch.float32), + requires_grad=False, + ) + set_weight_attrs(scale_bias, {"output_dim": 0}) + layer.register_parameter("scale_bias", scale_bias) + set_weight_attrs(scale_bias, extra_weight_attrs) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + self.kernel.process_weights_after_loading(layer) + + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return self.kernel.apply(layer, x, bias) diff --git a/python/sglang/srt/layers/quantization/npu_mxfp4.py b/python/sglang/srt/layers/quantization/npu_mxfp4.py new file mode 100644 index 000000000000..a62526a85be1 --- /dev/null +++ b/python/sglang/srt/layers/quantization/npu_mxfp4.py @@ -0,0 +1,122 @@ +"""Ascend NPU MXFP4 W4A8 online quantization config. + +Triggered by ``--quantization mxfp4_w4a8_npu``. + +Online mode: loads FP16/BF16 weights, quantises to MXFP4 (dual-level) in +``process_weights_after_loading``. During inference, activations are +dynamically quantised to MXFP4 and ``npu_dual_level_quant_matmul`` is used +for the matrix multiply. + +Hardware requirement: Ascend 950 (DualLevelQuantBatchMatmul is NOT supported +on Atlas A2/A3 – check your hardware before enabling). +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Dict, List, Optional + +import torch + +from sglang.srt.layers.quantization.base_config import ( + QuantizationConfig, + QuantizeMethodBase, +) +from sglang.srt.layers.quantization.unquant import ( + UnquantizedFusedMoEMethod, + UnquantizedLinearMethod, +) +from sglang.srt.layers.quantization.utils import is_layer_skipped + +if TYPE_CHECKING: + pass + +logger = logging.getLogger(__name__) + + +class NPUMxfp4Config(QuantizationConfig): + """Quantization config for Ascend NPU MXFP4 W4A8 online quantization. + + Weights are quantised online to MXFP4 dual-level format during model + loading. Activations are quantised dynamically to MXFP4 at inference + time. The matmul is executed via ``torch_npu.npu_dual_level_quant_matmul``. + """ + + def __init__( + self, + ignored_layers: Optional[List[str]] = None, + packed_modules_mapping: Optional[Dict[str, str]] = None, + ): + super().__init__() + self.ignored_layers = ignored_layers or [] + self.packed_modules_mapping = packed_modules_mapping or {} + + @classmethod + def get_name(cls) -> str: + return "mxfp4_w4a8_npu" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.bfloat16, torch.half] + + @classmethod + def get_min_capability(cls) -> int: + return 0 # NPU bypasses CUDA capability checks + + @classmethod + def get_config_filenames(cls) -> List[str]: + return [] + + @classmethod + def from_config(cls, config: Dict) -> "NPUMxfp4Config": + ignored_layers = cls.get_from_keys_or( + config, ["ignored_layers", "modules_to_not_convert"], None + ) + if ignored_layers: + normalized: List[str] = [] + for layer in ignored_layers: + base = layer.removeprefix("model.") + normalized.append(base) + normalized.append(f"model.{base}") + ignored_layers = normalized + packed_modules_mapping = ( + cls.get_from_keys_or(config, ["packed_modules_mapping"], {}) or {} + ) + return cls( + ignored_layers=ignored_layers, + packed_modules_mapping=packed_modules_mapping, + ) + + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional[QuantizeMethodBase]: + from sglang.srt.layers.linear import LinearBase + from sglang.srt.layers.moe.fused_moe_triton import FusedMoE + + if isinstance(layer, LinearBase): + if is_layer_skipped( + prefix, + self.ignored_layers, + fused_mapping=self.packed_modules_mapping, + ): + return UnquantizedLinearMethod() + from sglang.srt.hardware_backend.npu.quantization.linear_method_npu import ( + NPUMXFP4W4A8LinearMethod, + ) + + return NPUMXFP4W4A8LinearMethod(self) + elif isinstance(layer, FusedMoE): + # MoE MXFP4 not yet implemented; fall back to unquantised + logger.warning( + "MXFP4 W4A8 quantization is not yet supported for FusedMoE layers " + "(prefix=%s). Falling back to unquantized MoE — MoE weights will " + "run in full precision (BF16/FP16).", + prefix, + ) + return UnquantizedFusedMoEMethod( + layer.use_triton_kernels, layer.use_flashinfer_trtllm_moe + ) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] diff --git a/python/sglang/srt/layers/rotary_embedding/base.py b/python/sglang/srt/layers/rotary_embedding/base.py index 518d250211f8..06ff843e1485 100644 --- a/python/sglang/srt/layers/rotary_embedding/base.py +++ b/python/sglang/srt/layers/rotary_embedding/base.py @@ -39,7 +39,11 @@ if _is_npu: import torch_npu - from sgl_kernel_npu.norm.fused_rope_qk_mqa import fused_rope_qk_mqa + + try: + from sgl_kernel_npu.norm.fused_rope_qk_mqa import fused_rope_qk_mqa + except ImportError: + fused_rope_qk_mqa = None if _is_hip: from sglang.srt.layers.attention.utils import ( @@ -257,7 +261,10 @@ def forward_npu( else: cos_sin = self.cos_sin_cache.index_select(0, positions) - if query.shape[0] * query.shape[1] < 65535: + if ( + fused_rope_qk_mqa is not None + and query.shape[0] * query.shape[1] < 65535 + ): return fused_rope_qk_mqa( query, key, diff --git a/python/sglang/srt/models/transformers.py b/python/sglang/srt/models/transformers.py index 0ea9da14a1be..36a9eb48b7e6 100644 --- a/python/sglang/srt/models/transformers.py +++ b/python/sglang/srt/models/transformers.py @@ -99,6 +99,7 @@ def replace_linear_class( linear: nn.Linear, style: Literal["colwise", "rowwise"], quant_config: QuantizationConfig, + prefix: str = "", ) -> Union[ColumnParallelLinear, RowParallelLinear]: """ Replace nn.Linear with one of vLLM's tensor parallel linear classes. @@ -107,6 +108,7 @@ def replace_linear_class( linear (nn.Linear): `nn.Linear` to be replaced. style (str): Tensor parallel style of the new linear, e.g. "colwise". quant_config (QuantConfig): Quantization config for the new linear. + prefix (str): Layer name prefix used for per-layer quantization dispatch. Returns: Union[ColumnParallelLinear, RowParallelLinear]: The new linear. """ @@ -136,6 +138,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: output_size=linear.out_features, bias=linear.bias is not None, quant_config=quant_config, + prefix=prefix, ) @@ -227,14 +230,14 @@ def _tensor_parallel(module: nn.Module, prefix: str = ""): child_module, nn.Linear ): new_module = replace_linear_class( - child_module, style, self.quant_config + child_module, style, self.quant_config, prefix=qual_name ) setattr(module, child_name, new_module) self.log_replacement(qual_name, child_module, new_module) else: _tensor_parallel(child_module, prefix=qual_name) - _tensor_parallel(self.model) + _tensor_parallel(self.model, prefix="model") def replace_vocab_embed_class(self, module: nn.Module): # Use native set input embeddings diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index d6a746ee3b3b..325b97d9d389 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -116,6 +116,8 @@ "auto-round", "compressed-tensors", # for Ktransformers "modelslim", # for NPU + "mxfp4_w4a8_npu", # for NPU W4A8 + "mxfp4_w4a4_npu", # for NPU W4A4 "quark_int4fp8_moe", ]