Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
ef874c0
✨ feat(npu): add online MXFP8 quantization support for Ascend NPU (Pa…
TallMessiWu Mar 18, 2026
d2d19c6
✨ feat(diffusion): add online MXFP8 quantization support for Wan2.2 o…
TallMessiWu Mar 18, 2026
c838ade
:bug: fix(diffusion): fix npu method call error
TallMessiWu Mar 19, 2026
be3b684
:bug: fix(diffusion): fix MXFP8 quantization scale dimension mismatch…
TallMessiWu Mar 19, 2026
fd79b23
:recycle: refactor(mxfp8): split linear method into config and NPU la…
TallMessiWu Mar 20, 2026
df61b29
:twisted_rightwards_arrows: merge: sync from upstream
TallMessiWu Mar 20, 2026
490ad0b
:sparkles: feat(diffusion): add offline MXFP8 pre-quantized weight su…
TallMessiWu Mar 20, 2026
cc80690
:bug: fix(diffusion): correct MXFP8 weight dtype and scale shape
TallMessiWu Mar 23, 2026
b9aa785
✨ feat(wan22): Redesigned the wan_repack tool. Now support one-click …
TallMessiWu Mar 24, 2026
22bee9e
:recycle: refactor(mxfp8): hoist imports and replace print with logger
TallMessiWu Mar 24, 2026
a29bb3d
:pencil2: fix(diffusion/mxfp8): address review comments on ModelSlimM…
TallMessiWu Mar 25, 2026
3bbf703
:twisted_rightwards_arrows: chore(merge): sync upstream/main, keep MX…
TallMessiWu Mar 25, 2026
250fe65
:adhesive_bandage: fix(diffusion): register --quantization CLI arg to…
TallMessiWu Mar 25, 2026
e146b03
:bug: fix(mxfp8_npu): move weight to current NPU device before quanti…
TallMessiWu Mar 25, 2026
711bb8b
:rewind: revert(llm): remove LLM MXFP8 online quantization (Path B) f…
TallMessiWu Mar 25, 2026
1604d4e
:twisted_rightwards_arrows: chore(merge): sync upstream/main into junlin
TallMessiWu Mar 31, 2026
1101cf5
:adhesive_bandage: fix(loader): preserve --quantization flag priority…
TallMessiWu Mar 31, 2026
f1c652b
:sparkles: feat(npu/mxfp8): add W8A8 MXFP8 LLM support on Ascend NPU
TallMessiWu Apr 1, 2026
97c45b6
:recycle: refactor(npu/mxfp8): refactor code to align with vllm-ascend
TallMessiWu Apr 1, 2026
6026a18
🐛 fix(quantization/modelslim): resolve circular import in schemes/__i…
TallMessiWu Apr 2, 2026
3025e2d
:bug: fix(quantization/modelslim): fix no scheme found error.
TallMessiWu Apr 2, 2026
da92418
:bug: fix(llm/mxfp8): fix meaningless output issue
TallMessiWu Apr 3, 2026
29c04bc
:bug: fix(llm/mxfp8): fix meaningless output issue
TallMessiWu Apr 7, 2026
6693352
:sparkles: feat(npu-quant): add MXFP4 W4A8 online quantization for Qw…
TallMessiWu Apr 8, 2026
e9dec3c
:sparkles: feat(npu/w4a8): add W4A8 offline Dense scheme and guards
TallMessiWu Apr 16, 2026
91e9bab
:sparkles: feat(npu): add W4A4 single-level MXFP4 online quantization
TallMessiWu Apr 10, 2026
401bec1
:fire: fix(w4a4): remove W4A8_MXFP placeholder; add MoE warning to W4…
TallMessiWu Apr 17, 2026
f5d6206
:sparkles: feat(modelslim): add offline W4A4 MXFP4 scheme (W4A4_MXFP4)
TallMessiWu Apr 17, 2026
066becd
:recycle: refactor(quant): rename mxfp4_npu to mxfp4_w4a8_npu and add…
TallMessiWu Apr 24, 2026
93d542a
:sparkles: feat(modelslim): add W4A8_MXFP offline scheme for LLM
TallMessiWu Apr 24, 2026
456bd14
:bug: fix(modelslim): fix W4A8_MXFP weight dtype to float8_e4m3fn
TallMessiWu Apr 24, 2026
df88902
:twisted_rightwards_arrows: merge(npu/quant): merge w4a8 updates into…
TallMessiWu Apr 24, 2026
5a9fbd6
:bug: fix(npu/quant): align W4A4 MXFP4 with vllm-ascend reference
TallMessiWu Apr 25, 2026
f138c53
:bug: fix(modelslim): align offline MXFP4 W4A4 scheme with vllm-ascend
TallMessiWu Apr 27, 2026
48a2e62
:bug: fix(modelslim): reshape weight_scale to 3D before transpose for…
TallMessiWu Apr 27, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
ModelOptFp4Config,
)
from sglang.multimodal_gen.runtime.layers.quantization.modelslim import ModelSlimConfig
from sglang.multimodal_gen.runtime.layers.quantization.mxfp8_npu import MXFP8Config

