From ef874c0a1c92bf29a35e7f2e7efaf2bdaed748fa Mon Sep 17 00:00:00 2001 From: "(Messi) Junlin Wu" Date: Wed, 18 Mar 2026 15:53:53 +0800 Subject: [PATCH 01/29] =?UTF-8?q?=E2=9C=A8=20feat(npu):=20add=20online=20M?= =?UTF-8?q?XFP8=20quantization=20support=20for=20Ascend=20NPU=20(Path=20B)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add NPUMXFP8LinearMethod that enables --quantization mxfp8 on Ascend NPU, supporting both online (FP16/BF16 → MXFP8) and offline (serialized FP8 checkpoint) quantization via torch_npu APIs (npu_dynamic_mx_quant + npu_quant_matmul with group_sizes=[1,1,32]). --- .../npu/quantization/mxfp8_method_npu.py | 152 ++++++++++++++++++ python/sglang/srt/layers/quantization/fp8.py | 6 + .../ascend/test_ascend_mxfp8_quantization.py | 103 ++++++++++++ 3 files changed, 261 insertions(+) create mode 100644 python/sglang/srt/hardware_backend/npu/quantization/mxfp8_method_npu.py create mode 100644 test/srt/ascend/test_ascend_mxfp8_quantization.py diff --git a/python/sglang/srt/hardware_backend/npu/quantization/mxfp8_method_npu.py b/python/sglang/srt/hardware_backend/npu/quantization/mxfp8_method_npu.py new file mode 100644 index 000000000000..10bb30aa4fab --- /dev/null +++ b/python/sglang/srt/hardware_backend/npu/quantization/mxfp8_method_npu.py @@ -0,0 +1,152 @@ +from typing import TYPE_CHECKING, List, Optional + +import torch +from torch.nn.parameter import Parameter + +from sglang.srt.layers.parameter import ( + BlockQuantScaleParameter, + ModelWeightParameter, +) +from sglang.srt.layers.quantization.base_config import LinearMethodBase + +if TYPE_CHECKING: + from sglang.srt.layers.quantization.fp8 import Fp8Config + + +class NPUMXFP8LinearMethod(LinearMethodBase): + """Ascend NPU MXFP8 (Microscaling FP8) quantization for Linear layers. + + Supports two modes: + - Online quantization: loads FP16/BF16 weights and quantizes them to MXFP8 + at weight loading time. + - Offline quantization: loads pre-quantized FP8 weights with block scales + from a serialized checkpoint. + + Uses torch_npu APIs: + - npu_dynamic_mx_quant: dynamic MXFP8 activation quantization (block_size=32) + - npu_quant_matmul: MXFP8 matrix multiplication with group_sizes=[1,1,32] + """ + + MXFP8_BLOCK_SIZE = 32 + + def __init__(self, quant_config: "Fp8Config"): + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + + layer.logical_widths = output_partition_sizes + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + layer.orig_dtype = params_dtype + + is_serialized = self.quant_config.is_checkpoint_fp8_serialized + + # Weight: fp8 if serialized checkpoint, else original dtype (will be + # quantized in process_weights_after_loading) + weight_dtype = torch.float8_e4m3fn if is_serialized else params_dtype + weight = ModelWeightParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=weight_dtype, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + if is_serialized: + # Block scale: one scale per block of 32 elements along input dim. + # Stored as uint8 (representing float8_e8m0fnu) in checkpoint. + block_k = self.MXFP8_BLOCK_SIZE + scale_cols = (input_size_per_partition + block_k - 1) // block_k + scale = BlockQuantScaleParameter( + data=torch.zeros( + output_size_per_partition, + scale_cols, + dtype=torch.uint8, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + scale.format_ue8m0 = True + layer.register_parameter("weight_scale_inv", scale) + else: + layer.register_parameter("weight_scale_inv", None) + + def process_weights_after_loading(self, layer: torch.nn.Module): + import torch_npu + + is_serialized = self.quant_config.is_checkpoint_fp8_serialized + + if is_serialized: + # Checkpoint already has fp8 weights + uint8 scales. + # Ensure weight is float8_e4m3fn. + if layer.weight.data.dtype != torch.float8_e4m3fn: + layer.weight = Parameter( + torch_npu.npu_dtype_cast( + layer.weight.data, torch_npu.float8_e4m3fn + ), + requires_grad=False, + ) + else: + layer.weight.requires_grad_(False) + + # Scale is already uint8 (e8m0fnu), keep as-is. + layer.weight_scale_inv.requires_grad_(False) + else: + # Online quantization: quantize FP16/BF16 weights to MXFP8. + weight_fp = layer.weight.data + if weight_fp.dtype not in (torch.float16, torch.bfloat16): + weight_fp = weight_fp.to(torch.bfloat16) + + qw, w_scale = torch_npu.npu_dynamic_mx_quant( + weight_fp, dst_type=torch_npu.float8_e4m3fn + ) + layer.weight = Parameter(qw, requires_grad=False) + layer.weight_scale_inv = Parameter(w_scale, requires_grad=False) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + import torch_npu + + original_dtype = x.dtype + if original_dtype not in (torch.float16, torch.bfloat16): + x = x.to(torch.bfloat16) + original_dtype = torch.bfloat16 + + # Dynamic MXFP8 activation quantization (block_size=32) + qx, input_scale = torch_npu.npu_dynamic_mx_quant( + x, dst_type=torch_npu.float8_e4m3fn + ) + + # MXFP8 quantized matmul + output = torch_npu.npu_quant_matmul( + qx, + layer.weight.transpose(0, 1), + layer.weight_scale_inv.transpose(0, 1), + scale_dtype=torch_npu.float8_e8m0fnu, + pertoken_scale=input_scale, + pertoken_scale_dtype=torch_npu.float8_e8m0fnu, + bias=bias.to(torch.float32) if bias is not None else None, + output_dtype=original_dtype, + group_sizes=[1, 1, self.MXFP8_BLOCK_SIZE], + ) + return output diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 1e12baff13cf..21fa1f97b94d 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -222,6 +222,12 @@ def get_quant_method( prefix, self.ignored_layers, fused_mapping=self.packed_modules_mapping ): return UnquantizedLinearMethod() + if _is_npu and self.use_mxfp8: + from sglang.srt.hardware_backend.npu.quantization.mxfp8_method_npu import ( + NPUMXFP8LinearMethod, + ) + + return NPUMXFP8LinearMethod(self) return Fp8LinearMethod(self) elif isinstance(layer, FusedMoE): if is_layer_skipped( diff --git a/test/srt/ascend/test_ascend_mxfp8_quantization.py b/test/srt/ascend/test_ascend_mxfp8_quantization.py new file mode 100644 index 000000000000..e7af4ff97f1a --- /dev/null +++ b/test/srt/ascend/test_ascend_mxfp8_quantization.py @@ -0,0 +1,103 @@ +""" +Usage: +python3 -m unittest test_ascend_mxfp8_quantization.TestAscendMXFP8.test_gsm8k +""" + +import os +import time +import unittest +from types import SimpleNamespace +from urllib.parse import urlparse + +import requests + +from sglang.srt.utils import kill_process_tree +from sglang.test.few_shot_gsm8k import run_eval +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + is_in_ci, + popen_launch_server, +) + +if "ASCEND_RT_VISIBLE_DEVICES" not in os.environ: + os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "0,1" +DEFAULT_PORT_FOR_SRT_TEST_RUNNER = ( + 7000 + int(os.environ.get("ASCEND_RT_VISIBLE_DEVICES", "0")[0]) * 100 +) +DEFAULT_URL_FOR_TEST = f"http://127.0.0.1:{DEFAULT_PORT_FOR_SRT_TEST_RUNNER + 1000}" + + +class TestAscendMXFP8(CustomTestCase): + """Test online MXFP8 quantization (--quantization mxfp8) on Ascend NPU.""" + + @classmethod + def setUpClass(cls): + cls.model = "Qwen/Qwen2.5-0.5B-Instruct" + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--trust-remote-code", + "--quantization", + "mxfp8", + "--device", + "npu", + "--attention-backend", + "ascend", + ], + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_gsm8k(self): + url = urlparse(self.base_url) + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=200, + max_new_tokens=512, + parallel=128, + host=f"http://{url.hostname}", + port=int(url.port), + ) + metrics = run_eval(args) + print(metrics) + self.assertGreaterEqual(metrics["accuracy"], 0.25) + self.assertGreaterEqual(metrics["output_throughput"], 500) + + def run_decode(self, max_new_tokens): + response = requests.post( + self.base_url + "/generate", + json={ + "text": "The capital of France is", + "sampling_params": { + "temperature": 0, + "max_new_tokens": max_new_tokens, + }, + "ignore_eos": True, + }, + ) + return response.json() + + def test_throughput(self): + max_tokens = 256 + + tic = time.perf_counter() + res = self.run_decode(max_tokens) + tok = time.perf_counter() + print(res["text"]) + throughput = max_tokens / (tok - tic) + print(f"Throughput: {throughput} tokens/s") + + if is_in_ci(): + self.assertGreaterEqual(throughput, 20) + + +if __name__ == "__main__": + unittest.main() From d2d19c6f12fde51e6903f206cf493f6bdb7dea55 Mon Sep 17 00:00:00 2001 From: "(Messi) Junlin Wu" Date: Wed, 18 Mar 2026 16:00:42 +0800 Subject: [PATCH 02/29] =?UTF-8?q?=E2=9C=A8=20feat(diffusion):=20add=20onli?= =?UTF-8?q?ne=20MXFP8=20quantization=20support=20for=20Wan2.2=20on=20Ascen?= =?UTF-8?q?d=20NPU?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add MXFP8Config and NPUMXFP8DiffusionLinearMethod for the diffusion subsystem (multimodal_gen), enabling --quantization mxfp8 for Wan2.2 and other diffusion models on Ascend NPU. Also adds explicit quantization field to diffusion ServerArgs so online quantization can be specified without pre-quantized weights. --- .../runtime/layers/quantization/__init__.py | 4 +- .../runtime/layers/quantization/mxfp8_npu.py | 153 ++++++++++++++++++ .../component_loaders/transformer_loader.py | 11 +- .../multimodal_gen/runtime/server_args.py | 4 + 4 files changed, 170 insertions(+), 2 deletions(-) create mode 100644 python/sglang/multimodal_gen/runtime/layers/quantization/mxfp8_npu.py diff --git a/python/sglang/multimodal_gen/runtime/layers/quantization/__init__.py b/python/sglang/multimodal_gen/runtime/layers/quantization/__init__.py index 3d78bb58cd9e..625e50af8115 100644 --- a/python/sglang/multimodal_gen/runtime/layers/quantization/__init__.py +++ b/python/sglang/multimodal_gen/runtime/layers/quantization/__init__.py @@ -7,8 +7,9 @@ ) from sglang.multimodal_gen.runtime.layers.quantization.fp8 import Fp8Config from sglang.multimodal_gen.runtime.layers.quantization.modelslim import ModelSlimConfig +from sglang.multimodal_gen.runtime.layers.quantization.mxfp8_npu import MXFP8Config -QuantizationMethods = Literal["fp8", "modelslim"] +QuantizationMethods = Literal["fp8", "modelslim", "mxfp8"] QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods)) @@ -16,6 +17,7 @@ _CUSTOMIZED_METHOD_TO_QUANT_CONFIG = { "modelslim": ModelSlimConfig, "fp8": Fp8Config, + "mxfp8": MXFP8Config, } diff --git a/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp8_npu.py b/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp8_npu.py new file mode 100644 index 000000000000..6984a59d192e --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp8_npu.py @@ -0,0 +1,153 @@ +"""Online MXFP8 quantization for Diffusion models on Ascend NPU. + +Provides ``MXFP8Config`` (registered as ``"mxfp8"``) and +``NPUMXFP8DiffusionLinearMethod`` which quantise FP16/BF16 weights to MXFP8 +at load time and use ``npu_dynamic_mx_quant`` + ``npu_quant_matmul`` for +inference, mirroring the LLM-side ``NPUMXFP8LinearMethod``. +""" + +from __future__ import annotations + +import logging +from typing import Any, Dict, List, Optional + +import torch +from torch.nn.parameter import Parameter + +from sglang.multimodal_gen.runtime.layers.linear import LinearMethodBase +from sglang.multimodal_gen.runtime.layers.quantization.configs.base_config import ( + QuantizationConfig, + QuantizeMethodBase, +) +from sglang.multimodal_gen.runtime.models.parameter import ModelWeightParameter + +logger = logging.getLogger(__name__) + +MXFP8_BLOCK_SIZE = 32 + + +class MXFP8Config(QuantizationConfig): + """Config for online MXFP8 quantization on Ascend NPU (Diffusion).""" + + def __init__(self) -> None: + super().__init__() + + @classmethod + def get_name(cls) -> str: + return "mxfp8" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.bfloat16, torch.float16] + + @classmethod + def get_min_capability(cls) -> int: + return 0 # NPU, not CUDA + + @classmethod + def get_config_filenames(cls) -> List[str]: + return [] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "MXFP8Config": + return cls() + + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional[QuantizeMethodBase]: + from sglang.multimodal_gen.runtime.layers.linear import LinearBase + + if isinstance(layer, LinearBase): + return NPUMXFP8DiffusionLinearMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class NPUMXFP8DiffusionLinearMethod(LinearMethodBase): + """Ascend NPU MXFP8 linear method for Diffusion models. + + Online mode: loads FP16/BF16 weights → quantises to MXFP8 at load time. + Inference: dynamic MXFP8 activation quant + MXFP8 matmul (block_size=32). + """ + + def __init__(self, quant_config: MXFP8Config): + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + + layer.logical_widths = output_partition_sizes + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + layer.orig_dtype = params_dtype + + # Load weights in original dtype; quantise later in process_weights_after_loading + weight = ModelWeightParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=params_dtype, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + import torch_npu + + weight_fp = layer.weight.data + if weight_fp.dtype not in (torch.float16, torch.bfloat16): + weight_fp = weight_fp.to(torch.bfloat16) + + # Online MXFP8 quantisation of weights (block_size=32) + qw, w_scale = torch_npu.npu_dynamic_mx_quant( + weight_fp, dst_type=torch_npu.float8_e4m3fn + ) + layer.weight = Parameter(qw, requires_grad=False) + layer.weight_scale_inv = Parameter(w_scale, requires_grad=False) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + import torch_npu + + original_dtype = x.dtype + if original_dtype not in (torch.float16, torch.bfloat16): + x = x.to(torch.bfloat16) + original_dtype = torch.bfloat16 + + # Dynamic MXFP8 activation quantisation + qx, input_scale = torch_npu.npu_dynamic_mx_quant( + x, dst_type=torch_npu.float8_e4m3fn + ) + + # MXFP8 matmul + output = torch_npu.npu_quant_matmul( + qx, + layer.weight.transpose(0, 1), + layer.weight_scale_inv.transpose(0, 1), + scale_dtype=torch_npu.float8_e8m0fnu, + pertoken_scale=input_scale, + pertoken_scale_dtype=torch_npu.float8_e8m0fnu, + bias=bias.to(torch.float32) if bias is not None else None, + output_dtype=original_dtype, + group_sizes=[1, 1, MXFP8_BLOCK_SIZE], + ) + return output diff --git a/python/sglang/multimodal_gen/runtime/loader/component_loaders/transformer_loader.py b/python/sglang/multimodal_gen/runtime/loader/component_loaders/transformer_loader.py index 658689ec2c98..13c11bb09c88 100644 --- a/python/sglang/multimodal_gen/runtime/loader/component_loaders/transformer_loader.py +++ b/python/sglang/multimodal_gen/runtime/loader/component_loaders/transformer_loader.py @@ -82,7 +82,16 @@ def _resolve_quant_config( safetensors_list: list[str], component_model_path: str, ) -> Optional[dict]: - # priority: model config.json → safetensors metadata → nunchaku config + # priority: explicit --quantization flag → model config.json + # → safetensors metadata → nunchaku config + if server_args.quantization is not None: + from sglang.multimodal_gen.runtime.layers.quantization import ( + get_quantization_config, + ) + + quant_cls = get_quantization_config(server_args.quantization) + return quant_cls.from_config({}) + quant_config = get_quant_config(hf_config, component_model_path) if quant_config is None and server_args.transformer_weights_path: # try to read quantization_config from the safetensors metadata header diff --git a/python/sglang/multimodal_gen/runtime/server_args.py b/python/sglang/multimodal_gen/runtime/server_args.py index acd41d59a7a4..72c2cd5f63e9 100644 --- a/python/sglang/multimodal_gen/runtime/server_args.py +++ b/python/sglang/multimodal_gen/runtime/server_args.py @@ -167,6 +167,10 @@ class ServerArgs: disable_autocast: bool | None = None + # Explicit quantization method override (e.g. "mxfp8", "fp8", "modelslim"). + # When set, the transformer loader will use this instead of auto-detection. + quantization: str | None = None + # Quantization / Nunchaku SVDQuant configuration nunchaku_config: NunchakuSVDQuantArgs | NunchakuConfig | None = field( default_factory=NunchakuSVDQuantArgs, repr=False From c838adef02cb6f830f0073b557b74fc239d390fc Mon Sep 17 00:00:00 2001 From: "(Messi) Junlin Wu" Date: Thu, 19 Mar 2026 09:47:48 +0800 Subject: [PATCH 03/29] :bug: fix(diffusion): fix npu method call error --- .../runtime/layers/quantization/mxfp8_npu.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp8_npu.py b/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp8_npu.py index 6984a59d192e..fbeed2a03a4c 100644 --- a/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp8_npu.py +++ b/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp8_npu.py @@ -113,6 +113,10 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: if weight_fp.dtype not in (torch.float16, torch.bfloat16): weight_fp = weight_fp.to(torch.bfloat16) + # Ensure weight is on NPU before calling npu_dynamic_mx_quant + if not weight_fp.is_npu: + weight_fp = weight_fp.npu() + # Online MXFP8 quantisation of weights (block_size=32) qw, w_scale = torch_npu.npu_dynamic_mx_quant( weight_fp, dst_type=torch_npu.float8_e4m3fn @@ -133,6 +137,13 @@ def apply( x = x.to(torch.bfloat16) original_dtype = torch.bfloat16 + # npu_quant_matmul requires 3D input; flatten leading dims if needed + input_shape = x.shape + if x.dim() > 3: + x = x.reshape(-1, x.shape[-2], x.shape[-1]) + elif x.dim() == 2: + x = x.unsqueeze(0) + # Dynamic MXFP8 activation quantisation qx, input_scale = torch_npu.npu_dynamic_mx_quant( x, dst_type=torch_npu.float8_e4m3fn @@ -150,4 +161,10 @@ def apply( output_dtype=original_dtype, group_sizes=[1, 1, MXFP8_BLOCK_SIZE], ) + + # Restore original shape (replace last dim with output features) + if len(input_shape) != 3: + output_shape = list(input_shape[:-1]) + [output.shape[-1]] + output = output.reshape(output_shape) + return output From be3b684c10a72213e33f2bc50ced54dfc26eea1b Mon Sep 17 00:00:00 2001 From: "(Messi) Junlin Wu" Date: Thu, 19 Mar 2026 10:32:11 +0800 Subject: [PATCH 04/29] :bug: fix(diffusion): fix MXFP8 quantization scale dimension mismatch on NPU - Ensure weight tensor is on NPU device before npu_dynamic_mx_quant call - Flatten input x to 2D before quantization so input_scale is 3D (required by npu_quant_matmul) - Simplify output shape restoration logic Fixes: dimension of x1Scale(pertoken_scale) should be 3 but was 4 --- .../runtime/layers/quantization/mxfp8_npu.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp8_npu.py b/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp8_npu.py index fbeed2a03a4c..7e7cc37028a9 100644 --- a/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp8_npu.py +++ b/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp8_npu.py @@ -137,16 +137,13 @@ def apply( x = x.to(torch.bfloat16) original_dtype = torch.bfloat16 - # npu_quant_matmul requires 3D input; flatten leading dims if needed + # Flatten to 2D [tokens, hidden] so npu_dynamic_mx_quant returns 3D scale input_shape = x.shape - if x.dim() > 3: - x = x.reshape(-1, x.shape[-2], x.shape[-1]) - elif x.dim() == 2: - x = x.unsqueeze(0) + x_2d = x.reshape(-1, x.shape[-1]) # Dynamic MXFP8 activation quantisation qx, input_scale = torch_npu.npu_dynamic_mx_quant( - x, dst_type=torch_npu.float8_e4m3fn + x_2d, dst_type=torch_npu.float8_e4m3fn ) # MXFP8 matmul @@ -163,8 +160,7 @@ def apply( ) # Restore original shape (replace last dim with output features) - if len(input_shape) != 3: - output_shape = list(input_shape[:-1]) + [output.shape[-1]] - output = output.reshape(output_shape) + output_shape = list(input_shape[:-1]) + [output.shape[-1]] + output = output.reshape(output_shape) return output From fd79b235af48200c40f9469ab6d54348880e8637 Mon Sep 17 00:00:00 2001 From: "(Messi) Junlin Wu" Date: Fri, 20 Mar 2026 09:38:27 +0800 Subject: [PATCH 05/29] :recycle: refactor(mxfp8): split linear method into config and NPU layers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 按 reviewer 建议重构架构分层: - 在 fp8.py 新增 MXFP8LinearAscendMethod,负责权重定义(__init__、create_weights) - 简化 mxfp8_method_npu.py 中的 NPUMXFP8LinearMethod,只保留权重处理和 kernel 调用 - 改进架构分层,符合现有 NPU INT8 方法模式 --- .../npu/quantization/mxfp8_method_npu.py | 93 +++-------------- python/sglang/srt/layers/quantization/fp8.py | 99 ++++++++++++++++++- 2 files changed, 107 insertions(+), 85 deletions(-) diff --git a/python/sglang/srt/hardware_backend/npu/quantization/mxfp8_method_npu.py b/python/sglang/srt/hardware_backend/npu/quantization/mxfp8_method_npu.py index 10bb30aa4fab..97f150040553 100644 --- a/python/sglang/srt/hardware_backend/npu/quantization/mxfp8_method_npu.py +++ b/python/sglang/srt/hardware_backend/npu/quantization/mxfp8_method_npu.py @@ -1,97 +1,27 @@ -from typing import TYPE_CHECKING, List, Optional +from typing import Optional import torch from torch.nn.parameter import Parameter -from sglang.srt.layers.parameter import ( - BlockQuantScaleParameter, - ModelWeightParameter, -) -from sglang.srt.layers.quantization.base_config import LinearMethodBase -if TYPE_CHECKING: - from sglang.srt.layers.quantization.fp8 import Fp8Config +class NPUMXFP8LinearMethod: + """Ascend NPU MXFP8 weight processing and kernel calls. + This class handles NPU-specific operations: + - process_weights_after_loading: dtype casting and online quantization + - apply: dynamic activation quantization + MXFP8 matmul -class NPUMXFP8LinearMethod(LinearMethodBase): - """Ascend NPU MXFP8 (Microscaling FP8) quantization for Linear layers. - - Supports two modes: - - Online quantization: loads FP16/BF16 weights and quantizes them to MXFP8 - at weight loading time. - - Offline quantization: loads pre-quantized FP8 weights with block scales - from a serialized checkpoint. - - Uses torch_npu APIs: - - npu_dynamic_mx_quant: dynamic MXFP8 activation quantization (block_size=32) - - npu_quant_matmul: MXFP8 matrix multiplication with group_sizes=[1,1,32] + Weight creation and config management are handled by + MXFP8LinearAscendMethod in fp8.py. """ MXFP8_BLOCK_SIZE = 32 - def __init__(self, quant_config: "Fp8Config"): - self.quant_config = quant_config - - def create_weights( - self, - layer: torch.nn.Module, - input_size_per_partition: int, - output_partition_sizes: List[int], - input_size: int, - output_size: int, - params_dtype: torch.dtype, - **extra_weight_attrs, + def process_weights_after_loading( + self, layer: torch.nn.Module, is_serialized: bool ): - output_size_per_partition = sum(output_partition_sizes) - weight_loader = extra_weight_attrs.get("weight_loader") - - layer.logical_widths = output_partition_sizes - layer.input_size_per_partition = input_size_per_partition - layer.output_size_per_partition = output_size_per_partition - layer.orig_dtype = params_dtype - - is_serialized = self.quant_config.is_checkpoint_fp8_serialized - - # Weight: fp8 if serialized checkpoint, else original dtype (will be - # quantized in process_weights_after_loading) - weight_dtype = torch.float8_e4m3fn if is_serialized else params_dtype - weight = ModelWeightParameter( - data=torch.empty( - output_size_per_partition, - input_size_per_partition, - dtype=weight_dtype, - ), - input_dim=1, - output_dim=0, - weight_loader=weight_loader, - ) - layer.register_parameter("weight", weight) - - if is_serialized: - # Block scale: one scale per block of 32 elements along input dim. - # Stored as uint8 (representing float8_e8m0fnu) in checkpoint. - block_k = self.MXFP8_BLOCK_SIZE - scale_cols = (input_size_per_partition + block_k - 1) // block_k - scale = BlockQuantScaleParameter( - data=torch.zeros( - output_size_per_partition, - scale_cols, - dtype=torch.uint8, - ), - input_dim=1, - output_dim=0, - weight_loader=weight_loader, - ) - scale.format_ue8m0 = True - layer.register_parameter("weight_scale_inv", scale) - else: - layer.register_parameter("weight_scale_inv", None) - - def process_weights_after_loading(self, layer: torch.nn.Module): import torch_npu - is_serialized = self.quant_config.is_checkpoint_fp8_serialized - if is_serialized: # Checkpoint already has fp8 weights + uint8 scales. # Ensure weight is float8_e4m3fn. @@ -113,6 +43,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module): if weight_fp.dtype not in (torch.float16, torch.bfloat16): weight_fp = weight_fp.to(torch.bfloat16) + if not weight_fp.is_npu: + weight_fp = weight_fp.npu() + qw, w_scale = torch_npu.npu_dynamic_mx_quant( weight_fp, dst_type=torch_npu.float8_e4m3fn ) diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 21fa1f97b94d..8d6fdba452bd 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -223,11 +223,7 @@ def get_quant_method( ): return UnquantizedLinearMethod() if _is_npu and self.use_mxfp8: - from sglang.srt.hardware_backend.npu.quantization.mxfp8_method_npu import ( - NPUMXFP8LinearMethod, - ) - - return NPUMXFP8LinearMethod(self) + return MXFP8LinearAscendMethod(self) return Fp8LinearMethod(self) elif isinstance(layer, FusedMoE): if is_layer_skipped( @@ -245,6 +241,99 @@ def get_scaled_act_names(self) -> List[str]: return [] +class MXFP8LinearAscendMethod(LinearMethodBase): + """Ascend NPU MXFP8 (Microscaling FP8) quantization for Linear layers. + + Supports two modes: + - Online quantization: loads FP16/BF16 weights and quantizes them to MXFP8 + at weight loading time. + - Offline quantization: loads pre-quantized FP8 weights with block scales + from a serialized checkpoint. + + Weight creation is handled here; weight processing and kernel calls are + delegated to NPUMXFP8LinearMethod in the NPU hardware backend. + """ + + MXFP8_BLOCK_SIZE = 32 + + def __init__(self, quant_config: Fp8Config): + self.quant_config = quant_config + + from sglang.srt.hardware_backend.npu.quantization.mxfp8_method_npu import ( + NPUMXFP8LinearMethod + ) + + self.npu_method = NPUMXFP8LinearMethod() + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + + layer.logical_widths = output_partition_sizes + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + layer.orig_dtype = params_dtype + + is_serialized = self.quant_config.is_checkpoint_fp8_serialized + + # Weight: fp8 if serialized checkpoint, else original dtype (will be + # quantized in process_weights_after_loading) + weight_dtype = torch.float8_e4m3fn if is_serialized else params_dtype + weight = ModelWeightParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=weight_dtype, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + if is_serialized: + # Block scale: one scale per block of 32 elements along input dim. + # Stored as uint8 (representing float8_e8m0fnu) in checkpoint. + block_k = self.MXFP8_BLOCK_SIZE + scale_cols = (input_size_per_partition + block_k - 1) // block_k + scale = BlockQuantScaleParameter( + data=torch.zeros( + output_size_per_partition, + scale_cols, + dtype=torch.uint8, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + scale.format_ue8m0 = True + layer.register_parameter("weight_scale_inv", scale) + else: + layer.register_parameter("weight_scale_inv", None) + + def process_weights_after_loading(self, layer: torch.nn.Module): + self.npu_method.process_weights_after_loading( + layer, self.quant_config.is_checkpoint_fp8_serialized + ) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return self.npu_method.apply(layer, x, bias) + + class Fp8LinearMethod(LinearMethodBase): """Linear method for FP8. From 490ad0b11897062d9731b8b160ab6dbfd441c964 Mon Sep 17 00:00:00 2001 From: "(Messi) Junlin Wu" Date: Fri, 20 Mar 2026 11:43:07 +0800 Subject: [PATCH 06/29] :sparkles: feat(diffusion): add offline MXFP8 pre-quantized weight support for Wan2.2 TI2V --- .../runtime/layers/quantization/modelslim.py | 8 ++ .../quantization/modelslim_mxfp8_scheme.py | 117 ++++++++++++++++++ .../runtime/loader/fsdp_load.py | 1 + .../runtime/utils/quantization_utils.py | 7 +- 4 files changed, 132 insertions(+), 1 deletion(-) create mode 100644 python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp8_scheme.py diff --git a/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim.py b/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim.py index afb9a31e4db9..4d43c50e3675 100644 --- a/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim.py +++ b/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim.py @@ -119,6 +119,14 @@ def _get_scheme_from_parts( return ModelSlimW4A4Int4( quant_config=self.quant_description, prefix=layer_name ) + elif quant_type == "W8A8_MXFP8": + from sglang.multimodal_gen.runtime.layers.quantization.modelslim_mxfp8_scheme import ( + ModelSlimMXFP8Scheme, + ) + + return ModelSlimMXFP8Scheme( + quant_config=self.quant_description, prefix=layer_name + ) raise NotImplementedError("No modelslim compatible scheme was found.") def get_scheme( diff --git a/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp8_scheme.py b/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp8_scheme.py new file mode 100644 index 000000000000..eb40a59fa0ef --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp8_scheme.py @@ -0,0 +1,117 @@ +"""ModelSlim MXFP8 scheme for pre-quantized weight inference on Ascend NPU. + +Loads weights pre-quantized by msmodelslim (int8 storage for float8_e4m3fn, +uint8 storage for float8_e8m0fnu scales) and runs MXFP8 matmul at inference. +""" + +from typing import Dict, List, Optional + +import torch + +from sglang.multimodal_gen.runtime.models.parameter import ( + GroupQuantScaleParameter, + ModelWeightParameter, +) +from sglang.srt.layers.quantization.modelslim.schemes import ModelSlimLinearScheme + +MXFP8_BLOCK_SIZE = 32 + + +class ModelSlimMXFP8Scheme(ModelSlimLinearScheme): + + def __init__(self, quant_config: Dict[str, any], prefix: str): + self.quant_config = quant_config + self.prefix = prefix + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + weight_loader = extra_weight_attrs.get("weight_loader") + output_size_per_partition = sum(output_partition_sizes) + + # msmodelslim exports weight as int8 (storage for float8_e4m3fn) + weight = ModelWeightParameter( + data=torch.empty( + (output_size_per_partition, input_size_per_partition), + dtype=torch.int8, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + # msmodelslim exports weight_scale as uint8 (storage for float8_e8m0fnu) + # shape: [out, in/32 * 2] + scale_dim = input_size_per_partition // MXFP8_BLOCK_SIZE * 2 + weight_scale = GroupQuantScaleParameter( + data=torch.empty( + (output_size_per_partition, scale_dim), + dtype=torch.uint8, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight_scale", weight_scale) + + def process_weights_after_loading(self, layer: torch.nn.Module): + import torch_npu + + # Cast int8 → float8_e4m3fn + weight = layer.weight.data + weight = torch_npu.npu_dtype_cast(weight, torch_npu.float8_e4m3fn) + layer.weight = torch.nn.Parameter(weight, requires_grad=False) + + # Reshape weight_scale: [out, in/32*2] → [out, in/32, 2] + weight_scale = layer.weight_scale.data + weight_scale = weight_scale.reshape(weight_scale.shape[0], -1, 2) + layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) + + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + import torch_npu + + original_dtype = x.dtype + if original_dtype not in (torch.float16, torch.bfloat16): + x = x.to(torch.bfloat16) + original_dtype = torch.bfloat16 + + # Flatten to 2D for npu_dynamic_mx_quant + input_shape = x.shape + x_2d = x.reshape(-1, x.shape[-1]) + + # Dynamic MXFP8 activation quantisation + qx, input_scale = torch_npu.npu_dynamic_mx_quant( + x_2d, dst_type=torch_npu.float8_e4m3fn + ) + + # MXFP8 matmul + output = torch_npu.npu_quant_matmul( + qx, + layer.weight.transpose(0, 1), + layer.weight_scale.transpose(0, 1), + scale_dtype=torch_npu.float8_e8m0fnu, + pertoken_scale=input_scale, + pertoken_scale_dtype=torch_npu.float8_e8m0fnu, + bias=bias.to(torch.float32) if bias is not None else None, + output_dtype=original_dtype, + group_sizes=[1, 1, MXFP8_BLOCK_SIZE], + ) + + # Restore original shape + output_shape = list(input_shape[:-1]) + [output.shape[-1]] + output = output.reshape(output_shape) + + return output diff --git a/python/sglang/multimodal_gen/runtime/loader/fsdp_load.py b/python/sglang/multimodal_gen/runtime/loader/fsdp_load.py index dcfbb6eac76a..fab7dd96df3f 100644 --- a/python/sglang/multimodal_gen/runtime/loader/fsdp_load.py +++ b/python/sglang/multimodal_gen/runtime/loader/fsdp_load.py @@ -378,6 +378,7 @@ def load_model_from_full_model_state_dict( "bias", "norm_q", "norm_k", + "weight_scale", ] for new_param_name in unused_keys: if not any(pattern in new_param_name for pattern in ALLOWED_NEW_PARAM_PATTERNS): diff --git a/python/sglang/multimodal_gen/runtime/utils/quantization_utils.py b/python/sglang/multimodal_gen/runtime/utils/quantization_utils.py index e1489780d4f7..c235c0523624 100644 --- a/python/sglang/multimodal_gen/runtime/utils/quantization_utils.py +++ b/python/sglang/multimodal_gen/runtime/utils/quantization_utils.py @@ -16,9 +16,14 @@ def find_quant_modelslim_config(model_config, component_model_path): + # Try exact name first, then glob for variant filenames (e.g. after repack) quant_config_file = Path(component_model_path, "quant_model_description.json") + if not quant_config_file.is_file(): + candidates = sorted(Path(component_model_path).glob("quant_model_description*.json")) + quant_config_file = candidates[0] if candidates else None + quant_cfg = None - if quant_config_file.is_file(): + if quant_config_file is not None and Path(quant_config_file).is_file(): with open(quant_config_file) as f: quant_cfg = json.load(f) # This field is required for flagless model loading but is not present in From cc80690cfdee8656ed9e4e388017a85418ffcaf3 Mon Sep 17 00:00:00 2001 From: "(Messi) Junlin Wu" Date: Mon, 23 Mar 2026 12:04:22 +0800 Subject: [PATCH 07/29] :bug: fix(diffusion): correct MXFP8 weight dtype and scale shape Fix weight loading for msmodelslim pre-quantized MXFP8 weights: - Change weight dtype from int8 to float8_e4m3fn (actual storage format in safetensors) - Fix weight_scale shape from [out, in/32*2] to [out, in/32] (actual msmodelslim export) - Update process_weights_after_loading to reshape weight_scale [out, in/32] -> [out, -1, 2] --- .../quantization/modelslim_mxfp8_scheme.py | 20 ++++++++----------- 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp8_scheme.py b/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp8_scheme.py index eb40a59fa0ef..4485d251fdf6 100644 --- a/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp8_scheme.py +++ b/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp8_scheme.py @@ -1,7 +1,7 @@ """ModelSlim MXFP8 scheme for pre-quantized weight inference on Ascend NPU. -Loads weights pre-quantized by msmodelslim (int8 storage for float8_e4m3fn, -uint8 storage for float8_e8m0fnu scales) and runs MXFP8 matmul at inference. +Loads weights pre-quantized by msmodelslim (float8_e4m3fn weights, +uint8 scales) and runs MXFP8 matmul at inference. """ from typing import Dict, List, Optional @@ -36,11 +36,11 @@ def create_weights( weight_loader = extra_weight_attrs.get("weight_loader") output_size_per_partition = sum(output_partition_sizes) - # msmodelslim exports weight as int8 (storage for float8_e4m3fn) + # msmodelslim exports weight as float8_e4m3fn, shape [out, in] weight = ModelWeightParameter( data=torch.empty( (output_size_per_partition, input_size_per_partition), - dtype=torch.int8, + dtype=torch.float8_e4m3fn, ), input_dim=1, output_dim=0, @@ -48,9 +48,8 @@ def create_weights( ) layer.register_parameter("weight", weight) - # msmodelslim exports weight_scale as uint8 (storage for float8_e8m0fnu) - # shape: [out, in/32 * 2] - scale_dim = input_size_per_partition // MXFP8_BLOCK_SIZE * 2 + # msmodelslim exports weight_scale as uint8, shape [out, in/32] + scale_dim = input_size_per_partition // MXFP8_BLOCK_SIZE weight_scale = GroupQuantScaleParameter( data=torch.empty( (output_size_per_partition, scale_dim), @@ -63,14 +62,11 @@ def create_weights( layer.register_parameter("weight_scale", weight_scale) def process_weights_after_loading(self, layer: torch.nn.Module): - import torch_npu - - # Cast int8 → float8_e4m3fn + # weight is already float8_e4m3fn, no cast needed weight = layer.weight.data - weight = torch_npu.npu_dtype_cast(weight, torch_npu.float8_e4m3fn) layer.weight = torch.nn.Parameter(weight, requires_grad=False) - # Reshape weight_scale: [out, in/32*2] → [out, in/32, 2] + # Reshape weight_scale: [out, in/32] -> [out, in/32//2, 2] weight_scale = layer.weight_scale.data weight_scale = weight_scale.reshape(weight_scale.shape[0], -1, 2) layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) From b9aa78553f9b65c878c1e873757ea0899a5a17a3 Mon Sep 17 00:00:00 2001 From: "(Messi) Junlin Wu" Date: Tue, 24 Mar 2026 11:25:25 +0800 Subject: [PATCH 08/29] =?UTF-8?q?=E2=9C=A8=20feat(wan22):=20Redesigned=20t?= =?UTF-8?q?he=20wan=5Frepack=20tool.=20Now=20support=20one-click=20weight?= =?UTF-8?q?=20processing.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../sglang/multimodal_gen/tools/wan_repack.py | 330 ++++++++++++------ 1 file changed, 215 insertions(+), 115 deletions(-) diff --git a/python/sglang/multimodal_gen/tools/wan_repack.py b/python/sglang/multimodal_gen/tools/wan_repack.py index 2d7132747e7a..3fd23a10cbd0 100644 --- a/python/sglang/multimodal_gen/tools/wan_repack.py +++ b/python/sglang/multimodal_gen/tools/wan_repack.py @@ -1,115 +1,215 @@ -### Based on https://github.com/huggingface/diffusers/blob/main/scripts/convert_wan_to_diffusers.py - -import argparse -import json -import pathlib -from typing import Any, Dict, Tuple - -from safetensors.torch import load_file, save_file - -TRANSFORMER_KEYS_RENAME_DICT = { - "time_embedding.0": "condition_embedder.time_embedder.linear_1", - "time_embedding.2": "condition_embedder.time_embedder.linear_2", - "text_embedding.0": "condition_embedder.text_embedder.linear_1", - "text_embedding.2": "condition_embedder.text_embedder.linear_2", - "time_projection.1": "condition_embedder.time_proj", - "head.modulation": "scale_shift_table", - "head.head": "proj_out", - "modulation": "scale_shift_table", - "ffn.0": "ffn.net.0.proj", - "ffn.2": "ffn.net.2", - # Hack to swap the layer names - # The original model calls the norms in following order: norm1, norm3, norm2 - # We convert it to: norm1, norm2, norm3 - "norm2": "norm__placeholder", - "norm3": "norm2", - "norm__placeholder": "norm3", - # For the I2V model - "img_emb.proj.0": "condition_embedder.image_embedder.norm1", - "img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj", - "img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2", - "img_emb.proj.4": "condition_embedder.image_embedder.norm2", - # for the FLF2V model - "img_emb.emb_pos": "condition_embedder.image_embedder.pos_embed", - # Add attention component mappings - "self_attn.q": "attn1.to_q", - "self_attn.k": "attn1.to_k", - "self_attn.v": "attn1.to_v", - "self_attn.o": "attn1.to_out.0", - "self_attn.norm_q": "attn1.norm_q", - "self_attn.norm_k": "attn1.norm_k", - "cross_attn.q": "attn2.to_q", - "cross_attn.k": "attn2.to_k", - "cross_attn.v": "attn2.to_v", - "cross_attn.o": "attn2.to_out.0", - "cross_attn.norm_q": "attn2.norm_q", - "cross_attn.norm_k": "attn2.norm_k", - "attn2.to_k_img": "attn2.add_k_proj", - "attn2.to_v_img": "attn2.add_v_proj", - "attn2.norm_k_img": "attn2.norm_added_k", -} - - -def get_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]: - if model_type == "Wan-T2V-14B": - RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT - return RENAME_DICT - - -def update_dict_(dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]: - dict[new_key] = dict.pop(old_key) - - -def load_sharded_safetensors(path: pathlib.Path): - file_path = path - state_dict = {} - state_dict.update(load_file(file_path)) - return state_dict - - -def convert_transformer(model_type: str, model_dir: str, output_dir: str): - pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True) - RENAME_DICT = get_transformer_config(model_type) - - original_state_dict = load_sharded_safetensors( - pathlib.Path(model_dir, "*model*.safetensors") - ) - with open(pathlib.Path(model_dir, "*quant_model_description*.json")) as f: - original_quant_config = json.load(f) - - for key in list(original_state_dict.keys()): - new_key = key[:] - for replace_key, rename_key in RENAME_DICT.items(): - new_key = new_key.replace(replace_key, rename_key) - update_dict_(original_state_dict, key, new_key) - update_dict_(original_quant_config, key, new_key) - - save_file( - original_state_dict, - pathlib.Path(output_dir, "diffusion_pytorch_model.safetensors"), - ) - - with open(pathlib.Path(output_dir, "quant_model_description.json"), "w") as f: - json.dump(original_quant_config, f) - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument("--input-path", type=str, required=True) - parser.add_argument("--output-path", type=str, required=True) - return parser.parse_args() - - -if __name__ == "__main__": - args = get_args() - - convert_transformer( - "Wan-T2V-14B", - model_dir=pathlib.Path(args.input_path, "high_noise_model"), - output_dir=pathlib.Path(args.output_path, "transformer"), - ) - convert_transformer( - "Wan-T2V-14B", - model_dir=pathlib.Path(args.input_path, "low_noise_model"), - output_dir=pathlib.Path(args.output_path, "transformer_2"), - ) +### Based on https://github.com/huggingface/diffusers/blob/main/scripts/convert_wan_to_diffusers.py + +import argparse +import json +import pathlib +import shutil +from typing import Any, Dict, List + +from safetensors.torch import load_file, save_file + +TRANSFORMER_KEYS_RENAME_DICT = { + "time_embedding.0": "condition_embedder.time_embedder.linear_1", + "time_embedding.2": "condition_embedder.time_embedder.linear_2", + "text_embedding.0": "condition_embedder.text_embedder.linear_1", + "text_embedding.2": "condition_embedder.text_embedder.linear_2", + "time_projection.1": "condition_embedder.time_proj", + "head.modulation": "scale_shift_table", + "head.head": "proj_out", + "modulation": "scale_shift_table", + "ffn.0": "ffn.net.0.proj", + "ffn.2": "ffn.net.2", + # Hack to swap the layer names + # The original model calls the norms in following order: norm1, norm3, norm2 + # We convert it to: norm1, norm2, norm3 + "norm2": "norm__placeholder", + "norm3": "norm2", + "norm__placeholder": "norm3", + # For the I2V model + "img_emb.proj.0": "condition_embedder.image_embedder.norm1", + "img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj", + "img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2", + "img_emb.proj.4": "condition_embedder.image_embedder.norm2", + # for the FLF2V model + "img_emb.emb_pos": "condition_embedder.image_embedder.pos_embed", + # Add attention component mappings + "self_attn.q": "attn1.to_q", + "self_attn.k": "attn1.to_k", + "self_attn.v": "attn1.to_v", + "self_attn.o": "attn1.to_out.0", + "self_attn.norm_q": "attn1.norm_q", + "self_attn.norm_k": "attn1.norm_k", + "cross_attn.q": "attn2.to_q", + "cross_attn.k": "attn2.to_k", + "cross_attn.v": "attn2.to_v", + "cross_attn.o": "attn2.to_out.0", + "cross_attn.norm_q": "attn2.norm_q", + "cross_attn.norm_k": "attn2.norm_k", + "attn2.to_k_img": "attn2.add_k_proj", + "attn2.to_v_img": "attn2.add_v_proj", + "attn2.norm_k_img": "attn2.norm_added_k", +} + +SUPPORTED_MODEL_TYPES = ["Wan2.2-T2V-A14B", "Wan2.2-I2V-A14B", "Wan2.2-TI2V-5B"] + +# Cascade models have two transformers (high_noise + low_noise) +CASCADE_MODEL_TYPES = {"Wan2.2-T2V-A14B", "Wan2.2-I2V-A14B"} + + +def get_transformer_config(model_type: str) -> Dict[str, Any]: + if model_type in SUPPORTED_MODEL_TYPES: + return TRANSFORMER_KEYS_RENAME_DICT + else: + raise ValueError( + f"Unsupported model_type: {model_type}. Supported: {SUPPORTED_MODEL_TYPES}" + ) + + +def get_transformer_dirs(model_type: str) -> List[str]: + """Return the list of transformer directory names for a given model type.""" + if model_type in CASCADE_MODEL_TYPES: + return ["transformer", "transformer_2"] + return ["transformer"] + + +def get_quant_subpath( + model_type: str, quant_path: pathlib.Path, transformer_dir: str +) -> pathlib.Path: + """Return the quant weights subdirectory for a given transformer.""" + if model_type in CASCADE_MODEL_TYPES: + sub = "high_noise_model" if transformer_dir == "transformer" else "low_noise_model" + return quant_path / sub + return quant_path + + +def update_dict_(dict: Dict[str, Any], old_key: str, new_key: str) -> None: + dict[new_key] = dict.pop(old_key) + + +def load_sharded_safetensors(directory: pathlib.Path, pattern: str) -> dict: + candidates = sorted(directory.glob(pattern)) + if not candidates: + raise FileNotFoundError(f"No file matching '{pattern}' found in {directory}") + if len(candidates) > 1: + raise FileNotFoundError( + f"Multiple files matching '{pattern}' found in {directory}: {candidates}" + ) + print(f" Loading: {candidates[0].name}") + state_dict = {} + state_dict.update(load_file(candidates[0])) + return state_dict + + +def convert_transformer( + model_type: str, model_dir: pathlib.Path, output_dir: pathlib.Path +) -> None: + """Convert a single quantized transformer directory into Diffusers format.""" + model_path = pathlib.Path(model_dir) + out_path = pathlib.Path(output_dir) + out_path.mkdir(parents=True, exist_ok=True) + RENAME_DICT = get_transformer_config(model_type) + + state_dict = load_sharded_safetensors(model_path, "quant_model_weight*.safetensors") + + json_candidates = sorted(model_path.glob("quant_model_description*.json")) + if not json_candidates: + raise FileNotFoundError( + f"No quant_model_description*.json found in {model_path}" + ) + with open(json_candidates[0]) as f: + quant_config = json.load(f) + + for key in list(state_dict.keys()): + new_key = key[:] + for replace_key, rename_key in RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + if new_key != key: + update_dict_(state_dict, key, new_key) + # The quant JSON only covers quantized layers, not all model keys + if key in quant_config: + update_dict_(quant_config, key, new_key) + + save_file(state_dict, out_path / "diffusion_pytorch_model.safetensors") + + with open(out_path / "quant_model_description.json", "w") as f: + json.dump(quant_config, f, indent=2) + + +def repack( + model_type: str, + original_model_path: pathlib.Path, + quant_path: pathlib.Path, + output_path: pathlib.Path, +) -> None: + """ + Full one-step repack workflow: + 1. Copy the original HF Diffusers model to output_path, excluding transformer dir(s). + 2. For each transformer: convert quant weights and copy config.json from original. + """ + transformer_dirs = get_transformer_dirs(model_type) + + # Step 1: Copy original model, skipping transformer dirs (they will be replaced) + print(f"Step 1: Copying original model to {output_path}") + print(f" (skipping: {transformer_dirs})") + shutil.copytree( + str(original_model_path), + str(output_path), + ignore=shutil.ignore_patterns(*transformer_dirs), + ) + + # Step 2+: Convert each transformer + for i, tdir in enumerate(transformer_dirs): + q_path = get_quant_subpath(model_type, quant_path, tdir) + out_tdir = output_path / tdir + print(f"\nStep {i + 2}: Converting {tdir} (quant source: {q_path.name})...") + convert_transformer(model_type, q_path, out_tdir) + + # Copy config.json from the original transformer dir + src_config = original_model_path / tdir / "config.json" + if src_config.is_file(): + shutil.copy2(str(src_config), str(out_tdir / "config.json")) + print(f" Copied config.json from original {tdir}/") + + print(f"\nDone! Repacked model saved to: {output_path}") + + +def get_args(): + parser = argparse.ArgumentParser( + description="Repack msmodelslim quantized Wan2.2 weights into HF Diffusers format" + ) + parser.add_argument( + "--model-type", + type=str, + required=True, + choices=SUPPORTED_MODEL_TYPES, + help="Model type to convert", + ) + parser.add_argument( + "--original-model-path", + type=str, + required=True, + help="Path to the original HF Diffusers model (e.g., /weights/Wan2.2-TI2V-5B-Diffusers)", + ) + parser.add_argument( + "--quant-path", + type=str, + required=True, + help="Path to msmodelslim quantized weights directory", + ) + parser.add_argument( + "--output-path", + type=str, + required=True, + help="Output path for the repacked model (e.g., /weights/Wan2.2-TI2V-5B-Diffusers-MXFP8)", + ) + return parser.parse_args() + + +if __name__ == "__main__": + args = get_args() + repack( + model_type=args.model_type, + original_model_path=pathlib.Path(args.original_model_path), + quant_path=pathlib.Path(args.quant_path), + output_path=pathlib.Path(args.output_path), + ) From 22bee9e9c148d89e3c6794b9652387ea4b46893c Mon Sep 17 00:00:00 2001 From: "(Messi) Junlin Wu" Date: Tue, 24 Mar 2026 18:15:52 +0800 Subject: [PATCH 09/29] :recycle: refactor(mxfp8): hoist imports and replace print with logger --- .../quantization/modelslim_mxfp8_scheme.py | 9 +++++-- .../runtime/layers/quantization/mxfp8_npu.py | 15 ++++-------- .../runtime/utils/quantization_utils.py | 4 +++- .../sglang/multimodal_gen/tools/wan_repack.py | 24 +++++++++++++------ .../npu/quantization/mxfp8_method_npu.py | 5 +--- python/sglang/srt/layers/quantization/fp8.py | 2 +- 6 files changed, 34 insertions(+), 25 deletions(-) diff --git a/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp8_scheme.py b/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp8_scheme.py index 4485d251fdf6..a5c92f039e94 100644 --- a/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp8_scheme.py +++ b/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp8_scheme.py @@ -7,6 +7,7 @@ from typing import Dict, List, Optional import torch +import torch_npu from sglang.multimodal_gen.runtime.models.parameter import ( GroupQuantScaleParameter, @@ -77,14 +78,18 @@ def apply_weights( x: torch.Tensor, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: - import torch_npu original_dtype = x.dtype if original_dtype not in (torch.float16, torch.bfloat16): + # npu_dynamic_mx_quant only accepts fp16/bf16 activations x = x.to(torch.bfloat16) original_dtype = torch.bfloat16 - # Flatten to 2D for npu_dynamic_mx_quant + # npu_dynamic_mx_quant requires a 2D input [tokens, hidden_size]. + # Diffusion transformer inputs are typically 3D [batch, seq, hidden] or + # higher. Flattening to 2D merges all leading dimensions into a single + # token axis so the NPU kernel can compute per-token MXFP8 scales, then + # we restore the original shape from the output. input_shape = x.shape x_2d = x.reshape(-1, x.shape[-1]) diff --git a/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp8_npu.py b/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp8_npu.py index 7e7cc37028a9..1d4d88adb8ec 100644 --- a/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp8_npu.py +++ b/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp8_npu.py @@ -8,20 +8,21 @@ from __future__ import annotations -import logging from typing import Any, Dict, List, Optional import torch +import torch_npu from torch.nn.parameter import Parameter -from sglang.multimodal_gen.runtime.layers.linear import LinearMethodBase +from sglang.multimodal_gen.runtime.layers.linear import LinearBase, LinearMethodBase from sglang.multimodal_gen.runtime.layers.quantization.configs.base_config import ( QuantizationConfig, QuantizeMethodBase, ) from sglang.multimodal_gen.runtime.models.parameter import ModelWeightParameter +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger -logger = logging.getLogger(__name__) +logger = init_logger(__name__) MXFP8_BLOCK_SIZE = 32 @@ -55,8 +56,6 @@ def from_config(cls, config: Dict[str, Any]) -> "MXFP8Config": def get_quant_method( self, layer: torch.nn.Module, prefix: str ) -> Optional[QuantizeMethodBase]: - from sglang.multimodal_gen.runtime.layers.linear import LinearBase - if isinstance(layer, LinearBase): return NPUMXFP8DiffusionLinearMethod(self) return None @@ -107,15 +106,13 @@ def create_weights( layer.register_parameter("weight", weight) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - import torch_npu weight_fp = layer.weight.data if weight_fp.dtype not in (torch.float16, torch.bfloat16): weight_fp = weight_fp.to(torch.bfloat16) # Ensure weight is on NPU before calling npu_dynamic_mx_quant - if not weight_fp.is_npu: - weight_fp = weight_fp.npu() + assert weight_fp.is_npu # Online MXFP8 quantisation of weights (block_size=32) qw, w_scale = torch_npu.npu_dynamic_mx_quant( @@ -130,8 +127,6 @@ def apply( x: torch.Tensor, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: - import torch_npu - original_dtype = x.dtype if original_dtype not in (torch.float16, torch.bfloat16): x = x.to(torch.bfloat16) diff --git a/python/sglang/multimodal_gen/runtime/utils/quantization_utils.py b/python/sglang/multimodal_gen/runtime/utils/quantization_utils.py index c235c0523624..c2380a409ce2 100644 --- a/python/sglang/multimodal_gen/runtime/utils/quantization_utils.py +++ b/python/sglang/multimodal_gen/runtime/utils/quantization_utils.py @@ -19,7 +19,9 @@ def find_quant_modelslim_config(model_config, component_model_path): # Try exact name first, then glob for variant filenames (e.g. after repack) quant_config_file = Path(component_model_path, "quant_model_description.json") if not quant_config_file.is_file(): - candidates = sorted(Path(component_model_path).glob("quant_model_description*.json")) + candidates = sorted( + Path(component_model_path).glob("quant_model_description*.json") + ) quant_config_file = candidates[0] if candidates else None quant_cfg = None diff --git a/python/sglang/multimodal_gen/tools/wan_repack.py b/python/sglang/multimodal_gen/tools/wan_repack.py index 3fd23a10cbd0..d7951e6c3c5f 100644 --- a/python/sglang/multimodal_gen/tools/wan_repack.py +++ b/python/sglang/multimodal_gen/tools/wan_repack.py @@ -8,6 +8,10 @@ from safetensors.torch import load_file, save_file +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + TRANSFORMER_KEYS_RENAME_DICT = { "time_embedding.0": "condition_embedder.time_embedder.linear_1", "time_embedding.2": "condition_embedder.time_embedder.linear_2", @@ -77,7 +81,11 @@ def get_quant_subpath( ) -> pathlib.Path: """Return the quant weights subdirectory for a given transformer.""" if model_type in CASCADE_MODEL_TYPES: - sub = "high_noise_model" if transformer_dir == "transformer" else "low_noise_model" + sub = ( + "high_noise_model" + if transformer_dir == "transformer" + else "low_noise_model" + ) return quant_path / sub return quant_path @@ -94,7 +102,7 @@ def load_sharded_safetensors(directory: pathlib.Path, pattern: str) -> dict: raise FileNotFoundError( f"Multiple files matching '{pattern}' found in {directory}: {candidates}" ) - print(f" Loading: {candidates[0].name}") + state_dict = {} state_dict.update(load_file(candidates[0])) return state_dict @@ -149,8 +157,8 @@ def repack( transformer_dirs = get_transformer_dirs(model_type) # Step 1: Copy original model, skipping transformer dirs (they will be replaced) - print(f"Step 1: Copying original model to {output_path}") - print(f" (skipping: {transformer_dirs})") + logger.debug(f"Step 1: Copying original model to {output_path}") + logger.debug(f" (skipping: {transformer_dirs})") shutil.copytree( str(original_model_path), str(output_path), @@ -161,16 +169,18 @@ def repack( for i, tdir in enumerate(transformer_dirs): q_path = get_quant_subpath(model_type, quant_path, tdir) out_tdir = output_path / tdir - print(f"\nStep {i + 2}: Converting {tdir} (quant source: {q_path.name})...") + logger.debug( + f"\nStep {i + 2}: Converting {tdir} (quant source: {q_path.name})..." + ) convert_transformer(model_type, q_path, out_tdir) # Copy config.json from the original transformer dir src_config = original_model_path / tdir / "config.json" if src_config.is_file(): shutil.copy2(str(src_config), str(out_tdir / "config.json")) - print(f" Copied config.json from original {tdir}/") + logger.debug(f" Copied config.json from original {tdir}/") - print(f"\nDone! Repacked model saved to: {output_path}") + logger.info(f"\nDone! Repacked model saved to: {output_path}") def get_args(): diff --git a/python/sglang/srt/hardware_backend/npu/quantization/mxfp8_method_npu.py b/python/sglang/srt/hardware_backend/npu/quantization/mxfp8_method_npu.py index 97f150040553..c83b77cad7f4 100644 --- a/python/sglang/srt/hardware_backend/npu/quantization/mxfp8_method_npu.py +++ b/python/sglang/srt/hardware_backend/npu/quantization/mxfp8_method_npu.py @@ -1,6 +1,7 @@ from typing import Optional import torch +import torch_npu from torch.nn.parameter import Parameter @@ -20,8 +21,6 @@ class NPUMXFP8LinearMethod: def process_weights_after_loading( self, layer: torch.nn.Module, is_serialized: bool ): - import torch_npu - if is_serialized: # Checkpoint already has fp8 weights + uint8 scales. # Ensure weight is float8_e4m3fn. @@ -58,8 +57,6 @@ def apply( x: torch.Tensor, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: - import torch_npu - original_dtype = x.dtype if original_dtype not in (torch.float16, torch.bfloat16): x = x.to(torch.bfloat16) diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 9a841909ae8d..ef8c3e02619e 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -260,7 +260,7 @@ def __init__(self, quant_config: Fp8Config): self.quant_config = quant_config from sglang.srt.hardware_backend.npu.quantization.mxfp8_method_npu import ( - NPUMXFP8LinearMethod + NPUMXFP8LinearMethod, ) self.npu_method = NPUMXFP8LinearMethod() From a29bb3d6ad91d0f2f856e644989398cb36769c23 Mon Sep 17 00:00:00 2001 From: "(Messi) Junlin Wu" Date: Wed, 25 Mar 2026 09:21:20 +0800 Subject: [PATCH 10/29] :pencil2: fix(diffusion/mxfp8): address review comments on ModelSlimMXFP8Scheme - Remove unused __init__ (no quant_config/prefix needed, MXFP8 has only one mode) - Fix weight dtype: float8_e4m3fn (not int8) to match msmodelslim checkpoint format - Fix weight_scale shape: [out, in/32] (not in/32*2) to match actual tensor shape - Add comment explaining weight_scale name must match checkpoint key (not weight_scale_inv) - Improve flatten-to-2D comment to explain NPU kernel requirement --- .../runtime/layers/quantization/modelslim.py | 4 +--- .../layers/quantization/modelslim_mxfp8_scheme.py | 12 ++++++------ python/sglang/multimodal_gen/tools/wan_repack.py | 4 ++-- 3 files changed, 9 insertions(+), 11 deletions(-) diff --git a/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim.py b/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim.py index 4d43c50e3675..4a9b96f9c9c9 100644 --- a/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim.py +++ b/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim.py @@ -124,9 +124,7 @@ def _get_scheme_from_parts( ModelSlimMXFP8Scheme, ) - return ModelSlimMXFP8Scheme( - quant_config=self.quant_description, prefix=layer_name - ) + return ModelSlimMXFP8Scheme() raise NotImplementedError("No modelslim compatible scheme was found.") def get_scheme( diff --git a/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp8_scheme.py b/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp8_scheme.py index a5c92f039e94..c12464d691f7 100644 --- a/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp8_scheme.py +++ b/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp8_scheme.py @@ -4,7 +4,7 @@ uint8 scales) and runs MXFP8 matmul at inference. """ -from typing import Dict, List, Optional +from typing import List, Optional import torch import torch_npu @@ -20,10 +20,6 @@ class ModelSlimMXFP8Scheme(ModelSlimLinearScheme): - def __init__(self, quant_config: Dict[str, any], prefix: str): - self.quant_config = quant_config - self.prefix = prefix - def create_weights( self, layer: torch.nn.Module, @@ -49,7 +45,11 @@ def create_weights( ) layer.register_parameter("weight", weight) - # msmodelslim exports weight_scale as uint8, shape [out, in/32] + # msmodelslim exports weight_scale as uint8, shape [out, in/32]. + # NOTE: This parameter is intentionally named "weight_scale" (not + # "weight_scale_inv" as used in mxfp8_npu.py) because the weight loader + # matches parameter names to checkpoint keys, and msmodelslim checkpoints + # store this tensor under the key ".weight_scale". scale_dim = input_size_per_partition // MXFP8_BLOCK_SIZE weight_scale = GroupQuantScaleParameter( data=torch.empty( diff --git a/python/sglang/multimodal_gen/tools/wan_repack.py b/python/sglang/multimodal_gen/tools/wan_repack.py index d7951e6c3c5f..308b229d8593 100644 --- a/python/sglang/multimodal_gen/tools/wan_repack.py +++ b/python/sglang/multimodal_gen/tools/wan_repack.py @@ -90,8 +90,8 @@ def get_quant_subpath( return quant_path -def update_dict_(dict: Dict[str, Any], old_key: str, new_key: str) -> None: - dict[new_key] = dict.pop(old_key) +def update_dict_(d: Dict[str, Any], old_key: str, new_key: str) -> None: + d[new_key] = d.pop(old_key) def load_sharded_safetensors(directory: pathlib.Path, pattern: str) -> dict: From 250fe65deda8476402ad002e43feb4b6a6e184f8 Mon Sep 17 00:00:00 2001 From: "(Messi) Junlin Wu" Date: Wed, 25 Mar 2026 10:39:35 +0800 Subject: [PATCH 11/29] :adhesive_bandage: fix(diffusion): register --quantization CLI arg to avoid argparse ambiguity --- python/sglang/multimodal_gen/runtime/server_args.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/python/sglang/multimodal_gen/runtime/server_args.py b/python/sglang/multimodal_gen/runtime/server_args.py index 72c2cd5f63e9..60565998c005 100644 --- a/python/sglang/multimodal_gen/runtime/server_args.py +++ b/python/sglang/multimodal_gen/runtime/server_args.py @@ -727,6 +727,14 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: help="Disable autocast for denoising loop and vae decoding in pipeline sampling", ) + parser.add_argument( + "--quantization", + type=str, + default=None, + help='Quantization method override (e.g. "mxfp8", "fp8", "modelslim"). ' + "When set, the transformer loader will use this instead of auto-detection.", + ) + # Nunchaku SVDQuant quantization parameters NunchakuSVDQuantArgs.add_cli_args(parser) From e146b03179e8d4215138b3f5aeb032083f74d8e7 Mon Sep 17 00:00:00 2001 From: "(Messi) Junlin Wu" Date: Wed, 25 Mar 2026 11:02:44 +0800 Subject: [PATCH 12/29] :bug: fix(mxfp8_npu): move weight to current NPU device before quantization --- .../runtime/layers/quantization/mxfp8_npu.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp8_npu.py b/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp8_npu.py index 1d4d88adb8ec..b3dd4612460d 100644 --- a/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp8_npu.py +++ b/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp8_npu.py @@ -111,8 +111,17 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: if weight_fp.dtype not in (torch.float16, torch.bfloat16): weight_fp = weight_fp.to(torch.bfloat16) - # Ensure weight is on NPU before calling npu_dynamic_mx_quant - assert weight_fp.is_npu + # Move weight to NPU if needed. We intentionally use a conditional + # move rather than an assert because `dit_cpu_offload` defaults to + # True in ServerArgs, which causes fsdp_load to move every parameter + # back to CPU after loading (even when the target device is NPU). + # npu_dynamic_mx_quant requires an NPU tensor, so we must transfer + # here. The quantized fp8 weights produced below will remain on NPU + # for inference; if the model still needs to be offloaded after + # quantization (e.g. very large model on a small NPU), a higher-level + # offload pass can move them back afterwards. + if not weight_fp.is_npu: + weight_fp = weight_fp.to(f"npu:{torch.npu.current_device()}") # Online MXFP8 quantisation of weights (block_size=32) qw, w_scale = torch_npu.npu_dynamic_mx_quant( From 711bb8b8ec0227c150ceff643b0ddecd6cc5c4be Mon Sep 17 00:00:00 2001 From: "(Messi) Junlin Wu" Date: Wed, 25 Mar 2026 14:34:55 +0800 Subject: [PATCH 13/29] :rewind: revert(llm): remove LLM MXFP8 online quantization (Path B) for separate PR Revert LLM-side MXFP8 changes to split into a separate PR. This branch now only contains Wan2.2 Diffusion MXFP8 changes. Reverted files: - fp8.py: removed MXFP8LinearAscendMethod class and NPU branch - mxfp8_method_npu.py: deleted (NPU MXFP8 linear method) - test_ascend_mxfp8_quantization.py: deleted (LLM MXFP8 test) LLM MXFP8 code preserved on junlin_llm branch. --- .../npu/quantization/mxfp8_method_npu.py | 82 ------------ python/sglang/srt/layers/quantization/fp8.py | 126 +++--------------- .../ascend/test_ascend_mxfp8_quantization.py | 103 -------------- 3 files changed, 18 insertions(+), 293 deletions(-) delete mode 100644 python/sglang/srt/hardware_backend/npu/quantization/mxfp8_method_npu.py delete mode 100644 test/srt/ascend/test_ascend_mxfp8_quantization.py diff --git a/python/sglang/srt/hardware_backend/npu/quantization/mxfp8_method_npu.py b/python/sglang/srt/hardware_backend/npu/quantization/mxfp8_method_npu.py deleted file mode 100644 index c83b77cad7f4..000000000000 --- a/python/sglang/srt/hardware_backend/npu/quantization/mxfp8_method_npu.py +++ /dev/null @@ -1,82 +0,0 @@ -from typing import Optional - -import torch -import torch_npu -from torch.nn.parameter import Parameter - - -class NPUMXFP8LinearMethod: - """Ascend NPU MXFP8 weight processing and kernel calls. - - This class handles NPU-specific operations: - - process_weights_after_loading: dtype casting and online quantization - - apply: dynamic activation quantization + MXFP8 matmul - - Weight creation and config management are handled by - MXFP8LinearAscendMethod in fp8.py. - """ - - MXFP8_BLOCK_SIZE = 32 - - def process_weights_after_loading( - self, layer: torch.nn.Module, is_serialized: bool - ): - if is_serialized: - # Checkpoint already has fp8 weights + uint8 scales. - # Ensure weight is float8_e4m3fn. - if layer.weight.data.dtype != torch.float8_e4m3fn: - layer.weight = Parameter( - torch_npu.npu_dtype_cast( - layer.weight.data, torch_npu.float8_e4m3fn - ), - requires_grad=False, - ) - else: - layer.weight.requires_grad_(False) - - # Scale is already uint8 (e8m0fnu), keep as-is. - layer.weight_scale_inv.requires_grad_(False) - else: - # Online quantization: quantize FP16/BF16 weights to MXFP8. - weight_fp = layer.weight.data - if weight_fp.dtype not in (torch.float16, torch.bfloat16): - weight_fp = weight_fp.to(torch.bfloat16) - - if not weight_fp.is_npu: - weight_fp = weight_fp.npu() - - qw, w_scale = torch_npu.npu_dynamic_mx_quant( - weight_fp, dst_type=torch_npu.float8_e4m3fn - ) - layer.weight = Parameter(qw, requires_grad=False) - layer.weight_scale_inv = Parameter(w_scale, requires_grad=False) - - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - original_dtype = x.dtype - if original_dtype not in (torch.float16, torch.bfloat16): - x = x.to(torch.bfloat16) - original_dtype = torch.bfloat16 - - # Dynamic MXFP8 activation quantization (block_size=32) - qx, input_scale = torch_npu.npu_dynamic_mx_quant( - x, dst_type=torch_npu.float8_e4m3fn - ) - - # MXFP8 quantized matmul - output = torch_npu.npu_quant_matmul( - qx, - layer.weight.transpose(0, 1), - layer.weight_scale_inv.transpose(0, 1), - scale_dtype=torch_npu.float8_e8m0fnu, - pertoken_scale=input_scale, - pertoken_scale_dtype=torch_npu.float8_e8m0fnu, - bias=bias.to(torch.float32) if bias is not None else None, - output_dtype=original_dtype, - group_sizes=[1, 1, self.MXFP8_BLOCK_SIZE], - ) - return output diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index b59cd86d6ba0..1e12baff13cf 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -57,10 +57,7 @@ requant_weight_ue8m0_inplace, ) from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod -from sglang.srt.layers.quantization.marlin_utils_fp8 import ( - apply_fp8_marlin_linear, - prepare_fp8_layer_for_marlin, -) +from sglang.srt.layers.quantization.marlin_utils_fp8 import prepare_fp8_layer_for_marlin from sglang.srt.layers.quantization.unquant import ( UnquantizedFusedMoEMethod, UnquantizedLinearMethod, @@ -225,8 +222,6 @@ def get_quant_method( prefix, self.ignored_layers, fused_mapping=self.packed_modules_mapping ): return UnquantizedLinearMethod() - if _is_npu and self.use_mxfp8: - return MXFP8LinearAscendMethod(self) return Fp8LinearMethod(self) elif isinstance(layer, FusedMoE): if is_layer_skipped( @@ -244,99 +239,6 @@ def get_scaled_act_names(self) -> List[str]: return [] -class MXFP8LinearAscendMethod(LinearMethodBase): - """Ascend NPU MXFP8 (Microscaling FP8) quantization for Linear layers. - - Supports two modes: - - Online quantization: loads FP16/BF16 weights and quantizes them to MXFP8 - at weight loading time. - - Offline quantization: loads pre-quantized FP8 weights with block scales - from a serialized checkpoint. - - Weight creation is handled here; weight processing and kernel calls are - delegated to NPUMXFP8LinearMethod in the NPU hardware backend. - """ - - MXFP8_BLOCK_SIZE = 32 - - def __init__(self, quant_config: Fp8Config): - self.quant_config = quant_config - - from sglang.srt.hardware_backend.npu.quantization.mxfp8_method_npu import ( - NPUMXFP8LinearMethod, - ) - - self.npu_method = NPUMXFP8LinearMethod() - - def create_weights( - self, - layer: torch.nn.Module, - input_size_per_partition: int, - output_partition_sizes: List[int], - input_size: int, - output_size: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ): - output_size_per_partition = sum(output_partition_sizes) - weight_loader = extra_weight_attrs.get("weight_loader") - - layer.logical_widths = output_partition_sizes - layer.input_size_per_partition = input_size_per_partition - layer.output_size_per_partition = output_size_per_partition - layer.orig_dtype = params_dtype - - is_serialized = self.quant_config.is_checkpoint_fp8_serialized - - # Weight: fp8 if serialized checkpoint, else original dtype (will be - # quantized in process_weights_after_loading) - weight_dtype = torch.float8_e4m3fn if is_serialized else params_dtype - weight = ModelWeightParameter( - data=torch.empty( - output_size_per_partition, - input_size_per_partition, - dtype=weight_dtype, - ), - input_dim=1, - output_dim=0, - weight_loader=weight_loader, - ) - layer.register_parameter("weight", weight) - - if is_serialized: - # Block scale: one scale per block of 32 elements along input dim. - # Stored as uint8 (representing float8_e8m0fnu) in checkpoint. - block_k = self.MXFP8_BLOCK_SIZE - scale_cols = (input_size_per_partition + block_k - 1) // block_k - scale = BlockQuantScaleParameter( - data=torch.zeros( - output_size_per_partition, - scale_cols, - dtype=torch.uint8, - ), - input_dim=1, - output_dim=0, - weight_loader=weight_loader, - ) - scale.format_ue8m0 = True - layer.register_parameter("weight_scale_inv", scale) - else: - layer.register_parameter("weight_scale_inv", None) - - def process_weights_after_loading(self, layer: torch.nn.Module): - self.npu_method.process_weights_after_loading( - layer, self.quant_config.is_checkpoint_fp8_serialized - ) - - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - return self.npu_method.apply(layer, x, bias) - - class Fp8LinearMethod(LinearMethodBase): """Linear method for FP8. @@ -741,7 +643,7 @@ def apply( bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: if self.use_marlin: - return apply_fp8_marlin_linear( + return torch.ops.sglang.apply_fp8_marlin_linear( input=x, weight=layer.weight, weight_scale=layer.weight_scale, @@ -1096,15 +998,23 @@ def process_weights_after_loading_block_quant(self, layer: Module) -> None: w2_weight_scale, requires_grad=False ) layer.w2_input_scale = None - - if _use_aiter: + if _use_aiter: + # add this section for MI300 + # Pre-shuffle weights + layer.w13_weight.data = shuffle_weight( + layer.w13_weight.contiguous(), (16, 16) + ) + layer.w2_weight.data = shuffle_weight( + layer.w2_weight.contiguous(), (16, 16) + ) + elif _use_aiter: # Pre-shuffle weights - t = shuffle_weight(layer.w13_weight, (16, 16)) - layer.w13_weight.copy_(t) - del t - t = shuffle_weight(layer.w2_weight, (16, 16)) - layer.w2_weight.copy_(t) - del t + layer.w13_weight.data = shuffle_weight( + layer.w13_weight.contiguous(), (16, 16) + ) + layer.w2_weight.data = shuffle_weight( + layer.w2_weight.contiguous(), (16, 16) + ) elif _is_cpu: assert ( _is_cpu_amx_available diff --git a/test/srt/ascend/test_ascend_mxfp8_quantization.py b/test/srt/ascend/test_ascend_mxfp8_quantization.py deleted file mode 100644 index e7af4ff97f1a..000000000000 --- a/test/srt/ascend/test_ascend_mxfp8_quantization.py +++ /dev/null @@ -1,103 +0,0 @@ -""" -Usage: -python3 -m unittest test_ascend_mxfp8_quantization.TestAscendMXFP8.test_gsm8k -""" - -import os -import time -import unittest -from types import SimpleNamespace -from urllib.parse import urlparse - -import requests - -from sglang.srt.utils import kill_process_tree -from sglang.test.few_shot_gsm8k import run_eval -from sglang.test.test_utils import ( - DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - DEFAULT_URL_FOR_TEST, - CustomTestCase, - is_in_ci, - popen_launch_server, -) - -if "ASCEND_RT_VISIBLE_DEVICES" not in os.environ: - os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "0,1" -DEFAULT_PORT_FOR_SRT_TEST_RUNNER = ( - 7000 + int(os.environ.get("ASCEND_RT_VISIBLE_DEVICES", "0")[0]) * 100 -) -DEFAULT_URL_FOR_TEST = f"http://127.0.0.1:{DEFAULT_PORT_FOR_SRT_TEST_RUNNER + 1000}" - - -class TestAscendMXFP8(CustomTestCase): - """Test online MXFP8 quantization (--quantization mxfp8) on Ascend NPU.""" - - @classmethod - def setUpClass(cls): - cls.model = "Qwen/Qwen2.5-0.5B-Instruct" - cls.base_url = DEFAULT_URL_FOR_TEST - cls.process = popen_launch_server( - cls.model, - cls.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=[ - "--trust-remote-code", - "--quantization", - "mxfp8", - "--device", - "npu", - "--attention-backend", - "ascend", - ], - ) - - @classmethod - def tearDownClass(cls): - kill_process_tree(cls.process.pid) - - def test_gsm8k(self): - url = urlparse(self.base_url) - args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host=f"http://{url.hostname}", - port=int(url.port), - ) - metrics = run_eval(args) - print(metrics) - self.assertGreaterEqual(metrics["accuracy"], 0.25) - self.assertGreaterEqual(metrics["output_throughput"], 500) - - def run_decode(self, max_new_tokens): - response = requests.post( - self.base_url + "/generate", - json={ - "text": "The capital of France is", - "sampling_params": { - "temperature": 0, - "max_new_tokens": max_new_tokens, - }, - "ignore_eos": True, - }, - ) - return response.json() - - def test_throughput(self): - max_tokens = 256 - - tic = time.perf_counter() - res = self.run_decode(max_tokens) - tok = time.perf_counter() - print(res["text"]) - throughput = max_tokens / (tok - tic) - print(f"Throughput: {throughput} tokens/s") - - if is_in_ci(): - self.assertGreaterEqual(throughput, 20) - - -if __name__ == "__main__": - unittest.main() From 615dda782c2dd05b8282f62a73cd1da9c9704dc1 Mon Sep 17 00:00:00 2001 From: "(Messi) Junlin Wu" Date: Thu, 26 Mar 2026 11:36:36 +0800 Subject: [PATCH 14/29] :sparkles: feat(diffusion/mxfp4): add MXFP4 online quantization for Diffusion on Ascend NPU --- .../runtime/layers/quantization/__init__.py | 4 +- .../runtime/layers/quantization/mxfp4_npu.py | 180 ++++++++++++++++++ 2 files changed, 183 insertions(+), 1 deletion(-) create mode 100644 python/sglang/multimodal_gen/runtime/layers/quantization/mxfp4_npu.py diff --git a/python/sglang/multimodal_gen/runtime/layers/quantization/__init__.py b/python/sglang/multimodal_gen/runtime/layers/quantization/__init__.py index 9967879148e3..0eb464a8ca78 100644 --- a/python/sglang/multimodal_gen/runtime/layers/quantization/__init__.py +++ b/python/sglang/multimodal_gen/runtime/layers/quantization/__init__.py @@ -10,9 +10,10 @@ ModelOptFp4Config, ) from sglang.multimodal_gen.runtime.layers.quantization.modelslim import ModelSlimConfig +from sglang.multimodal_gen.runtime.layers.quantization.mxfp4_npu import MXFP4Config from sglang.multimodal_gen.runtime.layers.quantization.mxfp8_npu import MXFP8Config -QuantizationMethods = Literal["fp8", "modelopt_fp4", "modelslim", "mxfp8"] +QuantizationMethods = Literal["fp8", "modelopt_fp4", "modelslim", "mxfp8", "mxfp4"] QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods)) @@ -22,6 +23,7 @@ "modelslim": ModelSlimConfig, "fp8": Fp8Config, "mxfp8": MXFP8Config, + "mxfp4": MXFP4Config, } diff --git a/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp4_npu.py b/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp4_npu.py new file mode 100644 index 000000000000..5a3e7cf2eeb2 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp4_npu.py @@ -0,0 +1,180 @@ +"""Online MXFP4 quantization for Diffusion models on Ascend NPU. + +Provides ``MXFP4Config`` (registered as ``"mxfp4"``) and +``NPUMXFP4DiffusionLinearMethod`` which quantises FP16/BF16 weights to MXFP4 +at load time using dual-level MX quantization and uses +``npu_dynamic_dual_level_mx_quant`` + ``npu_dual_level_quant_matmul`` for +inference. + +NOTE: Online weight quantization via ``npu_dynamic_dual_level_mx_quant`` is +experimental. MindIE-SD only uses an offline (pre-quantized) path for MXFP4 +weights. The online path quantizes FP16/BF16 weights at load time, which may +produce different numerical results than the offline calibrated path. +""" + +from __future__ import annotations + +from typing import Any, Dict, List, Optional + +import torch +import torch_npu +from torch.nn.parameter import Parameter + +from sglang.multimodal_gen.runtime.layers.linear import LinearBase, LinearMethodBase +from sglang.multimodal_gen.runtime.layers.quantization.configs.base_config import ( + QuantizationConfig, + QuantizeMethodBase, +) +from sglang.multimodal_gen.runtime.models.parameter import ModelWeightParameter +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +class MXFP4Config(QuantizationConfig): + """Config for online MXFP4 quantization on Ascend NPU (Diffusion).""" + + def __init__(self) -> None: + super().__init__() + + @classmethod + def get_name(cls) -> str: + return "mxfp4" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.bfloat16, torch.float16] + + @classmethod + def get_min_capability(cls) -> int: + return 0 # NPU, not CUDA + + @classmethod + def get_config_filenames(cls) -> List[str]: + return [] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "MXFP4Config": + return cls() + + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional[QuantizeMethodBase]: + if isinstance(layer, LinearBase): + return NPUMXFP4DiffusionLinearMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class NPUMXFP4DiffusionLinearMethod(LinearMethodBase): + """Ascend NPU MXFP4 linear method for Diffusion models (dual-level). + + Online mode: loads FP16/BF16 weights → quantises to MXFP4 at load time + via ``npu_dynamic_dual_level_mx_quant``. + Inference: dynamic dual-level MXFP4 activation quant + dual-level matmul. + + Reference: MindIE-SD ``W4A4MXFP4DualQuantLinear`` (offline path only). + """ + + def __init__(self, quant_config: MXFP4Config): + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + + layer.logical_widths = output_partition_sizes + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + layer.orig_dtype = params_dtype + + # Load weights in original dtype; quantise later in process_weights_after_loading + weight = ModelWeightParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=params_dtype, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + weight_fp = layer.weight.data + if weight_fp.dtype not in (torch.float16, torch.bfloat16): + weight_fp = weight_fp.to(torch.bfloat16) + + # Move weight to NPU if needed. dit_cpu_offload defaults to True in + # ServerArgs, which causes fsdp_load to move parameters back to CPU + # after loading. npu_dynamic_dual_level_mx_quant requires an NPU tensor. + if not weight_fp.is_npu: + weight_fp = weight_fp.to(f"npu:{torch.npu.current_device()}") + + # Online dual-level MXFP4 weight quantisation. + # NOTE: This is experimental — MindIE-SD only has an offline path for + # MXFP4 weights. We assume npu_dynamic_dual_level_mx_quant can also + # quantise weights (not just activations). + # Returns: (qw, w_dual_scale, w_scale) + # qw — quantized weight in float4_e2m1fn_x2 (2 FP4 packed/byte) + # w_dual_scale — L0-level scale (goes to pos 3 in npu_dual_level_quant_matmul) + # w_scale — L1-level scale (goes to pos 5 in npu_dual_level_quant_matmul) + qw, w_dual_scale, w_scale = torch_npu.npu_dynamic_dual_level_mx_quant( + weight_fp, smooth_scale=None + ) + layer.weight = Parameter(qw, requires_grad=False) + layer.weight_dual_scale = Parameter(w_dual_scale, requires_grad=False) + layer.weight_scale = Parameter(w_scale, requires_grad=False) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + original_dtype = x.dtype + if original_dtype not in (torch.float16, torch.bfloat16): + x = x.to(torch.bfloat16) + original_dtype = torch.bfloat16 + + # Flatten to 2D [tokens, hidden] for the quantization operators + input_shape = x.shape + x_2d = x.reshape(-1, x.shape[-1]) + + # Dynamic dual-level MXFP4 activation quantisation + qx, act_l0_scale, act_l1_scale = torch_npu.npu_dynamic_dual_level_mx_quant( + x_2d, smooth_scale=None + ) + + # Dual-level MXFP4 matmul + # Arg order: act_quant, weight_quant, act_l0_scale, weight_dual_scale, + # act_l1_scale, weight_scale, bias=, output_dtype= + # NOTE: weight is NOT transposed (unlike MXFP8's npu_quant_matmul). + output = torch_npu.npu_dual_level_quant_matmul( + qx, + layer.weight, + act_l0_scale, + layer.weight_dual_scale, + act_l1_scale, + layer.weight_scale, + bias=bias.to(torch.float32) if bias is not None else None, + output_dtype=original_dtype, + ) + + # Restore original shape (replace last dim with output features) + output_shape = list(input_shape[:-1]) + [output.shape[-1]] + output = output.reshape(output_shape) + + return output From 99167a3391d376bd824f0fde9a0d71abe3c1878a Mon Sep 17 00:00:00 2001 From: "(Messi) Junlin Wu" Date: Thu, 26 Mar 2026 14:07:38 +0800 Subject: [PATCH 15/29] :sparkles: feat(diffusion/mxfp4): add MXFP4 ModelSlim offline quantization loading - Add ModelSlimMXFP4Scheme for loading msmodelslim pre-quantized MXFP4 weights - Support dual-level quantization via npu_dual_level_quant_matmul - Register W4A4_MXFP4 quant type in modelslim.py dispatcher - Handle FP4 packed weight casting and scale transformations Weights: float8_e4m3fn (FP4 packed) [out, in/2] Scales: uint8 (e8m0+127) [out, in/32] + bfloat16 dual [out, in/64] --- .../runtime/layers/quantization/modelslim.py | 6 + .../quantization/modelslim_mxfp4_scheme.py | 143 ++++++++++++++++++ 2 files changed, 149 insertions(+) create mode 100644 python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp4_scheme.py diff --git a/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim.py b/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim.py index 4a9b96f9c9c9..af292ca37f13 100644 --- a/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim.py +++ b/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim.py @@ -125,6 +125,12 @@ def _get_scheme_from_parts( ) return ModelSlimMXFP8Scheme() + elif quant_type == "W4A4_MXFP4": + from sglang.multimodal_gen.runtime.layers.quantization.modelslim_mxfp4_scheme import ( + ModelSlimMXFP4Scheme, + ) + + return ModelSlimMXFP4Scheme() raise NotImplementedError("No modelslim compatible scheme was found.") def get_scheme( diff --git a/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp4_scheme.py b/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp4_scheme.py new file mode 100644 index 000000000000..e9be95f9d04d --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp4_scheme.py @@ -0,0 +1,143 @@ +"""ModelSlim MXFP4 scheme for pre-quantized weight inference on Ascend NPU. + +Loads weights pre-quantized by msmodelslim (float8_e4m3fn as FP4 packed +container, uint8 scales, bfloat16 dual scales) and runs MXFP4 dual-level +matmul at inference via npu_dual_level_quant_matmul. + +Reference: MindIE-SD W4A4MXFP4DualQuantLinear +(MindIE-SD/mindiesd/quantization/layer.py) +""" + +from typing import List, Optional + +import torch +import torch_npu + +from sglang.multimodal_gen.runtime.models.parameter import ( + GroupQuantScaleParameter, + ModelWeightParameter, +) +from sglang.srt.layers.quantization.modelslim.schemes import ModelSlimLinearScheme + +MXFP4_BLOCK_SIZE = 32 + + +class ModelSlimMXFP4Scheme(ModelSlimLinearScheme): + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + weight_loader = extra_weight_attrs.get("weight_loader") + output_size_per_partition = sum(output_partition_sizes) + + # msmodelslim exports weight as float8_e4m3fn, shape [out, in/2]. + # Two FP4 (E2M1) values are packed into one float8_e4m3fn byte. + weight = ModelWeightParameter( + data=torch.empty( + (output_size_per_partition, input_size_per_partition // 2), + dtype=torch.float8_e4m3fn, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + # msmodelslim exports weight_scale as uint8, shape [out, in/32]. + # Stored as e8m0 scale + 127 offset. + scale_dim = input_size_per_partition // MXFP4_BLOCK_SIZE + weight_scale = GroupQuantScaleParameter( + data=torch.empty( + (output_size_per_partition, scale_dim), + dtype=torch.uint8, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight_scale", weight_scale) + + # L1 (coarse) scale for dual-level quantization matmul. + # MindIE-SD loads this as [out, in/64, 1], then squeeze(-1) + transpose. + # The dual_scale groups every 2 L0 blocks (64 elements) into one L1 block. + # TODO: The exact shape and dtype depend on the checkpoint export tool. + # msmodelslim's current version may not export this field; ensure the + # checkpoint includes weight_dual_scale for dual-level matmul support. + dual_scale_dim = scale_dim // 2 # in/32 / 2 = in/64 + weight_dual_scale = GroupQuantScaleParameter( + data=torch.empty( + (output_size_per_partition, dual_scale_dim), + dtype=torch.bfloat16, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight_dual_scale", weight_dual_scale) + + def process_weights_after_loading(self, layer: torch.nn.Module): + # Cast weight from fp8 container to FP4 packed format + weight = layer.weight.data + if not weight.is_npu: + weight = weight.to(f"npu:{torch.npu.current_device()}") + weight = torch_npu.npu_dtype_cast(weight, torch_npu.float4_e2m1fn_x2) + layer.weight = torch.nn.Parameter(weight, requires_grad=False) + + # Reshape weight_scale: [out, in/32] -> [out, in/64, 2] + # The dual-level matmul API expects L0 scales in this 3D format + weight_scale = layer.weight_scale.data + weight_scale = weight_scale.reshape(weight_scale.shape[0], -1, 2) + layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) + + # Transform weight_dual_scale: [out, in/64] -> [in/64, out] + # MindIE-SD does squeeze(-1).transpose(0,1); we skip squeeze since + # our parameter is already 2D [out, in/64] + weight_dual_scale = layer.weight_dual_scale.data + weight_dual_scale = weight_dual_scale.transpose(0, 1).contiguous() + layer.weight_dual_scale = torch.nn.Parameter( + weight_dual_scale, requires_grad=False + ) + + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + original_dtype = x.dtype + if original_dtype not in (torch.float16, torch.bfloat16): + x = x.to(torch.bfloat16) + original_dtype = torch.bfloat16 + + # Flatten to 2D for npu_dynamic_dual_level_mx_quant + input_shape = x.shape + x_2d = x.reshape(-1, x.shape[-1]) + + # Dual-level MXFP4 activation quantization + x1, l0_scale, l1_scale = torch_npu.npu_dynamic_dual_level_mx_quant( + x_2d, smooth_scale=None + ) + + # Dual-level MXFP4 matmul + output = torch_npu.npu_dual_level_quant_matmul( + x1, + layer.weight, + l0_scale, + layer.weight_dual_scale, + l1_scale, + layer.weight_scale, + bias=bias.to(torch.float32) if bias is not None else None, + output_dtype=original_dtype, + ) + + # Restore original shape + output_shape = list(input_shape[:-1]) + [output.shape[-1]] + return output.reshape(output_shape) From 9baae1c078eedeb9a8656335fcbe8eadb232c892 Mon Sep 17 00:00:00 2001 From: "(Messi) Junlin Wu" Date: Sat, 28 Mar 2026 13:43:33 +0800 Subject: [PATCH 16/29] :bug: fix(diffusion/mxfp4): add NZ format cast and dual_scale transpose for matmul --- .../runtime/layers/quantization/mxfp4_npu.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp4_npu.py b/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp4_npu.py index 5a3e7cf2eeb2..89cd38d2bfa6 100644 --- a/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp4_npu.py +++ b/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp4_npu.py @@ -134,6 +134,16 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: qw, w_dual_scale, w_scale = torch_npu.npu_dynamic_dual_level_mx_quant( weight_fp, smooth_scale=None ) + + # npu_dual_level_quant_matmul requires x2 (weight) in FRACTAL_NZ format. + # Reference: MindIE-SD W4A4MXFP4DualQuantLinear._init_dynamic_quant_param + qw = torch_npu.npu_format_cast(qw.view(torch.int8), 29) + + # x2Level0Scale must be [in/level0_block_size, out] — transpose from + # the [out, in/level0_block_size] shape returned by the quant op. + # Reference: MindIE-SD layer.py:409 + w_dual_scale = w_dual_scale.squeeze(-1).transpose(0, 1).contiguous() + layer.weight = Parameter(qw, requires_grad=False) layer.weight_dual_scale = Parameter(w_dual_scale, requires_grad=False) layer.weight_scale = Parameter(w_scale, requires_grad=False) From a543f681f8d22c601953e3a948e50c7914b1ba60 Mon Sep 17 00:00:00 2001 From: "(Messi) Junlin Wu" Date: Sat, 28 Mar 2026 13:43:33 +0800 Subject: [PATCH 17/29] :bug: fix(diffusion/mxfp4): add NZ format cast and dual_scale transpose for matmul --- .../runtime/layers/quantization/mxfp4_npu.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp4_npu.py b/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp4_npu.py index 5a3e7cf2eeb2..89cd38d2bfa6 100644 --- a/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp4_npu.py +++ b/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp4_npu.py @@ -134,6 +134,16 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: qw, w_dual_scale, w_scale = torch_npu.npu_dynamic_dual_level_mx_quant( weight_fp, smooth_scale=None ) + + # npu_dual_level_quant_matmul requires x2 (weight) in FRACTAL_NZ format. + # Reference: MindIE-SD W4A4MXFP4DualQuantLinear._init_dynamic_quant_param + qw = torch_npu.npu_format_cast(qw.view(torch.int8), 29) + + # x2Level0Scale must be [in/level0_block_size, out] — transpose from + # the [out, in/level0_block_size] shape returned by the quant op. + # Reference: MindIE-SD layer.py:409 + w_dual_scale = w_dual_scale.squeeze(-1).transpose(0, 1).contiguous() + layer.weight = Parameter(qw, requires_grad=False) layer.weight_dual_scale = Parameter(w_dual_scale, requires_grad=False) layer.weight_scale = Parameter(w_scale, requires_grad=False) From e37324a6a3c23afe1c463dfa4cb48b497c61c3a3 Mon Sep 17 00:00:00 2001 From: "(Messi) Junlin Wu" Date: Mon, 30 Mar 2026 09:15:02 +0800 Subject: [PATCH 18/29] :sparkles: feat(mxfp4/modelslim): support W4A4_MXFP4_DUALSCALE offline inference 1. Dispatch W4A4_MXFP4_DUALSCALE type to ModelSlimMXFP4Scheme in modelslim.py\n2. Add .linear. key stripping in wan_repack RENAME_DICT for MXFP4 checkpoints\n3. Support multi-shard safetensors loading in load_sharded_safetensors --- .../runtime/layers/quantization/modelslim.py | 2 +- python/sglang/multimodal_gen/tools/wan_repack.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim.py b/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim.py index af292ca37f13..3867ea138c2a 100644 --- a/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim.py +++ b/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim.py @@ -125,7 +125,7 @@ def _get_scheme_from_parts( ) return ModelSlimMXFP8Scheme() - elif quant_type == "W4A4_MXFP4": + elif quant_type in ("W4A4_MXFP4", "W4A4_MXFP4_DUALSCALE"): from sglang.multimodal_gen.runtime.layers.quantization.modelslim_mxfp4_scheme import ( ModelSlimMXFP4Scheme, ) diff --git a/python/sglang/multimodal_gen/tools/wan_repack.py b/python/sglang/multimodal_gen/tools/wan_repack.py index 308b229d8593..1e395cedc42b 100644 --- a/python/sglang/multimodal_gen/tools/wan_repack.py +++ b/python/sglang/multimodal_gen/tools/wan_repack.py @@ -52,6 +52,9 @@ "attn2.to_k_img": "attn2.add_k_proj", "attn2.to_v_img": "attn2.add_v_proj", "attn2.norm_k_img": "attn2.norm_added_k", + # MXFP4 msmodelslim wraps Linear layers with a `.linear.` subpath; + # strip it so keys match the SGLang model parameters. + ".linear.": ".", } SUPPORTED_MODEL_TYPES = ["Wan2.2-T2V-A14B", "Wan2.2-I2V-A14B", "Wan2.2-TI2V-5B"] @@ -98,13 +101,10 @@ def load_sharded_safetensors(directory: pathlib.Path, pattern: str) -> dict: candidates = sorted(directory.glob(pattern)) if not candidates: raise FileNotFoundError(f"No file matching '{pattern}' found in {directory}") - if len(candidates) > 1: - raise FileNotFoundError( - f"Multiple files matching '{pattern}' found in {directory}: {candidates}" - ) state_dict = {} - state_dict.update(load_file(candidates[0])) + for f in candidates: + state_dict.update(load_file(f)) return state_dict From 5f160757adb184fcb94705672dabc858a04b7748 Mon Sep 17 00:00:00 2001 From: "(Messi) Junlin Wu" Date: Mon, 30 Mar 2026 14:47:48 +0800 Subject: [PATCH 19/29] :sparkles: feat(quantization): support MXFP4 DualScale offline quantization - Add W4A4_MXFP4_DUALSCALE type to modelslim scheme dispatcher - Support .linear. key stripping in wan_repack for MXFP4 msmodelslim exports - Support multi-shard safetensors loading in repack tool - Fix modelslim quantization config loading from component directory - Add detailed error messages for unsupported quantization schemes --- .../runtime/layers/quantization/modelslim.py | 5 ++++- .../loader/component_loaders/transformer_loader.py | 9 +++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim.py b/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim.py index 3867ea138c2a..3f2d625d504f 100644 --- a/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim.py +++ b/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim.py @@ -131,7 +131,10 @@ def _get_scheme_from_parts( ) return ModelSlimMXFP4Scheme() - raise NotImplementedError("No modelslim compatible scheme was found.") + raise NotImplementedError( + f"No modelslim compatible scheme was found for layer '{layer_name}'. " + f"quant_description['{layer_name}.weight'] = '{quant_type}'" + ) def get_scheme( self, layer: torch.nn.Module, layer_name: Optional[str] = None diff --git a/python/sglang/multimodal_gen/runtime/loader/component_loaders/transformer_loader.py b/python/sglang/multimodal_gen/runtime/loader/component_loaders/transformer_loader.py index 50d2fa841b29..747c67d98162 100644 --- a/python/sglang/multimodal_gen/runtime/loader/component_loaders/transformer_loader.py +++ b/python/sglang/multimodal_gen/runtime/loader/component_loaders/transformer_loader.py @@ -91,6 +91,15 @@ def _resolve_quant_config( get_quantization_config, ) + # modelslim requires a per-layer quant description file; load it from + # the component directory rather than returning an empty config. + if server_args.quantization == "modelslim": + from sglang.multimodal_gen.runtime.utils.quantization_utils import ( + find_quant_modelslim_config, + ) + + return find_quant_modelslim_config(hf_config, component_model_path) + quant_cls = get_quantization_config(server_args.quantization) return quant_cls.from_config({}) From 9e2442ac60b9363a745651f3c966ed4e8810e83c Mon Sep 17 00:00:00 2001 From: "(Messi) Junlin Wu" Date: Mon, 30 Mar 2026 15:10:23 +0800 Subject: [PATCH 20/29] :bug: fix(diffusion/loader): load modelslim description when quantization flag is explicit When --quantization modelslim is explicitly passed, the loader must load the per-layer quant_model_description.json from the transformer directory rather than creating an empty config. This ensures ModelSlimConfig receives the quantization type mappings required for proper scheme dispatch. --- .../runtime/loader/component_loaders/transformer_loader.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/python/sglang/multimodal_gen/runtime/loader/component_loaders/transformer_loader.py b/python/sglang/multimodal_gen/runtime/loader/component_loaders/transformer_loader.py index 747c67d98162..607902022755 100644 --- a/python/sglang/multimodal_gen/runtime/loader/component_loaders/transformer_loader.py +++ b/python/sglang/multimodal_gen/runtime/loader/component_loaders/transformer_loader.py @@ -28,7 +28,6 @@ from sglang.multimodal_gen.runtime.utils.quantization_utils import ( build_nvfp4_config_from_safetensors_list, get_metadata_from_safetensors_file, - get_quant_config, get_quant_config_from_safetensors_metadata, ) from sglang.multimodal_gen.utils import PRECISION_TO_TYPE @@ -95,10 +94,10 @@ def _resolve_quant_config( # the component directory rather than returning an empty config. if server_args.quantization == "modelslim": from sglang.multimodal_gen.runtime.utils.quantization_utils import ( - find_quant_modelslim_config, + get_quant_config, ) - return find_quant_modelslim_config(hf_config, component_model_path) + return get_quant_config(hf_config, component_model_path) quant_cls = get_quantization_config(server_args.quantization) return quant_cls.from_config({}) From 3c2c53401115e03b7dd3599945f52299965913b5 Mon Sep 17 00:00:00 2001 From: "(Messi) Junlin Wu" Date: Tue, 31 Mar 2026 09:00:42 +0800 Subject: [PATCH 21/29] :bug: fix(modelslim/mxfp4): correct weight shape and dual scale format per msmodelslim export - weight: [out, in] float8_e4m3fn (not [out, in/2]) - weight_dual_scale: [out, in/512, 1] float32 (not [out, in/64] bfloat16) L1 scale groups 16 L0 blocks = 512 elements - Fix create_weights allocation and process_weights_after_loading transforms to match actual checkpoint tensor formats from msmodelslim --- .../quantization/modelslim_mxfp4_scheme.py | 41 ++++++++++--------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp4_scheme.py b/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp4_scheme.py index e9be95f9d04d..b5e91f8d8c77 100644 --- a/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp4_scheme.py +++ b/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp4_scheme.py @@ -1,9 +1,13 @@ """ModelSlim MXFP4 scheme for pre-quantized weight inference on Ascend NPU. -Loads weights pre-quantized by msmodelslim (float8_e4m3fn as FP4 packed -container, uint8 scales, bfloat16 dual scales) and runs MXFP4 dual-level +Loads weights pre-quantized by msmodelslim and runs MXFP4 dual-level matmul at inference via npu_dual_level_quant_matmul. +Checkpoint tensor formats (verified from msmodelslim export): + weight: [out, in] float8_e4m3fn (FP4 data in fp8 container) + weight_scale: [out, in/32] uint8 (L0 block scales, e8m0+127) + weight_dual_scale:[out, in/512, 1] float32 (L1 coarse scales) + Reference: MindIE-SD W4A4MXFP4DualQuantLinear (MindIE-SD/mindiesd/quantization/layer.py) """ @@ -20,6 +24,9 @@ from sglang.srt.layers.quantization.modelslim.schemes import ModelSlimLinearScheme MXFP4_BLOCK_SIZE = 32 +# L1 (dual) scale groups this many L0 blocks together. +# L1 block covers 16 * 32 = 512 elements. +MXFP4_DUAL_LEVEL_RATIO = 16 class ModelSlimMXFP4Scheme(ModelSlimLinearScheme): @@ -37,11 +44,12 @@ def create_weights( weight_loader = extra_weight_attrs.get("weight_loader") output_size_per_partition = sum(output_partition_sizes) - # msmodelslim exports weight as float8_e4m3fn, shape [out, in/2]. - # Two FP4 (E2M1) values are packed into one float8_e4m3fn byte. + # msmodelslim exports weight as float8_e4m3fn, shape [out, in]. + # Each byte is a float8 container for FP4 data; the actual FP4 packing + # (npu_dtype_cast → float4_e2m1fn_x2) happens in process_weights_after_loading. weight = ModelWeightParameter( data=torch.empty( - (output_size_per_partition, input_size_per_partition // 2), + (output_size_per_partition, input_size_per_partition), dtype=torch.float8_e4m3fn, ), input_dim=1, @@ -50,8 +58,7 @@ def create_weights( ) layer.register_parameter("weight", weight) - # msmodelslim exports weight_scale as uint8, shape [out, in/32]. - # Stored as e8m0 scale + 127 offset. + # L0 block scale: uint8 [out, in/32], e8m0 scale with +127 offset. scale_dim = input_size_per_partition // MXFP4_BLOCK_SIZE weight_scale = GroupQuantScaleParameter( data=torch.empty( @@ -64,17 +71,13 @@ def create_weights( ) layer.register_parameter("weight_scale", weight_scale) - # L1 (coarse) scale for dual-level quantization matmul. - # MindIE-SD loads this as [out, in/64, 1], then squeeze(-1) + transpose. - # The dual_scale groups every 2 L0 blocks (64 elements) into one L1 block. - # TODO: The exact shape and dtype depend on the checkpoint export tool. - # msmodelslim's current version may not export this field; ensure the - # checkpoint includes weight_dual_scale for dual-level matmul support. - dual_scale_dim = scale_dim // 2 # in/32 / 2 = in/64 + # L1 (coarse) dual scale: float32 [out, in/512, 1]. + # Each L1 block covers 16 L0 blocks (512 elements). + dual_scale_dim = scale_dim // MXFP4_DUAL_LEVEL_RATIO weight_dual_scale = GroupQuantScaleParameter( data=torch.empty( - (output_size_per_partition, dual_scale_dim), - dtype=torch.bfloat16, + (output_size_per_partition, dual_scale_dim, 1), + dtype=torch.float32, ), input_dim=1, output_dim=0, @@ -96,11 +99,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module): weight_scale = weight_scale.reshape(weight_scale.shape[0], -1, 2) layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) - # Transform weight_dual_scale: [out, in/64] -> [in/64, out] - # MindIE-SD does squeeze(-1).transpose(0,1); we skip squeeze since - # our parameter is already 2D [out, in/64] + # Transform weight_dual_scale: [out, in/512, 1] -> [in/512, out] weight_dual_scale = layer.weight_dual_scale.data - weight_dual_scale = weight_dual_scale.transpose(0, 1).contiguous() + weight_dual_scale = weight_dual_scale.squeeze(-1).transpose(0, 1).contiguous() layer.weight_dual_scale = torch.nn.Parameter( weight_dual_scale, requires_grad=False ) From 9f30028bac9a71811f2673e22a8aec3297801019 Mon Sep 17 00:00:00 2001 From: "(Messi) Junlin Wu" Date: Tue, 31 Mar 2026 09:26:37 +0800 Subject: [PATCH 22/29] :bug: fix(mxfp4/modelslim): fix runtime error --- .../layers/quantization/modelslim_mxfp4_scheme.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp4_scheme.py b/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp4_scheme.py index b5e91f8d8c77..36a176660da3 100644 --- a/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp4_scheme.py +++ b/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp4_scheme.py @@ -71,9 +71,9 @@ def create_weights( ) layer.register_parameter("weight_scale", weight_scale) - # L1 (coarse) dual scale: float32 [out, in/512, 1]. - # Each L1 block covers 16 L0 blocks (512 elements). - dual_scale_dim = scale_dim // MXFP4_DUAL_LEVEL_RATIO + # L1 (coarse) scale for dual-level quantization matmul. + # Each L1 block covers MXFP4_DUAL_LEVEL_RATIO L0 blocks = 16 * 32 = 512 elements. + dual_scale_dim = scale_dim // MXFP4_DUAL_LEVEL_RATIO # in/32 / 16 = in/512 weight_dual_scale = GroupQuantScaleParameter( data=torch.empty( (output_size_per_partition, dual_scale_dim, 1), @@ -91,6 +91,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module): if not weight.is_npu: weight = weight.to(f"npu:{torch.npu.current_device()}") weight = torch_npu.npu_dtype_cast(weight, torch_npu.float4_e2m1fn_x2) + # npu_dual_level_quant_matmul requires x2 in FRACTAL_NZ format (format 29). + # Reference: mxfp4_npu.py process_weights_after_loading + weight = torch_npu.npu_format_cast(weight.view(torch.int8), 29) layer.weight = torch.nn.Parameter(weight, requires_grad=False) # Reshape weight_scale: [out, in/32] -> [out, in/64, 2] From 4b694119e20b740b4b8738b5016962d739a8ff23 Mon Sep 17 00:00:00 2001 From: "(Messi) Junlin Wu" Date: Tue, 31 Mar 2026 10:02:10 +0800 Subject: [PATCH 23/29] =?UTF-8?q?=F0=9F=90=9B=20fix(mxfp4/modelslim):=20fi?= =?UTF-8?q?x=20mosaic=20issue?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../layers/quantization/modelslim_mxfp4_scheme.py | 10 ++++++++-- .../runtime/layers/quantization/mxfp4_npu.py | 2 +- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp4_scheme.py b/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp4_scheme.py index 36a176660da3..62109eebfae6 100644 --- a/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp4_scheme.py +++ b/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp4_scheme.py @@ -92,18 +92,24 @@ def process_weights_after_loading(self, layer: torch.nn.Module): weight = weight.to(f"npu:{torch.npu.current_device()}") weight = torch_npu.npu_dtype_cast(weight, torch_npu.float4_e2m1fn_x2) # npu_dual_level_quant_matmul requires x2 in FRACTAL_NZ format (format 29). - # Reference: mxfp4_npu.py process_weights_after_loading - weight = torch_npu.npu_format_cast(weight.view(torch.int8), 29) + # Reference: MindIE-SD W4A4MXFP4DualQuantLinear._init_dynamic_quant_param + weight = torch_npu.npu_format_cast(weight.view(torch.int8), 29, torch.int8) layer.weight = torch.nn.Parameter(weight, requires_grad=False) # Reshape weight_scale: [out, in/32] -> [out, in/64, 2] # The dual-level matmul API expects L0 scales in this 3D format weight_scale = layer.weight_scale.data + if not weight_scale.is_npu: + weight_scale = weight_scale.to(f"npu:{torch.npu.current_device()}") weight_scale = weight_scale.reshape(weight_scale.shape[0], -1, 2) layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) # Transform weight_dual_scale: [out, in/512, 1] -> [in/512, out] weight_dual_scale = layer.weight_dual_scale.data + if not weight_dual_scale.is_npu: + weight_dual_scale = weight_dual_scale.to( + f"npu:{torch.npu.current_device()}" + ) weight_dual_scale = weight_dual_scale.squeeze(-1).transpose(0, 1).contiguous() layer.weight_dual_scale = torch.nn.Parameter( weight_dual_scale, requires_grad=False diff --git a/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp4_npu.py b/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp4_npu.py index 89cd38d2bfa6..f302d9738d5e 100644 --- a/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp4_npu.py +++ b/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp4_npu.py @@ -137,7 +137,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # npu_dual_level_quant_matmul requires x2 (weight) in FRACTAL_NZ format. # Reference: MindIE-SD W4A4MXFP4DualQuantLinear._init_dynamic_quant_param - qw = torch_npu.npu_format_cast(qw.view(torch.int8), 29) + qw = torch_npu.npu_format_cast(qw.view(torch.int8), 29, torch.int8) # x2Level0Scale must be [in/level0_block_size, out] — transpose from # the [out, in/level0_block_size] shape returned by the quant op. From 96032d9e427c46c7fc7173c8b95b6c3cc329ddbb Mon Sep 17 00:00:00 2001 From: "(Messi) Junlin Wu" Date: Tue, 31 Mar 2026 10:08:03 +0800 Subject: [PATCH 24/29] =?UTF-8?q?=F0=9F=90=9B=20fix(mxfp4/modelslim):=20fi?= =?UTF-8?q?x=20runtime=20error?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../runtime/layers/quantization/modelslim_mxfp4_scheme.py | 4 +++- .../multimodal_gen/runtime/layers/quantization/mxfp4_npu.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp4_scheme.py b/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp4_scheme.py index 62109eebfae6..9a79e19163a4 100644 --- a/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp4_scheme.py +++ b/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp4_scheme.py @@ -93,7 +93,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module): weight = torch_npu.npu_dtype_cast(weight, torch_npu.float4_e2m1fn_x2) # npu_dual_level_quant_matmul requires x2 in FRACTAL_NZ format (format 29). # Reference: MindIE-SD W4A4MXFP4DualQuantLinear._init_dynamic_quant_param - weight = torch_npu.npu_format_cast(weight.view(torch.int8), 29, torch.int8) + weight = torch_npu.npu_format_cast( + weight.view(torch.int8), 29, customize_dtype=torch.int8 + ) layer.weight = torch.nn.Parameter(weight, requires_grad=False) # Reshape weight_scale: [out, in/32] -> [out, in/64, 2] diff --git a/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp4_npu.py b/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp4_npu.py index f302d9738d5e..cbb4f58488a8 100644 --- a/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp4_npu.py +++ b/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp4_npu.py @@ -137,7 +137,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # npu_dual_level_quant_matmul requires x2 (weight) in FRACTAL_NZ format. # Reference: MindIE-SD W4A4MXFP4DualQuantLinear._init_dynamic_quant_param - qw = torch_npu.npu_format_cast(qw.view(torch.int8), 29, torch.int8) + qw = torch_npu.npu_format_cast( + qw.view(torch.int8), 29, customize_dtype=torch.int8 + ) # x2Level0Scale must be [in/level0_block_size, out] — transpose from # the [out, in/level0_block_size] shape returned by the quant op. From a252146195443ad698c44388c36dee7b8baa9f26 Mon Sep 17 00:00:00 2001 From: "(Messi) Junlin Wu" Date: Tue, 31 Mar 2026 10:32:46 +0800 Subject: [PATCH 25/29] :chart_with_upwards_trend: Add temporary debugging logs --- .../quantization/modelslim_mxfp4_scheme.py | 48 ++++++++++++++++++- 1 file changed, 47 insertions(+), 1 deletion(-) diff --git a/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp4_scheme.py b/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp4_scheme.py index 9a79e19163a4..6eda4e6edb26 100644 --- a/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp4_scheme.py +++ b/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp4_scheme.py @@ -86,33 +86,79 @@ def create_weights( layer.register_parameter("weight_dual_scale", weight_dual_scale) def process_weights_after_loading(self, layer: torch.nn.Module): + import logging + + _log = logging.getLogger(__name__) + # Cast weight from fp8 container to FP4 packed format weight = layer.weight.data + _log.warning( + "[MXFP4-DBG] weight BEFORE cast: shape=%s dtype=%s device=%s", + weight.shape, + weight.dtype, + weight.device, + ) if not weight.is_npu: weight = weight.to(f"npu:{torch.npu.current_device()}") weight = torch_npu.npu_dtype_cast(weight, torch_npu.float4_e2m1fn_x2) + _log.warning( + "[MXFP4-DBG] weight AFTER npu_dtype_cast: shape=%s dtype=%s", + weight.shape, + weight.dtype, + ) + weight_int8_view = weight.view(torch.int8) + _log.warning( + "[MXFP4-DBG] weight.view(int8): shape=%s dtype=%s", + weight_int8_view.shape, + weight_int8_view.dtype, + ) # npu_dual_level_quant_matmul requires x2 in FRACTAL_NZ format (format 29). # Reference: MindIE-SD W4A4MXFP4DualQuantLinear._init_dynamic_quant_param weight = torch_npu.npu_format_cast( - weight.view(torch.int8), 29, customize_dtype=torch.int8 + weight_int8_view, 29, customize_dtype=torch.int8 + ) + _log.warning( + "[MXFP4-DBG] weight AFTER format_cast(NZ): shape=%s dtype=%s", + weight.shape, + weight.dtype, ) layer.weight = torch.nn.Parameter(weight, requires_grad=False) # Reshape weight_scale: [out, in/32] -> [out, in/64, 2] # The dual-level matmul API expects L0 scales in this 3D format weight_scale = layer.weight_scale.data + _log.warning( + "[MXFP4-DBG] weight_scale BEFORE reshape: shape=%s dtype=%s", + weight_scale.shape, + weight_scale.dtype, + ) if not weight_scale.is_npu: weight_scale = weight_scale.to(f"npu:{torch.npu.current_device()}") weight_scale = weight_scale.reshape(weight_scale.shape[0], -1, 2) + _log.warning( + "[MXFP4-DBG] weight_scale AFTER reshape: shape=%s dtype=%s", + weight_scale.shape, + weight_scale.dtype, + ) layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) # Transform weight_dual_scale: [out, in/512, 1] -> [in/512, out] weight_dual_scale = layer.weight_dual_scale.data + _log.warning( + "[MXFP4-DBG] weight_dual_scale BEFORE transform: shape=%s dtype=%s", + weight_dual_scale.shape, + weight_dual_scale.dtype, + ) if not weight_dual_scale.is_npu: weight_dual_scale = weight_dual_scale.to( f"npu:{torch.npu.current_device()}" ) weight_dual_scale = weight_dual_scale.squeeze(-1).transpose(0, 1).contiguous() + _log.warning( + "[MXFP4-DBG] weight_dual_scale AFTER transform: shape=%s dtype=%s", + weight_dual_scale.shape, + weight_dual_scale.dtype, + ) layer.weight_dual_scale = torch.nn.Parameter( weight_dual_scale, requires_grad=False ) From 31800790c5564a2f0277cd0cde0dd4c2153acc8d Mon Sep 17 00:00:00 2001 From: "(Messi) Junlin Wu" Date: Tue, 31 Mar 2026 10:44:26 +0800 Subject: [PATCH 26/29] =?UTF-8?q?=F0=9F=90=9B=20fix(mxfp4/modelslim):=20fi?= =?UTF-8?q?x=20mosaic=20issue?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../quantization/modelslim_mxfp4_scheme.py | 95 ++++++++----------- .../sglang/multimodal_gen/tools/wan_repack.py | 3 + 2 files changed, 45 insertions(+), 53 deletions(-) diff --git a/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp4_scheme.py b/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp4_scheme.py index 6eda4e6edb26..153d330e31cf 100644 --- a/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp4_scheme.py +++ b/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp4_scheme.py @@ -5,8 +5,9 @@ Checkpoint tensor formats (verified from msmodelslim export): weight: [out, in] float8_e4m3fn (FP4 data in fp8 container) - weight_scale: [out, in/32] uint8 (L0 block scales, e8m0+127) - weight_dual_scale:[out, in/512, 1] float32 (L1 coarse scales) + weight_scale: [out, in/32] uint8 (L1 block scales, e8m0+127) + weight_dual_scale:[out, in/512, 1] float32 (L0 coarse scales) + mul_scale: [in] float32 (smooth quant activation scale) Reference: MindIE-SD W4A4MXFP4DualQuantLinear (MindIE-SD/mindiesd/quantization/layer.py) @@ -18,6 +19,7 @@ import torch_npu from sglang.multimodal_gen.runtime.models.parameter import ( + BasevLLMParameter, GroupQuantScaleParameter, ModelWeightParameter, ) @@ -58,7 +60,7 @@ def create_weights( ) layer.register_parameter("weight", weight) - # L0 block scale: uint8 [out, in/32], e8m0 scale with +127 offset. + # L1 block scale: uint8 [out, in/32], e8m0 scale with +127 offset. scale_dim = input_size_per_partition // MXFP4_BLOCK_SIZE weight_scale = GroupQuantScaleParameter( data=torch.empty( @@ -71,8 +73,8 @@ def create_weights( ) layer.register_parameter("weight_scale", weight_scale) - # L1 (coarse) scale for dual-level quantization matmul. - # Each L1 block covers MXFP4_DUAL_LEVEL_RATIO L0 blocks = 16 * 32 = 512 elements. + # L0 (coarse) scale for dual-level quantization matmul. + # Each L0 block covers MXFP4_DUAL_LEVEL_RATIO L1 blocks = 16 * 32 = 512 elements. dual_scale_dim = scale_dim // MXFP4_DUAL_LEVEL_RATIO # in/32 / 16 = in/512 weight_dual_scale = GroupQuantScaleParameter( data=torch.empty( @@ -85,84 +87,63 @@ def create_weights( ) layer.register_parameter("weight_dual_scale", weight_dual_scale) - def process_weights_after_loading(self, layer: torch.nn.Module): - import logging - - _log = logging.getLogger(__name__) + # Smooth quant activation scale (mul_scale) from NonFusionSmoothQuantWrapper. + # msmodelslim exports this as `.div.mul_scale` with shape [in]. + # After repack, it becomes `.mul_scale`. + # This is CRITICAL: the offline-quantized weights were calibrated with + # x * mul_scale applied to the activation. Omitting it causes mosaic output. + # Ref: MindIE-SD W4A4MXFP4DualQuantLinear.quant_matmul lines 385-386. + mul_scale = BasevLLMParameter( + data=torch.empty( + (input_size_per_partition,), + dtype=torch.float32, + ), + weight_loader=weight_loader, + ) + # If mul_scale is not in the checkpoint (e.g. non-smooth-quant model + # or old repack without .div. handling), initialize to ones so that + # x * 1.0 = x (no-op). fsdp_load.py checks this attribute. + mul_scale.missing_param_init = "ones" + layer.register_parameter("mul_scale", mul_scale) + def process_weights_after_loading(self, layer: torch.nn.Module): # Cast weight from fp8 container to FP4 packed format weight = layer.weight.data - _log.warning( - "[MXFP4-DBG] weight BEFORE cast: shape=%s dtype=%s device=%s", - weight.shape, - weight.dtype, - weight.device, - ) if not weight.is_npu: weight = weight.to(f"npu:{torch.npu.current_device()}") weight = torch_npu.npu_dtype_cast(weight, torch_npu.float4_e2m1fn_x2) - _log.warning( - "[MXFP4-DBG] weight AFTER npu_dtype_cast: shape=%s dtype=%s", - weight.shape, - weight.dtype, - ) - weight_int8_view = weight.view(torch.int8) - _log.warning( - "[MXFP4-DBG] weight.view(int8): shape=%s dtype=%s", - weight_int8_view.shape, - weight_int8_view.dtype, - ) # npu_dual_level_quant_matmul requires x2 in FRACTAL_NZ format (format 29). # Reference: MindIE-SD W4A4MXFP4DualQuantLinear._init_dynamic_quant_param weight = torch_npu.npu_format_cast( - weight_int8_view, 29, customize_dtype=torch.int8 - ) - _log.warning( - "[MXFP4-DBG] weight AFTER format_cast(NZ): shape=%s dtype=%s", - weight.shape, - weight.dtype, + weight.view(torch.int8), 29, customize_dtype=torch.int8 ) layer.weight = torch.nn.Parameter(weight, requires_grad=False) # Reshape weight_scale: [out, in/32] -> [out, in/64, 2] - # The dual-level matmul API expects L0 scales in this 3D format + # The dual-level matmul API expects L1 scales in this 3D format weight_scale = layer.weight_scale.data - _log.warning( - "[MXFP4-DBG] weight_scale BEFORE reshape: shape=%s dtype=%s", - weight_scale.shape, - weight_scale.dtype, - ) if not weight_scale.is_npu: weight_scale = weight_scale.to(f"npu:{torch.npu.current_device()}") weight_scale = weight_scale.reshape(weight_scale.shape[0], -1, 2) - _log.warning( - "[MXFP4-DBG] weight_scale AFTER reshape: shape=%s dtype=%s", - weight_scale.shape, - weight_scale.dtype, - ) layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) # Transform weight_dual_scale: [out, in/512, 1] -> [in/512, out] weight_dual_scale = layer.weight_dual_scale.data - _log.warning( - "[MXFP4-DBG] weight_dual_scale BEFORE transform: shape=%s dtype=%s", - weight_dual_scale.shape, - weight_dual_scale.dtype, - ) if not weight_dual_scale.is_npu: weight_dual_scale = weight_dual_scale.to( f"npu:{torch.npu.current_device()}" ) weight_dual_scale = weight_dual_scale.squeeze(-1).transpose(0, 1).contiguous() - _log.warning( - "[MXFP4-DBG] weight_dual_scale AFTER transform: shape=%s dtype=%s", - weight_dual_scale.shape, - weight_dual_scale.dtype, - ) layer.weight_dual_scale = torch.nn.Parameter( weight_dual_scale, requires_grad=False ) + # Move mul_scale to NPU if present and not already there + mul_scale = layer.mul_scale.data + if not mul_scale.is_npu: + mul_scale = mul_scale.to(f"npu:{torch.npu.current_device()}") + layer.mul_scale = torch.nn.Parameter(mul_scale, requires_grad=False) + def apply_weights( self, layer: torch.nn.Module, @@ -179,6 +160,14 @@ def apply_weights( input_shape = x.shape x_2d = x.reshape(-1, x.shape[-1]) + # Apply smooth quant scale before activation quantization. + # The offline-quantized weights were calibrated under x * mul_scale, + # so we MUST apply it here for scale alignment. + # Reference: MindIE-SD W4A4MXFP4DualQuantLinear.quant_matmul + mul_scale = layer.mul_scale + if not torch.all(mul_scale == 1.0): + x_2d = x_2d * mul_scale.to(x_2d.dtype) + # Dual-level MXFP4 activation quantization x1, l0_scale, l1_scale = torch_npu.npu_dynamic_dual_level_mx_quant( x_2d, smooth_scale=None diff --git a/python/sglang/multimodal_gen/tools/wan_repack.py b/python/sglang/multimodal_gen/tools/wan_repack.py index 1e395cedc42b..9623719ca9ea 100644 --- a/python/sglang/multimodal_gen/tools/wan_repack.py +++ b/python/sglang/multimodal_gen/tools/wan_repack.py @@ -55,6 +55,9 @@ # MXFP4 msmodelslim wraps Linear layers with a `.linear.` subpath; # strip it so keys match the SGLang model parameters. ".linear.": ".", + # NonFusionSmoothQuantWrapper exports smooth quant scale as `.div.mul_scale`; + # strip `.div.` so it loads as a direct parameter `mul_scale` on the linear layer. + ".div.": ".", } SUPPORTED_MODEL_TYPES = ["Wan2.2-T2V-A14B", "Wan2.2-I2V-A14B", "Wan2.2-TI2V-5B"] From 82d16245ac7bbb6acefaa1f899b86eae0d8bd29c Mon Sep 17 00:00:00 2001 From: "(Messi) Junlin Wu" Date: Tue, 31 Mar 2026 17:12:52 +0800 Subject: [PATCH 27/29] :adhesive_bandage: fix(loader): align comment style with junlin branch --- .../multimodal_gen/runtime/loader/transformer_load_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/multimodal_gen/runtime/loader/transformer_load_utils.py b/python/sglang/multimodal_gen/runtime/loader/transformer_load_utils.py index 34df39a83bf4..c2d25f042f56 100644 --- a/python/sglang/multimodal_gen/runtime/loader/transformer_load_utils.py +++ b/python/sglang/multimodal_gen/runtime/loader/transformer_load_utils.py @@ -276,7 +276,7 @@ def _resolve_quant_config( resolve quant config from checkpoints' metadata priority: explicit --quantization flag -> model config.json -> safetensors metadata -> format-specific fallback """ - # Explicit --quantization flag takes highest priority + # priority: explicit --quantization flag (e.g. mxfp8, mxfp4, modelslim) if server_args.quantization is not None: from sglang.multimodal_gen.runtime.layers.quantization import ( get_quantization_config, From 28ed2168c561370b3aa36ccbf10be202f5615587 Mon Sep 17 00:00:00 2001 From: Junlin Wu Date: Tue, 12 May 2026 14:51:34 +0800 Subject: [PATCH 28/29] :bug: fix(diffusion/mxfp4): guard torch_npu import; fix CI on non-NPU runners 1. Guard `import torch_npu` with `if _is_npu:` in mxfp4_npu.py and modelslim_mxfp4_scheme.py -- fixes ModuleNotFoundError on all GPU/AMD/MUSA CI runners\n2. Precompute layer.use_mul_scale flag in process_weights_after_loading to avoid GPU-to-CPU sync on every forward pass\n3. Use torch.no_grad() + copy_() instead of .data= for weight shuffle in fp8.py elif _use_aiter: block --- .../layers/quantization/modelslim_mxfp4_scheme.py | 11 +++++++++-- .../runtime/layers/quantization/mxfp4_npu.py | 8 +++++++- python/sglang/srt/layers/quantization/fp8.py | 13 +++++++------ 3 files changed, 23 insertions(+), 9 deletions(-) diff --git a/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp4_scheme.py b/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp4_scheme.py index 153d330e31cf..f58b8012f1d0 100644 --- a/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp4_scheme.py +++ b/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp4_scheme.py @@ -16,7 +16,13 @@ from typing import List, Optional import torch -import torch_npu + +from sglang.multimodal_gen.runtime.platforms import current_platform + +_is_npu = current_platform.is_npu() + +if _is_npu: + import torch_npu from sglang.multimodal_gen.runtime.models.parameter import ( BasevLLMParameter, @@ -143,6 +149,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module): if not mul_scale.is_npu: mul_scale = mul_scale.to(f"npu:{torch.npu.current_device()}") layer.mul_scale = torch.nn.Parameter(mul_scale, requires_grad=False) + layer.use_mul_scale = not torch.all(mul_scale == 1.0).item() def apply_weights( self, @@ -165,7 +172,7 @@ def apply_weights( # so we MUST apply it here for scale alignment. # Reference: MindIE-SD W4A4MXFP4DualQuantLinear.quant_matmul mul_scale = layer.mul_scale - if not torch.all(mul_scale == 1.0): + if getattr(layer, "use_mul_scale", True): x_2d = x_2d * mul_scale.to(x_2d.dtype) # Dual-level MXFP4 activation quantization diff --git a/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp4_npu.py b/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp4_npu.py index cbb4f58488a8..3bdad4d15e82 100644 --- a/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp4_npu.py +++ b/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp4_npu.py @@ -17,9 +17,15 @@ from typing import Any, Dict, List, Optional import torch -import torch_npu from torch.nn.parameter import Parameter +from sglang.multimodal_gen.runtime.platforms import current_platform + +_is_npu = current_platform.is_npu() + +if _is_npu: + import torch_npu + from sglang.multimodal_gen.runtime.layers.linear import LinearBase, LinearMethodBase from sglang.multimodal_gen.runtime.layers.quantization.configs.base_config import ( QuantizationConfig, diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 291e2c2de3bd..4498ced15f3c 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -1147,12 +1147,13 @@ def process_weights_after_loading_block_quant(self, layer: Module) -> None: ) elif _use_aiter: # Pre-shuffle weights - layer.w13_weight.data = shuffle_weight( - layer.w13_weight.contiguous(), (16, 16) - ) - layer.w2_weight.data = shuffle_weight( - layer.w2_weight.contiguous(), (16, 16) - ) + with torch.no_grad(): + layer.w13_weight.copy_( + shuffle_weight(layer.w13_weight.contiguous(), (16, 16)) + ) + layer.w2_weight.copy_( + shuffle_weight(layer.w2_weight.contiguous(), (16, 16)) + ) elif _is_cpu: assert ( _is_cpu_amx_available From 373fc3ffce76ef383dd38849abb2ea20a3010bca Mon Sep 17 00:00:00 2001 From: Junlin Wu Date: Fri, 15 May 2026 20:20:09 +0800 Subject: [PATCH 29/29] :recycle: refactor(diffusion/mxfp4): rename MXFP4Config to NPUMXFP4Config Disambiguate from upstream ROCm Mxfp4Config (mxfp4.py) which differs only by letter case. NPU prefix aligns with LLM-side npu_mxfp4 naming convention. --- .../runtime/layers/quantization/__init__.py | 6 ++++-- .../runtime/layers/quantization/mxfp4_npu.py | 8 ++++---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/python/sglang/multimodal_gen/runtime/layers/quantization/__init__.py b/python/sglang/multimodal_gen/runtime/layers/quantization/__init__.py index 9be26e62e7ee..1bce2a37470b 100644 --- a/python/sglang/multimodal_gen/runtime/layers/quantization/__init__.py +++ b/python/sglang/multimodal_gen/runtime/layers/quantization/__init__.py @@ -15,7 +15,9 @@ ) from sglang.multimodal_gen.runtime.layers.quantization.modelslim import ModelSlimConfig from sglang.multimodal_gen.runtime.layers.quantization.mxfp4 import Mxfp4Config -from sglang.multimodal_gen.runtime.layers.quantization.mxfp4_npu import MXFP4Config +from sglang.multimodal_gen.runtime.layers.quantization.mxfp4_npu import ( + NPUMXFP4Config, +) from sglang.multimodal_gen.runtime.layers.quantization.mxfp8_npu import MXFP8Config QuantizationMethods = Literal[ @@ -40,7 +42,7 @@ "fp8": Fp8Config, "mxfp4": Mxfp4Config, "mxfp8": MXFP8Config, - "mxfp4_npu": MXFP4Config, + "mxfp4_npu": NPUMXFP4Config, } diff --git a/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp4_npu.py b/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp4_npu.py index 076080f19e61..3798f36b41ad 100644 --- a/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp4_npu.py +++ b/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp4_npu.py @@ -1,6 +1,6 @@ """Online MXFP4 quantization for Diffusion models on Ascend NPU. -Provides ``MXFP4Config`` (registered as ``"mxfp4_npu"``) and +Provides ``NPUMXFP4Config`` (registered as ``"mxfp4_npu"``) and ``NPUMXFP4DiffusionLinearMethod`` which quantises FP16/BF16 weights to MXFP4 at load time using dual-level MX quantization and uses ``npu_dynamic_dual_level_mx_quant`` + ``npu_dual_level_quant_matmul`` for @@ -40,7 +40,7 @@ logger = init_logger(__name__) -class MXFP4Config(QuantizationConfig): +class NPUMXFP4Config(QuantizationConfig): """Config for online MXFP4 quantization on Ascend NPU (Diffusion).""" def __init__(self) -> None: @@ -63,7 +63,7 @@ def get_config_filenames(cls) -> List[str]: return [] @classmethod - def from_config(cls, config: Dict[str, Any]) -> "MXFP4Config": + def from_config(cls, config: Dict[str, Any]) -> "NPUMXFP4Config": return cls() def get_quant_method( @@ -87,7 +87,7 @@ class NPUMXFP4DiffusionLinearMethod(LinearMethodBase): Reference: MindIE-SD ``W4A4MXFP4DualQuantLinear`` (offline path only). """ - def __init__(self, quant_config: MXFP4Config): + def __init__(self, quant_config: NPUMXFP4Config): self.quant_config = quant_config def create_weights(