diff --git a/python/sglang/multimodal_gen/runtime/layers/quantization/__init__.py b/python/sglang/multimodal_gen/runtime/layers/quantization/__init__.py index 43e8ff081893..1bce2a37470b 100644 --- a/python/sglang/multimodal_gen/runtime/layers/quantization/__init__.py +++ b/python/sglang/multimodal_gen/runtime/layers/quantization/__init__.py @@ -15,10 +15,20 @@ ) from sglang.multimodal_gen.runtime.layers.quantization.modelslim import ModelSlimConfig from sglang.multimodal_gen.runtime.layers.quantization.mxfp4 import Mxfp4Config +from sglang.multimodal_gen.runtime.layers.quantization.mxfp4_npu import ( + NPUMXFP4Config, +) from sglang.multimodal_gen.runtime.layers.quantization.mxfp8_npu import MXFP8Config QuantizationMethods = Literal[ - "fp8", "modelopt", "modelopt_fp8", "modelopt_fp4", "modelslim", "mxfp4" + "fp8", + "modelopt", + "modelopt_fp8", + "modelopt_fp4", + "modelslim", + "mxfp8", + "mxfp4", + "mxfp4_npu", ] QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods)) @@ -32,6 +42,7 @@ "fp8": Fp8Config, "mxfp4": Mxfp4Config, "mxfp8": MXFP8Config, + "mxfp4_npu": NPUMXFP4Config, } diff --git a/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim.py b/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim.py index ee12fba41751..62b84ee17216 100644 --- a/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim.py +++ b/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim.py @@ -139,7 +139,16 @@ def _get_scheme_from_parts( ) return ModelSlimMXFP8Scheme() - raise NotImplementedError("No modelslim compatible scheme was found.") + elif quant_type in ("W4A4_MXFP4", "W4A4_MXFP4_DUALSCALE"): + from sglang.multimodal_gen.runtime.layers.quantization.modelslim_mxfp4_scheme import ( + ModelSlimMXFP4Scheme, + ) + + return ModelSlimMXFP4Scheme() + raise NotImplementedError( + f"No modelslim compatible scheme was found for layer '{layer_name}'. 
" + f"quant_description['{layer_name}.weight'] = '{quant_type}'" + ) def get_scheme( self, layer: torch.nn.Module, layer_name: Optional[str] = None diff --git a/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp4_scheme.py b/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp4_scheme.py new file mode 100644 index 000000000000..f58b8012f1d0 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp4_scheme.py @@ -0,0 +1,197 @@ +"""ModelSlim MXFP4 scheme for pre-quantized weight inference on Ascend NPU. + +Loads weights pre-quantized by msmodelslim and runs MXFP4 dual-level +matmul at inference via npu_dual_level_quant_matmul. + +Checkpoint tensor formats (verified from msmodelslim export): + weight: [out, in] float8_e4m3fn (FP4 data in fp8 container) + weight_scale: [out, in/32] uint8 (L1 block scales, e8m0+127) + weight_dual_scale:[out, in/512, 1] float32 (L0 coarse scales) + mul_scale: [in] float32 (smooth quant activation scale) + +Reference: MindIE-SD W4A4MXFP4DualQuantLinear +(MindIE-SD/mindiesd/quantization/layer.py) +""" + +from typing import List, Optional + +import torch + +from sglang.multimodal_gen.runtime.platforms import current_platform + +_is_npu = current_platform.is_npu() + +if _is_npu: + import torch_npu + +from sglang.multimodal_gen.runtime.models.parameter import ( + BasevLLMParameter, + GroupQuantScaleParameter, + ModelWeightParameter, +) +from sglang.srt.layers.quantization.modelslim.schemes import ModelSlimLinearScheme + +MXFP4_BLOCK_SIZE = 32 +# L0 (dual) scale groups this many L1 blocks together. +# L0 block covers 16 * 32 = 512 elements. 
+MXFP4_DUAL_LEVEL_RATIO = 16 + + +class ModelSlimMXFP4Scheme(ModelSlimLinearScheme): + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + weight_loader = extra_weight_attrs.get("weight_loader") + output_size_per_partition = sum(output_partition_sizes) + + # msmodelslim exports weight as float8_e4m3fn, shape [out, in]. + # Each byte is a float8 container for FP4 data; the actual FP4 packing + # (npu_dtype_cast → float4_e2m1fn_x2) happens in process_weights_after_loading. + weight = ModelWeightParameter( + data=torch.empty( + (output_size_per_partition, input_size_per_partition), + dtype=torch.float8_e4m3fn, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + # L1 block scale: uint8 [out, in/32], e8m0 scale with +127 offset. + scale_dim = input_size_per_partition // MXFP4_BLOCK_SIZE + weight_scale = GroupQuantScaleParameter( + data=torch.empty( + (output_size_per_partition, scale_dim), + dtype=torch.uint8, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight_scale", weight_scale) + + # L0 (coarse) scale for dual-level quantization matmul. + # Each L0 block covers MXFP4_DUAL_LEVEL_RATIO L1 blocks = 16 * 32 = 512 elements. + dual_scale_dim = scale_dim // MXFP4_DUAL_LEVEL_RATIO # in/32 / 16 = in/512 + weight_dual_scale = GroupQuantScaleParameter( + data=torch.empty( + (output_size_per_partition, dual_scale_dim, 1), + dtype=torch.float32, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight_dual_scale", weight_dual_scale) + + # Smooth quant activation scale (mul_scale) from NonFusionSmoothQuantWrapper. + # msmodelslim exports this as `.div.mul_scale` with shape [in]. + # After repack, it becomes `.mul_scale`. 
+ # This is CRITICAL: the offline-quantized weights were calibrated with + # x * mul_scale applied to the activation. Omitting it causes mosaic output. + # Ref: MindIE-SD W4A4MXFP4DualQuantLinear.quant_matmul lines 385-386. + mul_scale = BasevLLMParameter( + data=torch.empty( + (input_size_per_partition,), + dtype=torch.float32, + ), + weight_loader=weight_loader, + ) + # If mul_scale is not in the checkpoint (e.g. non-smooth-quant model + # or old repack without .div. handling), initialize to ones so that + # x * 1.0 = x (no-op). fsdp_load.py checks this attribute. + mul_scale.missing_param_init = "ones" + layer.register_parameter("mul_scale", mul_scale) + + def process_weights_after_loading(self, layer: torch.nn.Module): + # Cast weight from fp8 container to FP4 packed format + weight = layer.weight.data + if not weight.is_npu: + weight = weight.to(f"npu:{torch.npu.current_device()}") + weight = torch_npu.npu_dtype_cast(weight, torch_npu.float4_e2m1fn_x2) + # npu_dual_level_quant_matmul requires x2 in FRACTAL_NZ format (format 29). 
+ # Reference: MindIE-SD W4A4MXFP4DualQuantLinear._init_dynamic_quant_param + weight = torch_npu.npu_format_cast( + weight.view(torch.int8), 29, customize_dtype=torch.int8 + ) + layer.weight = torch.nn.Parameter(weight, requires_grad=False) + + # Reshape weight_scale: [out, in/32] -> [out, in/64, 2] + # The dual-level matmul API expects L1 scales in this 3D format + weight_scale = layer.weight_scale.data + if not weight_scale.is_npu: + weight_scale = weight_scale.to(f"npu:{torch.npu.current_device()}") + weight_scale = weight_scale.reshape(weight_scale.shape[0], -1, 2) + layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) + + # Transform weight_dual_scale: [out, in/512, 1] -> [in/512, out] + weight_dual_scale = layer.weight_dual_scale.data + if not weight_dual_scale.is_npu: + weight_dual_scale = weight_dual_scale.to( + f"npu:{torch.npu.current_device()}" + ) + weight_dual_scale = weight_dual_scale.squeeze(-1).transpose(0, 1).contiguous() + layer.weight_dual_scale = torch.nn.Parameter( + weight_dual_scale, requires_grad=False + ) + + # Move mul_scale to NPU if present and not already there + mul_scale = layer.mul_scale.data + if not mul_scale.is_npu: + mul_scale = mul_scale.to(f"npu:{torch.npu.current_device()}") + layer.mul_scale = torch.nn.Parameter(mul_scale, requires_grad=False) + layer.use_mul_scale = not torch.all(mul_scale == 1.0).item() + + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + original_dtype = x.dtype + if original_dtype not in (torch.float16, torch.bfloat16): + x = x.to(torch.bfloat16) + original_dtype = torch.bfloat16 + + # Flatten to 2D for npu_dynamic_dual_level_mx_quant + input_shape = x.shape + x_2d = x.reshape(-1, x.shape[-1]) + + # Apply smooth quant scale before activation quantization. + # The offline-quantized weights were calibrated under x * mul_scale, + # so we MUST apply it here for scale alignment. 
+ # Reference: MindIE-SD W4A4MXFP4DualQuantLinear.quant_matmul + mul_scale = layer.mul_scale + if getattr(layer, "use_mul_scale", True): + x_2d = x_2d * mul_scale.to(x_2d.dtype) + + # Dual-level MXFP4 activation quantization + x1, l0_scale, l1_scale = torch_npu.npu_dynamic_dual_level_mx_quant( + x_2d, smooth_scale=None + ) + + # Dual-level MXFP4 matmul + output = torch_npu.npu_dual_level_quant_matmul( + x1, + layer.weight, + l0_scale, + layer.weight_dual_scale, + l1_scale, + layer.weight_scale, + bias=bias.to(torch.float32) if bias is not None else None, + output_dtype=original_dtype, + ) + + # Restore original shape + output_shape = list(input_shape[:-1]) + [output.shape[-1]] + return output.reshape(output_shape) diff --git a/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp4_npu.py b/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp4_npu.py new file mode 100644 index 000000000000..3798f36b41ad --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp4_npu.py @@ -0,0 +1,201 @@ +"""Online MXFP4 quantization for Diffusion models on Ascend NPU. + +Provides ``NPUMXFP4Config`` (registered as ``"mxfp4_npu"``) and +``NPUMXFP4DiffusionLinearMethod`` which quantises FP16/BF16 weights to MXFP4 +at load time using dual-level MX quantization and uses +``npu_dynamic_dual_level_mx_quant`` + ``npu_dual_level_quant_matmul`` for +inference. + +The ``"mxfp4_npu"`` key is distinct from upstream's ROCm ``"mxfp4"`` +(``Mxfp4Config`` in ``mxfp4.py``) which targets AMD MI350+ via aiter kernels. + +NOTE: Online weight quantization via ``npu_dynamic_dual_level_mx_quant`` is +experimental. MindIE-SD only uses an offline (pre-quantized) path for MXFP4 +weights. The online path quantizes FP16/BF16 weights at load time, which may +produce different numerical results than the offline calibrated path. 
+""" + +from __future__ import annotations + +from typing import Any, Dict, List, Optional + +import torch +from torch.nn.parameter import Parameter + +from sglang.multimodal_gen.runtime.platforms import current_platform + +_is_npu = current_platform.is_npu() + +if _is_npu: + import torch_npu + +from sglang.multimodal_gen.runtime.layers.linear import LinearBase, LinearMethodBase +from sglang.multimodal_gen.runtime.layers.quantization.configs.base_config import ( + QuantizationConfig, + QuantizeMethodBase, +) +from sglang.multimodal_gen.runtime.models.parameter import ModelWeightParameter +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +class NPUMXFP4Config(QuantizationConfig): + """Config for online MXFP4 quantization on Ascend NPU (Diffusion).""" + + def __init__(self) -> None: + super().__init__() + + @classmethod + def get_name(cls) -> str: + return "mxfp4_npu" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.bfloat16, torch.float16] + + @classmethod + def get_min_capability(cls) -> int: + return 0 # NPU, not CUDA + + @classmethod + def get_config_filenames(cls) -> List[str]: + return [] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "NPUMXFP4Config": + return cls() + + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional[QuantizeMethodBase]: + if isinstance(layer, LinearBase): + return NPUMXFP4DiffusionLinearMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class NPUMXFP4DiffusionLinearMethod(LinearMethodBase): + """Ascend NPU MXFP4 linear method for Diffusion models (dual-level). + + Online mode: loads FP16/BF16 weights → quantises to MXFP4 at load time + via ``npu_dynamic_dual_level_mx_quant``. + Inference: dynamic dual-level MXFP4 activation quant + dual-level matmul. + + Reference: MindIE-SD ``W4A4MXFP4DualQuantLinear`` (offline path only). 
+ """ + + def __init__(self, quant_config: NPUMXFP4Config): + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + + layer.logical_widths = output_partition_sizes + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + layer.orig_dtype = params_dtype + + # Load weights in original dtype; quantise later in process_weights_after_loading + weight = ModelWeightParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=params_dtype, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + weight_fp = layer.weight.data + if weight_fp.dtype not in (torch.float16, torch.bfloat16): + weight_fp = weight_fp.to(torch.bfloat16) + + # Move weight to NPU if needed. dit_cpu_offload defaults to True in + # ServerArgs, which causes fsdp_load to move parameters back to CPU + # after loading. npu_dynamic_dual_level_mx_quant requires an NPU tensor. + if not weight_fp.is_npu: + weight_fp = weight_fp.to(f"npu:{torch.npu.current_device()}") + + # Online dual-level MXFP4 weight quantisation. + # NOTE: This is experimental — MindIE-SD only has an offline path for + # MXFP4 weights. We assume npu_dynamic_dual_level_mx_quant can also + # quantise weights (not just activations). 
+ # Returns: (qw, w_dual_scale, w_scale) + # qw — quantized weight in float4_e2m1fn_x2 (2 FP4 packed/byte) + # w_dual_scale — L0-level scale (goes to pos 3 in npu_dual_level_quant_matmul) + # w_scale — L1-level scale (goes to pos 5 in npu_dual_level_quant_matmul) + qw, w_dual_scale, w_scale = torch_npu.npu_dynamic_dual_level_mx_quant( + weight_fp, smooth_scale=None + ) + + # npu_dual_level_quant_matmul requires x2 (weight) in FRACTAL_NZ format. + # Reference: MindIE-SD W4A4MXFP4DualQuantLinear._init_dynamic_quant_param + qw = torch_npu.npu_format_cast( + qw.view(torch.int8), 29, customize_dtype=torch.int8 + ) + + # x2Level0Scale must be [in/level0_block_size, out] — transpose from + # the [out, in/level0_block_size] shape returned by the quant op. + # Reference: MindIE-SD layer.py:409 + w_dual_scale = w_dual_scale.squeeze(-1).transpose(0, 1).contiguous() + + layer.weight = Parameter(qw, requires_grad=False) + layer.weight_dual_scale = Parameter(w_dual_scale, requires_grad=False) + layer.weight_scale = Parameter(w_scale, requires_grad=False) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + original_dtype = x.dtype + if original_dtype not in (torch.float16, torch.bfloat16): + x = x.to(torch.bfloat16) + original_dtype = torch.bfloat16 + + # Flatten to 2D [tokens, hidden] for the quantization operators + input_shape = x.shape + x_2d = x.reshape(-1, x.shape[-1]) + + # Dynamic dual-level MXFP4 activation quantisation + qx, act_l0_scale, act_l1_scale = torch_npu.npu_dynamic_dual_level_mx_quant( + x_2d, smooth_scale=None + ) + + # Dual-level MXFP4 matmul + # Arg order: act_quant, weight_quant, act_l0_scale, weight_dual_scale, + # act_l1_scale, weight_scale, bias=, output_dtype= + # NOTE: weight is NOT transposed (unlike MXFP8's npu_quant_matmul). 
+ output = torch_npu.npu_dual_level_quant_matmul( + qx, + layer.weight, + act_l0_scale, + layer.weight_dual_scale, + act_l1_scale, + layer.weight_scale, + bias=bias.to(torch.float32) if bias is not None else None, + output_dtype=original_dtype, + ) + + # Restore original shape (replace last dim with output features) + output_shape = list(input_shape[:-1]) + [output.shape[-1]] + output = output.reshape(output_shape) + + return output diff --git a/python/sglang/multimodal_gen/runtime/loader/transformer_load_utils.py b/python/sglang/multimodal_gen/runtime/loader/transformer_load_utils.py index 4a683d93d7dd..dd037034a1d3 100644 --- a/python/sglang/multimodal_gen/runtime/loader/transformer_load_utils.py +++ b/python/sglang/multimodal_gen/runtime/loader/transformer_load_utils.py @@ -485,15 +485,27 @@ def _resolve_quant_config( resolve quant config from checkpoints' metadata priority: explicit --quantization flag -> model config.json -> safetensors metadata -> format-specific fallback """ - # priority: explicit --quantization flag (e.g. mxfp8, mxfp4, modelslim) + # priority: explicit --quantization flag (e.g. mxfp8, mxfp4_npu, modelslim) if server_args.quantization is not None: from sglang.multimodal_gen.runtime.layers.quantization import ( get_quantization_config, ) + # modelslim requires a per-layer quant description file; load it from + # the component directory rather than returning an empty config. 
+ if server_args.quantization == "modelslim": + return get_quant_config(hf_config, component_model_path) + quant_cls = get_quantization_config(server_args.quantization) return quant_cls.from_config({}) + quant_config = get_quant_config(hf_config, component_model_path) + if quant_config is None and server_args.transformer_weights_path: + for safetensors_file in safetensors_list: + quant_config = get_quant_config_from_safetensors_metadata(safetensors_file) + if quant_config is not None: + return quant_config + arch_config = server_args.pipeline_config.dit_config.arch_config param_names_mapping_dict = arch_config.param_names_mapping reverse_param_names_mapping_dict = getattr( diff --git a/python/sglang/multimodal_gen/runtime/server_args.py b/python/sglang/multimodal_gen/runtime/server_args.py index 0673e79324be..c543dd4db269 100644 --- a/python/sglang/multimodal_gen/runtime/server_args.py +++ b/python/sglang/multimodal_gen/runtime/server_args.py @@ -1292,8 +1292,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "auto-detected from the checkpoint config or safetensors metadata when " "possible. Applies to both pre-quantized checkpoints and online " "quantization. Use this flag to override auto-detection. " - "Options: 'fp8', 'mxfp8', 'mxfp4', 'modelslim'. " - "Note: MXFP4 requires ROCm and MI350+ (gfx95x)." + "Options: 'fp8', 'mxfp8', 'mxfp4', 'mxfp4_npu', 'modelslim'. " + "Note: 'mxfp4' targets ROCm + MI350+ (gfx95x); " + "'mxfp4_npu' / 'mxfp8' target Ascend NPU (A5 series for mxfp4_npu)." 
), ) parser.add_argument( diff --git a/python/sglang/multimodal_gen/tools/wan_repack.py b/python/sglang/multimodal_gen/tools/wan_repack.py index 308b229d8593..9623719ca9ea 100644 --- a/python/sglang/multimodal_gen/tools/wan_repack.py +++ b/python/sglang/multimodal_gen/tools/wan_repack.py @@ -52,6 +52,12 @@ "attn2.to_k_img": "attn2.add_k_proj", "attn2.to_v_img": "attn2.add_v_proj", "attn2.norm_k_img": "attn2.norm_added_k", + # MXFP4 msmodelslim wraps Linear layers with a `.linear.` subpath; + # strip it so keys match the SGLang model parameters. + ".linear.": ".", + # NonFusionSmoothQuantWrapper exports smooth quant scale as `.div.mul_scale`; + # strip `.div.` so it loads as a direct parameter `mul_scale` on the linear layer. + ".div.": ".", } SUPPORTED_MODEL_TYPES = ["Wan2.2-T2V-A14B", "Wan2.2-I2V-A14B", "Wan2.2-TI2V-5B"] @@ -98,13 +104,10 @@ def load_sharded_safetensors(directory: pathlib.Path, pattern: str) -> dict: candidates = sorted(directory.glob(pattern)) if not candidates: raise FileNotFoundError(f"No file matching '{pattern}' found in {directory}") - if len(candidates) > 1: - raise FileNotFoundError( - f"Multiple files matching '{pattern}' found in {directory}: {candidates}" - ) state_dict = {} - state_dict.update(load_file(candidates[0])) + for f in candidates: + state_dict.update(load_file(f)) return state_dict