QuantizationMethods = Literal["fp8", "modelopt_fp4", "modelslim"]
QuantizationMethods = Literal["fp8", "modelopt_fp4", "modelslim", "mxfp8"]

QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods))

Expand All @@ -20,6 +21,7 @@
"modelopt_fp4": ModelOptFp4Config,
"modelslim": ModelSlimConfig,
"fp8": Fp8Config,
"mxfp8": MXFP8Config,
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,12 @@ def _get_scheme_from_parts(
return ModelSlimW4A4Int4(
quant_config=self.quant_description, prefix=layer_name
)
elif quant_type == "W8A8_MXFP8":
from sglang.multimodal_gen.runtime.layers.quantization.modelslim_mxfp8_scheme import (
ModelSlimMXFP8Scheme,
)

return ModelSlimMXFP8Scheme()
raise NotImplementedError("No modelslim compatible scheme was found.")

def get_scheme(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
"""ModelSlim MXFP8 scheme for pre-quantized weight inference on Ascend NPU.

Loads weights pre-quantized by msmodelslim (float8_e4m3fn weights,
uint8 scales) and runs MXFP8 matmul at inference.
"""

from typing import List, Optional

import torch
import torch_npu

from sglang.multimodal_gen.runtime.models.parameter import (
GroupQuantScaleParameter,
ModelWeightParameter,
)
from sglang.srt.layers.quantization.modelslim.schemes import ModelSlimLinearScheme

# MX block size: the number of consecutive input-dimension elements that share
# one scale entry. It sizes the weight_scale parameter
# (input_size_per_partition // MXFP8_BLOCK_SIZE columns) and is passed to
# npu_quant_matmul as the last entry of group_sizes.
MXFP8_BLOCK_SIZE = 32


class ModelSlimMXFP8Scheme(ModelSlimLinearScheme):
    """Linear scheme for checkpoints already quantised to MXFP8 by msmodelslim.

    The checkpoint supplies a ``float8_e4m3fn`` weight and a ``uint8``
    ``weight_scale``; activations are quantised dynamically on every forward
    call and multiplied with ``torch_npu.npu_quant_matmul``.
    """

    # Activation dtypes passed to npu_dynamic_mx_quant unchanged; any other
    # dtype is cast to bfloat16 before quantisation.
    _KERNEL_ACT_DTYPES = (torch.float16, torch.bfloat16)

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register empty ``weight`` and ``weight_scale`` parameters on ``layer``.

        Args:
            layer: Module that receives the two parameters.
            input_size_per_partition: Input features held by this partition.
            output_partition_sizes: Output sizes of the fused sub-layers; their
                sum is the number of weight rows.
            input_size: Unused here.
            output_size: Unused here.
            params_dtype: Unused here; storage dtypes are fixed by the
                checkpoint format.
            **extra_weight_attrs: Only ``"weight_loader"`` is read.
        """
        loader = extra_weight_attrs.get("weight_loader")
        out_features = sum(output_partition_sizes)
        in_features = input_size_per_partition

        # msmodelslim stores the weight as float8_e4m3fn with shape [out, in].
        weight_storage = torch.empty(
            (out_features, in_features),
            dtype=torch.float8_e4m3fn,
        )
        layer.register_parameter(
            "weight",
            ModelWeightParameter(
                data=weight_storage,
                input_dim=1,
                output_dim=0,
                weight_loader=loader,
            ),
        )

        # msmodelslim stores the scale as uint8 with shape [out, in/32].
        # NOTE: the parameter must be called "weight_scale" (mxfp8_npu.py uses
        # "weight_scale_inv" instead) because the weight loader pairs parameter
        # names with checkpoint keys, and msmodelslim writes this tensor under
        # "<layer>.weight_scale".
        scale_storage = torch.empty(
            (out_features, in_features // MXFP8_BLOCK_SIZE),
            dtype=torch.uint8,
        )
        layer.register_parameter(
            "weight_scale",
            GroupQuantScaleParameter(
                data=scale_storage,
                input_dim=1,
                output_dim=0,
                weight_loader=loader,
            ),
        )

    def process_weights_after_loading(self, layer: torch.nn.Module):
        """Re-wrap the loaded tensors as plain frozen ``torch.nn.Parameter`` objects.

        The weight is kept as loaded (already float8_e4m3fn, so no cast); the
        scale is reshaped from [out, in/32] to [out, in/32//2, 2].
        """
        layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)

        loaded_scale = layer.weight_scale.data
        paired_scale = loaded_scale.reshape(loaded_scale.shape[0], -1, 2)
        layer.weight_scale = torch.nn.Parameter(paired_scale, requires_grad=False)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the MXFP8 linear layer on ``x``.

        Args:
            layer: Module holding ``weight`` and ``weight_scale``.
            x: Activations; every leading dimension is folded into one token
                axis for the kernels.
            bias: Optional bias, handed to the matmul as float32.

        Returns:
            Tensor shaped like ``x`` except for the last dimension, which is
            taken from the matmul result. Its dtype is ``x.dtype`` when that is
            float16 or bfloat16, and bfloat16 otherwise.
        """
        compute_dtype = x.dtype
        if compute_dtype not in self._KERNEL_ACT_DTYPES:
            # npu_dynamic_mx_quant only accepts fp16/bf16 activations.
            compute_dtype = torch.bfloat16
            x = x.to(compute_dtype)

        # npu_dynamic_mx_quant needs a 2D [tokens, hidden_size] input, while
        # diffusion transformer activations are usually [batch, seq, hidden] or
        # higher-rank. Collapse the leading dimensions here and undo it on the
        # matmul output.
        leading_dims = x.shape[:-1]
        tokens = x.reshape(-1, x.shape[-1])

        # Dynamic MXFP8 quantisation of the activations.
        quantised_tokens, token_scale = torch_npu.npu_dynamic_mx_quant(
            tokens, dst_type=torch_npu.float8_e4m3fn
        )

        fp32_bias = None if bias is None else bias.to(torch.float32)

        # MXFP8 matmul against the transposed weight and weight scale.
        result = torch_npu.npu_quant_matmul(
            quantised_tokens,
            layer.weight.transpose(0, 1),
            layer.weight_scale.transpose(0, 1),
            scale_dtype=torch_npu.float8_e8m0fnu,
            pertoken_scale=token_scale,
            pertoken_scale_dtype=torch_npu.float8_e8m0fnu,
            bias=fp32_bias,
            output_dtype=compute_dtype,
            group_sizes=[1, 1, MXFP8_BLOCK_SIZE],
        )

        # Put the leading dimensions back in front of the output features.
        return result.reshape([*leading_dims, result.shape[-1]])
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
"""Online MXFP8 quantization for Diffusion models on Ascend NPU.

Provides ``MXFP8Config`` (registered as ``"mxfp8"``) and
``NPUMXFP8DiffusionLinearMethod`` which quantise FP16/BF16 weights to MXFP8
at load time and use ``npu_dynamic_mx_quant`` + ``npu_quant_matmul`` for
inference, mirroring the LLM-side ``NPUMXFP8LinearMethod``.
"""

from __future__ import annotations

from typing import Any, Dict, List, Optional

import torch
import torch_npu
from torch.nn.parameter import Parameter

from sglang.multimodal_gen.runtime.layers.linear import LinearBase, LinearMethodBase
from sglang.multimodal_gen.runtime.layers.quantization.configs.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.multimodal_gen.runtime.models.parameter import ModelWeightParameter
from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger

# Module-level logger, named after this module.
logger = init_logger(__name__)

# MX block size; passed to npu_quant_matmul as the last entry of group_sizes.
MXFP8_BLOCK_SIZE = 32


class MXFP8Config(QuantizationConfig):
    """Quantization config selecting online MXFP8 on Ascend NPU (Diffusion).

    The config has no options of its own: it only identifies the method by
    name and attaches ``NPUMXFP8DiffusionLinearMethod`` to linear layers.
    """

    def __init__(self) -> None:
        super().__init__()

    @classmethod
    def get_name(cls) -> str:
        """Return the identifier of this quantization method."""
        return "mxfp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes this method supports."""
        return [torch.bfloat16, torch.float16]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return 0: this method targets NPU, not CUDA."""
        return 0

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """Return an empty list; no config file names are associated."""
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "MXFP8Config":
        """Build a config instance; the contents of ``config`` are ignored."""
        return cls()

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional[QuantizeMethodBase]:
        """Return the MXFP8 linear method for ``LinearBase`` layers, else ``None``.

        ``prefix`` is accepted but not consulted.
        """
        if not isinstance(layer, LinearBase):
            return None
        return NPUMXFP8DiffusionLinearMethod(self)

    def get_scaled_act_names(self) -> List[str]:
        """Return an empty list of scaled-activation names."""
        return []


class NPUMXFP8DiffusionLinearMethod(LinearMethodBase):
    """MXFP8 linear method for Diffusion models running on Ascend NPU.

    Online mode: the weight is loaded in its original dtype and converted to
    MXFP8 once loading has finished.
    Inference: activations are quantised dynamically to MXFP8 and multiplied
    with the quantised weight (block_size=32).
    """

    # Dtypes passed to npu_dynamic_mx_quant unchanged; any other dtype is cast
    # to bfloat16 first.
    _KERNEL_DTYPES = (torch.float16, torch.bfloat16)

    def __init__(self, quant_config: MXFP8Config):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register an unquantised ``weight`` parameter and record layer metadata.

        Args:
            layer: Module that receives the parameter and the bookkeeping
                attributes (``logical_widths``, ``input_size_per_partition``,
                ``output_size_per_partition``, ``orig_dtype``).
            input_size_per_partition: Input features held by this partition.
            output_partition_sizes: Output sizes of the fused sub-layers; their
                sum is the number of weight rows.
            input_size: Unused here.
            output_size: Unused here.
            params_dtype: Dtype the weight is stored in until quantisation.
            **extra_weight_attrs: Only ``"weight_loader"`` is read.
        """
        out_features = sum(output_partition_sizes)
        loader = extra_weight_attrs.get("weight_loader")

        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = out_features
        layer.orig_dtype = params_dtype

        # The checkpoint is read in its original dtype; conversion to MXFP8 is
        # deferred to process_weights_after_loading.
        full_precision_storage = torch.empty(
            out_features,
            input_size_per_partition,
            dtype=params_dtype,
        )
        layer.register_parameter(
            "weight",
            ModelWeightParameter(
                data=full_precision_storage,
                input_dim=1,
                output_dim=0,
                weight_loader=loader,
            ),
        )

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Quantise the loaded weight to MXFP8 and store it with its scale.

        Replaces ``layer.weight`` with the float8_e4m3fn tensor and sets
        ``layer.weight_scale_inv`` to the scale returned by
        ``npu_dynamic_mx_quant``; both are frozen parameters.
        """
        source_weight = layer.weight.data
        if source_weight.dtype not in self._KERNEL_DTYPES:
            source_weight = source_weight.to(torch.bfloat16)

        # A conditional transfer is used here on purpose instead of an assert:
        # `dit_cpu_offload` defaults to True in ServerArgs, so fsdp_load moves
        # every parameter back to CPU after loading even when the target
        # device is NPU, and npu_dynamic_mx_quant needs an NPU tensor. The
        # fp8 weights produced below stay on NPU for inference; if the model
        # still has to be offloaded afterwards (e.g. a very large model on a
        # small NPU), a higher-level offload pass can move them back.
        if not source_weight.is_npu:
            source_weight = source_weight.to(f"npu:{torch.npu.current_device()}")

        # Online MXFP8 quantisation of the weight (block_size=32).
        quantised_weight, weight_scale = torch_npu.npu_dynamic_mx_quant(
            source_weight, dst_type=torch_npu.float8_e4m3fn
        )
        layer.weight = Parameter(quantised_weight, requires_grad=False)
        layer.weight_scale_inv = Parameter(weight_scale, requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the MXFP8 linear layer on ``x``.

        Args:
            layer: Module holding ``weight`` and ``weight_scale_inv``.
            x: Activations; every leading dimension is folded into one token
                axis for the kernels.
            bias: Optional bias, handed to the matmul as float32.

        Returns:
            Tensor shaped like ``x`` except for the last dimension, which is
            taken from the matmul result. Its dtype is ``x.dtype`` when that is
            float16 or bfloat16, and bfloat16 otherwise.
        """
        compute_dtype = x.dtype
        if compute_dtype not in self._KERNEL_DTYPES:
            compute_dtype = torch.bfloat16
            x = x.to(compute_dtype)

        # Collapse to 2D [tokens, hidden] so npu_dynamic_mx_quant returns a 3D
        # scale.
        leading_dims = x.shape[:-1]
        tokens = x.reshape(-1, x.shape[-1])

        # Dynamic MXFP8 quantisation of the activations.
        quantised_tokens, token_scale = torch_npu.npu_dynamic_mx_quant(
            tokens, dst_type=torch_npu.float8_e4m3fn
        )

        fp32_bias = None if bias is None else bias.to(torch.float32)

        # MXFP8 matmul against the transposed weight and weight scale.
        result = torch_npu.npu_quant_matmul(
            quantised_tokens,
            layer.weight.transpose(0, 1),
            layer.weight_scale_inv.transpose(0, 1),
            scale_dtype=torch_npu.float8_e8m0fnu,
            pertoken_scale=token_scale,
            pertoken_scale_dtype=torch_npu.float8_e8m0fnu,
            bias=fp32_bias,
            output_dtype=compute_dtype,
            group_sizes=[1, 1, MXFP8_BLOCK_SIZE],
        )

        # Put the leading dimensions back; the last one becomes the output
        # feature count.
        return result.reshape([*leading_dims, result.shape[-1]])
1 change: 1 addition & 0 deletions python/sglang/multimodal_gen/runtime/loader/fsdp_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,7 @@ def load_model_from_full_model_state_dict(
"bias",
"norm_q",
"norm_k",
"weight_scale",
]
for new_param_name in unused_keys:
meta_sharded_param = meta_sd.get(new_param_name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -274,8 +274,17 @@ def _resolve_quant_config(
) -> Optional[QuantizationConfig]:
"""
resolve quant config from checkpoints' metadata
priority: model config.json -> safetensors metadata -> format-specific fallback
priority: explicit --quantization flag -> model config.json -> safetensors metadata -> format-specific fallback
"""
# priority: explicit --quantization flag (e.g. mxfp8, mxfp4, modelslim)
if server_args.quantization is not None:
from sglang.multimodal_gen.runtime.layers.quantization import (
get_quantization_config,
)

quant_cls = get_quantization_config(server_args.quantization)
return quant_cls.from_config({})

quant_config = get_quant_config(hf_config, component_model_path)
if quant_config is None and server_args.transformer_weights_path:
for safetensors_file in safetensors_list:
Expand Down
12 changes: 12 additions & 0 deletions python/sglang/multimodal_gen/runtime/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,10 @@ class ServerArgs:

disable_autocast: bool | None = None

# Explicit quantization method override (e.g. "mxfp8", "fp8", "modelslim").
# When set, the transformer loader will use this instead of auto-detection.
quantization: str | None = None

# Quantization / Nunchaku SVDQuant configuration
nunchaku_config: NunchakuSVDQuantArgs | NunchakuConfig | None = field(
default_factory=NunchakuSVDQuantArgs, repr=False
Expand Down Expand Up @@ -771,6 +775,14 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
help="Disable autocast for denoising loop and vae decoding in pipeline sampling",
)

parser.add_argument(
"--quantization",
type=str,
default=None,
help='Quantization method override (e.g. "mxfp8", "fp8", "modelslim"). '
"When set, the transformer loader will use this instead of auto-detection.",
)

# Nunchaku SVDQuant quantization parameters
NunchakuSVDQuantArgs.add_cli_args(parser)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,16 @@


def find_quant_modelslim_config(model_config, component_model_path):
# Try exact name first, then glob for variant filenames (e.g. after repack)
quant_config_file = Path(component_model_path, "quant_model_description.json")
if not quant_config_file.is_file():
candidates = sorted(
Path(component_model_path).glob("quant_model_description*.json")
)
quant_config_file = candidates[0] if candidates else None

quant_cfg = None
if quant_config_file.is_file():
if quant_config_file is not None and Path(quant_config_file).is_file():
with open(quant_config_file) as f:
quant_cfg = json.load(f)
# This field is required for flagless model loading but is not present in
Expand Down
Loading
Loading