Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
ef874c0
✨ feat(npu): add online MXFP8 quantization support for Ascend NPU (Pa…
TallMessiWu Mar 18, 2026
d2d19c6
✨ feat(diffusion): add online MXFP8 quantization support for Wan2.2 o…
TallMessiWu Mar 18, 2026
c838ade
:bug: fix(diffusion): fix npu method call error
TallMessiWu Mar 19, 2026
be3b684
:bug: fix(diffusion): fix MXFP8 quantization scale dimension mismatch…
TallMessiWu Mar 19, 2026
fd79b23
:recycle: refactor(mxfp8): split linear method into config and NPU la…
TallMessiWu Mar 20, 2026
df61b29
:twisted_rightwards_arrows: merge: sync from upstream
TallMessiWu Mar 20, 2026
490ad0b
:sparkles: feat(diffusion): add offline MXFP8 pre-quantized weight su…
TallMessiWu Mar 20, 2026
cc80690
:bug: fix(diffusion): correct MXFP8 weight dtype and scale shape
TallMessiWu Mar 23, 2026
b9aa785
✨ feat(wan22): Redesigned the wan_repack tool. Now support one-click …
TallMessiWu Mar 24, 2026
22bee9e
:recycle: refactor(mxfp8): hoist imports and replace print with logger
TallMessiWu Mar 24, 2026
a29bb3d
:pencil2: fix(diffusion/mxfp8): address review comments on ModelSlimM…
TallMessiWu Mar 25, 2026
3bbf703
:twisted_rightwards_arrows: chore(merge): sync upstream/main, keep MX…
TallMessiWu Mar 25, 2026
250fe65
:adhesive_bandage: fix(diffusion): register --quantization CLI arg to…
TallMessiWu Mar 25, 2026
e146b03
:bug: fix(mxfp8_npu): move weight to current NPU device before quanti…
TallMessiWu Mar 25, 2026
711bb8b
:rewind: revert(llm): remove LLM MXFP8 online quantization (Path B) f…
TallMessiWu Mar 25, 2026
1604d4e
:twisted_rightwards_arrows: chore(merge): sync upstream/main into junlin
TallMessiWu Mar 31, 2026
1101cf5
:adhesive_bandage: fix(loader): preserve --quantization flag priority…
TallMessiWu Mar 31, 2026
f1c652b
:sparkles: feat(npu/mxfp8): add W8A8 MXFP8 LLM support on Ascend NPU
TallMessiWu Apr 1, 2026
97c45b6
:recycle: refactor(npu/mxfp8): refactor code to align with vllm-ascend
TallMessiWu Apr 1, 2026
6026a18
🐛 fix(quantization/modelslim): resolve circular import in schemes/__i…
TallMessiWu Apr 2, 2026
3025e2d
:bug: fix(quantization/modelslim): fix no scheme found error.
TallMessiWu Apr 2, 2026
da92418
:bug: fix(llm/mxfp8): fix meaningless output issue
TallMessiWu Apr 3, 2026
29c04bc
:bug: fix(llm/mxfp8): fix meaningless output issue
TallMessiWu Apr 7, 2026
80d862c
:twisted_rightwards_arrows: chore(merge): sync upstream/main, drop di…
TallMessiWu May 20, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
from typing import TYPE_CHECKING, Optional

import torch
import torch_npu
from torch.nn.parameter import Parameter

from sglang.srt.hardware_backend.npu.utils import npu_format_cast
from sglang.srt.layers.quantization.base_config import LinearMethodBase

if TYPE_CHECKING:
from sglang.srt.layers.quantization.base_config import QuantizationConfig

# MXFP8 block size; used as the last entry of ``group_sizes`` in the
# ``torch_npu.npu_quant_matmul`` call made by ``NPUMXFP8LinearMethod.apply``.
MXFP8_BLOCK_SIZE = 32
# The ``float8_e8m0fnu`` dtype passed to ``npu_quant_matmul`` as ``scale_dtype``
# and ``pertoken_scale_dtype``. It is looked up on ``torch_npu`` first, then on
# ``torch``; the value is ``None`` when neither module exposes the attribute.
_FLOAT8_E8M0FNU_DTYPE = getattr(torch_npu, "float8_e8m0fnu", getattr(torch, "float8_e8m0fnu", None))


class _NPULinearMethodBase(LinearMethodBase):

Expand Down Expand Up @@ -111,6 +116,101 @@ def apply(
)


class NPUMXFP8LinearMethod(_NPULinearMethodBase):
    """Online MXFP8 linear method for LLM (SRT) layers on Ascend NPU.

    The weight is loaded in its checkpoint precision and quantised to MXFP8
    once loading has finished. At inference time the activations are
    quantised dynamically and multiplied with the stored weight through
    ``torch_npu.npu_quant_matmul`` (``MXFP8_BLOCK_SIZE`` is the last entry of
    ``group_sizes``).
    """

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes,
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register an unquantised ``weight`` parameter on ``layer``.

        The parameter has shape ``(sum(output_partition_sizes),
        input_size_per_partition)`` and dtype ``params_dtype``; quantisation
        is deferred to ``process_weights_after_loading``. ``input_size`` and
        ``output_size`` are accepted but not used here. Only the
        ``weight_loader`` entry of ``extra_weight_attrs`` is read.
        """
        from sglang.srt.layers.parameter import ModelWeightParameter

        out_features = sum(output_partition_sizes)

        # Bookkeeping attributes recorded on the layer itself.
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = out_features
        layer.orig_dtype = params_dtype

        # Placeholder in the original dtype; the loader fills it in later.
        layer.register_parameter(
            "weight",
            ModelWeightParameter(
                data=torch.empty(
                    out_features,
                    input_size_per_partition,
                    dtype=params_dtype,
                ),
                input_dim=1,
                output_dim=0,
                weight_loader=extra_weight_attrs.get("weight_loader"),
            ),
        )

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Quantise the loaded weight to MXFP8 and store it transposed.

        Replaces ``layer.weight`` with the quantised tensor and adds
        ``layer.weight_scale_inv``; both are transposed on dims 0/1 and made
        contiguous so ``apply`` does not have to transpose on every call.
        """
        loaded = layer.weight.data

        # Anything that is not fp16/bf16 is converted to bf16 before quantising.
        if loaded.dtype not in (torch.float16, torch.bfloat16):
            loaded = loaded.to(torch.bfloat16)

        # CPU offload may have moved the weight back to host memory.
        if not loaded.is_npu:
            loaded = loaded.to(f"npu:{torch.npu.current_device()}")

        quantised, block_scale = torch_npu.npu_dynamic_mx_quant(
            loaded, dst_type=torch_npu.float8_e4m3fn
        )

        for attr_name, tensor in (
            ("weight", quantised),
            ("weight_scale_inv", block_scale),
        ):
            setattr(
                layer,
                attr_name,
                Parameter(tensor.transpose(0, 1).contiguous(), requires_grad=False),
            )

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the MXFP8 linear layer on ``x``.

        ``x`` is flattened to 2D, quantised dynamically, multiplied with the
        pre-transposed weight, and reshaped so that only the last dimension
        differs from the input. The result dtype equals ``x.dtype`` when that
        is fp16/bf16 and is bf16 otherwise. ``bias``, when given, is passed to
        the matmul as float32.
        """
        out_dtype = (
            x.dtype if x.dtype in (torch.float16, torch.bfloat16) else torch.bfloat16
        )
        if x.dtype != out_dtype:
            x = x.to(out_dtype)

        # npu_dynamic_mx_quant is fed a 2D tensor; remember the leading dims.
        leading_dims = x.shape[:-1]
        tokens = x.reshape(-1, x.shape[-1])

        q_tokens, token_scale = torch_npu.npu_dynamic_mx_quant(
            tokens, dst_type=torch_npu.float8_e4m3fn
        )

        fp32_bias = None if bias is None else bias.to(torch.float32)

        # Weight and weight scale were already transposed at load time.
        result = torch_npu.npu_quant_matmul(
            q_tokens,
            layer.weight,
            layer.weight_scale_inv,
            scale_dtype=_FLOAT8_E8M0FNU_DTYPE,
            pertoken_scale=token_scale,
            pertoken_scale_dtype=_FLOAT8_E8M0FNU_DTYPE,
            bias=fp32_bias,
            output_dtype=out_dtype,
            group_sizes=[1, 1, MXFP8_BLOCK_SIZE],
        )

        # Put the leading dims back; the last dim is now the output width.
        return result.reshape(*leading_dims, result.shape[-1])


class NPU_W4A4DynamicLinearMethod(_NPULinearMethodBase):

def process_weights_after_loading(self, layer):
Expand Down
8 changes: 8 additions & 0 deletions python/sglang/srt/layers/quantization/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,8 @@ def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.bfloat16, torch.half]

def get_min_capability(self) -> int:
if is_npu():
return 0 # NPU bypasses CUDA capability checks
if _is_musa:
return 31

Expand Down Expand Up @@ -258,6 +260,12 @@ def get_quant_method(
prefix, self.ignored_layers, fused_mapping=self.packed_modules_mapping
):
return UnquantizedLinearMethod()
if is_npu() and self.use_mxfp8:
from sglang.srt.hardware_backend.npu.quantization.linear_method_npu import (
NPUMXFP8LinearMethod,
)

return NPUMXFP8LinearMethod(self)
return Fp8LinearMethod(self)
elif isinstance(layer, FusedMoE):
if is_layer_skipped(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
QuantizationConfig,
)
from sglang.srt.layers.quantization.modelslim.schemes import (
ModelSlimMXFP8Scheme,
ModelSlimW4A4Int4,
ModelSlimW4A4Int4MoE,
ModelSlimW4A8Int8MoE,
Expand Down Expand Up @@ -180,6 +181,7 @@ def get_linear_scheme(
("W4A4_DYNAMIC", ModelSlimW4A4Int4),
("W8A8", ModelSlimW8A8Int8),
("W8A8_DYNAMIC", ModelSlimW8A8Int8),
("W8A8_MXFP8", ModelSlimMXFP8Scheme),
]

quant_schemes = [self.quant_description.get(prefix + ".weight", "")]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
# SPDX-License-Identifier: Apache-2.0

# NOTE: Import order is critical to avoid circular dependency.
# modelslim_mxfp8 imports ModelSlimLinearScheme from this package,
# so the base class must be imported first.
# isort: off
from .modelslim_scheme import ModelSlimLinearScheme, ModelSlimMoEScheme
from .modelslim_mxfp8 import ModelSlimMXFP8Scheme

# isort: on
from .modelslim_w4a4_int4 import ModelSlimW4A4Int4
from .modelslim_w4a4_int4_moe import ModelSlimW4A4Int4MoE
from .modelslim_w4a8_int8_moe import ModelSlimW4A8Int8MoE
Expand All @@ -10,6 +17,7 @@
__all__ = [
"ModelSlimLinearScheme",
"ModelSlimMoEScheme",
"ModelSlimMXFP8Scheme",
"ModelSlimW8A8Int8",
"ModelSlimW4A4Int4",
"ModelSlimW4A4Int4MoE",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
"""ModelSlim MXFP8 scheme for pre-quantized weight inference on Ascend NPU (SRT).

Loads weights pre-quantized by msmodelslim (float8_e4m3fn weights,
uint8 scales) and runs MXFP8 matmul at inference.
"""

from typing import Any, Dict, List, Optional

import torch
import torch_npu

from sglang.srt.layers.parameter import GroupQuantScaleParameter, ModelWeightParameter
from sglang.srt.layers.quantization.modelslim.schemes import ModelSlimLinearScheme

# MXFP8 block size. It divides the per-partition input width to size the
# weight-scale tensor in ``create_weights`` and is the last entry of
# ``group_sizes`` in the ``torch_npu.npu_quant_matmul`` call.
MXFP8_BLOCK_SIZE = 32
# The ``float8_e8m0fnu`` dtype passed to ``npu_quant_matmul`` as ``scale_dtype``
# and ``pertoken_scale_dtype``. It is looked up on ``torch_npu`` first, then on
# ``torch``; the value is ``None`` when neither module exposes the attribute.
_FLOAT8_E8M0FNU_DTYPE = getattr(
torch_npu, "float8_e8m0fnu", getattr(torch, "float8_e8m0fnu", None)
)


class ModelSlimMXFP8Scheme(ModelSlimLinearScheme):
    """Linear scheme for MXFP8 weights pre-quantised by msmodelslim.

    ``create_weights`` registers a ``float8_e4m3fn`` weight and a ``uint8``
    block scale that the checkpoint loader fills in,
    ``process_weights_after_loading`` re-lays them out for the kernel, and
    ``apply_weights`` quantises the activations dynamically and calls
    ``torch_npu.npu_quant_matmul``.
    """

    def __init__(
        self,
        quant_config: Optional[Dict[str, Any]] = None,
        prefix: Optional[str] = None,
    ):
        """Accept and discard the standard scheme-constructor arguments."""
        # quant_config / prefix are accepted to match the linear-scheme
        # dispatch signature used by ModelSlimConfig.get_linear_scheme;
        # MXFP8 needs no per-layer config beyond what create_weights derives.
        del quant_config, prefix

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the ``weight`` and ``weight_scale`` parameters on ``layer``.

        ``input_size``, ``output_size`` and ``params_dtype`` are accepted but
        not used; only the ``weight_loader`` entry of ``extra_weight_attrs``
        is read.

        Raises:
            ValueError: if ``input_size_per_partition`` is not a multiple of
                ``2 * MXFP8_BLOCK_SIZE``.
        """
        weight_loader = extra_weight_attrs.get("weight_loader")
        output_size_per_partition = sum(output_partition_sizes)

        # One scale covers MXFP8_BLOCK_SIZE input elements, and
        # process_weights_after_loading regroups the scale columns in pairs.
        # Any other width would either truncate the scale tensor silently or
        # fail with an opaque reshape error after loading, so reject it here.
        required_multiple = 2 * MXFP8_BLOCK_SIZE
        if input_size_per_partition % required_multiple != 0:
            raise ValueError(
                "ModelSlimMXFP8Scheme requires input_size_per_partition to be "
                f"a multiple of {required_multiple}, "
                f"got {input_size_per_partition}"
            )

        # msmodelslim exports weight as float8_e4m3fn, shape [out, in]
        weight = ModelWeightParameter(
            data=torch.empty(
                (output_size_per_partition, input_size_per_partition),
                dtype=torch.float8_e4m3fn,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        # msmodelslim exports weight_scale as uint8, shape [out, in/32].
        # NOTE: Named "weight_scale" (not "weight_scale_inv") to match the
        # checkpoint key exported by msmodelslim.
        scale_dim = input_size_per_partition // MXFP8_BLOCK_SIZE
        weight_scale = GroupQuantScaleParameter(
            data=torch.empty(
                (output_size_per_partition, scale_dim),
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale", weight_scale)

    def process_weights_after_loading(self, layer: torch.nn.Module):
        """Re-lay out the loaded weight and scale for ``npu_quant_matmul``.

        The scale is regrouped from ``[n, k]`` to ``[n, k // 2, 2]``, then both
        the weight and the scale have dims 0 and 1 swapped. The results are
        non-contiguous views written back through ``.data``.
        """
        # Pre-transpose weight and scale to [in, out] for npu_quant_matmul.
        # Use .data assignment without .contiguous() to preserve the transpose
        # view strides — npu_quant_matmul reads strides correctly and calling
        # .contiguous() would reorder data, breaking the block-scale mapping.
        n_dim, k_dim = layer.weight_scale.data.shape

        # Pair adjacent scale columns into a trailing dimension of size 2.
        # NOTE(review): presumably the scale layout npu_quant_matmul expects
        # for MXFP8 — confirm against the torch_npu documentation.
        paired_scale = layer.weight_scale.data.reshape(n_dim, k_dim // 2, 2)

        layer.weight.data = layer.weight.data.transpose(0, 1)
        layer.weight_scale.data = paired_scale.transpose(0, 1)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the MXFP8 linear layer on ``x``.

        ``x`` is flattened to 2D, quantised dynamically, multiplied with the
        pre-transposed weight, and reshaped so that only the last dimension
        differs from the input. The result dtype equals ``x.dtype`` when that
        is fp16/bf16 and is bf16 otherwise. ``bias``, when given, is passed to
        the matmul as float32.
        """
        original_dtype = x.dtype
        if original_dtype not in (torch.float16, torch.bfloat16):
            x = x.to(torch.bfloat16)
            original_dtype = torch.bfloat16

        # npu_dynamic_mx_quant requires a 2D input [tokens, hidden_size]
        input_shape = x.shape
        x_2d = x.reshape(-1, x.shape[-1])

        # Dynamic MXFP8 activation quantisation
        qx, input_scale = torch_npu.npu_dynamic_mx_quant(
            x_2d, dst_type=torch_npu.float8_e4m3fn
        )

        # MXFP8 matmul (weight & scale already transposed at load time)
        output = torch_npu.npu_quant_matmul(
            qx,
            layer.weight,
            layer.weight_scale,
            scale_dtype=_FLOAT8_E8M0FNU_DTYPE,
            pertoken_scale=input_scale,
            pertoken_scale_dtype=_FLOAT8_E8M0FNU_DTYPE,
            bias=bias.to(torch.float32) if bias is not None else None,
            output_dtype=original_dtype,
            group_sizes=[1, 1, MXFP8_BLOCK_SIZE],
        )

        # Restore original shape (replace last dim with output features)
        output_shape = list(input_shape[:-1]) + [output.shape[-1]]
        return output.reshape(output_shape)
11 changes: 9 additions & 2 deletions python/sglang/srt/layers/rotary_embedding/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,11 @@

if _is_npu:
import torch_npu
from sgl_kernel_npu.norm.fused_rope_qk_mqa import fused_rope_qk_mqa

try:
from sgl_kernel_npu.norm.fused_rope_qk_mqa import fused_rope_qk_mqa
except ImportError:
fused_rope_qk_mqa = None

if _is_hip:
from sglang.srt.layers.attention.utils import (
Expand Down Expand Up @@ -267,7 +271,10 @@ def forward_npu(
else:
cos_sin = self.cos_sin_cache.index_select(0, positions)

if query.shape[0] * query.shape[1] < 65535:
if (
fused_rope_qk_mqa is not None
and query.shape[0] * query.shape[1] < 65535
):
return fused_rope_qk_mqa(
query,
key,
Expand Down
Loading