From ef874c0a1c92bf29a35e7f2e7efaf2bdaed748fa Mon Sep 17 00:00:00 2001 From: "(Messi) Junlin Wu" Date: Wed, 18 Mar 2026 15:53:53 +0800 Subject: [PATCH 01/25] =?UTF-8?q?=E2=9C=A8=20feat(npu):=20add=20online=20M?= =?UTF-8?q?XFP8=20quantization=20support=20for=20Ascend=20NPU=20(Path=20B)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add NPUMXFP8LinearMethod that enables --quantization mxfp8 on Ascend NPU, supporting both online (FP16/BF16 → MXFP8) and offline (serialized FP8 checkpoint) quantization via torch_npu APIs (npu_dynamic_mx_quant + npu_quant_matmul with group_sizes=[1,1,32]). --- .../npu/quantization/mxfp8_method_npu.py | 152 ++++++++++++++++++ python/sglang/srt/layers/quantization/fp8.py | 6 + .../ascend/test_ascend_mxfp8_quantization.py | 103 ++++++++++++ 3 files changed, 261 insertions(+) create mode 100644 python/sglang/srt/hardware_backend/npu/quantization/mxfp8_method_npu.py create mode 100644 test/srt/ascend/test_ascend_mxfp8_quantization.py diff --git a/python/sglang/srt/hardware_backend/npu/quantization/mxfp8_method_npu.py b/python/sglang/srt/hardware_backend/npu/quantization/mxfp8_method_npu.py new file mode 100644 index 000000000000..10bb30aa4fab --- /dev/null +++ b/python/sglang/srt/hardware_backend/npu/quantization/mxfp8_method_npu.py @@ -0,0 +1,152 @@ +from typing import TYPE_CHECKING, List, Optional + +import torch +from torch.nn.parameter import Parameter + +from sglang.srt.layers.parameter import ( + BlockQuantScaleParameter, + ModelWeightParameter, +) +from sglang.srt.layers.quantization.base_config import LinearMethodBase + +if TYPE_CHECKING: + from sglang.srt.layers.quantization.fp8 import Fp8Config + + +class NPUMXFP8LinearMethod(LinearMethodBase): + """Ascend NPU MXFP8 (Microscaling FP8) quantization for Linear layers. + + Supports two modes: + - Online quantization: loads FP16/BF16 weights and quantizes them to MXFP8 + at weight loading time. + - Offline quantization: loads pre-quantized FP8 weights with block scales + from a serialized checkpoint. + + Uses torch_npu APIs: + - npu_dynamic_mx_quant: dynamic MXFP8 activation quantization (block_size=32) + - npu_quant_matmul: MXFP8 matrix multiplication with group_sizes=[1,1,32] + """ + + MXFP8_BLOCK_SIZE = 32 + + def __init__(self, quant_config: "Fp8Config"): + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + + layer.logical_widths = output_partition_sizes + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + layer.orig_dtype = params_dtype + + is_serialized = self.quant_config.is_checkpoint_fp8_serialized + + # Weight: fp8 if serialized checkpoint, else original dtype (will be + # quantized in process_weights_after_loading) + weight_dtype = torch.float8_e4m3fn if is_serialized else params_dtype + weight = ModelWeightParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=weight_dtype, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + if is_serialized: + # Block scale: one scale per block of 32 elements along input dim. + # Stored as uint8 (representing float8_e8m0fnu) in checkpoint. + block_k = self.MXFP8_BLOCK_SIZE + scale_cols = (input_size_per_partition + block_k - 1) // block_k + scale = BlockQuantScaleParameter( + data=torch.zeros( + output_size_per_partition, + scale_cols, + dtype=torch.uint8, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + scale.format_ue8m0 = True + layer.register_parameter("weight_scale_inv", scale) + else: + layer.register_parameter("weight_scale_inv", None) + + def process_weights_after_loading(self, layer: torch.nn.Module): + import torch_npu + + is_serialized = self.quant_config.is_checkpoint_fp8_serialized + + if is_serialized: + # Checkpoint already has fp8 weights + uint8 scales. + # Ensure weight is float8_e4m3fn. + if layer.weight.data.dtype != torch.float8_e4m3fn: + layer.weight = Parameter( + torch_npu.npu_dtype_cast( + layer.weight.data, torch_npu.float8_e4m3fn + ), + requires_grad=False, + ) + else: + layer.weight.requires_grad_(False) + + # Scale is already uint8 (e8m0fnu), keep as-is. + layer.weight_scale_inv.requires_grad_(False) + else: + # Online quantization: quantize FP16/BF16 weights to MXFP8. + weight_fp = layer.weight.data + if weight_fp.dtype not in (torch.float16, torch.bfloat16): + weight_fp = weight_fp.to(torch.bfloat16) + + qw, w_scale = torch_npu.npu_dynamic_mx_quant( + weight_fp, dst_type=torch_npu.float8_e4m3fn + ) + layer.weight = Parameter(qw, requires_grad=False) + layer.weight_scale_inv = Parameter(w_scale, requires_grad=False) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + import torch_npu + + original_dtype = x.dtype + if original_dtype not in (torch.float16, torch.bfloat16): + x = x.to(torch.bfloat16) + original_dtype = torch.bfloat16 + + # Dynamic MXFP8 activation quantization (block_size=32) + qx, input_scale = torch_npu.npu_dynamic_mx_quant( + x, dst_type=torch_npu.float8_e4m3fn + ) + + # MXFP8 quantized matmul + output = torch_npu.npu_quant_matmul( + qx, + layer.weight.transpose(0, 1), + layer.weight_scale_inv.transpose(0, 1), + scale_dtype=torch_npu.float8_e8m0fnu, + pertoken_scale=input_scale, + pertoken_scale_dtype=torch_npu.float8_e8m0fnu, + bias=bias.to(torch.float32) if bias is not None else None, + output_dtype=original_dtype, + group_sizes=[1, 1, self.MXFP8_BLOCK_SIZE], + ) + return output diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 1e12baff13cf..21fa1f97b94d 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -222,6 +222,12 @@ def get_quant_method( prefix, self.ignored_layers, fused_mapping=self.packed_modules_mapping ): return UnquantizedLinearMethod() + if _is_npu and self.use_mxfp8: + from sglang.srt.hardware_backend.npu.quantization.mxfp8_method_npu import ( + NPUMXFP8LinearMethod, + ) + + return NPUMXFP8LinearMethod(self) return Fp8LinearMethod(self) elif isinstance(layer, FusedMoE): if is_layer_skipped( diff --git a/test/srt/ascend/test_ascend_mxfp8_quantization.py b/test/srt/ascend/test_ascend_mxfp8_quantization.py new file mode 100644 index 000000000000..e7af4ff97f1a --- /dev/null +++ b/test/srt/ascend/test_ascend_mxfp8_quantization.py @@ -0,0 +1,103 @@ +""" +Usage: +python3 -m unittest test_ascend_mxfp8_quantization.TestAscendMXFP8.test_gsm8k +""" + +import os +import time +import unittest +from types import SimpleNamespace +from urllib.parse import urlparse + +import requests + +from sglang.srt.utils import kill_process_tree +from sglang.test.few_shot_gsm8k import run_eval +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + is_in_ci, + popen_launch_server, +) + +if "ASCEND_RT_VISIBLE_DEVICES" not in os.environ: + os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "0,1" +DEFAULT_PORT_FOR_SRT_TEST_RUNNER = ( + 7000 + int(os.environ.get("ASCEND_RT_VISIBLE_DEVICES", "0")[0]) * 100 +) +DEFAULT_URL_FOR_TEST = f"http://127.0.0.1:{DEFAULT_PORT_FOR_SRT_TEST_RUNNER + 1000}" + + +class TestAscendMXFP8(CustomTestCase): + """Test online MXFP8 quantization (--quantization mxfp8) on Ascend NPU.""" + + @classmethod + def setUpClass(cls): + cls.model = "Qwen/Qwen2.5-0.5B-Instruct" + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--trust-remote-code", + "--quantization", + "mxfp8", + "--device", + "npu", + "--attention-backend", + "ascend", + ], + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_gsm8k(self): + url = urlparse(self.base_url) + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=200, + max_new_tokens=512, + parallel=128, + host=f"http://{url.hostname}", + port=int(url.port), + ) + metrics = run_eval(args) + print(metrics) + self.assertGreaterEqual(metrics["accuracy"], 0.25) + self.assertGreaterEqual(metrics["output_throughput"], 500) + + def run_decode(self, max_new_tokens): + response = requests.post( + self.base_url + "/generate", + json={ + "text": "The capital of France is", + "sampling_params": { + "temperature": 0, + "max_new_tokens": max_new_tokens, + }, + "ignore_eos": True, + }, + ) + return response.json() + + def test_throughput(self): + max_tokens = 256 + + tic = time.perf_counter() + res = self.run_decode(max_tokens) + tok = time.perf_counter() + print(res["text"]) + throughput = max_tokens / (tok - tic) + print(f"Throughput: {throughput} tokens/s") + + if is_in_ci(): + self.assertGreaterEqual(throughput, 20) + + +if __name__ == "__main__": + unittest.main() From d2d19c6f12fde51e6903f206cf493f6bdb7dea55 Mon Sep 17 00:00:00 2001 From: "(Messi) Junlin Wu" Date: Wed, 18 Mar 2026 16:00:42 +0800 Subject: [PATCH 02/25] =?UTF-8?q?=E2=9C=A8=20feat(diffusion):=20add=20onli?= =?UTF-8?q?ne=20MXFP8=20quantization=20support=20for=20Wan2.2=20on=20Ascen?= =?UTF-8?q?d=20NPU?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add MXFP8Config and NPUMXFP8DiffusionLinearMethod for the diffusion subsystem (multimodal_gen), enabling --quantization mxfp8 for Wan2.2 and other diffusion models on Ascend NPU. Also adds explicit quantization field to diffusion ServerArgs so online quantization can be specified without pre-quantized weights. --- .../runtime/layers/quantization/__init__.py | 4 +- .../runtime/layers/quantization/mxfp8_npu.py | 153 ++++++++++++++++++ .../component_loaders/transformer_loader.py | 11 +- .../multimodal_gen/runtime/server_args.py | 4 + 4 files changed, 170 insertions(+), 2 deletions(-) create mode 100644 python/sglang/multimodal_gen/runtime/layers/quantization/mxfp8_npu.py diff --git a/python/sglang/multimodal_gen/runtime/layers/quantization/__init__.py b/python/sglang/multimodal_gen/runtime/layers/quantization/__init__.py index 3d78bb58cd9e..625e50af8115 100644 --- a/python/sglang/multimodal_gen/runtime/layers/quantization/__init__.py +++ b/python/sglang/multimodal_gen/runtime/layers/quantization/__init__.py @@ -7,8 +7,9 @@ ) from sglang.multimodal_gen.runtime.layers.quantization.fp8 import Fp8Config from sglang.multimodal_gen.runtime.layers.quantization.modelslim import ModelSlimConfig +from sglang.multimodal_gen.runtime.layers.quantization.mxfp8_npu import MXFP8Config -QuantizationMethods = Literal["fp8", "modelslim"] +QuantizationMethods = Literal["fp8", "modelslim", "mxfp8"] QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods)) @@ -16,6 +17,7 @@ _CUSTOMIZED_METHOD_TO_QUANT_CONFIG = { "modelslim": ModelSlimConfig, "fp8": Fp8Config, + "mxfp8": MXFP8Config, } diff --git a/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp8_npu.py b/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp8_npu.py new file mode 100644 index 000000000000..6984a59d192e --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp8_npu.py @@ -0,0 +1,153 @@ +"""Online MXFP8 quantization for Diffusion models on Ascend NPU. + +Provides ``MXFP8Config`` (registered as ``"mxfp8"``) and +``NPUMXFP8DiffusionLinearMethod`` which quantise FP16/BF16 weights to MXFP8 +at load time and use ``npu_dynamic_mx_quant`` + ``npu_quant_matmul`` for +inference, mirroring the LLM-side ``NPUMXFP8LinearMethod``. +""" + +from __future__ import annotations + +import logging +from typing import Any, Dict, List, Optional + +import torch +from torch.nn.parameter import Parameter + +from sglang.multimodal_gen.runtime.layers.linear import LinearMethodBase +from sglang.multimodal_gen.runtime.layers.quantization.configs.base_config import ( + QuantizationConfig, + QuantizeMethodBase, +) +from sglang.multimodal_gen.runtime.models.parameter import ModelWeightParameter + +logger = logging.getLogger(__name__) + +MXFP8_BLOCK_SIZE = 32 + + +class MXFP8Config(QuantizationConfig): + """Config for online MXFP8 quantization on Ascend NPU (Diffusion).""" + + def __init__(self) -> None: + super().__init__() + + @classmethod + def get_name(cls) -> str: + return "mxfp8" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.bfloat16, torch.float16] + + @classmethod + def get_min_capability(cls) -> int: + return 0 # NPU, not CUDA + + @classmethod + def get_config_filenames(cls) -> List[str]: + return [] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "MXFP8Config": + return cls() + + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional[QuantizeMethodBase]: + from sglang.multimodal_gen.runtime.layers.linear import LinearBase + + if isinstance(layer, LinearBase): + return NPUMXFP8DiffusionLinearMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class NPUMXFP8DiffusionLinearMethod(LinearMethodBase): + """Ascend NPU MXFP8 linear method for Diffusion models. + + Online mode: loads FP16/BF16 weights → quantises to MXFP8 at load time. + Inference: dynamic MXFP8 activation quant + MXFP8 matmul (block_size=32). + """ + + def __init__(self, quant_config: MXFP8Config): + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + + layer.logical_widths = output_partition_sizes + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + layer.orig_dtype = params_dtype + + # Load weights in original dtype; quantise later in process_weights_after_loading + weight = ModelWeightParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=params_dtype, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + import torch_npu + + weight_fp = layer.weight.data + if weight_fp.dtype not in (torch.float16, torch.bfloat16): + weight_fp = weight_fp.to(torch.bfloat16) + + # Online MXFP8 quantisation of weights (block_size=32) + qw, w_scale = torch_npu.npu_dynamic_mx_quant( + weight_fp, dst_type=torch_npu.float8_e4m3fn + ) + layer.weight = Parameter(qw, requires_grad=False) + layer.weight_scale_inv = Parameter(w_scale, requires_grad=False) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + import torch_npu + + original_dtype = x.dtype + if original_dtype not in (torch.float16, torch.bfloat16): + x = x.to(torch.bfloat16) + original_dtype = torch.bfloat16 + + # Dynamic MXFP8 activation quantisation + qx, input_scale = torch_npu.npu_dynamic_mx_quant( + x, dst_type=torch_npu.float8_e4m3fn + ) + + # MXFP8 matmul + output = torch_npu.npu_quant_matmul( + qx, + layer.weight.transpose(0, 1), + layer.weight_scale_inv.transpose(0, 1), + scale_dtype=torch_npu.float8_e8m0fnu, + pertoken_scale=input_scale, + pertoken_scale_dtype=torch_npu.float8_e8m0fnu, + bias=bias.to(torch.float32) if bias is not None else None, + output_dtype=original_dtype, + group_sizes=[1, 1, MXFP8_BLOCK_SIZE], + ) + return output diff --git a/python/sglang/multimodal_gen/runtime/loader/component_loaders/transformer_loader.py b/python/sglang/multimodal_gen/runtime/loader/component_loaders/transformer_loader.py index 658689ec2c98..13c11bb09c88 100644 --- a/python/sglang/multimodal_gen/runtime/loader/component_loaders/transformer_loader.py +++ b/python/sglang/multimodal_gen/runtime/loader/component_loaders/transformer_loader.py @@ -82,7 +82,16 @@ def _resolve_quant_config( safetensors_list: list[str], component_model_path: str, ) -> Optional[dict]: - # priority: model config.json → safetensors metadata → nunchaku config + # priority: explicit --quantization flag → model config.json + # → safetensors metadata → nunchaku config + if server_args.quantization is not None: + from sglang.multimodal_gen.runtime.layers.quantization import ( + get_quantization_config, + ) + + quant_cls = get_quantization_config(server_args.quantization) + return quant_cls.from_config({}) + quant_config = get_quant_config(hf_config, component_model_path) if quant_config is None and server_args.transformer_weights_path: # try to read quantization_config from the safetensors metadata header diff --git a/python/sglang/multimodal_gen/runtime/server_args.py b/python/sglang/multimodal_gen/runtime/server_args.py index acd41d59a7a4..72c2cd5f63e9 100644 --- a/python/sglang/multimodal_gen/runtime/server_args.py +++ b/python/sglang/multimodal_gen/runtime/server_args.py @@ -167,6 +167,10 @@ class ServerArgs: disable_autocast: bool | None = None + # Explicit quantization method override (e.g. "mxfp8", "fp8", "modelslim"). + # When set, the transformer loader will use this instead of auto-detection. + quantization: str | None = None + # Quantization / Nunchaku SVDQuant configuration nunchaku_config: NunchakuSVDQuantArgs | NunchakuConfig | None = field( default_factory=NunchakuSVDQuantArgs, repr=False From c838adef02cb6f830f0073b557b74fc239d390fc Mon Sep 17 00:00:00 2001 From: "(Messi) Junlin Wu" Date: Thu, 19 Mar 2026 09:47:48 +0800 Subject: [PATCH 03/25] :bug: fix(diffusion): fix npu method call error --- .../runtime/layers/quantization/mxfp8_npu.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp8_npu.py b/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp8_npu.py index 6984a59d192e..fbeed2a03a4c 100644 --- a/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp8_npu.py +++ b/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp8_npu.py @@ -113,6 +113,10 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: if weight_fp.dtype not in (torch.float16, torch.bfloat16): weight_fp = weight_fp.to(torch.bfloat16) + # Ensure weight is on NPU before calling npu_dynamic_mx_quant + if not weight_fp.is_npu: + weight_fp = weight_fp.npu() + # Online MXFP8 quantisation of weights (block_size=32) qw, w_scale = torch_npu.npu_dynamic_mx_quant( weight_fp, dst_type=torch_npu.float8_e4m3fn @@ -133,6 +137,13 @@ def apply( x = x.to(torch.bfloat16) original_dtype = torch.bfloat16 + # npu_quant_matmul requires 3D input; flatten leading dims if needed + input_shape = x.shape + if x.dim() > 3: + x = x.reshape(-1, x.shape[-2], x.shape[-1]) + elif x.dim() == 2: + x = x.unsqueeze(0) + # Dynamic MXFP8 activation quantisation qx, input_scale = torch_npu.npu_dynamic_mx_quant( x, dst_type=torch_npu.float8_e4m3fn @@ -150,4 +161,10 @@ def apply( output_dtype=original_dtype, group_sizes=[1, 1, MXFP8_BLOCK_SIZE], ) + + # Restore original shape (replace last dim with output features) + if len(input_shape) != 3: + output_shape = list(input_shape[:-1]) + [output.shape[-1]] + output = output.reshape(output_shape) + return output From be3b684c10a72213e33f2bc50ced54dfc26eea1b Mon Sep 17 00:00:00 2001 From: "(Messi) Junlin Wu" Date: Thu, 19 Mar 2026 10:32:11 +0800 Subject: [PATCH 04/25] :bug: fix(diffusion): fix MXFP8 quantization scale dimension mismatch on NPU - Ensure weight tensor is on NPU device before npu_dynamic_mx_quant call - Flatten input x to 2D before quantization so input_scale is 3D (required by npu_quant_matmul) - Simplify output shape restoration logic Fixes: dimension of x1Scale(pertoken_scale) should be 3 but was 4 --- .../runtime/layers/quantization/mxfp8_npu.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp8_npu.py b/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp8_npu.py index fbeed2a03a4c..7e7cc37028a9 100644 --- a/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp8_npu.py +++ b/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp8_npu.py @@ -137,16 +137,13 @@ def apply( x = x.to(torch.bfloat16) original_dtype = torch.bfloat16 - # npu_quant_matmul requires 3D input; flatten leading dims if needed + # Flatten to 2D [tokens, hidden] so npu_dynamic_mx_quant returns 3D scale input_shape = x.shape - if x.dim() > 3: - x = x.reshape(-1, x.shape[-2], x.shape[-1]) - elif x.dim() == 2: - x = x.unsqueeze(0) + x_2d = x.reshape(-1, x.shape[-1]) # Dynamic MXFP8 activation quantisation qx, input_scale = torch_npu.npu_dynamic_mx_quant( - x, dst_type=torch_npu.float8_e4m3fn + x_2d, dst_type=torch_npu.float8_e4m3fn ) # MXFP8 matmul @@ -163,8 +160,7 @@ def apply( ) # Restore original shape (replace last dim with output features) - if len(input_shape) != 3: - output_shape = list(input_shape[:-1]) + [output.shape[-1]] - output = output.reshape(output_shape) + output_shape = list(input_shape[:-1]) + [output.shape[-1]] + output = output.reshape(output_shape) return output From fd79b235af48200c40f9469ab6d54348880e8637 Mon Sep 17 00:00:00 2001 From: "(Messi) Junlin Wu" Date: Fri, 20 Mar 2026 09:38:27 +0800 Subject: [PATCH 05/25] :recycle: refactor(mxfp8): split linear method into config and NPU layers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 按 reviewer 建议重构架构分层: - 在 fp8.py 新增 MXFP8LinearAscendMethod,负责权重定义(__init__、create_weights) - 简化 mxfp8_method_npu.py 中的 NPUMXFP8LinearMethod,只保留权重处理和 kernel 调用 - 改进架构分层,符合现有 NPU INT8 方法模式 --- .../npu/quantization/mxfp8_method_npu.py | 93 +++-------------- python/sglang/srt/layers/quantization/fp8.py | 99 ++++++++++++++++++- 2 files changed, 107 insertions(+), 85 deletions(-) diff --git a/python/sglang/srt/hardware_backend/npu/quantization/mxfp8_method_npu.py b/python/sglang/srt/hardware_backend/npu/quantization/mxfp8_method_npu.py index 10bb30aa4fab..97f150040553 100644 --- a/python/sglang/srt/hardware_backend/npu/quantization/mxfp8_method_npu.py +++ b/python/sglang/srt/hardware_backend/npu/quantization/mxfp8_method_npu.py @@ -1,97 +1,27 @@ -from typing import TYPE_CHECKING, List, Optional +from typing import Optional import torch from torch.nn.parameter import Parameter -from sglang.srt.layers.parameter import ( - BlockQuantScaleParameter, - ModelWeightParameter, -) -from sglang.srt.layers.quantization.base_config import LinearMethodBase -if TYPE_CHECKING: - from sglang.srt.layers.quantization.fp8 import Fp8Config +class NPUMXFP8LinearMethod: + """Ascend NPU MXFP8 weight processing and kernel calls. + This class handles NPU-specific operations: + - process_weights_after_loading: dtype casting and online quantization + - apply: dynamic activation quantization + MXFP8 matmul -class NPUMXFP8LinearMethod(LinearMethodBase): - """Ascend NPU MXFP8 (Microscaling FP8) quantization for Linear layers. - - Supports two modes: - - Online quantization: loads FP16/BF16 weights and quantizes them to MXFP8 - at weight loading time. - - Offline quantization: loads pre-quantized FP8 weights with block scales - from a serialized checkpoint. - - Uses torch_npu APIs: - - npu_dynamic_mx_quant: dynamic MXFP8 activation quantization (block_size=32) - - npu_quant_matmul: MXFP8 matrix multiplication with group_sizes=[1,1,32] + Weight creation and config management are handled by + MXFP8LinearAscendMethod in fp8.py. """ MXFP8_BLOCK_SIZE = 32 - def __init__(self, quant_config: "Fp8Config"): - self.quant_config = quant_config - - def create_weights( - self, - layer: torch.nn.Module, - input_size_per_partition: int, - output_partition_sizes: List[int], - input_size: int, - output_size: int, - params_dtype: torch.dtype, - **extra_weight_attrs, + def process_weights_after_loading( + self, layer: torch.nn.Module, is_serialized: bool ): - output_size_per_partition = sum(output_partition_sizes) - weight_loader = extra_weight_attrs.get("weight_loader") - - layer.logical_widths = output_partition_sizes - layer.input_size_per_partition = input_size_per_partition - layer.output_size_per_partition = output_size_per_partition - layer.orig_dtype = params_dtype - - is_serialized = self.quant_config.is_checkpoint_fp8_serialized - - # Weight: fp8 if serialized checkpoint, else original dtype (will be - # quantized in process_weights_after_loading) - weight_dtype = torch.float8_e4m3fn if is_serialized else params_dtype - weight = ModelWeightParameter( - data=torch.empty( - output_size_per_partition, - input_size_per_partition, - dtype=weight_dtype, - ), - input_dim=1, - output_dim=0, - weight_loader=weight_loader, - ) - layer.register_parameter("weight", weight) - - if is_serialized: - # Block scale: one scale per block of 32 elements along input dim. - # Stored as uint8 (representing float8_e8m0fnu) in checkpoint. - block_k = self.MXFP8_BLOCK_SIZE - scale_cols = (input_size_per_partition + block_k - 1) // block_k - scale = BlockQuantScaleParameter( - data=torch.zeros( - output_size_per_partition, - scale_cols, - dtype=torch.uint8, - ), - input_dim=1, - output_dim=0, - weight_loader=weight_loader, - ) - scale.format_ue8m0 = True - layer.register_parameter("weight_scale_inv", scale) - else: - layer.register_parameter("weight_scale_inv", None) - - def process_weights_after_loading(self, layer: torch.nn.Module): import torch_npu - is_serialized = self.quant_config.is_checkpoint_fp8_serialized - if is_serialized: # Checkpoint already has fp8 weights + uint8 scales. # Ensure weight is float8_e4m3fn. @@ -113,6 +43,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module): if weight_fp.dtype not in (torch.float16, torch.bfloat16): weight_fp = weight_fp.to(torch.bfloat16) + if not weight_fp.is_npu: + weight_fp = weight_fp.npu() + qw, w_scale = torch_npu.npu_dynamic_mx_quant( weight_fp, dst_type=torch_npu.float8_e4m3fn ) diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 21fa1f97b94d..8d6fdba452bd 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -223,11 +223,7 @@ def get_quant_method( ): return UnquantizedLinearMethod() if _is_npu and self.use_mxfp8: - from sglang.srt.hardware_backend.npu.quantization.mxfp8_method_npu import ( - NPUMXFP8LinearMethod, - ) - - return NPUMXFP8LinearMethod(self) + return MXFP8LinearAscendMethod(self) return Fp8LinearMethod(self) elif isinstance(layer, FusedMoE): if is_layer_skipped( @@ -245,6 +241,99 @@ def get_scaled_act_names(self) -> List[str]: return [] +class MXFP8LinearAscendMethod(LinearMethodBase): + """Ascend NPU MXFP8 (Microscaling FP8) quantization for Linear layers. + + Supports two modes: + - Online quantization: loads FP16/BF16 weights and quantizes them to MXFP8 + at weight loading time. + - Offline quantization: loads pre-quantized FP8 weights with block scales + from a serialized checkpoint. + + Weight creation is handled here; weight processing and kernel calls are + delegated to NPUMXFP8LinearMethod in the NPU hardware backend. + """ + + MXFP8_BLOCK_SIZE = 32 + + def __init__(self, quant_config: Fp8Config): + self.quant_config = quant_config + + from sglang.srt.hardware_backend.npu.quantization.mxfp8_method_npu import ( + NPUMXFP8LinearMethod + ) + + self.npu_method = NPUMXFP8LinearMethod() + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + + layer.logical_widths = output_partition_sizes + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + layer.orig_dtype = params_dtype + + is_serialized = self.quant_config.is_checkpoint_fp8_serialized + + # Weight: fp8 if serialized checkpoint, else original dtype (will be + # quantized in process_weights_after_loading) + weight_dtype = torch.float8_e4m3fn if is_serialized else params_dtype + weight = ModelWeightParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=weight_dtype, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + if is_serialized: + # Block scale: one scale per block of 32 elements along input dim. + # Stored as uint8 (representing float8_e8m0fnu) in checkpoint. + block_k = self.MXFP8_BLOCK_SIZE + scale_cols = (input_size_per_partition + block_k - 1) // block_k + scale = BlockQuantScaleParameter( + data=torch.zeros( + output_size_per_partition, + scale_cols, + dtype=torch.uint8, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + scale.format_ue8m0 = True + layer.register_parameter("weight_scale_inv", scale) + else: + layer.register_parameter("weight_scale_inv", None) + + def process_weights_after_loading(self, layer: torch.nn.Module): + self.npu_method.process_weights_after_loading( + layer, self.quant_config.is_checkpoint_fp8_serialized + ) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return self.npu_method.apply(layer, x, bias) + + class Fp8LinearMethod(LinearMethodBase): """Linear method for FP8. From 490ad0b11897062d9731b8b160ab6dbfd441c964 Mon Sep 17 00:00:00 2001 From: "(Messi) Junlin Wu" Date: Fri, 20 Mar 2026 11:43:07 +0800 Subject: [PATCH 06/25] :sparkles: feat(diffusion): add offline MXFP8 pre-quantized weight support for Wan2.2 TI2V --- .../runtime/layers/quantization/modelslim.py | 8 ++ .../quantization/modelslim_mxfp8_scheme.py | 117 ++++++++++++++++++ .../runtime/loader/fsdp_load.py | 1 + .../runtime/utils/quantization_utils.py | 7 +- 4 files changed, 132 insertions(+), 1 deletion(-) create mode 100644 python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp8_scheme.py diff --git a/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim.py b/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim.py index afb9a31e4db9..4d43c50e3675 100644 --- a/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim.py +++ b/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim.py @@ -119,6 +119,14 @@ def _get_scheme_from_parts( return ModelSlimW4A4Int4( quant_config=self.quant_description, prefix=layer_name ) + elif quant_type == "W8A8_MXFP8": + from sglang.multimodal_gen.runtime.layers.quantization.modelslim_mxfp8_scheme import ( + ModelSlimMXFP8Scheme, + ) + + return ModelSlimMXFP8Scheme( + quant_config=self.quant_description, prefix=layer_name + ) raise NotImplementedError("No modelslim compatible scheme was found.") def get_scheme( diff --git a/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp8_scheme.py b/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp8_scheme.py new file mode 100644 index 000000000000..eb40a59fa0ef --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp8_scheme.py @@ -0,0 +1,117 @@ +"""ModelSlim MXFP8 scheme for pre-quantized weight inference on Ascend NPU. + +Loads weights pre-quantized by msmodelslim (int8 storage for float8_e4m3fn, +uint8 storage for float8_e8m0fnu scales) and runs MXFP8 matmul at inference. +""" + +from typing import Dict, List, Optional + +import torch + +from sglang.multimodal_gen.runtime.models.parameter import ( + GroupQuantScaleParameter, + ModelWeightParameter, +) +from sglang.srt.layers.quantization.modelslim.schemes import ModelSlimLinearScheme + +MXFP8_BLOCK_SIZE = 32 + + +class ModelSlimMXFP8Scheme(ModelSlimLinearScheme): + + def __init__(self, quant_config: Dict[str, any], prefix: str): + self.quant_config = quant_config + self.prefix = prefix + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + weight_loader = extra_weight_attrs.get("weight_loader") + output_size_per_partition = sum(output_partition_sizes) + + # msmodelslim exports weight as int8 (storage for float8_e4m3fn) + weight = ModelWeightParameter( + data=torch.empty( + (output_size_per_partition, input_size_per_partition), + dtype=torch.int8, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + # msmodelslim exports weight_scale as uint8 (storage for float8_e8m0fnu) + # shape: [out, in/32 * 2] + scale_dim = input_size_per_partition // MXFP8_BLOCK_SIZE * 2 + weight_scale = GroupQuantScaleParameter( + data=torch.empty( + (output_size_per_partition, scale_dim), + dtype=torch.uint8, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight_scale", weight_scale) + + def process_weights_after_loading(self, layer: torch.nn.Module): + import torch_npu + + # Cast int8 → float8_e4m3fn + weight = layer.weight.data + weight = torch_npu.npu_dtype_cast(weight, torch_npu.float8_e4m3fn) + layer.weight = torch.nn.Parameter(weight, requires_grad=False) + + # Reshape weight_scale: [out, in/32*2] → [out, in/32, 2] + weight_scale = layer.weight_scale.data + weight_scale = weight_scale.reshape(weight_scale.shape[0], -1, 2) + layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) + + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + import torch_npu + + original_dtype = x.dtype + if original_dtype not in (torch.float16, torch.bfloat16): + x = x.to(torch.bfloat16) + original_dtype = torch.bfloat16 + + # Flatten to 2D for npu_dynamic_mx_quant + input_shape = x.shape + x_2d = x.reshape(-1, x.shape[-1]) + + # Dynamic MXFP8 activation quantisation + qx, input_scale = torch_npu.npu_dynamic_mx_quant( + x_2d, dst_type=torch_npu.float8_e4m3fn + ) + + # MXFP8 matmul + output = torch_npu.npu_quant_matmul( + qx, + layer.weight.transpose(0, 1), + layer.weight_scale.transpose(0, 1), + scale_dtype=torch_npu.float8_e8m0fnu, + pertoken_scale=input_scale, + pertoken_scale_dtype=torch_npu.float8_e8m0fnu, + bias=bias.to(torch.float32) if bias is not None else None, + output_dtype=original_dtype, + group_sizes=[1, 1, MXFP8_BLOCK_SIZE], + ) + + # Restore original shape + output_shape = list(input_shape[:-1]) + [output.shape[-1]] + output = output.reshape(output_shape) + + return output diff --git a/python/sglang/multimodal_gen/runtime/loader/fsdp_load.py b/python/sglang/multimodal_gen/runtime/loader/fsdp_load.py index dcfbb6eac76a..fab7dd96df3f 100644 --- a/python/sglang/multimodal_gen/runtime/loader/fsdp_load.py +++ b/python/sglang/multimodal_gen/runtime/loader/fsdp_load.py @@ -378,6 +378,7 @@ def load_model_from_full_model_state_dict( "bias", "norm_q", "norm_k", + "weight_scale", ] for new_param_name in unused_keys: if not any(pattern in new_param_name for pattern in ALLOWED_NEW_PARAM_PATTERNS): diff --git a/python/sglang/multimodal_gen/runtime/utils/quantization_utils.py b/python/sglang/multimodal_gen/runtime/utils/quantization_utils.py index e1489780d4f7..c235c0523624 100644 --- a/python/sglang/multimodal_gen/runtime/utils/quantization_utils.py +++ b/python/sglang/multimodal_gen/runtime/utils/quantization_utils.py @@ -16,9 +16,14 @@ def find_quant_modelslim_config(model_config, component_model_path): + # Try exact name first, then glob for variant filenames (e.g. after repack) quant_config_file = Path(component_model_path, "quant_model_description.json") + if not quant_config_file.is_file(): + candidates = sorted(Path(component_model_path).glob("quant_model_description*.json")) + quant_config_file = candidates[0] if candidates else None + quant_cfg = None - if quant_config_file.is_file(): + if quant_config_file is not None and Path(quant_config_file).is_file(): with open(quant_config_file) as f: quant_cfg = json.load(f) # This field is required for flagless model loading but is not present in From cc80690cfdee8656ed9e4e388017a85418ffcaf3 Mon Sep 17 00:00:00 2001 From: "(Messi) Junlin Wu" Date: Mon, 23 Mar 2026 12:04:22 +0800 Subject: [PATCH 07/25] :bug: fix(diffusion): correct MXFP8 weight dtype and scale shape Fix weight loading for msmodelslim pre-quantized MXFP8 weights: - Change weight dtype from int8 to float8_e4m3fn (actual storage format in safetensors) - Fix weight_scale shape from [out, in/32*2] to [out, in/32] (actual msmodelslim export) - Update process_weights_after_loading to reshape weight_scale [out, in/32] -> [out, -1, 2] --- .../quantization/modelslim_mxfp8_scheme.py | 20 ++++++++----------- 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp8_scheme.py b/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp8_scheme.py index eb40a59fa0ef..4485d251fdf6 100644 --- a/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp8_scheme.py +++ b/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp8_scheme.py @@ -1,7 +1,7 @@ """ModelSlim MXFP8 scheme for pre-quantized weight inference on Ascend NPU. -Loads weights pre-quantized by msmodelslim (int8 storage for float8_e4m3fn, -uint8 storage for float8_e8m0fnu scales) and runs MXFP8 matmul at inference. +Loads weights pre-quantized by msmodelslim (float8_e4m3fn weights, +uint8 scales) and runs MXFP8 matmul at inference. """ from typing import Dict, List, Optional @@ -36,11 +36,11 @@ def create_weights( weight_loader = extra_weight_attrs.get("weight_loader") output_size_per_partition = sum(output_partition_sizes) - # msmodelslim exports weight as int8 (storage for float8_e4m3fn) + # msmodelslim exports weight as float8_e4m3fn, shape [out, in] weight = ModelWeightParameter( data=torch.empty( (output_size_per_partition, input_size_per_partition), - dtype=torch.int8, + dtype=torch.float8_e4m3fn, ), input_dim=1, output_dim=0, @@ -48,9 +48,8 @@ def create_weights( ) layer.register_parameter("weight", weight) - # msmodelslim exports weight_scale as uint8 (storage for float8_e8m0fnu) - # shape: [out, in/32 * 2] - scale_dim = input_size_per_partition // MXFP8_BLOCK_SIZE * 2 + # msmodelslim exports weight_scale as uint8, shape [out, in/32] + scale_dim = input_size_per_partition // MXFP8_BLOCK_SIZE weight_scale = GroupQuantScaleParameter( data=torch.empty( (output_size_per_partition, scale_dim), @@ -63,14 +62,11 @@ def create_weights( layer.register_parameter("weight_scale", weight_scale) def process_weights_after_loading(self, layer: torch.nn.Module): - import torch_npu - - # Cast int8 → float8_e4m3fn + # weight is already float8_e4m3fn, no cast needed weight = layer.weight.data - weight = torch_npu.npu_dtype_cast(weight, torch_npu.float8_e4m3fn) layer.weight = torch.nn.Parameter(weight, requires_grad=False) - # Reshape weight_scale: [out, in/32*2] → [out, in/32, 2] + # Reshape weight_scale: [out, in/32] -> [out, in/32//2, 2] weight_scale = layer.weight_scale.data weight_scale = weight_scale.reshape(weight_scale.shape[0], -1, 2) layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) From b9aa78553f9b65c878c1e873757ea0899a5a17a3 Mon Sep 17 00:00:00 2001 From: "(Messi) Junlin Wu" Date: Tue, 24 Mar 2026 11:25:25 +0800 Subject: [PATCH 08/25] =?UTF-8?q?=E2=9C=A8=20feat(wan22):=20Redesigned=20t?= =?UTF-8?q?he=20wan=5Frepack=20tool.=20Now=20support=20one-click=20weight?= =?UTF-8?q?=20processing.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../sglang/multimodal_gen/tools/wan_repack.py | 330 ++++++++++++------ 1 file changed, 215 insertions(+), 115 deletions(-) diff --git a/python/sglang/multimodal_gen/tools/wan_repack.py b/python/sglang/multimodal_gen/tools/wan_repack.py index 2d7132747e7a..3fd23a10cbd0 100644 --- a/python/sglang/multimodal_gen/tools/wan_repack.py +++ b/python/sglang/multimodal_gen/tools/wan_repack.py @@ -1,115 +1,215 @@ -### Based on https://github.com/huggingface/diffusers/blob/main/scripts/convert_wan_to_diffusers.py - -import argparse -import json -import pathlib -from typing import Any, Dict, Tuple - -from safetensors.torch import load_file, save_file - -TRANSFORMER_KEYS_RENAME_DICT = { - "time_embedding.0": "condition_embedder.time_embedder.linear_1", - "time_embedding.2": "condition_embedder.time_embedder.linear_2", - "text_embedding.0": "condition_embedder.text_embedder.linear_1", - "text_embedding.2": "condition_embedder.text_embedder.linear_2", - "time_projection.1": "condition_embedder.time_proj", - "head.modulation": "scale_shift_table", - "head.head": "proj_out", - "modulation": "scale_shift_table", - "ffn.0": "ffn.net.0.proj", - "ffn.2": "ffn.net.2", - # Hack to swap the layer names - # The original model calls the norms in following order: norm1, norm3, norm2 - # We convert it to: norm1, norm2, norm3 - "norm2": "norm__placeholder", - "norm3": "norm2", - "norm__placeholder": "norm3", - # For the I2V model - "img_emb.proj.0": "condition_embedder.image_embedder.norm1", - "img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj", - "img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2", - "img_emb.proj.4": "condition_embedder.image_embedder.norm2", - # for the FLF2V model - "img_emb.emb_pos": "condition_embedder.image_embedder.pos_embed", - # Add attention component mappings - "self_attn.q": "attn1.to_q", - "self_attn.k": "attn1.to_k", - "self_attn.v": "attn1.to_v", - "self_attn.o": "attn1.to_out.0", - "self_attn.norm_q": "attn1.norm_q", - "self_attn.norm_k": "attn1.norm_k", - "cross_attn.q": "attn2.to_q", - "cross_attn.k": "attn2.to_k", - "cross_attn.v": "attn2.to_v", - "cross_attn.o": "attn2.to_out.0", - "cross_attn.norm_q": "attn2.norm_q", - "cross_attn.norm_k": "attn2.norm_k", - "attn2.to_k_img": "attn2.add_k_proj", - "attn2.to_v_img": "attn2.add_v_proj", - "attn2.norm_k_img": "attn2.norm_added_k", -} - - -def get_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]: - if model_type == "Wan-T2V-14B": - RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT - return RENAME_DICT - - -def update_dict_(dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]: - dict[new_key] = dict.pop(old_key) - - -def load_sharded_safetensors(path: pathlib.Path): - file_path = path - state_dict = {} - state_dict.update(load_file(file_path)) - return state_dict - - -def convert_transformer(model_type: str, model_dir: str, output_dir: str): - pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True) - RENAME_DICT = get_transformer_config(model_type) - - original_state_dict = load_sharded_safetensors( - pathlib.Path(model_dir, "*model*.safetensors") - ) - with open(pathlib.Path(model_dir, "*quant_model_description*.json")) as f: - original_quant_config = json.load(f) - - for key in list(original_state_dict.keys()): - new_key = key[:] - for replace_key, rename_key in RENAME_DICT.items(): - new_key = new_key.replace(replace_key, rename_key) - update_dict_(original_state_dict, key, new_key) - update_dict_(original_quant_config, key, new_key) - - save_file( - original_state_dict, - pathlib.Path(output_dir, "diffusion_pytorch_model.safetensors"), - ) - - with open(pathlib.Path(output_dir, "quant_model_description.json"), "w") as f: - json.dump(original_quant_config, f) - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument("--input-path", type=str, required=True) - parser.add_argument("--output-path", type=str, required=True) - return parser.parse_args() - - -if __name__ == "__main__": - args = get_args() - - convert_transformer( - "Wan-T2V-14B", - model_dir=pathlib.Path(args.input_path, "high_noise_model"), - output_dir=pathlib.Path(args.output_path, "transformer"), - ) - convert_transformer( - "Wan-T2V-14B", - model_dir=pathlib.Path(args.input_path, "low_noise_model"), - output_dir=pathlib.Path(args.output_path, "transformer_2"), - ) +### Based on https://github.com/huggingface/diffusers/blob/main/scripts/convert_wan_to_diffusers.py + +import argparse +import json +import pathlib +import shutil +from typing import Any, Dict, List + +from safetensors.torch import load_file, save_file + +TRANSFORMER_KEYS_RENAME_DICT = { + "time_embedding.0": "condition_embedder.time_embedder.linear_1", + "time_embedding.2": "condition_embedder.time_embedder.linear_2", + "text_embedding.0": "condition_embedder.text_embedder.linear_1", + "text_embedding.2": "condition_embedder.text_embedder.linear_2", + "time_projection.1": "condition_embedder.time_proj", + "head.modulation": "scale_shift_table", + "head.head": "proj_out", + "modulation": "scale_shift_table", + "ffn.0": "ffn.net.0.proj", + "ffn.2": "ffn.net.2", + # Hack to swap the layer names + # The original model calls the norms in following order: norm1, norm3, norm2 + # We convert it to: norm1, norm2, norm3 + "norm2": "norm__placeholder", + "norm3": "norm2", + "norm__placeholder": "norm3", + # For the I2V model + "img_emb.proj.0": "condition_embedder.image_embedder.norm1", + "img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj", + "img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2", + "img_emb.proj.4": "condition_embedder.image_embedder.norm2", + # for the FLF2V model + "img_emb.emb_pos": "condition_embedder.image_embedder.pos_embed", + # Add attention component mappings + "self_attn.q": "attn1.to_q", + "self_attn.k": "attn1.to_k", + "self_attn.v": "attn1.to_v", + "self_attn.o": "attn1.to_out.0", + "self_attn.norm_q": "attn1.norm_q", + "self_attn.norm_k": "attn1.norm_k", + "cross_attn.q": "attn2.to_q", + "cross_attn.k": "attn2.to_k", + "cross_attn.v": "attn2.to_v", + "cross_attn.o": "attn2.to_out.0", + "cross_attn.norm_q": "attn2.norm_q", + "cross_attn.norm_k": "attn2.norm_k", + "attn2.to_k_img": "attn2.add_k_proj", + "attn2.to_v_img": "attn2.add_v_proj", + "attn2.norm_k_img": "attn2.norm_added_k", +} + +SUPPORTED_MODEL_TYPES = ["Wan2.2-T2V-A14B", "Wan2.2-I2V-A14B", "Wan2.2-TI2V-5B"] + +# Cascade models have two transformers (high_noise + low_noise) +CASCADE_MODEL_TYPES = {"Wan2.2-T2V-A14B", "Wan2.2-I2V-A14B"} + + +def get_transformer_config(model_type: str) -> Dict[str, Any]: + if model_type in SUPPORTED_MODEL_TYPES: + return TRANSFORMER_KEYS_RENAME_DICT + else: + raise ValueError( + f"Unsupported model_type: {model_type}. Supported: {SUPPORTED_MODEL_TYPES}" + ) + + +def get_transformer_dirs(model_type: str) -> List[str]: + """Return the list of transformer directory names for a given model type.""" + if model_type in CASCADE_MODEL_TYPES: + return ["transformer", "transformer_2"] + return ["transformer"] + + +def get_quant_subpath( + model_type: str, quant_path: pathlib.Path, transformer_dir: str +) -> pathlib.Path: + """Return the quant weights subdirectory for a given transformer.""" + if model_type in CASCADE_MODEL_TYPES: + sub = "high_noise_model" if transformer_dir == "transformer" else "low_noise_model" + return quant_path / sub + return quant_path + + +def update_dict_(dict: Dict[str, Any], old_key: str, new_key: str) -> None: + dict[new_key] = dict.pop(old_key) + + +def load_sharded_safetensors(directory: pathlib.Path, pattern: str) -> dict: + candidates = sorted(directory.glob(pattern)) + if not candidates: + raise FileNotFoundError(f"No file matching '{pattern}' found in {directory}") + if len(candidates) > 1: + raise FileNotFoundError( + f"Multiple files matching '{pattern}' found in {directory}: {candidates}" + ) + print(f" Loading: {candidates[0].name}") + state_dict = {} + state_dict.update(load_file(candidates[0])) + return state_dict + + +def convert_transformer( + model_type: str, model_dir: pathlib.Path, output_dir: pathlib.Path +) -> None: + """Convert a single quantized transformer directory into Diffusers format.""" + model_path = pathlib.Path(model_dir) + out_path = pathlib.Path(output_dir) + out_path.mkdir(parents=True, exist_ok=True) + RENAME_DICT = get_transformer_config(model_type) + + state_dict = load_sharded_safetensors(model_path, "quant_model_weight*.safetensors") + + json_candidates = sorted(model_path.glob("quant_model_description*.json")) + if not json_candidates: + raise FileNotFoundError( + f"No quant_model_description*.json found in {model_path}" + ) + with open(json_candidates[0]) as f: + quant_config = json.load(f) + + for key in list(state_dict.keys()): + new_key = key[:] + for replace_key, rename_key in RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + if new_key != key: + update_dict_(state_dict, key, new_key) + # The quant JSON only covers quantized layers, not all model keys + if key in quant_config: + update_dict_(quant_config, key, new_key) + + save_file(state_dict, out_path / "diffusion_pytorch_model.safetensors") + + with open(out_path / "quant_model_description.json", "w") as f: + json.dump(quant_config, f, indent=2) + + +def repack( + model_type: str, + original_model_path: pathlib.Path, + quant_path: pathlib.Path, + output_path: pathlib.Path, +) -> None: + """ + Full one-step repack workflow: + 1. Copy the original HF Diffusers model to output_path, excluding transformer dir(s). + 2. For each transformer: convert quant weights and copy config.json from original. + """ + transformer_dirs = get_transformer_dirs(model_type) + + # Step 1: Copy original model, skipping transformer dirs (they will be replaced) + print(f"Step 1: Copying original model to {output_path}") + print(f" (skipping: {transformer_dirs})") + shutil.copytree( + str(original_model_path), + str(output_path), + ignore=shutil.ignore_patterns(*transformer_dirs), + ) + + # Step 2+: Convert each transformer + for i, tdir in enumerate(transformer_dirs): + q_path = get_quant_subpath(model_type, quant_path, tdir) + out_tdir = output_path / tdir + print(f"\nStep {i + 2}: Converting {tdir} (quant source: {q_path.name})...") + convert_transformer(model_type, q_path, out_tdir) + + # Copy config.json from the original transformer dir + src_config = original_model_path / tdir / "config.json" + if src_config.is_file(): + shutil.copy2(str(src_config), str(out_tdir / "config.json")) + print(f" Copied config.json from original {tdir}/") + + print(f"\nDone! Repacked model saved to: {output_path}") + + +def get_args(): + parser = argparse.ArgumentParser( + description="Repack msmodelslim quantized Wan2.2 weights into HF Diffusers format" + ) + parser.add_argument( + "--model-type", + type=str, + required=True, + choices=SUPPORTED_MODEL_TYPES, + help="Model type to convert", + ) + parser.add_argument( + "--original-model-path", + type=str, + required=True, + help="Path to the original HF Diffusers model (e.g., /weights/Wan2.2-TI2V-5B-Diffusers)", + ) + parser.add_argument( + "--quant-path", + type=str, + required=True, + help="Path to msmodelslim quantized weights directory", + ) + parser.add_argument( + "--output-path", + type=str, + required=True, + help="Output path for the repacked model (e.g., /weights/Wan2.2-TI2V-5B-Diffusers-MXFP8)", + ) + return parser.parse_args() + + +if __name__ == "__main__": + args = get_args() + repack( + model_type=args.model_type, + original_model_path=pathlib.Path(args.original_model_path), + quant_path=pathlib.Path(args.quant_path), + output_path=pathlib.Path(args.output_path), + ) From 22bee9e9c148d89e3c6794b9652387ea4b46893c Mon Sep 17 00:00:00 2001 From: "(Messi) Junlin Wu" Date: Tue, 24 Mar 2026 18:15:52 +0800 Subject: [PATCH 09/25] :recycle: refactor(mxfp8): hoist imports and replace print with logger --- .../quantization/modelslim_mxfp8_scheme.py | 9 +++++-- .../runtime/layers/quantization/mxfp8_npu.py | 15 ++++-------- .../runtime/utils/quantization_utils.py | 4 +++- .../sglang/multimodal_gen/tools/wan_repack.py | 24 +++++++++++++------ .../npu/quantization/mxfp8_method_npu.py | 5 +--- python/sglang/srt/layers/quantization/fp8.py | 2 +- 6 files changed, 34 insertions(+), 25 deletions(-) diff --git a/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp8_scheme.py b/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp8_scheme.py index 4485d251fdf6..a5c92f039e94 100644 --- a/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp8_scheme.py +++ b/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp8_scheme.py @@ -7,6 +7,7 @@ from typing import Dict, List, Optional import torch +import torch_npu from sglang.multimodal_gen.runtime.models.parameter import ( GroupQuantScaleParameter, @@ -77,14 +78,18 @@ def apply_weights( x: torch.Tensor, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: - import torch_npu original_dtype = x.dtype if original_dtype not in (torch.float16, torch.bfloat16): + # npu_dynamic_mx_quant only accepts fp16/bf16 activations x = x.to(torch.bfloat16) original_dtype = torch.bfloat16 - # Flatten to 2D for npu_dynamic_mx_quant + # npu_dynamic_mx_quant requires a 2D input [tokens, hidden_size]. + # Diffusion transformer inputs are typically 3D [batch, seq, hidden] or + # higher. Flattening to 2D merges all leading dimensions into a single + # token axis so the NPU kernel can compute per-token MXFP8 scales, then + # we restore the original shape from the output. input_shape = x.shape x_2d = x.reshape(-1, x.shape[-1]) diff --git a/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp8_npu.py b/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp8_npu.py index 7e7cc37028a9..1d4d88adb8ec 100644 --- a/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp8_npu.py +++ b/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp8_npu.py @@ -8,20 +8,21 @@ from __future__ import annotations -import logging from typing import Any, Dict, List, Optional import torch +import torch_npu from torch.nn.parameter import Parameter -from sglang.multimodal_gen.runtime.layers.linear import LinearMethodBase +from sglang.multimodal_gen.runtime.layers.linear import LinearBase, LinearMethodBase from sglang.multimodal_gen.runtime.layers.quantization.configs.base_config import ( QuantizationConfig, QuantizeMethodBase, ) from sglang.multimodal_gen.runtime.models.parameter import ModelWeightParameter +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger -logger = logging.getLogger(__name__) +logger = init_logger(__name__) MXFP8_BLOCK_SIZE = 32 @@ -55,8 +56,6 @@ def from_config(cls, config: Dict[str, Any]) -> "MXFP8Config": def get_quant_method( self, layer: torch.nn.Module, prefix: str ) -> Optional[QuantizeMethodBase]: - from sglang.multimodal_gen.runtime.layers.linear import LinearBase - if isinstance(layer, LinearBase): return NPUMXFP8DiffusionLinearMethod(self) return None @@ -107,15 +106,13 @@ def create_weights( layer.register_parameter("weight", weight) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - import torch_npu weight_fp = layer.weight.data if weight_fp.dtype not in (torch.float16, torch.bfloat16): weight_fp = weight_fp.to(torch.bfloat16) # Ensure weight is on NPU before calling npu_dynamic_mx_quant - if not weight_fp.is_npu: - weight_fp = weight_fp.npu() + assert weight_fp.is_npu # Online MXFP8 quantisation of weights (block_size=32) qw, w_scale = torch_npu.npu_dynamic_mx_quant( @@ -130,8 +127,6 @@ def apply( x: torch.Tensor, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: - import torch_npu - original_dtype = x.dtype if original_dtype not in (torch.float16, torch.bfloat16): x = x.to(torch.bfloat16) diff --git a/python/sglang/multimodal_gen/runtime/utils/quantization_utils.py b/python/sglang/multimodal_gen/runtime/utils/quantization_utils.py index c235c0523624..c2380a409ce2 100644 --- a/python/sglang/multimodal_gen/runtime/utils/quantization_utils.py +++ b/python/sglang/multimodal_gen/runtime/utils/quantization_utils.py @@ -19,7 +19,9 @@ def find_quant_modelslim_config(model_config, component_model_path): # Try exact name first, then glob for variant filenames (e.g. after repack) quant_config_file = Path(component_model_path, "quant_model_description.json") if not quant_config_file.is_file(): - candidates = sorted(Path(component_model_path).glob("quant_model_description*.json")) + candidates = sorted( + Path(component_model_path).glob("quant_model_description*.json") + ) quant_config_file = candidates[0] if candidates else None quant_cfg = None diff --git a/python/sglang/multimodal_gen/tools/wan_repack.py b/python/sglang/multimodal_gen/tools/wan_repack.py index 3fd23a10cbd0..d7951e6c3c5f 100644 --- a/python/sglang/multimodal_gen/tools/wan_repack.py +++ b/python/sglang/multimodal_gen/tools/wan_repack.py @@ -8,6 +8,10 @@ from safetensors.torch import load_file, save_file +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + TRANSFORMER_KEYS_RENAME_DICT = { "time_embedding.0": "condition_embedder.time_embedder.linear_1", "time_embedding.2": "condition_embedder.time_embedder.linear_2", @@ -77,7 +81,11 @@ def get_quant_subpath( ) -> pathlib.Path: """Return the quant weights subdirectory for a given transformer.""" if model_type in CASCADE_MODEL_TYPES: - sub = "high_noise_model" if transformer_dir == "transformer" else "low_noise_model" + sub = ( + "high_noise_model" + if transformer_dir == "transformer" + else "low_noise_model" + ) return quant_path / sub return quant_path @@ -94,7 +102,7 @@ def load_sharded_safetensors(directory: pathlib.Path, pattern: str) -> dict: raise FileNotFoundError( f"Multiple files matching '{pattern}' found in {directory}: {candidates}" ) - print(f" Loading: {candidates[0].name}") + state_dict = {} state_dict.update(load_file(candidates[0])) return state_dict @@ -149,8 +157,8 @@ def repack( transformer_dirs = get_transformer_dirs(model_type) # Step 1: Copy original model, skipping transformer dirs (they will be replaced) - print(f"Step 1: Copying original model to {output_path}") - print(f" (skipping: {transformer_dirs})") + logger.debug(f"Step 1: Copying original model to {output_path}") + logger.debug(f" (skipping: {transformer_dirs})") shutil.copytree( str(original_model_path), str(output_path), @@ -161,16 +169,18 @@ def repack( for i, tdir in enumerate(transformer_dirs): q_path = get_quant_subpath(model_type, quant_path, tdir) out_tdir = output_path / tdir - print(f"\nStep {i + 2}: Converting {tdir} (quant source: {q_path.name})...") + logger.debug( + f"\nStep {i + 2}: Converting {tdir} (quant source: {q_path.name})..." + ) convert_transformer(model_type, q_path, out_tdir) # Copy config.json from the original transformer dir src_config = original_model_path / tdir / "config.json" if src_config.is_file(): shutil.copy2(str(src_config), str(out_tdir / "config.json")) - print(f" Copied config.json from original {tdir}/") + logger.debug(f" Copied config.json from original {tdir}/") - print(f"\nDone! Repacked model saved to: {output_path}") + logger.info(f"\nDone! Repacked model saved to: {output_path}") def get_args(): diff --git a/python/sglang/srt/hardware_backend/npu/quantization/mxfp8_method_npu.py b/python/sglang/srt/hardware_backend/npu/quantization/mxfp8_method_npu.py index 97f150040553..c83b77cad7f4 100644 --- a/python/sglang/srt/hardware_backend/npu/quantization/mxfp8_method_npu.py +++ b/python/sglang/srt/hardware_backend/npu/quantization/mxfp8_method_npu.py @@ -1,6 +1,7 @@ from typing import Optional import torch +import torch_npu from torch.nn.parameter import Parameter @@ -20,8 +21,6 @@ class NPUMXFP8LinearMethod: def process_weights_after_loading( self, layer: torch.nn.Module, is_serialized: bool ): - import torch_npu - if is_serialized: # Checkpoint already has fp8 weights + uint8 scales. # Ensure weight is float8_e4m3fn. @@ -58,8 +57,6 @@ def apply( x: torch.Tensor, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: - import torch_npu - original_dtype = x.dtype if original_dtype not in (torch.float16, torch.bfloat16): x = x.to(torch.bfloat16) diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 9a841909ae8d..ef8c3e02619e 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -260,7 +260,7 @@ def __init__(self, quant_config: Fp8Config): self.quant_config = quant_config from sglang.srt.hardware_backend.npu.quantization.mxfp8_method_npu import ( - NPUMXFP8LinearMethod + NPUMXFP8LinearMethod, ) self.npu_method = NPUMXFP8LinearMethod() From a29bb3d6ad91d0f2f856e644989398cb36769c23 Mon Sep 17 00:00:00 2001 From: "(Messi) Junlin Wu" Date: Wed, 25 Mar 2026 09:21:20 +0800 Subject: [PATCH 10/25] :pencil2: fix(diffusion/mxfp8): address review comments on ModelSlimMXFP8Scheme - Remove unused __init__ (no quant_config/prefix needed, MXFP8 has only one mode) - Fix weight dtype: float8_e4m3fn (not int8) to match msmodelslim checkpoint format - Fix weight_scale shape: [out, in/32] (not in/32*2) to match actual tensor shape - Add comment explaining weight_scale name must match checkpoint key (not weight_scale_inv) - Improve flatten-to-2D comment to explain NPU kernel requirement --- .../runtime/layers/quantization/modelslim.py | 4 +--- .../layers/quantization/modelslim_mxfp8_scheme.py | 12 ++++++------ python/sglang/multimodal_gen/tools/wan_repack.py | 4 ++-- 3 files changed, 9 insertions(+), 11 deletions(-) diff --git a/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim.py b/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim.py index 4d43c50e3675..4a9b96f9c9c9 100644 --- a/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim.py +++ b/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim.py @@ -124,9 +124,7 @@ def _get_scheme_from_parts( ModelSlimMXFP8Scheme, ) - return ModelSlimMXFP8Scheme( - quant_config=self.quant_description, prefix=layer_name - ) + return ModelSlimMXFP8Scheme() raise NotImplementedError("No modelslim compatible scheme was found.") def get_scheme( diff --git a/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp8_scheme.py b/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp8_scheme.py index a5c92f039e94..c12464d691f7 100644 --- a/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp8_scheme.py +++ b/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp8_scheme.py @@ -4,7 +4,7 @@ uint8 scales) and runs MXFP8 matmul at inference. """ -from typing import Dict, List, Optional +from typing import List, Optional import torch import torch_npu @@ -20,10 +20,6 @@ class ModelSlimMXFP8Scheme(ModelSlimLinearScheme): - def __init__(self, quant_config: Dict[str, any], prefix: str): - self.quant_config = quant_config - self.prefix = prefix - def create_weights( self, layer: torch.nn.Module, @@ -49,7 +45,11 @@ def create_weights( ) layer.register_parameter("weight", weight) - # msmodelslim exports weight_scale as uint8, shape [out, in/32] + # msmodelslim exports weight_scale as uint8, shape [out, in/32]. + # NOTE: This parameter is intentionally named "weight_scale" (not + # "weight_scale_inv" as used in mxfp8_npu.py) because the weight loader + # matches parameter names to checkpoint keys, and msmodelslim checkpoints + # store this tensor under the key ".weight_scale". scale_dim = input_size_per_partition // MXFP8_BLOCK_SIZE weight_scale = GroupQuantScaleParameter( data=torch.empty( diff --git a/python/sglang/multimodal_gen/tools/wan_repack.py b/python/sglang/multimodal_gen/tools/wan_repack.py index d7951e6c3c5f..308b229d8593 100644 --- a/python/sglang/multimodal_gen/tools/wan_repack.py +++ b/python/sglang/multimodal_gen/tools/wan_repack.py @@ -90,8 +90,8 @@ def get_quant_subpath( return quant_path -def update_dict_(dict: Dict[str, Any], old_key: str, new_key: str) -> None: - dict[new_key] = dict.pop(old_key) +def update_dict_(d: Dict[str, Any], old_key: str, new_key: str) -> None: + d[new_key] = d.pop(old_key) def load_sharded_safetensors(directory: pathlib.Path, pattern: str) -> dict: From 250fe65deda8476402ad002e43feb4b6a6e184f8 Mon Sep 17 00:00:00 2001 From: "(Messi) Junlin Wu" Date: Wed, 25 Mar 2026 10:39:35 +0800 Subject: [PATCH 11/25] :adhesive_bandage: fix(diffusion): register --quantization CLI arg to avoid argparse ambiguity --- python/sglang/multimodal_gen/runtime/server_args.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/python/sglang/multimodal_gen/runtime/server_args.py b/python/sglang/multimodal_gen/runtime/server_args.py index 72c2cd5f63e9..60565998c005 100644 --- a/python/sglang/multimodal_gen/runtime/server_args.py +++ b/python/sglang/multimodal_gen/runtime/server_args.py @@ -727,6 +727,14 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: help="Disable autocast for denoising loop and vae decoding in pipeline sampling", ) + parser.add_argument( + "--quantization", + type=str, + default=None, + help='Quantization method override (e.g. "mxfp8", "fp8", "modelslim"). ' + "When set, the transformer loader will use this instead of auto-detection.", + ) + # Nunchaku SVDQuant quantization parameters NunchakuSVDQuantArgs.add_cli_args(parser) From e146b03179e8d4215138b3f5aeb032083f74d8e7 Mon Sep 17 00:00:00 2001 From: "(Messi) Junlin Wu" Date: Wed, 25 Mar 2026 11:02:44 +0800 Subject: [PATCH 12/25] :bug: fix(mxfp8_npu): move weight to current NPU device before quantization --- .../runtime/layers/quantization/mxfp8_npu.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp8_npu.py b/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp8_npu.py index 1d4d88adb8ec..b3dd4612460d 100644 --- a/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp8_npu.py +++ b/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp8_npu.py @@ -111,8 +111,17 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: if weight_fp.dtype not in (torch.float16, torch.bfloat16): weight_fp = weight_fp.to(torch.bfloat16) - # Ensure weight is on NPU before calling npu_dynamic_mx_quant - assert weight_fp.is_npu + # Move weight to NPU if needed. We intentionally use a conditional + # move rather than an assert because `dit_cpu_offload` defaults to + # True in ServerArgs, which causes fsdp_load to move every parameter + # back to CPU after loading (even when the target device is NPU). + # npu_dynamic_mx_quant requires an NPU tensor, so we must transfer + # here. The quantized fp8 weights produced below will remain on NPU + # for inference; if the model still needs to be offloaded after + # quantization (e.g. very large model on a small NPU), a higher-level + # offload pass can move them back afterwards. + if not weight_fp.is_npu: + weight_fp = weight_fp.to(f"npu:{torch.npu.current_device()}") # Online MXFP8 quantisation of weights (block_size=32) qw, w_scale = torch_npu.npu_dynamic_mx_quant( From 711bb8b8ec0227c150ceff643b0ddecd6cc5c4be Mon Sep 17 00:00:00 2001 From: "(Messi) Junlin Wu" Date: Wed, 25 Mar 2026 14:34:55 +0800 Subject: [PATCH 13/25] :rewind: revert(llm): remove LLM MXFP8 online quantization (Path B) for separate PR Revert LLM-side MXFP8 changes to split into a separate PR. This branch now only contains Wan2.2 Diffusion MXFP8 changes. Reverted files: - fp8.py: removed MXFP8LinearAscendMethod class and NPU branch - mxfp8_method_npu.py: deleted (NPU MXFP8 linear method) - test_ascend_mxfp8_quantization.py: deleted (LLM MXFP8 test) LLM MXFP8 code preserved on junlin_llm branch. --- .../npu/quantization/mxfp8_method_npu.py | 82 ------------ python/sglang/srt/layers/quantization/fp8.py | 126 +++--------------- .../ascend/test_ascend_mxfp8_quantization.py | 103 -------------- 3 files changed, 18 insertions(+), 293 deletions(-) delete mode 100644 python/sglang/srt/hardware_backend/npu/quantization/mxfp8_method_npu.py delete mode 100644 test/srt/ascend/test_ascend_mxfp8_quantization.py diff --git a/python/sglang/srt/hardware_backend/npu/quantization/mxfp8_method_npu.py b/python/sglang/srt/hardware_backend/npu/quantization/mxfp8_method_npu.py deleted file mode 100644 index c83b77cad7f4..000000000000 --- a/python/sglang/srt/hardware_backend/npu/quantization/mxfp8_method_npu.py +++ /dev/null @@ -1,82 +0,0 @@ -from typing import Optional - -import torch -import torch_npu -from torch.nn.parameter import Parameter - - -class NPUMXFP8LinearMethod: - """Ascend NPU MXFP8 weight processing and kernel calls. - - This class handles NPU-specific operations: - - process_weights_after_loading: dtype casting and online quantization - - apply: dynamic activation quantization + MXFP8 matmul - - Weight creation and config management are handled by - MXFP8LinearAscendMethod in fp8.py. - """ - - MXFP8_BLOCK_SIZE = 32 - - def process_weights_after_loading( - self, layer: torch.nn.Module, is_serialized: bool - ): - if is_serialized: - # Checkpoint already has fp8 weights + uint8 scales. - # Ensure weight is float8_e4m3fn. - if layer.weight.data.dtype != torch.float8_e4m3fn: - layer.weight = Parameter( - torch_npu.npu_dtype_cast( - layer.weight.data, torch_npu.float8_e4m3fn - ), - requires_grad=False, - ) - else: - layer.weight.requires_grad_(False) - - # Scale is already uint8 (e8m0fnu), keep as-is. - layer.weight_scale_inv.requires_grad_(False) - else: - # Online quantization: quantize FP16/BF16 weights to MXFP8. - weight_fp = layer.weight.data - if weight_fp.dtype not in (torch.float16, torch.bfloat16): - weight_fp = weight_fp.to(torch.bfloat16) - - if not weight_fp.is_npu: - weight_fp = weight_fp.npu() - - qw, w_scale = torch_npu.npu_dynamic_mx_quant( - weight_fp, dst_type=torch_npu.float8_e4m3fn - ) - layer.weight = Parameter(qw, requires_grad=False) - layer.weight_scale_inv = Parameter(w_scale, requires_grad=False) - - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - original_dtype = x.dtype - if original_dtype not in (torch.float16, torch.bfloat16): - x = x.to(torch.bfloat16) - original_dtype = torch.bfloat16 - - # Dynamic MXFP8 activation quantization (block_size=32) - qx, input_scale = torch_npu.npu_dynamic_mx_quant( - x, dst_type=torch_npu.float8_e4m3fn - ) - - # MXFP8 quantized matmul - output = torch_npu.npu_quant_matmul( - qx, - layer.weight.transpose(0, 1), - layer.weight_scale_inv.transpose(0, 1), - scale_dtype=torch_npu.float8_e8m0fnu, - pertoken_scale=input_scale, - pertoken_scale_dtype=torch_npu.float8_e8m0fnu, - bias=bias.to(torch.float32) if bias is not None else None, - output_dtype=original_dtype, - group_sizes=[1, 1, self.MXFP8_BLOCK_SIZE], - ) - return output diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index b59cd86d6ba0..1e12baff13cf 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -57,10 +57,7 @@ requant_weight_ue8m0_inplace, ) from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod -from sglang.srt.layers.quantization.marlin_utils_fp8 import ( - apply_fp8_marlin_linear, - prepare_fp8_layer_for_marlin, -) +from sglang.srt.layers.quantization.marlin_utils_fp8 import prepare_fp8_layer_for_marlin from sglang.srt.layers.quantization.unquant import ( UnquantizedFusedMoEMethod, UnquantizedLinearMethod, @@ -225,8 +222,6 @@ def get_quant_method( prefix, self.ignored_layers, fused_mapping=self.packed_modules_mapping ): return UnquantizedLinearMethod() - if _is_npu and self.use_mxfp8: - return MXFP8LinearAscendMethod(self) return Fp8LinearMethod(self) elif isinstance(layer, FusedMoE): if is_layer_skipped( @@ -244,99 +239,6 @@ def get_scaled_act_names(self) -> List[str]: return [] -class MXFP8LinearAscendMethod(LinearMethodBase): - """Ascend NPU MXFP8 (Microscaling FP8) quantization for Linear layers. - - Supports two modes: - - Online quantization: loads FP16/BF16 weights and quantizes them to MXFP8 - at weight loading time. - - Offline quantization: loads pre-quantized FP8 weights with block scales - from a serialized checkpoint. - - Weight creation is handled here; weight processing and kernel calls are - delegated to NPUMXFP8LinearMethod in the NPU hardware backend. - """ - - MXFP8_BLOCK_SIZE = 32 - - def __init__(self, quant_config: Fp8Config): - self.quant_config = quant_config - - from sglang.srt.hardware_backend.npu.quantization.mxfp8_method_npu import ( - NPUMXFP8LinearMethod, - ) - - self.npu_method = NPUMXFP8LinearMethod() - - def create_weights( - self, - layer: torch.nn.Module, - input_size_per_partition: int, - output_partition_sizes: List[int], - input_size: int, - output_size: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ): - output_size_per_partition = sum(output_partition_sizes) - weight_loader = extra_weight_attrs.get("weight_loader") - - layer.logical_widths = output_partition_sizes - layer.input_size_per_partition = input_size_per_partition - layer.output_size_per_partition = output_size_per_partition - layer.orig_dtype = params_dtype - - is_serialized = self.quant_config.is_checkpoint_fp8_serialized - - # Weight: fp8 if serialized checkpoint, else original dtype (will be - # quantized in process_weights_after_loading) - weight_dtype = torch.float8_e4m3fn if is_serialized else params_dtype - weight = ModelWeightParameter( - data=torch.empty( - output_size_per_partition, - input_size_per_partition, - dtype=weight_dtype, - ), - input_dim=1, - output_dim=0, - weight_loader=weight_loader, - ) - layer.register_parameter("weight", weight) - - if is_serialized: - # Block scale: one scale per block of 32 elements along input dim. - # Stored as uint8 (representing float8_e8m0fnu) in checkpoint. - block_k = self.MXFP8_BLOCK_SIZE - scale_cols = (input_size_per_partition + block_k - 1) // block_k - scale = BlockQuantScaleParameter( - data=torch.zeros( - output_size_per_partition, - scale_cols, - dtype=torch.uint8, - ), - input_dim=1, - output_dim=0, - weight_loader=weight_loader, - ) - scale.format_ue8m0 = True - layer.register_parameter("weight_scale_inv", scale) - else: - layer.register_parameter("weight_scale_inv", None) - - def process_weights_after_loading(self, layer: torch.nn.Module): - self.npu_method.process_weights_after_loading( - layer, self.quant_config.is_checkpoint_fp8_serialized - ) - - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - return self.npu_method.apply(layer, x, bias) - - class Fp8LinearMethod(LinearMethodBase): """Linear method for FP8. @@ -741,7 +643,7 @@ def apply( bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: if self.use_marlin: - return apply_fp8_marlin_linear( + return torch.ops.sglang.apply_fp8_marlin_linear( input=x, weight=layer.weight, weight_scale=layer.weight_scale, @@ -1096,15 +998,23 @@ def process_weights_after_loading_block_quant(self, layer: Module) -> None: w2_weight_scale, requires_grad=False ) layer.w2_input_scale = None - - if _use_aiter: + if _use_aiter: + # add this section for MI300 + # Pre-shuffle weights + layer.w13_weight.data = shuffle_weight( + layer.w13_weight.contiguous(), (16, 16) + ) + layer.w2_weight.data = shuffle_weight( + layer.w2_weight.contiguous(), (16, 16) + ) + elif _use_aiter: # Pre-shuffle weights - t = shuffle_weight(layer.w13_weight, (16, 16)) - layer.w13_weight.copy_(t) - del t - t = shuffle_weight(layer.w2_weight, (16, 16)) - layer.w2_weight.copy_(t) - del t + layer.w13_weight.data = shuffle_weight( + layer.w13_weight.contiguous(), (16, 16) + ) + layer.w2_weight.data = shuffle_weight( + layer.w2_weight.contiguous(), (16, 16) + ) elif _is_cpu: assert ( _is_cpu_amx_available diff --git a/test/srt/ascend/test_ascend_mxfp8_quantization.py b/test/srt/ascend/test_ascend_mxfp8_quantization.py deleted file mode 100644 index e7af4ff97f1a..000000000000 --- a/test/srt/ascend/test_ascend_mxfp8_quantization.py +++ /dev/null @@ -1,103 +0,0 @@ -""" -Usage: -python3 -m unittest test_ascend_mxfp8_quantization.TestAscendMXFP8.test_gsm8k -""" - -import os -import time -import unittest -from types import SimpleNamespace -from urllib.parse import urlparse - -import requests - -from sglang.srt.utils import kill_process_tree -from sglang.test.few_shot_gsm8k import run_eval -from sglang.test.test_utils import ( - DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - DEFAULT_URL_FOR_TEST, - CustomTestCase, - is_in_ci, - popen_launch_server, -) - -if "ASCEND_RT_VISIBLE_DEVICES" not in os.environ: - os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "0,1" -DEFAULT_PORT_FOR_SRT_TEST_RUNNER = ( - 7000 + int(os.environ.get("ASCEND_RT_VISIBLE_DEVICES", "0")[0]) * 100 -) -DEFAULT_URL_FOR_TEST = f"http://127.0.0.1:{DEFAULT_PORT_FOR_SRT_TEST_RUNNER + 1000}" - - -class TestAscendMXFP8(CustomTestCase): - """Test online MXFP8 quantization (--quantization mxfp8) on Ascend NPU.""" - - @classmethod - def setUpClass(cls): - cls.model = "Qwen/Qwen2.5-0.5B-Instruct" - cls.base_url = DEFAULT_URL_FOR_TEST - cls.process = popen_launch_server( - cls.model, - cls.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=[ - "--trust-remote-code", - "--quantization", - "mxfp8", - "--device", - "npu", - "--attention-backend", - "ascend", - ], - ) - - @classmethod - def tearDownClass(cls): - kill_process_tree(cls.process.pid) - - def test_gsm8k(self): - url = urlparse(self.base_url) - args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=200, - max_new_tokens=512, - parallel=128, - host=f"http://{url.hostname}", - port=int(url.port), - ) - metrics = run_eval(args) - print(metrics) - self.assertGreaterEqual(metrics["accuracy"], 0.25) - self.assertGreaterEqual(metrics["output_throughput"], 500) - - def run_decode(self, max_new_tokens): - response = requests.post( - self.base_url + "/generate", - json={ - "text": "The capital of France is", - "sampling_params": { - "temperature": 0, - "max_new_tokens": max_new_tokens, - }, - "ignore_eos": True, - }, - ) - return response.json() - - def test_throughput(self): - max_tokens = 256 - - tic = time.perf_counter() - res = self.run_decode(max_tokens) - tok = time.perf_counter() - print(res["text"]) - throughput = max_tokens / (tok - tic) - print(f"Throughput: {throughput} tokens/s") - - if is_in_ci(): - self.assertGreaterEqual(throughput, 20) - - -if __name__ == "__main__": - unittest.main() From 1101cf5d6ce628d77cc6061a6c4f2db6285f8830 Mon Sep 17 00:00:00 2001 From: "(Messi) Junlin Wu" Date: Tue, 31 Mar 2026 15:36:35 +0800 Subject: [PATCH 14/25] :adhesive_bandage: fix(loader): preserve --quantization flag priority in _resolve_quant_config Upstream refactored class methods to standalone functions in transformer_load_utils.py but dropped the server_args.quantization priority path. Re-add it so mxfp8/mxfp4/modelslim still work. --- .../runtime/loader/transformer_load_utils.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/python/sglang/multimodal_gen/runtime/loader/transformer_load_utils.py b/python/sglang/multimodal_gen/runtime/loader/transformer_load_utils.py index 23c60043ed2e..fd433d37c444 100644 --- a/python/sglang/multimodal_gen/runtime/loader/transformer_load_utils.py +++ b/python/sglang/multimodal_gen/runtime/loader/transformer_load_utils.py @@ -274,8 +274,17 @@ def _resolve_quant_config( ) -> Optional[QuantizationConfig]: """ resolve quant config from checkpoints' metadata - priority: model config.json -> safetensors metadata -> format-specific fallback + priority: explicit --quantization flag -> model config.json -> safetensors metadata -> format-specific fallback """ + # priority: explicit --quantization flag (e.g. mxfp8, mxfp4, modelslim) + if server_args.quantization is not None: + from sglang.multimodal_gen.runtime.layers.quantization import ( + get_quantization_config, + ) + + quant_cls = get_quantization_config(server_args.quantization) + return quant_cls.from_config({}) + quant_config = get_quant_config(hf_config, component_model_path) if quant_config is None and server_args.transformer_weights_path: for safetensors_file in safetensors_list: From f1c652b69782dc29b12af74cbfe36c0f1fed051a Mon Sep 17 00:00:00 2001 From: "(Messi) Junlin Wu" Date: Wed, 1 Apr 2026 11:36:36 +0800 Subject: [PATCH 15/25] :sparkles: feat(npu/mxfp8): add W8A8 MXFP8 LLM support on Ascend NPU 1. Add NPUMXFP8LinearMethod to linear_method_npu.py (online quant)\n2. Add NPU dispatch branch in Fp8Config.get_quant_method\n3. Fix get_min_capability to return 0 on NPU\n4. Add ModelSlimMXFP8Scheme for offline pre-quantized MXFP8 weights\n5. Register W8A8_MXFP8 scheme in modelslim dispatcher --- .../npu/quantization/linear_method_npu.py | 98 ++++++++++++++++ python/sglang/srt/layers/quantization/fp8.py | 8 ++ .../quantization/modelslim/modelslim.py | 6 + .../modelslim/schemes/__init__.py | 2 + .../modelslim/schemes/modelslim_mxfp8.py | 106 ++++++++++++++++++ 5 files changed, 220 insertions(+) create mode 100644 python/sglang/srt/layers/quantization/modelslim/schemes/modelslim_mxfp8.py diff --git a/python/sglang/srt/hardware_backend/npu/quantization/linear_method_npu.py b/python/sglang/srt/hardware_backend/npu/quantization/linear_method_npu.py index 788620a317bb..0f94b3364a71 100644 --- a/python/sglang/srt/hardware_backend/npu/quantization/linear_method_npu.py +++ b/python/sglang/srt/hardware_backend/npu/quantization/linear_method_npu.py @@ -1,6 +1,8 @@ from typing import TYPE_CHECKING, Optional import torch +import torch_npu +from torch.nn.parameter import Parameter from sglang.srt.hardware_backend.npu.utils import npu_format_cast from sglang.srt.layers.quantization.base_config import LinearMethodBase @@ -8,6 +10,8 @@ if TYPE_CHECKING: from sglang.srt.layers.quantization.base_config import QuantizationConfig +MXFP8_BLOCK_SIZE = 32 + class _NPULinearMethodBase(LinearMethodBase): @@ -111,6 +115,100 @@ def apply( ) +class NPUMXFP8LinearMethod(_NPULinearMethodBase): + """Ascend NPU MXFP8 linear method for LLM (SRT) models. + + Online mode: loads FP16/BF16 weights → quantises to MXFP8 at load time. + Inference: dynamic MXFP8 activation quant + MXFP8 matmul (block_size=32). + """ + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes, + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + from sglang.srt.layers.parameter import ModelWeightParameter + + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + + layer.logical_widths = output_partition_sizes + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + layer.orig_dtype = params_dtype + + # Load weights in original dtype; quantise later in process_weights_after_loading + weight = ModelWeightParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=params_dtype, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + weight_fp = layer.weight.data + if weight_fp.dtype not in (torch.float16, torch.bfloat16): + weight_fp = weight_fp.to(torch.bfloat16) + + # Move weight to NPU if needed (cpu offload may have moved it back to CPU) + if not weight_fp.is_npu: + weight_fp = weight_fp.to(f"npu:{torch.npu.current_device()}") + + # Online MXFP8 quantisation of weights (block_size=32) + qw, w_scale = torch_npu.npu_dynamic_mx_quant( + weight_fp, dst_type=torch_npu.float8_e4m3fn + ) + layer.weight = Parameter(qw, requires_grad=False) + layer.weight_scale_inv = Parameter(w_scale, requires_grad=False) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + original_dtype = x.dtype + if original_dtype not in (torch.float16, torch.bfloat16): + x = x.to(torch.bfloat16) + original_dtype = torch.bfloat16 + + # Flatten to 2D [tokens, hidden] for npu_dynamic_mx_quant + input_shape = x.shape + x_2d = x.reshape(-1, x.shape[-1]) + + # Dynamic MXFP8 activation quantisation + qx, input_scale = torch_npu.npu_dynamic_mx_quant( + x_2d, dst_type=torch_npu.float8_e4m3fn + ) + + # MXFP8 matmul + output = torch_npu.npu_quant_matmul( + qx, + layer.weight.transpose(0, 1), + layer.weight_scale_inv.transpose(0, 1), + scale_dtype=torch_npu.float8_e8m0fnu, + pertoken_scale=input_scale, + pertoken_scale_dtype=torch_npu.float8_e8m0fnu, + bias=bias.to(torch.float32) if bias is not None else None, + output_dtype=original_dtype, + group_sizes=[1, 1, MXFP8_BLOCK_SIZE], + ) + + # Restore original shape (replace last dim with output features) + output_shape = list(input_shape[:-1]) + [output.shape[-1]] + return output.reshape(output_shape) + + class NPU_W4A4DynamicLinearMethod(_NPULinearMethodBase): def process_weights_after_loading(self, layer): diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index d9ed7a08c25b..689e85534c1f 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -169,6 +169,8 @@ def get_supported_act_dtypes(cls) -> List[torch.dtype]: return [torch.bfloat16, torch.half] def get_min_capability(self) -> int: + if is_npu(): + return 0 # NPU bypasses CUDA capability checks return 100 if self.use_mxfp8 else 80 @classmethod @@ -222,6 +224,12 @@ def get_quant_method( prefix, self.ignored_layers, fused_mapping=self.packed_modules_mapping ): return UnquantizedLinearMethod() + if is_npu() and self.use_mxfp8: + from sglang.srt.hardware_backend.npu.quantization.linear_method_npu import ( + NPUMXFP8LinearMethod, + ) + + return NPUMXFP8LinearMethod(self) return Fp8LinearMethod(self) elif isinstance(layer, FusedMoE): if is_layer_skipped( diff --git a/python/sglang/srt/layers/quantization/modelslim/modelslim.py b/python/sglang/srt/layers/quantization/modelslim/modelslim.py index 84acecccc415..37ffa43bfac6 100644 --- a/python/sglang/srt/layers/quantization/modelslim/modelslim.py +++ b/python/sglang/srt/layers/quantization/modelslim/modelslim.py @@ -190,6 +190,12 @@ def _get_scheme_from_parts( return ModelSlimW4A4Int4( quant_config=self.quant_description, prefix=layer_name ) + elif quant_type == "W8A8_MXFP8": + from sglang.srt.layers.quantization.modelslim.schemes.modelslim_mxfp8 import ( + ModelSlimMXFP8Scheme, + ) + + return ModelSlimMXFP8Scheme() raise NotImplementedError("No modelslim compatible scheme was found.") def get_linear_scheme( diff --git a/python/sglang/srt/layers/quantization/modelslim/schemes/__init__.py b/python/sglang/srt/layers/quantization/modelslim/schemes/__init__.py index c349fd3c4251..849d65f918ae 100644 --- a/python/sglang/srt/layers/quantization/modelslim/schemes/__init__.py +++ b/python/sglang/srt/layers/quantization/modelslim/schemes/__init__.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 +from .modelslim_mxfp8 import ModelSlimMXFP8Scheme from .modelslim_scheme import ModelSlimLinearScheme, ModelSlimMoEScheme from .modelslim_w4a4_int4 import ModelSlimW4A4Int4 from .modelslim_w4a4_int4_moe import ModelSlimW4A4Int4MoE @@ -10,6 +11,7 @@ __all__ = [ "ModelSlimLinearScheme", "ModelSlimMoEScheme", + "ModelSlimMXFP8Scheme", "ModelSlimW8A8Int8", "ModelSlimW4A4Int4", "ModelSlimW4A4Int4MoE", diff --git a/python/sglang/srt/layers/quantization/modelslim/schemes/modelslim_mxfp8.py b/python/sglang/srt/layers/quantization/modelslim/schemes/modelslim_mxfp8.py new file mode 100644 index 000000000000..ddaa7f0e7f40 --- /dev/null +++ b/python/sglang/srt/layers/quantization/modelslim/schemes/modelslim_mxfp8.py @@ -0,0 +1,106 @@ +"""ModelSlim MXFP8 scheme for pre-quantized weight inference on Ascend NPU (SRT). + +Loads weights pre-quantized by msmodelslim (float8_e4m3fn weights, +uint8 scales) and runs MXFP8 matmul at inference. +""" + +from typing import List, Optional + +import torch +import torch_npu + +from sglang.srt.layers.parameter import GroupQuantScaleParameter, ModelWeightParameter +from sglang.srt.layers.quantization.modelslim.schemes import ModelSlimLinearScheme + +MXFP8_BLOCK_SIZE = 32 + + +class ModelSlimMXFP8Scheme(ModelSlimLinearScheme): + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + weight_loader = extra_weight_attrs.get("weight_loader") + output_size_per_partition = sum(output_partition_sizes) + + # msmodelslim exports weight as float8_e4m3fn, shape [out, in] + weight = ModelWeightParameter( + data=torch.empty( + (output_size_per_partition, input_size_per_partition), + dtype=torch.float8_e4m3fn, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + # msmodelslim exports weight_scale as uint8, shape [out, in/32]. + # NOTE: Named "weight_scale" (not "weight_scale_inv") to match the + # checkpoint key exported by msmodelslim. + scale_dim = input_size_per_partition // MXFP8_BLOCK_SIZE + weight_scale = GroupQuantScaleParameter( + data=torch.empty( + (output_size_per_partition, scale_dim), + dtype=torch.uint8, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight_scale", weight_scale) + + def process_weights_after_loading(self, layer: torch.nn.Module): + # weight is already float8_e4m3fn, no cast needed + weight = layer.weight.data + layer.weight = torch.nn.Parameter(weight, requires_grad=False) + + # Reshape weight_scale: [out, in/32] -> [out, in/64, 2] + # npu_quant_matmul expects the scale in paired-element format + weight_scale = layer.weight_scale.data + weight_scale = weight_scale.reshape(weight_scale.shape[0], -1, 2) + layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) + + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + original_dtype = x.dtype + if original_dtype not in (torch.float16, torch.bfloat16): + x = x.to(torch.bfloat16) + original_dtype = torch.bfloat16 + + # npu_dynamic_mx_quant requires a 2D input [tokens, hidden_size] + input_shape = x.shape + x_2d = x.reshape(-1, x.shape[-1]) + + # Dynamic MXFP8 activation quantisation + qx, input_scale = torch_npu.npu_dynamic_mx_quant( + x_2d, dst_type=torch_npu.float8_e4m3fn + ) + + # MXFP8 matmul + output = torch_npu.npu_quant_matmul( + qx, + layer.weight.transpose(0, 1), + layer.weight_scale.transpose(0, 1), + scale_dtype=torch_npu.float8_e8m0fnu, + pertoken_scale=input_scale, + pertoken_scale_dtype=torch_npu.float8_e8m0fnu, + bias=bias.to(torch.float32) if bias is not None else None, + output_dtype=original_dtype, + group_sizes=[1, 1, MXFP8_BLOCK_SIZE], + ) + + # Restore original shape (replace last dim with output features) + output_shape = list(input_shape[:-1]) + [output.shape[-1]] + return output.reshape(output_shape) From 97c45b602bc8cfb65f718d3fec714afe9a7efa96 Mon Sep 17 00:00:00 2001 From: "(Messi) Junlin Wu" Date: Wed, 1 Apr 2026 21:23:11 +0800 Subject: [PATCH 16/25] :recycle: refactor(npu/mxfp8): refactor code to align with vllm-ascend --- .../npu/quantization/linear_method_npu.py | 16 +++++++++------- .../modelslim/schemes/modelslim_mxfp8.py | 18 ++++++++++-------- 2 files changed, 19 insertions(+), 15 deletions(-) diff --git a/python/sglang/srt/hardware_backend/npu/quantization/linear_method_npu.py b/python/sglang/srt/hardware_backend/npu/quantization/linear_method_npu.py index 0f94b3364a71..de9593958934 100644 --- a/python/sglang/srt/hardware_backend/npu/quantization/linear_method_npu.py +++ b/python/sglang/srt/hardware_backend/npu/quantization/linear_method_npu.py @@ -11,6 +11,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig MXFP8_BLOCK_SIZE = 32 +_FLOAT8_E8M0FNU_DTYPE = getattr(torch_npu, "float8_e8m0fnu", getattr(torch, "float8_e8m0fnu", None)) class _NPULinearMethodBase(LinearMethodBase): @@ -168,8 +169,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: qw, w_scale = torch_npu.npu_dynamic_mx_quant( weight_fp, dst_type=torch_npu.float8_e4m3fn ) - layer.weight = Parameter(qw, requires_grad=False) - layer.weight_scale_inv = Parameter(w_scale, requires_grad=False) + # Pre-transpose to [in, out] for npu_quant_matmul (avoid per-call transpose) + layer.weight = Parameter(qw.transpose(0, 1).contiguous(), requires_grad=False) + layer.weight_scale_inv = Parameter(w_scale.transpose(0, 1).contiguous(), requires_grad=False) def apply( self, @@ -191,14 +193,14 @@ def apply( x_2d, dst_type=torch_npu.float8_e4m3fn ) - # MXFP8 matmul + # MXFP8 matmul (weight & scale already transposed at load time) output = torch_npu.npu_quant_matmul( qx, - layer.weight.transpose(0, 1), - layer.weight_scale_inv.transpose(0, 1), - scale_dtype=torch_npu.float8_e8m0fnu, + layer.weight, + layer.weight_scale_inv, + scale_dtype=_FLOAT8_E8M0FNU_DTYPE, pertoken_scale=input_scale, - pertoken_scale_dtype=torch_npu.float8_e8m0fnu, + pertoken_scale_dtype=_FLOAT8_E8M0FNU_DTYPE, bias=bias.to(torch.float32) if bias is not None else None, output_dtype=original_dtype, group_sizes=[1, 1, MXFP8_BLOCK_SIZE], diff --git a/python/sglang/srt/layers/quantization/modelslim/schemes/modelslim_mxfp8.py b/python/sglang/srt/layers/quantization/modelslim/schemes/modelslim_mxfp8.py index ddaa7f0e7f40..796b7bdd3d2f 100644 --- a/python/sglang/srt/layers/quantization/modelslim/schemes/modelslim_mxfp8.py +++ b/python/sglang/srt/layers/quantization/modelslim/schemes/modelslim_mxfp8.py @@ -13,6 +13,7 @@ from sglang.srt.layers.quantization.modelslim.schemes import ModelSlimLinearScheme MXFP8_BLOCK_SIZE = 32 +_FLOAT8_E8M0FNU_DTYPE = getattr(torch_npu, "float8_e8m0fnu", getattr(torch, "float8_e8m0fnu", None)) class ModelSlimMXFP8Scheme(ModelSlimLinearScheme): @@ -60,13 +61,14 @@ def create_weights( def process_weights_after_loading(self, layer: torch.nn.Module): # weight is already float8_e4m3fn, no cast needed weight = layer.weight.data - layer.weight = torch.nn.Parameter(weight, requires_grad=False) + # Pre-transpose to [in, out] for npu_quant_matmul (avoid per-call transpose) + layer.weight = torch.nn.Parameter(weight.transpose(0, 1).contiguous(), requires_grad=False) # Reshape weight_scale: [out, in/32] -> [out, in/64, 2] - # npu_quant_matmul expects the scale in paired-element format + # then transpose to [in/64, out, 2] for npu_quant_matmul weight_scale = layer.weight_scale.data weight_scale = weight_scale.reshape(weight_scale.shape[0], -1, 2) - layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) + layer.weight_scale = torch.nn.Parameter(weight_scale.transpose(0, 1).contiguous(), requires_grad=False) def apply_weights( self, @@ -88,14 +90,14 @@ def apply_weights( x_2d, dst_type=torch_npu.float8_e4m3fn ) - # MXFP8 matmul + # MXFP8 matmul (weight & scale already transposed at load time) output = torch_npu.npu_quant_matmul( qx, - layer.weight.transpose(0, 1), - layer.weight_scale.transpose(0, 1), - scale_dtype=torch_npu.float8_e8m0fnu, + layer.weight, + layer.weight_scale, + scale_dtype=_FLOAT8_E8M0FNU_DTYPE, pertoken_scale=input_scale, - pertoken_scale_dtype=torch_npu.float8_e8m0fnu, + pertoken_scale_dtype=_FLOAT8_E8M0FNU_DTYPE, bias=bias.to(torch.float32) if bias is not None else None, output_dtype=original_dtype, group_sizes=[1, 1, MXFP8_BLOCK_SIZE], From 6026a189c243c8df0e88daf2a7b79bc59067fafe Mon Sep 17 00:00:00 2001 From: "(Messi) Junlin Wu" Date: Thu, 2 Apr 2026 14:39:24 +0800 Subject: [PATCH 17/25] =?UTF-8?q?=F0=9F=90=9B=20fix(quantization/modelslim?= =?UTF-8?q?):=20resolve=20circular=20import=20in=20schemes/=5F=5Finit=5F?= =?UTF-8?q?=5F.py?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Move ModelSlimLinearScheme import before ModelSlimMXFP8Scheme to resolve circular dependency. modelslim_mxfp8.py imports ModelSlimLinearScheme from the schemes package, which failed when __init__.py hadn't yet defined it. The fix ensures base class is available in module namespace before any subclass imports attempt to reference it. Issue: ImportError: cannot import name 'ModelSlimLinearScheme' from partially initialized module 'sglang.srt.layers.quantization.modelslim.schemes' --- .../srt/layers/quantization/modelslim/schemes/__init__.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/quantization/modelslim/schemes/__init__.py b/python/sglang/srt/layers/quantization/modelslim/schemes/__init__.py index 849d65f918ae..bfc2a350c619 100644 --- a/python/sglang/srt/layers/quantization/modelslim/schemes/__init__.py +++ b/python/sglang/srt/layers/quantization/modelslim/schemes/__init__.py @@ -1,7 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 -from .modelslim_mxfp8 import ModelSlimMXFP8Scheme +# NOTE: Import order is critical to avoid circular dependency. +# modelslim_mxfp8 imports ModelSlimLinearScheme from this package, +# so the base class must be imported first. +# isort: off from .modelslim_scheme import ModelSlimLinearScheme, ModelSlimMoEScheme +from .modelslim_mxfp8 import ModelSlimMXFP8Scheme + +# isort: on from .modelslim_w4a4_int4 import ModelSlimW4A4Int4 from .modelslim_w4a4_int4_moe import ModelSlimW4A4Int4MoE from .modelslim_w4a8_int8_moe import ModelSlimW4A8Int8MoE From 3025e2d463d77cbcc4afa71ced1d40f49b81a124 Mon Sep 17 00:00:00 2001 From: "(Messi) Junlin Wu" Date: Thu, 2 Apr 2026 16:54:42 +0800 Subject: [PATCH 18/25] :bug: fix(quantization/modelslim): fix no scheme found error. --- python/sglang/srt/models/transformers.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/models/transformers.py b/python/sglang/srt/models/transformers.py index 0ea9da14a1be..36a9eb48b7e6 100644 --- a/python/sglang/srt/models/transformers.py +++ b/python/sglang/srt/models/transformers.py @@ -99,6 +99,7 @@ def replace_linear_class( linear: nn.Linear, style: Literal["colwise", "rowwise"], quant_config: QuantizationConfig, + prefix: str = "", ) -> Union[ColumnParallelLinear, RowParallelLinear]: """ Replace nn.Linear with one of vLLM's tensor parallel linear classes. @@ -107,6 +108,7 @@ def replace_linear_class( linear (nn.Linear): `nn.Linear` to be replaced. style (str): Tensor parallel style of the new linear, e.g. "colwise". quant_config (QuantConfig): Quantization config for the new linear. + prefix (str): Layer name prefix used for per-layer quantization dispatch. Returns: Union[ColumnParallelLinear, RowParallelLinear]: The new linear. """ @@ -136,6 +138,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: output_size=linear.out_features, bias=linear.bias is not None, quant_config=quant_config, + prefix=prefix, ) @@ -227,14 +230,14 @@ def _tensor_parallel(module: nn.Module, prefix: str = ""): child_module, nn.Linear ): new_module = replace_linear_class( - child_module, style, self.quant_config + child_module, style, self.quant_config, prefix=qual_name ) setattr(module, child_name, new_module) self.log_replacement(qual_name, child_module, new_module) else: _tensor_parallel(child_module, prefix=qual_name) - _tensor_parallel(self.model) + _tensor_parallel(self.model, prefix="model") def replace_vocab_embed_class(self, module: nn.Module): # Use native set input embeddings From da924189d833e50ecdc79be8fe39ac9661d359dc Mon Sep 17 00:00:00 2001 From: "(Messi) Junlin Wu" Date: Fri, 3 Apr 2026 15:29:04 +0800 Subject: [PATCH 19/25] :bug: fix(llm/mxfp8): fix meaningless output issue --- .../modelslim/schemes/modelslim_mxfp8.py | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/layers/quantization/modelslim/schemes/modelslim_mxfp8.py b/python/sglang/srt/layers/quantization/modelslim/schemes/modelslim_mxfp8.py index 796b7bdd3d2f..77702afa4db5 100644 --- a/python/sglang/srt/layers/quantization/modelslim/schemes/modelslim_mxfp8.py +++ b/python/sglang/srt/layers/quantization/modelslim/schemes/modelslim_mxfp8.py @@ -13,7 +13,9 @@ from sglang.srt.layers.quantization.modelslim.schemes import ModelSlimLinearScheme MXFP8_BLOCK_SIZE = 32 -_FLOAT8_E8M0FNU_DTYPE = getattr(torch_npu, "float8_e8m0fnu", getattr(torch, "float8_e8m0fnu", None)) +_FLOAT8_E8M0FNU_DTYPE = getattr( + torch_npu, "float8_e8m0fnu", getattr(torch, "float8_e8m0fnu", None) +) class ModelSlimMXFP8Scheme(ModelSlimLinearScheme): @@ -60,15 +62,15 @@ def create_weights( def process_weights_after_loading(self, layer: torch.nn.Module): # weight is already float8_e4m3fn, no cast needed - weight = layer.weight.data # Pre-transpose to [in, out] for npu_quant_matmul (avoid per-call transpose) - layer.weight = torch.nn.Parameter(weight.transpose(0, 1).contiguous(), requires_grad=False) - - # Reshape weight_scale: [out, in/32] -> [out, in/64, 2] - # then transpose to [in/64, out, 2] for npu_quant_matmul - weight_scale = layer.weight_scale.data - weight_scale = weight_scale.reshape(weight_scale.shape[0], -1, 2) - layer.weight_scale = torch.nn.Parameter(weight_scale.transpose(0, 1).contiguous(), requires_grad=False) + # NOTE: Use .data in-place (no .contiguous()) to preserve the transpose view. + # npu_quant_matmul reads strides to understand the memory layout; calling + # .contiguous() would physically reorder data and break the block-scale + # mapping, producing garbled output. + n_dim, k_dim = layer.weight_scale.data.shape + layer.weight_scale.data = layer.weight_scale.data.reshape(n_dim, k_dim // 2, 2) + layer.weight.data = layer.weight.data.transpose(0, 1) + layer.weight_scale.data = layer.weight_scale.data.transpose(0, 1) def apply_weights( self, From 29c04bc2fbdb4863b6f262d4d51e78cd8f56e6c3 Mon Sep 17 00:00:00 2001 From: "(Messi) Junlin Wu" Date: Tue, 7 Apr 2026 15:19:32 +0800 Subject: [PATCH 20/25] :bug: fix(llm/mxfp8): fix meaningless output issue --- .../quantization/modelslim/schemes/modelslim_mxfp8.py | 10 ++++------ python/sglang/srt/layers/rotary_embedding/base.py | 11 +++++++++-- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/python/sglang/srt/layers/quantization/modelslim/schemes/modelslim_mxfp8.py b/python/sglang/srt/layers/quantization/modelslim/schemes/modelslim_mxfp8.py index 77702afa4db5..02fd515db594 100644 --- a/python/sglang/srt/layers/quantization/modelslim/schemes/modelslim_mxfp8.py +++ b/python/sglang/srt/layers/quantization/modelslim/schemes/modelslim_mxfp8.py @@ -61,12 +61,10 @@ def create_weights( layer.register_parameter("weight_scale", weight_scale) def process_weights_after_loading(self, layer: torch.nn.Module): - # weight is already float8_e4m3fn, no cast needed - # Pre-transpose to [in, out] for npu_quant_matmul (avoid per-call transpose) - # NOTE: Use .data in-place (no .contiguous()) to preserve the transpose view. - # npu_quant_matmul reads strides to understand the memory layout; calling - # .contiguous() would physically reorder data and break the block-scale - # mapping, producing garbled output. + # Pre-transpose weight and scale to [in, out] for npu_quant_matmul. + # Use .data assignment without .contiguous() to preserve the transpose + # view strides — npu_quant_matmul reads strides correctly and calling + # .contiguous() would reorder data, breaking the block-scale mapping. n_dim, k_dim = layer.weight_scale.data.shape layer.weight_scale.data = layer.weight_scale.data.reshape(n_dim, k_dim // 2, 2) layer.weight.data = layer.weight.data.transpose(0, 1) diff --git a/python/sglang/srt/layers/rotary_embedding/base.py b/python/sglang/srt/layers/rotary_embedding/base.py index 518d250211f8..06ff843e1485 100644 --- a/python/sglang/srt/layers/rotary_embedding/base.py +++ b/python/sglang/srt/layers/rotary_embedding/base.py @@ -39,7 +39,11 @@ if _is_npu: import torch_npu - from sgl_kernel_npu.norm.fused_rope_qk_mqa import fused_rope_qk_mqa + + try: + from sgl_kernel_npu.norm.fused_rope_qk_mqa import fused_rope_qk_mqa + except ImportError: + fused_rope_qk_mqa = None if _is_hip: from sglang.srt.layers.attention.utils import ( @@ -257,7 +261,10 @@ def forward_npu( else: cos_sin = self.cos_sin_cache.index_select(0, positions) - if query.shape[0] * query.shape[1] < 65535: + if ( + fused_rope_qk_mqa is not None + and query.shape[0] * query.shape[1] < 65535 + ): return fused_rope_qk_mqa( query, key, From 66933523f6e7712a9e8f8f6d1a8431e6f2d136ce Mon Sep 17 00:00:00 2001 From: "(Messi) Junlin Wu" Date: Wed, 8 Apr 2026 12:00:52 +0800 Subject: [PATCH 21/25] :sparkles: feat(npu-quant): add MXFP4 W4A8 online quantization for Qwen3 Dense LLM MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements NPUMXFP4W4A8LinearMethod for Ascend NPU online quantization of dense Qwen3/3.5 LLM models triggered via --quantization mxfp4_npu. Weight flow (process_weights_after_loading): BF16/FP16 → npu_dynamic_dual_level_mx_quant → FP4 (NZ format) l0_scale [in/512, out] (FP32) + l1_scale (FP8_E8M0) Inference flow (apply): activation → npu_dynamic_dual_level_mx_quant → FP4 → npu_dual_level_quant_matmul (W4A4 compute with dual-level scales) Config registered as 'mxfp4_npu' in QUANTIZATION_METHODS. Hardware: requires Ascend 950 (DualLevelQuantBatchMatmul not on A2/A3). --- .../npu/quantization/linear_method_npu.py | 128 +++++++++++++++++- .../srt/layers/quantization/__init__.py | 2 + .../srt/layers/quantization/npu_mxfp4.py | 113 ++++++++++++++++ 3 files changed, 241 insertions(+), 2 deletions(-) create mode 100644 python/sglang/srt/layers/quantization/npu_mxfp4.py diff --git a/python/sglang/srt/hardware_backend/npu/quantization/linear_method_npu.py b/python/sglang/srt/hardware_backend/npu/quantization/linear_method_npu.py index de9593958934..dc679c2c96ef 100644 --- a/python/sglang/srt/hardware_backend/npu/quantization/linear_method_npu.py +++ b/python/sglang/srt/hardware_backend/npu/quantization/linear_method_npu.py @@ -11,7 +11,9 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig MXFP8_BLOCK_SIZE = 32 -_FLOAT8_E8M0FNU_DTYPE = getattr(torch_npu, "float8_e8m0fnu", getattr(torch, "float8_e8m0fnu", None)) +_FLOAT8_E8M0FNU_DTYPE = getattr( + torch_npu, "float8_e8m0fnu", getattr(torch, "float8_e8m0fnu", None) +) class _NPULinearMethodBase(LinearMethodBase): @@ -171,7 +173,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: ) # Pre-transpose to [in, out] for npu_quant_matmul (avoid per-call transpose) layer.weight = Parameter(qw.transpose(0, 1).contiguous(), requires_grad=False) - layer.weight_scale_inv = Parameter(w_scale.transpose(0, 1).contiguous(), requires_grad=False) + layer.weight_scale_inv = Parameter( + w_scale.transpose(0, 1).contiguous(), requires_grad=False + ) def apply( self, @@ -211,6 +215,126 @@ def apply( return output.reshape(output_shape) +class NPUMXFP4W4A8LinearMethod(_NPULinearMethodBase): + """Ascend NPU W4A8 online quantization: MXFP4 weights + MXFP8 activations. + + Weight quantization flow (process_weights_after_loading): + BF16/FP16 weight → npu_dynamic_dual_level_mx_quant → FP4 + l0_scale(FP32) + l1_scale(FP8_E8M0) + → npu_format_cast to FRACTAL_NZ (required by npu_dual_level_quant_matmul) + → w_dual_scale transposed to [in/512, out] (required by matmul API) + + Inference flow (apply): + FP16/BF16 activation → npu_dynamic_dual_level_mx_quant → FP4 + act_l0_scale + act_l1_scale + → npu_dual_level_quant_matmul(FP4_act, FP4_weight, scales...) → FP16/BF16 output + + Note: The "A8" refers to the MXFP8 intermediate scale format (FP8_E8M0 l1_scale). + The actual matmul compute is W4A4 (both operands in FP4) since there is no + W4A8 mixed-precision kernel in the current torch_npu public API. + """ + + _FLOAT4_E2M1FN_X2_DTYPE = getattr( + torch_npu, "float4_e2m1fn_x2", getattr(torch, "float4_e2m1fn_x2", None) + ) + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes, + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + from sglang.srt.layers.parameter import ModelWeightParameter + + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + + layer.logical_widths = output_partition_sizes + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + layer.orig_dtype = params_dtype + + # Load weights in original dtype; quantise to MXFP4 in process_weights_after_loading + weight = ModelWeightParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=params_dtype, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + weight_fp = layer.weight.data + if weight_fp.dtype not in (torch.float16, torch.bfloat16): + weight_fp = weight_fp.to(torch.bfloat16) + + # Move to NPU if needed (cpu offload may have put it on CPU) + if not weight_fp.is_npu: + weight_fp = weight_fp.to(f"npu:{torch.npu.current_device()}") + + # Online MXFP4 dual-level quantisation of weights + # qw: float4_e2m1fn_x2, shape [out, in] + # w_dual_scale: float32, shape [out, in/512, 1] (L0) + # w_scale: float8_e8m0, shape [out, (ceil(in/32)+1)//2, 2] (L1) + qw, w_dual_scale, w_scale = torch_npu.npu_dynamic_dual_level_mx_quant( + weight_fp, smooth_scale=None + ) + + # npu_dual_level_quant_matmul requires x2 in FRACTAL_NZ format (format=29) + # view as int8 first because npu_format_cast only accepts int-dtype tensors + qw = torch_npu.npu_format_cast(qw.view(torch.int8), 29) + + # npu_dual_level_quant_matmul expects x2_level0_scale shape [in/512, out]: + # squeeze the trailing dim-1 axis, then transpose + w_dual_scale = w_dual_scale.squeeze(-1).transpose(0, 1).contiguous() + + layer.weight = Parameter(qw, requires_grad=False) + layer.weight_dual_scale = Parameter(w_dual_scale, requires_grad=False) + layer.weight_scale = Parameter(w_scale, requires_grad=False) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + original_dtype = x.dtype + if original_dtype not in (torch.float16, torch.bfloat16): + x = x.to(torch.bfloat16) + original_dtype = torch.bfloat16 + + # Flatten to 2D [tokens, hidden] for dual-level quant API + input_shape = x.shape + x_2d = x.reshape(-1, x.shape[-1]) + + # Dynamic MXFP4 activation quantisation (W4 activations → A4 for matmul) + qx, act_l0_scale, act_l1_scale = torch_npu.npu_dynamic_dual_level_mx_quant( + x_2d, smooth_scale=None + ) + + # MXFP4 matmul: W4A4 compute (weight already in NZ format + transposed scales) + output = torch_npu.npu_dual_level_quant_matmul( + qx, + layer.weight, + act_l0_scale, + layer.weight_dual_scale, + act_l1_scale, + layer.weight_scale, + bias=bias.to(torch.float32) if bias is not None else None, + output_dtype=original_dtype, + ) + + # Restore original shape (replace last dim with output features) + output_shape = list(input_shape[:-1]) + [output.shape[-1]] + return output.reshape(output_shape) + + class NPU_W4A4DynamicLinearMethod(_NPULinearMethodBase): def process_weights_after_loading(self, layer): diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py index 8a6b1b06e193..c830450f2fbe 100644 --- a/python/sglang/srt/layers/quantization/__init__.py +++ b/python/sglang/srt/layers/quantization/__init__.py @@ -36,6 +36,7 @@ def override_quantization_method(self, *args, **kwargs): from sglang.srt.layers.quantization.modelslim.modelslim import ModelSlimConfig from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config from sglang.srt.layers.quantization.mxfp4 import Mxfp4Config +from sglang.srt.layers.quantization.npu_mxfp4 import NPUMxfp4Config from sglang.srt.layers.quantization.petit import PetitNvFp4Config from sglang.srt.layers.quantization.qoq import QoQConfig from sglang.srt.layers.quantization.quark.quark import QuarkConfig @@ -77,6 +78,7 @@ def override_quantization_method(self, *args, **kwargs): "auto-round": AutoRoundConfig, "modelslim": ModelSlimConfig, "quark_int4fp8_moe": QuarkInt4Fp8Config, + "mxfp4_npu": NPUMxfp4Config, } diff --git a/python/sglang/srt/layers/quantization/npu_mxfp4.py b/python/sglang/srt/layers/quantization/npu_mxfp4.py new file mode 100644 index 000000000000..eb01af932e33 --- /dev/null +++ b/python/sglang/srt/layers/quantization/npu_mxfp4.py @@ -0,0 +1,113 @@ +"""Ascend NPU MXFP4 W4A8 online quantization config. + +Triggered by ``--quantization mxfp4_npu``. + +Online mode: loads FP16/BF16 weights, quantises to MXFP4 (dual-level) in +``process_weights_after_loading``. During inference, activations are +dynamically quantised to MXFP4 and ``npu_dual_level_quant_matmul`` is used +for the matrix multiply. + +Hardware requirement: Ascend 950 (DualLevelQuantBatchMatmul is NOT supported +on Atlas A2/A3 – check your hardware before enabling). +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Dict, List, Optional + +import torch + +from sglang.srt.layers.quantization.base_config import ( + QuantizationConfig, + QuantizeMethodBase, +) +from sglang.srt.layers.quantization.unquant import ( + UnquantizedFusedMoEMethod, + UnquantizedLinearMethod, +) +from sglang.srt.layers.quantization.utils import is_layer_skipped + +if TYPE_CHECKING: + pass + + +class NPUMxfp4Config(QuantizationConfig): + """Quantization config for Ascend NPU MXFP4 W4A8 online quantization. + + Weights are quantised online to MXFP4 dual-level format during model + loading. Activations are quantised dynamically to MXFP4 at inference + time. The matmul is executed via ``torch_npu.npu_dual_level_quant_matmul``. + """ + + def __init__( + self, + ignored_layers: Optional[List[str]] = None, + packed_modules_mapping: Optional[Dict[str, str]] = None, + ): + super().__init__() + self.ignored_layers = ignored_layers or [] + self.packed_modules_mapping = packed_modules_mapping or {} + + @classmethod + def get_name(cls) -> str: + return "mxfp4_npu" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.bfloat16, torch.half] + + @classmethod + def get_min_capability(cls) -> int: + return 0 # NPU bypasses CUDA capability checks + + @classmethod + def get_config_filenames(cls) -> List[str]: + return [] + + @classmethod + def from_config(cls, config: Dict) -> "NPUMxfp4Config": + ignored_layers = cls.get_from_keys_or( + config, ["ignored_layers", "modules_to_not_convert"], None + ) + if ignored_layers: + normalized: List[str] = [] + for layer in ignored_layers: + base = layer.removeprefix("model.") + normalized.append(base) + normalized.append(f"model.{base}") + ignored_layers = normalized + packed_modules_mapping = ( + cls.get_from_keys_or(config, ["packed_modules_mapping"], {}) or {} + ) + return cls( + ignored_layers=ignored_layers, + packed_modules_mapping=packed_modules_mapping, + ) + + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional[QuantizeMethodBase]: + from sglang.srt.layers.linear import LinearBase + from sglang.srt.layers.moe.fused_moe_triton import FusedMoE + + if isinstance(layer, LinearBase): + if is_layer_skipped( + prefix, + self.ignored_layers, + fused_mapping=self.packed_modules_mapping, + ): + return UnquantizedLinearMethod() + from sglang.srt.hardware_backend.npu.quantization.linear_method_npu import ( + NPUMXFP4W4A8LinearMethod, + ) + + return NPUMXFP4W4A8LinearMethod(self) + elif isinstance(layer, FusedMoE): + # MoE MXFP4 not yet implemented; fall back to unquantised + return UnquantizedFusedMoEMethod( + layer.use_triton_kernels, layer.use_flashinfer_trtllm_moe + ) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] From e9dec3c1a1267b145eab5286cc813f32a541ede0 Mon Sep 17 00:00:00 2001 From: "(Messi) Junlin Wu" Date: Thu, 16 Apr 2026 15:03:15 +0800 Subject: [PATCH 22/25] :sparkles: feat(npu/w4a8): add W4A8 offline Dense scheme and guards 1. Add ModelSlimW4A8Int8 + NPUW4A8DynamicLinearMethod for W4A8_DYNAMIC dispatch 2. Add hardware warning + try/except guard in NPUMXFP4W4A8LinearMethod (Ascend 950 required) 3. Add MoE unquantized fallback warning in NPUMxfp4Config --- .../npu/quantization/linear_method_npu.py | 146 +++++++++++++++++- .../quantization/modelslim/modelslim.py | 5 + .../modelslim/schemes/__init__.py | 2 + .../modelslim/schemes/modelslim_w4a8_int8.py | 131 ++++++++++++++++ .../srt/layers/quantization/npu_mxfp4.py | 9 ++ 5 files changed, 290 insertions(+), 3 deletions(-) create mode 100644 python/sglang/srt/layers/quantization/modelslim/schemes/modelslim_w4a8_int8.py diff --git a/python/sglang/srt/hardware_backend/npu/quantization/linear_method_npu.py b/python/sglang/srt/hardware_backend/npu/quantization/linear_method_npu.py index dc679c2c96ef..393c966c06fc 100644 --- a/python/sglang/srt/hardware_backend/npu/quantization/linear_method_npu.py +++ b/python/sglang/srt/hardware_backend/npu/quantization/linear_method_npu.py @@ -270,6 +270,24 @@ def create_weights( layer.register_parameter("weight", weight) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + import logging + + from sglang.srt.utils import get_npu_memory_capacity + + _logger = logging.getLogger(__name__) + + # Heuristic hardware check: npu_dynamic_dual_level_mx_quant requires Ascend 950. + # Atlas A2/A3 have ≤64 GB per card; Ascend 950 has ≥96 GB per card. + npu_mem_mb = get_npu_memory_capacity() + if npu_mem_mb < 96 * 1024: + _logger.warning( + "MXFP4 W4A8 dual-level quantization may not be supported on this " + "hardware (detected NPU memory %.1f GB < 96 GB). " + "npu_dynamic_dual_level_mx_quant requires Ascend 950 (Atlas A3). " + "Continuing — expect a RuntimeError if the kernel is unavailable.", + npu_mem_mb / 1024, + ) + weight_fp = layer.weight.data if weight_fp.dtype not in (torch.float16, torch.bfloat16): weight_fp = weight_fp.to(torch.bfloat16) @@ -282,9 +300,17 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # qw: float4_e2m1fn_x2, shape [out, in] # w_dual_scale: float32, shape [out, in/512, 1] (L0) # w_scale: float8_e8m0, shape [out, (ceil(in/32)+1)//2, 2] (L1) - qw, w_dual_scale, w_scale = torch_npu.npu_dynamic_dual_level_mx_quant( - weight_fp, smooth_scale=None - ) + try: + qw, w_dual_scale, w_scale = torch_npu.npu_dynamic_dual_level_mx_quant( + weight_fp, smooth_scale=None + ) + except (RuntimeError, AttributeError) as e: + raise RuntimeError( + "npu_dynamic_dual_level_mx_quant failed — this operation requires " + "Ascend 950 (Atlas A3). Atlas 800I A2/A3 and earlier chips do NOT " + "support DualLevelQuantBatchMatmul. " + f"Original error: {e}" + ) from e # npu_dual_level_quant_matmul requires x2 in FRACTAL_NZ format (format=29) # view as int8 first because npu_format_cast only accepts int-dtype tensors @@ -335,6 +361,120 @@ def apply( return output.reshape(output_shape) +class NPUW4A8DynamicLinearMethod(_NPULinearMethodBase): + """Ascend NPU W4A8 offline quantization linear method. + + Offline mode: loads ModelSlim pre-quantized INT4 weights. + For ``new_quant_version=True`` (version "1.0.0"): 2 int4 values are pre-packed + into 1 int8 in the checkpoint (shape ``[N/2, K]``). + For old version: plain int4 stored as int8 (shape ``[N, K]``). + + Uses ``torch_npu.npu_weight_quant_batchmatmul`` for inference — activations + stay in high precision and INT4 weights are dequantized on-the-fly. + """ + + def __init__( + self, + group_size: int = 256, + new_quant_version: bool = True, + ): + super().__init__() + self.group_size = group_size + self.new_quant_version = new_quant_version + + @staticmethod + def _process_scale_second( + weight: torch.Tensor, + scale: torch.Tensor, + per_group_scale: torch.Tensor, + is_new_quant: bool = False, + ): + """Merge per-channel (L1) and per-group (L2) scales into antiquant_scale. + + Args: + weight: weight after transpose, shape ``[K, N/2]`` (new) or ``[K, N]`` (old) + scale: per-channel L1 scale, shape ``[N]`` + per_group_scale: per-group L2 scale after transpose, shape ``[K//group_size, N]`` + is_new_quant: whether weight dim is compressed (N/2) + + Returns: + (antiquant_scale, bias): ``antiquant_scale`` shape ``[K//group_size, N]``; + ``bias`` is non-None only for old version (asymmetric compensation term). + """ + k, n_compressed = weight.shape + group_num, n_scale = per_group_scale.shape + + # Logical N dimension + n = n_compressed * 2 if is_new_quant else n_compressed + + bias = None + if not is_new_quant: + weight_high = weight.to(torch.float32).reshape( + group_num, -1, n + ) * per_group_scale.reshape(group_num, 1, n) + weight_high = weight_high.reshape(k, n) + bias = 8 * (weight_high.to(torch.float32) * scale).sum(dim=0) + + antiquant_scale = (scale * per_group_scale).reshape(group_num, n) + return antiquant_scale, bias + + def process_weights_after_loading(self, layer: torch.nn.Module): + from sglang.srt.hardware_backend.npu.utils import npu_format_cast + + # Transpose [N, K] → [K, N] (or [N/2, K] → [K, N/2] for packed) + layer.weight.data = layer.weight.data.transpose(0, 1).contiguous() + # Cast to FRACTAL_NZ format for NPU matmul efficiency + layer.weight.data = npu_format_cast(layer.weight.data) + + # Flatten per-channel scales to 1-D float32 + layer.weight_scale.data = layer.weight_scale.data.flatten().to(torch.float32) + layer.weight_offset.data = layer.weight_offset.data.flatten() + + # Merge L1/L2 scales: weight_scale_second loaded as [N, K//group_size], + # transpose to [K//group_size, N] for process_scale_second + layer.weight_scale_second.data, scale_bias = self._process_scale_second( + layer.weight.data, + layer.weight_scale.data, + layer.weight_scale_second.data.transpose(0, 1).contiguous(), + is_new_quant=self.new_quant_version, + ) + + if self.new_quant_version: + # Handle optional scale_bias parameter + if hasattr(layer, "scale_bias"): + if layer.scale_bias.data.shape[1] == 1: + layer.scale_bias.data = layer.scale_bias.data.flatten() + else: + layer.scale_bias.data = layer.scale_bias.data.contiguous() + # Pack 4 int8 (2×int4) into int32 for NPU kernel + assert ( + layer.weight.data.shape[-1] % 4 == 0 + ), f"Last dim of weight must be divisible by 4, got {layer.weight.data.shape}" + layer.weight.data = layer.weight.data.view(torch.int32).contiguous() + else: + # Old version: use NPU int4-pack conversion + if scale_bias is not None: + param = torch.nn.Parameter(scale_bias, requires_grad=False) + layer.register_parameter("weight_scale_bias", param) + layer.weight.data = torch_npu.npu_convert_weight_to_int4pack( + layer.weight.data.to(torch.int32) + ) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # Weight-dequant path: INT4 weights dequantized on-the-fly, activations in high precision + return torch_npu.npu_weight_quant_batchmatmul( + x, + layer.weight, + antiquant_scale=layer.weight_scale_second.to(x.dtype), + antiquant_group_size=self.group_size, + ) + + class NPU_W4A4DynamicLinearMethod(_NPULinearMethodBase): def process_weights_after_loading(self, layer): diff --git a/python/sglang/srt/layers/quantization/modelslim/modelslim.py b/python/sglang/srt/layers/quantization/modelslim/modelslim.py index 37ffa43bfac6..c1b66bf2f743 100644 --- a/python/sglang/srt/layers/quantization/modelslim/modelslim.py +++ b/python/sglang/srt/layers/quantization/modelslim/modelslim.py @@ -17,6 +17,7 @@ from sglang.srt.layers.quantization.modelslim.schemes import ( ModelSlimW4A4Int4, ModelSlimW4A4Int4MoE, + ModelSlimW4A8Int8, ModelSlimW4A8Int8MoE, ModelSlimW8A8Int8, ModelSlimW8A8Int8MoE, @@ -190,6 +191,10 @@ def _get_scheme_from_parts( return ModelSlimW4A4Int4( quant_config=self.quant_description, prefix=layer_name ) + elif quant_type == "W4A8_DYNAMIC": + return ModelSlimW4A8Int8( + quant_config=self.quant_description, prefix=layer_name + ) elif quant_type == "W8A8_MXFP8": from sglang.srt.layers.quantization.modelslim.schemes.modelslim_mxfp8 import ( ModelSlimMXFP8Scheme, diff --git a/python/sglang/srt/layers/quantization/modelslim/schemes/__init__.py b/python/sglang/srt/layers/quantization/modelslim/schemes/__init__.py index bfc2a350c619..30ea1d04eb48 100644 --- a/python/sglang/srt/layers/quantization/modelslim/schemes/__init__.py +++ b/python/sglang/srt/layers/quantization/modelslim/schemes/__init__.py @@ -10,6 +10,7 @@ # isort: on from .modelslim_w4a4_int4 import ModelSlimW4A4Int4 from .modelslim_w4a4_int4_moe import ModelSlimW4A4Int4MoE +from .modelslim_w4a8_int8 import ModelSlimW4A8Int8 from .modelslim_w4a8_int8_moe import ModelSlimW4A8Int8MoE from .modelslim_w8a8_int8 import ModelSlimW8A8Int8 from .modelslim_w8a8_int8_moe import ModelSlimW8A8Int8MoE @@ -21,6 +22,7 @@ "ModelSlimW8A8Int8", "ModelSlimW4A4Int4", "ModelSlimW4A4Int4MoE", + "ModelSlimW4A8Int8", "ModelSlimW4A8Int8MoE", "ModelSlimW8A8Int8MoE", ] diff --git a/python/sglang/srt/layers/quantization/modelslim/schemes/modelslim_w4a8_int8.py b/python/sglang/srt/layers/quantization/modelslim/schemes/modelslim_w4a8_int8.py new file mode 100644 index 000000000000..4a07ffb67746 --- /dev/null +++ b/python/sglang/srt/layers/quantization/modelslim/schemes/modelslim_w4a8_int8.py @@ -0,0 +1,131 @@ +# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Dict, List, Optional + +import torch + +from sglang.srt.hardware_backend.npu.quantization.linear_method_npu import ( + NPUW4A8DynamicLinearMethod, +) +from sglang.srt.layers.quantization.modelslim.schemes import ModelSlimLinearScheme +from sglang.srt.utils import set_weight_attrs + + +class ModelSlimW4A8Int8(ModelSlimLinearScheme): + """ModelSlim offline W4A8 Dense Linear scheme. + + Handles ``W4A8_DYNAMIC`` quant_type from ``quant_model_description.json``. + + Weight layout in the checkpoint: + - ``new_quant_version`` (version == "1.0.0"): INT4×2 pre-packed into INT8, + so on-disk shape is ``[N/2, K]``. + - Old version: each INT8 stores one INT4, on-disk shape is ``[N, K]``. + + Delegates weight processing and matmul to ``NPUW4A8DynamicLinearMethod`` + which uses ``torch_npu.npu_weight_quant_batchmatmul`` (weight-dequant path). + """ + + def __init__( + self, + quant_config: Dict[str, Any], + prefix: str, + ): + self.quant_config = quant_config + self.group_size: int = quant_config.get("group_size", 256) + self.new_quant_version: bool = quant_config.get("version", "0") == "1.0.0" + self.kernel = NPUW4A8DynamicLinearMethod( + group_size=self.group_size, + new_quant_version=self.new_quant_version, + ) + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ) -> None: + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + + # ── Weight ────────────────────────────────────────────────────────── + # new_quant_version: 2 INT4 packed per INT8 → shape [N/2, K] + # old version : 1 INT4 per INT8 → shape [N, K] + weight_n = ( + output_size_per_partition // 2 + if self.new_quant_version + else output_size_per_partition + ) + weight = torch.nn.Parameter( + torch.empty(weight_n, input_size_per_partition, dtype=torch.int8), + requires_grad=False, + ) + set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + layer.register_parameter("weight", weight) + set_weight_attrs(weight, extra_weight_attrs) + + # ── Per-channel L1 scale & offset: [N, 1] ─────────────────────────── + weight_scale = torch.nn.Parameter( + torch.empty(output_size_per_partition, 1, dtype=params_dtype), + requires_grad=False, + ) + set_weight_attrs(weight_scale, {"output_dim": 0}) + layer.register_parameter("weight_scale", weight_scale) + set_weight_attrs(weight_scale, extra_weight_attrs) + + weight_offset = torch.nn.Parameter( + torch.empty(output_size_per_partition, 1, dtype=params_dtype), + requires_grad=False, + ) + set_weight_attrs(weight_offset, {"output_dim": 0}) + layer.register_parameter("weight_offset", weight_offset) + set_weight_attrs(weight_offset, extra_weight_attrs) + + # ── Per-group L2 scale & offset: [N, K//group_size] ───────────────── + # Note: for RowParallelLinear (K partitioned), input_dim=1 would be needed; + # for ColumnParallelLinear (N partitioned), output_dim=0 suffices. + # Initial implementation covers the column-parallel case. + group_num = input_size_per_partition // self.group_size + weight_scale_second = torch.nn.Parameter( + torch.empty(output_size_per_partition, group_num, dtype=params_dtype), + requires_grad=False, + ) + set_weight_attrs(weight_scale_second, {"output_dim": 0}) + layer.register_parameter("weight_scale_second", weight_scale_second) + set_weight_attrs(weight_scale_second, extra_weight_attrs) + + weight_offset_second = torch.nn.Parameter( + torch.empty(output_size_per_partition, group_num, dtype=params_dtype), + requires_grad=False, + ) + set_weight_attrs(weight_offset_second, {"output_dim": 0}) + layer.register_parameter("weight_offset_second", weight_offset_second) + set_weight_attrs(weight_offset_second, extra_weight_attrs) + + # ── scale_bias (new_quant_version only): [N, 1] ───────────────────── + # Shape is [N, 16] for RowParallelLinear (down_proj / o_proj), + # but [N, 1] for ColumnParallelLinear. Using [N, 1] for simplicity; + # process_weights_after_loading handles both shapes dynamically. + if self.new_quant_version: + scale_bias = torch.nn.Parameter( + torch.empty(output_size_per_partition, 1, dtype=torch.float32), + requires_grad=False, + ) + set_weight_attrs(scale_bias, {"output_dim": 0}) + layer.register_parameter("scale_bias", scale_bias) + set_weight_attrs(scale_bias, extra_weight_attrs) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + self.kernel.process_weights_after_loading(layer) + + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return self.kernel.apply(layer, x, bias) diff --git a/python/sglang/srt/layers/quantization/npu_mxfp4.py b/python/sglang/srt/layers/quantization/npu_mxfp4.py index eb01af932e33..203812a1127d 100644 --- a/python/sglang/srt/layers/quantization/npu_mxfp4.py +++ b/python/sglang/srt/layers/quantization/npu_mxfp4.py @@ -13,6 +13,7 @@ from __future__ import annotations +import logging from typing import TYPE_CHECKING, Dict, List, Optional import torch @@ -30,6 +31,8 @@ if TYPE_CHECKING: pass +logger = logging.getLogger(__name__) + class NPUMxfp4Config(QuantizationConfig): """Quantization config for Ascend NPU MXFP4 W4A8 online quantization. @@ -104,6 +107,12 @@ def get_quant_method( return NPUMXFP4W4A8LinearMethod(self) elif isinstance(layer, FusedMoE): # MoE MXFP4 not yet implemented; fall back to unquantised + logger.warning( + "MXFP4 W4A8 quantization is not yet supported for FusedMoE layers " + "(prefix=%s). Falling back to unquantized MoE — MoE weights will " + "run in full precision (BF16/FP16).", + prefix, + ) return UnquantizedFusedMoEMethod( layer.use_triton_kernels, layer.use_flashinfer_trtllm_moe ) From 066becd1c878031d2223071d66d5ec6f132fb13e Mon Sep 17 00:00:00 2001 From: "(Messi) Junlin Wu" Date: Fri, 24 Apr 2026 17:34:43 +0800 Subject: [PATCH 23/25] :recycle: refactor(quant): rename mxfp4_npu to mxfp4_w4a8_npu and add mxfp4_w4a4_npu --- python/sglang/srt/layers/quantization/__init__.py | 2 +- python/sglang/srt/layers/quantization/npu_mxfp4.py | 4 ++-- python/sglang/srt/server_args.py | 2 ++ 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py index c830450f2fbe..d07117c8a7cd 100644 --- a/python/sglang/srt/layers/quantization/__init__.py +++ b/python/sglang/srt/layers/quantization/__init__.py @@ -78,7 +78,7 @@ def override_quantization_method(self, *args, **kwargs): "auto-round": AutoRoundConfig, "modelslim": ModelSlimConfig, "quark_int4fp8_moe": QuarkInt4Fp8Config, - "mxfp4_npu": NPUMxfp4Config, + "mxfp4_w4a8_npu": NPUMxfp4Config, } diff --git a/python/sglang/srt/layers/quantization/npu_mxfp4.py b/python/sglang/srt/layers/quantization/npu_mxfp4.py index 203812a1127d..a62526a85be1 100644 --- a/python/sglang/srt/layers/quantization/npu_mxfp4.py +++ b/python/sglang/srt/layers/quantization/npu_mxfp4.py @@ -1,6 +1,6 @@ """Ascend NPU MXFP4 W4A8 online quantization config. -Triggered by ``--quantization mxfp4_npu``. +Triggered by ``--quantization mxfp4_w4a8_npu``. Online mode: loads FP16/BF16 weights, quantises to MXFP4 (dual-level) in ``process_weights_after_loading``. During inference, activations are @@ -53,7 +53,7 @@ def __init__( @classmethod def get_name(cls) -> str: - return "mxfp4_npu" + return "mxfp4_w4a8_npu" @classmethod def get_supported_act_dtypes(cls) -> List[torch.dtype]: diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index d6a746ee3b3b..325b97d9d389 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -116,6 +116,8 @@ "auto-round", "compressed-tensors", # for Ktransformers "modelslim", # for NPU + "mxfp4_w4a8_npu", # for NPU W4A8 + "mxfp4_w4a4_npu", # for NPU W4A4 "quark_int4fp8_moe", ] From 93d542a769c692b18a2c80afebe9e4582f71c91d Mon Sep 17 00:00:00 2001 From: "(Messi) Junlin Wu" Date: Fri, 24 Apr 2026 18:02:41 +0800 Subject: [PATCH 24/25] :sparkles: feat(modelslim): add W4A8_MXFP offline scheme for LLM --- .../quantization/modelslim/modelslim.py | 6 + .../modelslim/schemes/__init__.py | 2 + .../modelslim/schemes/modelslim_mxfp4_w4a8.py | 117 ++++++++++++++++++ 3 files changed, 125 insertions(+) create mode 100644 python/sglang/srt/layers/quantization/modelslim/schemes/modelslim_mxfp4_w4a8.py diff --git a/python/sglang/srt/layers/quantization/modelslim/modelslim.py b/python/sglang/srt/layers/quantization/modelslim/modelslim.py index c1b66bf2f743..c6c2e8f81f3e 100644 --- a/python/sglang/srt/layers/quantization/modelslim/modelslim.py +++ b/python/sglang/srt/layers/quantization/modelslim/modelslim.py @@ -201,6 +201,12 @@ def _get_scheme_from_parts( ) return ModelSlimMXFP8Scheme() + elif quant_type == "W4A8_MXFP": + from sglang.srt.layers.quantization.modelslim.schemes.modelslim_mxfp4_w4a8 import ( + ModelSlimMXFP4W4A8Scheme, + ) + + return ModelSlimMXFP4W4A8Scheme() raise NotImplementedError("No modelslim compatible scheme was found.") def get_linear_scheme( diff --git a/python/sglang/srt/layers/quantization/modelslim/schemes/__init__.py b/python/sglang/srt/layers/quantization/modelslim/schemes/__init__.py index 30ea1d04eb48..c7755427a946 100644 --- a/python/sglang/srt/layers/quantization/modelslim/schemes/__init__.py +++ b/python/sglang/srt/layers/quantization/modelslim/schemes/__init__.py @@ -6,6 +6,7 @@ # isort: off from .modelslim_scheme import ModelSlimLinearScheme, ModelSlimMoEScheme from .modelslim_mxfp8 import ModelSlimMXFP8Scheme +from .modelslim_mxfp4_w4a8 import ModelSlimMXFP4W4A8Scheme # isort: on from .modelslim_w4a4_int4 import ModelSlimW4A4Int4 @@ -19,6 +20,7 @@ "ModelSlimLinearScheme", "ModelSlimMoEScheme", "ModelSlimMXFP8Scheme", + "ModelSlimMXFP4W4A8Scheme", "ModelSlimW8A8Int8", "ModelSlimW4A4Int4", "ModelSlimW4A4Int4MoE", diff --git a/python/sglang/srt/layers/quantization/modelslim/schemes/modelslim_mxfp4_w4a8.py b/python/sglang/srt/layers/quantization/modelslim/schemes/modelslim_mxfp4_w4a8.py new file mode 100644 index 000000000000..dc1a52aae936 --- /dev/null +++ b/python/sglang/srt/layers/quantization/modelslim/schemes/modelslim_mxfp4_w4a8.py @@ -0,0 +1,117 @@ +"""ModelSlim W4A8_MXFP scheme for pre-quantized weight inference on Ascend NPU (SRT). + +Loads weights pre-quantized by msmodelslim (FP4 weights packed as uint8, uint8 +bias-shifted MXFP8_E8M0 scales) and runs W4A8 matmul at inference. + +Weight format exported by msmodelslim (on_w4a8_mx_dynamic_per_block): + weight: pack_fp4_to_uint8 → uint8, shape [out, in/2], group_size=32 + weight_scale: (scale + 127).uint8 shape [out, in/32] + +Inference: + activation → npu_dynamic_mx_quant(float8_e4m3fn) → qx + per-token scale + npu_quant_matmul(qx, weight, weight_scale, x1_dtype=fp8, x2_dtype=fp4) +""" + +from typing import List, Optional + +import torch +import torch_npu + +from sglang.srt.layers.parameter import GroupQuantScaleParameter, ModelWeightParameter +from sglang.srt.layers.quantization.modelslim.schemes import ModelSlimLinearScheme + +MXFP4_W4A8_BLOCK_SIZE = 32 + +_FLOAT8_E8M0FNU_DTYPE = getattr( + torch_npu, "float8_e8m0fnu", getattr(torch, "float8_e8m0fnu", None) +) +_FLOAT4_E2M1FN_X2_DTYPE = getattr( + torch_npu, "float4_e2m1fn_x2", getattr(torch, "float4_e2m1fn_x2", None) +) + + +class ModelSlimMXFP4W4A8Scheme(ModelSlimLinearScheme): + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + weight_loader = extra_weight_attrs.get("weight_loader") + output_size_per_partition = sum(output_partition_sizes) + + # msmodelslim packs 2 FP4 values per uint8 → shape [out, in/2] + weight = ModelWeightParameter( + data=torch.empty( + (output_size_per_partition, input_size_per_partition // 2), + dtype=torch.uint8, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + # msmodelslim exports weight_scale as uint8 with +127 bias, shape [out, in/32] + scale_dim = input_size_per_partition // MXFP4_W4A8_BLOCK_SIZE + weight_scale = GroupQuantScaleParameter( + data=torch.empty( + (output_size_per_partition, scale_dim), + dtype=torch.uint8, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight_scale", weight_scale) + + def process_weights_after_loading(self, layer: torch.nn.Module): + # Same transform as ModelSlimMXFP8Scheme: + # weight_scale: [out, in/32] → reshape [out, in/64, 2] → transpose [in/64, out, 2] + # weight: [out, in/2] → transpose [in/2, out] + n_dim, k_dim = layer.weight_scale.data.shape + layer.weight_scale.data = layer.weight_scale.data.reshape(n_dim, k_dim // 2, 2) + layer.weight.data = layer.weight.data.transpose(0, 1) + layer.weight_scale.data = layer.weight_scale.data.transpose(0, 1) + + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + original_dtype = x.dtype + if original_dtype not in (torch.float16, torch.bfloat16): + x = x.to(torch.bfloat16) + original_dtype = torch.bfloat16 + + input_shape = x.shape + x_2d = x.reshape(-1, x.shape[-1]) + + # Dynamically quantize activations to FP8 (A8 in W4A8) + qx, input_scale = torch_npu.npu_dynamic_mx_quant( + x_2d, dst_type=torch.float8_e4m3fn + ) + + # W4A8 matmul: FP8 activations × FP4 weights (already transposed at load time) + output = torch_npu.npu_quant_matmul( + qx, + layer.weight, + layer.weight_scale, + scale_dtype=_FLOAT8_E8M0FNU_DTYPE, + pertoken_scale=input_scale, + pertoken_scale_dtype=_FLOAT8_E8M0FNU_DTYPE, + bias=bias.to(torch.float32) if bias is not None else None, + output_dtype=original_dtype, + x1_dtype=torch.float8_e4m3fn, + x2_dtype=_FLOAT4_E2M1FN_X2_DTYPE, + group_sizes=[1, 1, MXFP4_W4A8_BLOCK_SIZE], + ) + + output_shape = list(input_shape[:-1]) + [output.shape[-1]] + return output.reshape(output_shape) From 456bd1436059ef79313609e056c2a7e9e02514bf Mon Sep 17 00:00:00 2001 From: "(Messi) Junlin Wu" Date: Fri, 24 Apr 2026 18:13:54 +0800 Subject: [PATCH 25/25] :bug: fix(modelslim): fix W4A8_MXFP weight dtype to float8_e4m3fn --- .../modelslim/schemes/modelslim_mxfp4_w4a8.py | 29 +++++-------------- 1 file changed, 8 insertions(+), 21 deletions(-) diff --git a/python/sglang/srt/layers/quantization/modelslim/schemes/modelslim_mxfp4_w4a8.py b/python/sglang/srt/layers/quantization/modelslim/schemes/modelslim_mxfp4_w4a8.py index dc1a52aae936..1e4913d989f2 100644 --- a/python/sglang/srt/layers/quantization/modelslim/schemes/modelslim_mxfp4_w4a8.py +++ b/python/sglang/srt/layers/quantization/modelslim/schemes/modelslim_mxfp4_w4a8.py @@ -1,15 +1,12 @@ """ModelSlim W4A8_MXFP scheme for pre-quantized weight inference on Ascend NPU (SRT). -Loads weights pre-quantized by msmodelslim (FP4 weights packed as uint8, uint8 -bias-shifted MXFP8_E8M0 scales) and runs W4A8 matmul at inference. - -Weight format exported by msmodelslim (on_w4a8_mx_dynamic_per_block): - weight: pack_fp4_to_uint8 → uint8, shape [out, in/2], group_size=32 - weight_scale: (scale + 127).uint8 shape [out, in/32] +Loads weights pre-quantized by msmodelslim: + weight: float8_e4m3fn, shape [out, in], group_size=32 + weight_scale: uint8 (+127 biased), shape [out, in/32] Inference: activation → npu_dynamic_mx_quant(float8_e4m3fn) → qx + per-token scale - npu_quant_matmul(qx, weight, weight_scale, x1_dtype=fp8, x2_dtype=fp4) + npu_quant_matmul(qx, weight, weight_scale, scale_dtype=FP8_E8M0) """ from typing import List, Optional @@ -25,9 +22,6 @@ _FLOAT8_E8M0FNU_DTYPE = getattr( torch_npu, "float8_e8m0fnu", getattr(torch, "float8_e8m0fnu", None) ) -_FLOAT4_E2M1FN_X2_DTYPE = getattr( - torch_npu, "float4_e2m1fn_x2", getattr(torch, "float4_e2m1fn_x2", None) -) class ModelSlimMXFP4W4A8Scheme(ModelSlimLinearScheme): @@ -45,11 +39,10 @@ def create_weights( weight_loader = extra_weight_attrs.get("weight_loader") output_size_per_partition = sum(output_partition_sizes) - # msmodelslim packs 2 FP4 values per uint8 → shape [out, in/2] weight = ModelWeightParameter( data=torch.empty( - (output_size_per_partition, input_size_per_partition // 2), - dtype=torch.uint8, + (output_size_per_partition, input_size_per_partition), + dtype=torch.float8_e4m3fn, ), input_dim=1, output_dim=0, @@ -57,7 +50,6 @@ def create_weights( ) layer.register_parameter("weight", weight) - # msmodelslim exports weight_scale as uint8 with +127 bias, shape [out, in/32] scale_dim = input_size_per_partition // MXFP4_W4A8_BLOCK_SIZE weight_scale = GroupQuantScaleParameter( data=torch.empty( @@ -71,9 +63,8 @@ def create_weights( layer.register_parameter("weight_scale", weight_scale) def process_weights_after_loading(self, layer: torch.nn.Module): - # Same transform as ModelSlimMXFP8Scheme: # weight_scale: [out, in/32] → reshape [out, in/64, 2] → transpose [in/64, out, 2] - # weight: [out, in/2] → transpose [in/2, out] + # weight: [out, in] → transpose [in, out] n_dim, k_dim = layer.weight_scale.data.shape layer.weight_scale.data = layer.weight_scale.data.reshape(n_dim, k_dim // 2, 2) layer.weight.data = layer.weight.data.transpose(0, 1) @@ -93,12 +84,10 @@ def apply_weights( input_shape = x.shape x_2d = x.reshape(-1, x.shape[-1]) - # Dynamically quantize activations to FP8 (A8 in W4A8) qx, input_scale = torch_npu.npu_dynamic_mx_quant( - x_2d, dst_type=torch.float8_e4m3fn + x_2d, dst_type=torch_npu.float8_e4m3fn ) - # W4A8 matmul: FP8 activations × FP4 weights (already transposed at load time) output = torch_npu.npu_quant_matmul( qx, layer.weight, @@ -108,8 +97,6 @@ def apply_weights( pertoken_scale_dtype=_FLOAT8_E8M0FNU_DTYPE, bias=bias.to(torch.float32) if bias is not None else None, output_dtype=original_dtype, - x1_dtype=torch.float8_e4m3fn, - x2_dtype=_FLOAT4_E2M1FN_X2_DTYPE, group_sizes=[1, 1, MXFP4_W4A8_BLOCK_SIZE], )