diff --git a/python/sglang/srt/hardware_backend/npu/quantization/linear_method_npu.py b/python/sglang/srt/hardware_backend/npu/quantization/linear_method_npu.py index 788620a317bb..de9593958934 100644 --- a/python/sglang/srt/hardware_backend/npu/quantization/linear_method_npu.py +++ b/python/sglang/srt/hardware_backend/npu/quantization/linear_method_npu.py @@ -1,6 +1,8 @@ from typing import TYPE_CHECKING, Optional import torch +import torch_npu +from torch.nn.parameter import Parameter from sglang.srt.hardware_backend.npu.utils import npu_format_cast from sglang.srt.layers.quantization.base_config import LinearMethodBase @@ -8,6 +10,9 @@ if TYPE_CHECKING: from sglang.srt.layers.quantization.base_config import QuantizationConfig +MXFP8_BLOCK_SIZE = 32 +_FLOAT8_E8M0FNU_DTYPE = getattr(torch_npu, "float8_e8m0fnu", getattr(torch, "float8_e8m0fnu", None)) + class _NPULinearMethodBase(LinearMethodBase): @@ -111,6 +116,101 @@ def apply( ) +class NPUMXFP8LinearMethod(_NPULinearMethodBase): + """Ascend NPU MXFP8 linear method for LLM (SRT) models. + + Online mode: loads FP16/BF16 weights → quantises to MXFP8 at load time. + Inference: dynamic MXFP8 activation quant + MXFP8 matmul (block_size=32). + """ + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes, + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + from sglang.srt.layers.parameter import ModelWeightParameter + + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + + layer.logical_widths = output_partition_sizes + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + layer.orig_dtype = params_dtype + + # Load weights in original dtype; quantise later in process_weights_after_loading + weight = ModelWeightParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=params_dtype, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + weight_fp = layer.weight.data + if weight_fp.dtype not in (torch.float16, torch.bfloat16): + weight_fp = weight_fp.to(torch.bfloat16) + + # Move weight to NPU if needed (cpu offload may have moved it back to CPU) + if not weight_fp.is_npu: + weight_fp = weight_fp.to(f"npu:{torch.npu.current_device()}") + + # Online MXFP8 quantisation of weights (block_size=32) + qw, w_scale = torch_npu.npu_dynamic_mx_quant( + weight_fp, dst_type=torch_npu.float8_e4m3fn + ) + # Pre-transpose to [in, out] for npu_quant_matmul (avoid per-call transpose) + layer.weight = Parameter(qw.transpose(0, 1).contiguous(), requires_grad=False) + layer.weight_scale_inv = Parameter(w_scale.transpose(0, 1).contiguous(), requires_grad=False) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + original_dtype = x.dtype + if original_dtype not in (torch.float16, torch.bfloat16): + x = x.to(torch.bfloat16) + original_dtype = torch.bfloat16 + + # Flatten to 2D [tokens, hidden] for npu_dynamic_mx_quant + input_shape = x.shape + x_2d = x.reshape(-1, x.shape[-1]) + + # Dynamic MXFP8 activation quantisation + qx, input_scale = torch_npu.npu_dynamic_mx_quant( + x_2d, dst_type=torch_npu.float8_e4m3fn + ) + + # MXFP8 matmul (weight & scale already transposed at load time) + output = torch_npu.npu_quant_matmul( + qx, + layer.weight, + layer.weight_scale_inv, + scale_dtype=_FLOAT8_E8M0FNU_DTYPE, + pertoken_scale=input_scale, + pertoken_scale_dtype=_FLOAT8_E8M0FNU_DTYPE, + bias=bias.to(torch.float32) if bias is not None else None, + output_dtype=original_dtype, + group_sizes=[1, 1, MXFP8_BLOCK_SIZE], + ) + + # Restore original shape (replace last dim with output features) + output_shape = list(input_shape[:-1]) + [output.shape[-1]] + return output.reshape(output_shape) + + class NPU_W4A4DynamicLinearMethod(_NPULinearMethodBase): def process_weights_after_loading(self, layer): diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 8d7dfa2d3661..800080d2b69e 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -202,6 +202,8 @@ def get_supported_act_dtypes(cls) -> List[torch.dtype]: return [torch.bfloat16, torch.half] def get_min_capability(self) -> int: + if is_npu(): + return 0 # NPU bypasses CUDA capability checks if _is_musa: return 31 @@ -258,6 +260,12 @@ def get_quant_method( prefix, self.ignored_layers, fused_mapping=self.packed_modules_mapping ): return UnquantizedLinearMethod() + if is_npu() and self.use_mxfp8: + from sglang.srt.hardware_backend.npu.quantization.linear_method_npu import ( + NPUMXFP8LinearMethod, + ) + + return NPUMXFP8LinearMethod(self) return Fp8LinearMethod(self) elif isinstance(layer, FusedMoE): if is_layer_skipped( diff --git a/python/sglang/srt/layers/quantization/modelslim/modelslim.py b/python/sglang/srt/layers/quantization/modelslim/modelslim.py index 3d0c9079afd5..b08dcac7972a 100644 --- a/python/sglang/srt/layers/quantization/modelslim/modelslim.py +++ b/python/sglang/srt/layers/quantization/modelslim/modelslim.py @@ -14,6 +14,7 @@ QuantizationConfig, ) from sglang.srt.layers.quantization.modelslim.schemes import ( + ModelSlimMXFP8Scheme, ModelSlimW4A4Int4, ModelSlimW4A4Int4MoE, ModelSlimW4A8Int8MoE, @@ -180,6 +181,7 @@ def get_linear_scheme( ("W4A4_DYNAMIC", ModelSlimW4A4Int4), ("W8A8", ModelSlimW8A8Int8), ("W8A8_DYNAMIC", ModelSlimW8A8Int8), + ("W8A8_MXFP8", ModelSlimMXFP8Scheme), ] quant_schemes = [self.quant_description.get(prefix + ".weight", "")] diff --git a/python/sglang/srt/layers/quantization/modelslim/schemes/__init__.py b/python/sglang/srt/layers/quantization/modelslim/schemes/__init__.py index c349fd3c4251..bfc2a350c619 100644 --- a/python/sglang/srt/layers/quantization/modelslim/schemes/__init__.py +++ b/python/sglang/srt/layers/quantization/modelslim/schemes/__init__.py @@ -1,6 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 +# NOTE: Import order is critical to avoid circular dependency. +# modelslim_mxfp8 imports ModelSlimLinearScheme from this package, +# so the base class must be imported first. +# isort: off from .modelslim_scheme import ModelSlimLinearScheme, ModelSlimMoEScheme +from .modelslim_mxfp8 import ModelSlimMXFP8Scheme + +# isort: on from .modelslim_w4a4_int4 import ModelSlimW4A4Int4 from .modelslim_w4a4_int4_moe import ModelSlimW4A4Int4MoE from .modelslim_w4a8_int8_moe import ModelSlimW4A8Int8MoE @@ -10,6 +17,7 @@ __all__ = [ "ModelSlimLinearScheme", "ModelSlimMoEScheme", + "ModelSlimMXFP8Scheme", "ModelSlimW8A8Int8", "ModelSlimW4A4Int4", "ModelSlimW4A4Int4MoE", diff --git a/python/sglang/srt/layers/quantization/modelslim/schemes/modelslim_mxfp8.py b/python/sglang/srt/layers/quantization/modelslim/schemes/modelslim_mxfp8.py new file mode 100644 index 000000000000..9a56875128c8 --- /dev/null +++ b/python/sglang/srt/layers/quantization/modelslim/schemes/modelslim_mxfp8.py @@ -0,0 +1,118 @@ +"""ModelSlim MXFP8 scheme for pre-quantized weight inference on Ascend NPU (SRT). + +Loads weights pre-quantized by msmodelslim (float8_e4m3fn weights, +uint8 scales) and runs MXFP8 matmul at inference. +""" + +from typing import Dict, List, Optional + +import torch +import torch_npu + +from sglang.srt.layers.parameter import GroupQuantScaleParameter, ModelWeightParameter +from sglang.srt.layers.quantization.modelslim.schemes import ModelSlimLinearScheme + +MXFP8_BLOCK_SIZE = 32 +_FLOAT8_E8M0FNU_DTYPE = getattr( + torch_npu, "float8_e8m0fnu", getattr(torch, "float8_e8m0fnu", None) +) + + +class ModelSlimMXFP8Scheme(ModelSlimLinearScheme): + + def __init__( + self, + quant_config: Optional[Dict[str, any]] = None, + prefix: Optional[str] = None, + ): + # quant_config / prefix are accepted to match the linear-scheme + # dispatch signature used by ModelSlimConfig.get_linear_scheme; + # MXFP8 needs no per-layer config beyond what create_weights derives. + del quant_config, prefix + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + weight_loader = extra_weight_attrs.get("weight_loader") + output_size_per_partition = sum(output_partition_sizes) + + # msmodelslim exports weight as float8_e4m3fn, shape [out, in] + weight = ModelWeightParameter( + data=torch.empty( + (output_size_per_partition, input_size_per_partition), + dtype=torch.float8_e4m3fn, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + # msmodelslim exports weight_scale as uint8, shape [out, in/32]. + # NOTE: Named "weight_scale" (not "weight_scale_inv") to match the + # checkpoint key exported by msmodelslim. + scale_dim = input_size_per_partition // MXFP8_BLOCK_SIZE + weight_scale = GroupQuantScaleParameter( + data=torch.empty( + (output_size_per_partition, scale_dim), + dtype=torch.uint8, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight_scale", weight_scale) + + def process_weights_after_loading(self, layer: torch.nn.Module): + # Pre-transpose weight and scale to [in, out] for npu_quant_matmul. + # Use .data assignment without .contiguous() to preserve the transpose + # view strides — npu_quant_matmul reads strides correctly and calling + # .contiguous() would reorder data, breaking the block-scale mapping. + n_dim, k_dim = layer.weight_scale.data.shape + layer.weight_scale.data = layer.weight_scale.data.reshape(n_dim, k_dim // 2, 2) + layer.weight.data = layer.weight.data.transpose(0, 1) + layer.weight_scale.data = layer.weight_scale.data.transpose(0, 1) + + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + original_dtype = x.dtype + if original_dtype not in (torch.float16, torch.bfloat16): + x = x.to(torch.bfloat16) + original_dtype = torch.bfloat16 + + # npu_dynamic_mx_quant requires a 2D input [tokens, hidden_size] + input_shape = x.shape + x_2d = x.reshape(-1, x.shape[-1]) + + # Dynamic MXFP8 activation quantisation + qx, input_scale = torch_npu.npu_dynamic_mx_quant( + x_2d, dst_type=torch_npu.float8_e4m3fn + ) + + # MXFP8 matmul (weight & scale already transposed at load time) + output = torch_npu.npu_quant_matmul( + qx, + layer.weight, + layer.weight_scale, + scale_dtype=_FLOAT8_E8M0FNU_DTYPE, + pertoken_scale=input_scale, + pertoken_scale_dtype=_FLOAT8_E8M0FNU_DTYPE, + bias=bias.to(torch.float32) if bias is not None else None, + output_dtype=original_dtype, + group_sizes=[1, 1, MXFP8_BLOCK_SIZE], + ) + + # Restore original shape (replace last dim with output features) + output_shape = list(input_shape[:-1]) + [output.shape[-1]] + return output.reshape(output_shape) diff --git a/python/sglang/srt/layers/rotary_embedding/base.py b/python/sglang/srt/layers/rotary_embedding/base.py index 2b13c1594d82..cc0d501b016b 100644 --- a/python/sglang/srt/layers/rotary_embedding/base.py +++ b/python/sglang/srt/layers/rotary_embedding/base.py @@ -39,7 +39,11 @@ if _is_npu: import torch_npu - from sgl_kernel_npu.norm.fused_rope_qk_mqa import fused_rope_qk_mqa + + try: + from sgl_kernel_npu.norm.fused_rope_qk_mqa import fused_rope_qk_mqa + except ImportError: + fused_rope_qk_mqa = None if _is_hip: from sglang.srt.layers.attention.utils import ( @@ -267,7 +271,10 @@ def forward_npu( else: cos_sin = self.cos_sin_cache.index_select(0, positions) - if query.shape[0] * query.shape[1] < 65535: + if ( + fused_rope_qk_mqa is not None + and query.shape[0] * query.shape[1] < 65535 + ): return fused_rope_qk_mqa( query, key,