Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
ef874c0
✨ feat(npu): add online MXFP8 quantization support for Ascend NPU (Pa…
TallMessiWu Mar 18, 2026
d2d19c6
✨ feat(diffusion): add online MXFP8 quantization support for Wan2.2 o…
TallMessiWu Mar 18, 2026
c838ade
:bug: fix(diffusion): fix npu method call error
TallMessiWu Mar 19, 2026
be3b684
:bug: fix(diffusion): fix MXFP8 quantization scale dimension mismatch…
TallMessiWu Mar 19, 2026
fd79b23
:recycle: refactor(mxfp8): split linear method into config and NPU la…
TallMessiWu Mar 20, 2026
df61b29
:twisted_rightwards_arrows: merge: sync from upstream
TallMessiWu Mar 20, 2026
490ad0b
:sparkles: feat(diffusion): add offline MXFP8 pre-quantized weight su…
TallMessiWu Mar 20, 2026
cc80690
:bug: fix(diffusion): correct MXFP8 weight dtype and scale shape
TallMessiWu Mar 23, 2026
b9aa785
✨ feat(wan22): Redesigned the wan_repack tool. Now support one-click …
TallMessiWu Mar 24, 2026
22bee9e
:recycle: refactor(mxfp8): hoist imports and replace print with logger
TallMessiWu Mar 24, 2026
a29bb3d
:pencil2: fix(diffusion/mxfp8): address review comments on ModelSlimM…
TallMessiWu Mar 25, 2026
3bbf703
:twisted_rightwards_arrows: chore(merge): sync upstream/main, keep MX…
TallMessiWu Mar 25, 2026
250fe65
:adhesive_bandage: fix(diffusion): register --quantization CLI arg to…
TallMessiWu Mar 25, 2026
e146b03
:bug: fix(mxfp8_npu): move weight to current NPU device before quanti…
TallMessiWu Mar 25, 2026
711bb8b
:rewind: revert(llm): remove LLM MXFP8 online quantization (Path B) f…
TallMessiWu Mar 25, 2026
1604d4e
:twisted_rightwards_arrows: chore(merge): sync upstream/main into junlin
TallMessiWu Mar 31, 2026
1101cf5
:adhesive_bandage: fix(loader): preserve --quantization flag priority…
TallMessiWu Mar 31, 2026
553a82c
Merge branch 'main' into junlin
ping1jing2 Apr 1, 2026
6bc42f9
:bug: fix(diffusion/mxfp8): fix torch_npu import error in non-npu env…
TallMessiWu Apr 2, 2026
4dc2135
Merge branch 'main' into junlin
ping1jing2 Apr 3, 2026
92e2939
Merge branch 'main' into junlin
ping1jing2 Apr 7, 2026
0617ae8
:memo: docs(quant/mxfp8): update docs
TallMessiWu Apr 8, 2026
8f54ac3
:twisted_rightwards_arrows: merge(junlin): sync upstream/main, preser…
TallMessiWu Apr 25, 2026
a0ee0be
:bug: fix(ci): fix Windows encoding and path separator bugs in pre-co…
TallMessiWu Apr 27, 2026
f92f820
:twisted_rightwards_arrows: Merge remote-tracking branch 'upstream/ma…
TallMessiWu Apr 27, 2026
e1dd1d3
:green_heart: fix(docs): fix broken relative path to ascend_npu_quant…
TallMessiWu Apr 27, 2026
1787605
Merge branch 'main' into junlin
ping1jing2 Apr 29, 2026
b3cd75b
:green_heart: fix(test): add missing quantization attr to _make_serve…
TallMessiWu Apr 29, 2026
4f79522
Merge branch 'main' into junlin
ping1jing2 Apr 30, 2026
d87905a
Merge branch 'main' into junlin
ping1jing2 May 1, 2026
762f21f
docs: sync LMSYS SGLang blog cards
github-actions[bot] May 1, 2026
ba014f2
Merge branch 'main' into junlin
ping1jing2 May 2, 2026
9be9672
Merge branch 'main' into junlin
ping1jing2 May 4, 2026
d44c6d2
:rewind: revert: drop auto-generated LMSYS blog sync from PR diff
TallMessiWu May 4, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions docs/advanced_features/quantization.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ The following table summarizes quantization method support across NVIDIA and AMD
| `bitsandbytes` | Yes | Experimental | No | Depends on bitsandbytes ROCm support |
| `torchao` (`int4wo`, etc.) | Yes | Partial | No | `int4wo` not supported on AMD; other methods may work |
| `modelslim` | No | No | Yes | Ascend quantization; Uses CANN kernels |
| `mxfp8` (diffusion) | No | No | Yes (A2/A3) | Ascend NPU only; online MXFP8 quantization for diffusion models (e.g., Wan2.2); requires CANN ≥ 8.0.RC3 |

On AMD, several of these methods use [Aiter](https://github.com/ROCm/aiter) for acceleration -- set `SGLANG_USE_AITER=1` where noted. See [AMD GPU setup](../platforms/amd_gpu.md) for installation and configuration details.

Expand Down Expand Up @@ -590,6 +591,36 @@ SGLang running on AMD GPUs (CDNA3 or CDNA4 architecture) supports the quantizati

Other layers (e.g. projections in the attention layers) have their weights quantized online to float8 directly.

## Diffusion Model Quantization on Ascend NPU

SGLang-Diffusion supports MXFP8 quantization for diffusion models (such as Wan2.2) on Ascend A5 NPUs, in both online and offline (ModelSlim) modes. This is separate from the LLM serving path and uses the `sglang serve` / `sglang generate` CLI.

**Requirements:** Ascend A5, CANN ≥ 8.0.RC3

### Online MXFP8

Pass `--quantization mxfp8` to dynamically quantize FP16/BF16 transformer weights to MXFP8 at load time:

```bash
sglang serve \
--model-path Wan-AI/Wan2.2-T2V-A14B-Diffusers \
--quantization mxfp8 \
--num-gpus 4
```

### Offline MXFP8 (ModelSlim)

Pre-quantize with [msModelSlim](https://gitcode.com/Ascend/msmodelslim) and load the checkpoint directly — the quantization scheme is auto-detected from `quant_model_description.json`:

```bash
sglang generate \
--model-path /path/to/wan2_2_mxfp8_diffusers \
--prompt "a beautiful sunset" \
--save-output
```

For the full quantization + format conversion workflow and a complete list of supported schemes, see [Diffusion Quantization on Ascend NPU](../platforms/ascend/ascend_npu_quantization.md#diffusion-model-quantization-on-ascend-npu) and [SGLang-Diffusion Quantization](../diffusion/quantization.md#modelslim).

## Reference

- [GPTQModel](https://github.com/ModelCloud/GPTQModel)
Expand Down
2 changes: 1 addition & 1 deletion docs/diffusion/quantization.md
Original file line number Diff line number Diff line change
Expand Up @@ -374,4 +374,4 @@ MindStudio-ModelSlim (msModelSlim) is a model offline quantization compression t
- [x] ```W4A4_DYNAMIC``` linear with online quantization of activations
- [x] ```W8A8``` linear with offline quantization of activations
- [x] ```W8A8_DYNAMIC``` linear with online quantization of activations
- [ ] ```mxfp8``` linear in progress
- [x] ```mxfp8``` linear with online/offline MXFP8 quantization (Ascend A5, CANN ≥ 8.0.RC3; see [Ascend NPU quantization](../platforms/ascend/ascend_npu_quantization.md#diffusion-model-quantization-on-ascend-npu))
98 changes: 88 additions & 10 deletions docs/platforms/ascend/ascend_npu_quantization.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,16 @@ To load already quantized models, simply load the model weights and config. Agai
SGLang support **mix-bits** quantization (independently defines and loads each layer depending on the type of quantification specified in the `quant_model_description'.json`). [Advanced mix-bits for MoE](https://github.com/sgl-project/sglang/pull/17361) in progress, will add independent quantization determination for the w13 (up-gate) and w2 (down) layers.

[ModelSlim on Ascend support](https://github.com/sgl-project/sglang/pull/14504)
| Quantization scheme | Layer type | A2 Supported | A3 Supported | A5 Supported | Diffusion models |
|-----------------------------------------------------------|--------------------------|:----------------------------------------:|:----------------------------------------:|:------------------------------------------:|:------------------------------------------:|
| W4A4 dynamic | Linear | **<span style="color: green;">√</span>** | **<span style="color: green;">√</span>** | **<span style="color: yellow;">TBD</span>** | **<span style="color: green;">√</span>** |
| W8A8 static | Linear | **<span style="color: green;">√</span>** | **<span style="color: green;">√</span>** | **<span style="color: yellow;">TBD</span>** | **<span style="color: green;">√</span>** |
| W8A8 dynamic | Linear | **<span style="color: green;">√</span>** | **<span style="color: green;">√</span>** | **<span style="color: yellow;">TBD</span>** | **<span style="color: green;">√</span>** |
| [MXFP8](https://github.com/sgl-project/sglang/pull/20922) | Linear | **<span style="color: red;">x</span>** | **<span style="color: red;">x</span>** | **<span style="color: blue;">WIP</span>** | **<span style="color: blue;">WIP</span>** |
| W4A4 dynamic | MoE | **<span style="color: green;">√</span>** | **<span style="color: green;">√</span>** | **<span style="color: yellow;">TBD</span>** | **<span style="color: red;">x</span>** |
| W4A8 dynamic | MoE | **<span style="color: green;">√</span>** | **<span style="color: green;">√</span>** | **<span style="color: yellow;">TBD</span>** | **<span style="color: red;">x</span>** |
| W8A8 dynamic | MoE | **<span style="color: green;">√</span>** | **<span style="color: green;">√</span>** | **<span style="color: yellow;">TBD</span>** | **<span style="color: red;">x</span>** |
| [MXFP8](https://github.com/sgl-project/sglang/pull/20922) | MoE | **<span style="color: red;">x</span>** | **<span style="color: red;">x</span>** | **<span style="color: blue;">WIP</span>** | **<span style="color: red;">x</span>** |
| Quantization scheme | `quant_type` in JSON | Scheme class | Layer type | A2 Supported | A3 Supported | A5 Supported | Diffusion models |
|-----------------------------------------------------------|----------------------|--------------------------|--------------------------|:----------------------------------------:|:----------------------------------------:|:------------------------------------------:|:------------------------------------------:|
| W4A4 dynamic | `W4A4_DYNAMIC` | `ModelSlimW4A4Int4` | Linear | **<span style="color: green;">√</span>** | **<span style="color: green;">√</span>** | **<span style="color: yellow;">TBD</span>** | **<span style="color: green;">√</span>** |
| W8A8 static | `W8A8` | `ModelSlimW8A8Int8` | Linear | **<span style="color: green;">√</span>** | **<span style="color: green;">√</span>** | **<span style="color: yellow;">TBD</span>** | **<span style="color: green;">√</span>** |
| W8A8 dynamic | `W8A8_DYNAMIC` | `ModelSlimW8A8Int8` | Linear | **<span style="color: green;">√</span>** | **<span style="color: green;">√</span>** | **<span style="color: yellow;">TBD</span>** | **<span style="color: green;">√</span>** |
| [MXFP8](https://github.com/sgl-project/sglang/pull/20922) | `W8A8_MXFP8` | `ModelSlimMXFP8Scheme` | Linear | **<span style="color: red;">x</span>** | **<span style="color: red;">x</span>** | **<span style="color: blue;">WIP</span>** | **<span style="color: green;"></span>** (A5) |
| W4A4 dynamic | `W4A4_DYNAMIC` | `ModelSlimW4A4Int4` | MoE | **<span style="color: green;">√</span>** | **<span style="color: green;">√</span>** | **<span style="color: yellow;">TBD</span>** | **<span style="color: red;">x</span>** |
| W4A8 dynamic | `W4A8_DYNAMIC` | `ModelSlimW4A8Int8MoE` | MoE | **<span style="color: green;">√</span>** | **<span style="color: green;">√</span>** | **<span style="color: yellow;">TBD</span>** | **<span style="color: red;">x</span>** |
| W8A8 dynamic | `W8A8_DYNAMIC` | `ModelSlimW8A8Int8` | MoE | **<span style="color: green;">√</span>** | **<span style="color: green;">√</span>** | **<span style="color: yellow;">TBD</span>** | **<span style="color: red;">x</span>** |
| [MXFP8](https://github.com/sgl-project/sglang/pull/20922) | `W8A8_MXFP8` | `ModelSlimMXFP8Scheme` | MoE | **<span style="color: red;">x</span>** | **<span style="color: red;">x</span>** | **<span style="color: blue;">WIP</span>** | **<span style="color: red;">x</span>** |

[AWQ on Ascend support](https://github.com/sgl-project/sglang/pull/10158):
| Quantization scheme | Layer type | A2 Supported | A3 Supported | A5 Supported |
Expand Down Expand Up @@ -54,3 +54,81 @@ Compressed-tensors (LLM Compressor) on Ascend support:
| [GGUF (all types)](https://github.com/sgl-project/sglang/pull/17883) | MoE | **<span style="color: green;">√</span>** | **<span style="color: green;">√</span>** | **<span style="color: yellow;">TBD</span>** |

> Note: On Ascend, GGUF weights are pre-dequantized to FP16/BF16 during model loading to ensure optimal inference performance. This enables support for all GGUF quantization types (Q2_K, Q4_K_M, IQ4_XS, etc.) while maintaining high inference speed.

in progress

## Diffusion Model Quantization on Ascend NPU

SGLang-Diffusion supports MXFP8 online and offline quantization for diffusion models (such as Wan2.2) on Ascend NPUs. MXFP8 requires A5; the ModelSlim W8A8/W4A4 schemes work on A2/A3.

**Requirements for MXFP8:** CANN ≥ 8.0.RC3, Ascend A5

| Quantization method | `quant_type` in JSON | Scheme class | Mode | A2/A3 Supported | A5 Supported | Trigger |
|---------------------|-----------------------|-------------------------------|---------|:--------------------------------------------:|:----------------------------------------:|---------------------------------------------------|
| MXFP8 (W8A8) | — | `MXFP8Config` | Online | **<span style="color: red;">x</span>** | **<span style="color: green;">√</span>** | `--quantization mxfp8` |
| MXFP8 (W8A8) | `W8A8_MXFP8` | `ModelSlimMXFP8Scheme` | Offline | **<span style="color: red;">x</span>** | **<span style="color: green;">√</span>** | auto-detected from `quant_model_description.json` |
| W8A8 static | `W8A8` | `ModelSlimW8A8Int8` | Offline | **<span style="color: green;">√</span>** | **<span style="color: yellow;">TBD</span>** | auto-detected from `quant_model_description.json` |
| W8A8 dynamic | `W8A8_DYNAMIC` | `ModelSlimW8A8Int8` | Offline | **<span style="color: green;">√</span>** | **<span style="color: yellow;">TBD</span>** | auto-detected from `quant_model_description.json` |
| W4A4 dynamic | `W4A4_DYNAMIC` | `ModelSlimW4A4Int4` | Offline | **<span style="color: green;">√</span>** | **<span style="color: yellow;">TBD</span>** | auto-detected from `quant_model_description.json` |

### Online MXFP8 Quantization

Online quantization dynamically quantizes FP16/BF16 weights to MXFP8 at load time using `npu_dynamic_mx_quant` + `npu_quant_matmul` CANN kernels. Pass `--quantization mxfp8` to override auto-detection.

```bash
# Start the diffusion server with online MXFP8 quantization
sglang serve \
--model-path Wan-AI/Wan2.2-T2V-A14B-Diffusers \
--quantization mxfp8 \
--num-gpus 4
```

```bash
# One-shot generation
sglang generate \
--model-path Wan-AI/Wan2.2-T2V-A14B-Diffusers \
--quantization mxfp8 \
--prompt "a beautiful sunset over the mountains" \
--save-output
```

### Offline MXFP8 Quantization (ModelSlim)

For offline quantization, pre-quantize the model with msModelSlim and load the resulting checkpoint. The quantization scheme is auto-detected from `quant_model_description.json`, so no extra `--quantization` flag is needed.

**Step 1: Quantize with msModelSlim**

```bash
msmodelslim quant \
--model_path /path/to/wan2_2_float_weights \
--save_path /path/to/wan2_2_mxfp8_weights \
--device npu \
--model_type Wan2_2 \
--quant_type mxfp8 \
--trust_remote_code True
```

> Note: SGLang does not support quantized embeddings; disable embedding quantization when using msmodelslim.

**Step 2: Convert to Diffusers format**

msModelSlim saves quantized Wan2.2 weights in the original Wan format. Convert to Diffusers format using the provided repack script:

```bash
python python/sglang/multimodal_gen/tools/wan_repack.py \
--input-path /path/to/wan2_2_mxfp8_weights \
--output-path /path/to/wan2_2_mxfp8_diffusers
```

Then copy all files from the original Diffusers checkpoint (except the `transformer`/`transformer_2` folders) into the output directory.

**Step 3: Run inference**

```bash
sglang generate \
--model-path /path/to/wan2_2_mxfp8_diffusers \
--prompt "a beautiful sunset over the mountains" \
--save-output
```

For pre-quantized checkpoints available on ModelScope, see [modelscope/Eco-Tech](https://modelscope.cn/models/Eco-Tech).
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@
ModelOptFp8Config,
)
from sglang.multimodal_gen.runtime.layers.quantization.modelslim import ModelSlimConfig
from sglang.multimodal_gen.runtime.layers.quantization.mxfp8_npu import MXFP8Config
Comment thread
TallMessiWu marked this conversation as resolved.

QuantizationMethods = Literal[
"fp8", "modelopt", "modelopt_fp8", "modelopt_fp4", "modelslim"
"fp8", "modelopt", "modelopt_fp8", "modelopt_fp4", "modelslim", "mxfp8"
]

QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods))
Expand All @@ -28,6 +29,7 @@
"modelopt_fp4": ModelOptFp4Config,
"modelslim": ModelSlimConfig,
"fp8": Fp8Config,
"mxfp8": MXFP8Config,
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,12 @@ def _get_scheme_from_parts(
return ModelSlimW4A4Int4(
quant_config=self.quant_description, prefix=layer_name
)
elif quant_type == "W8A8_MXFP8":
from sglang.multimodal_gen.runtime.layers.quantization.modelslim_mxfp8_scheme import (
ModelSlimMXFP8Scheme,
)

return ModelSlimMXFP8Scheme()
raise NotImplementedError("No modelslim compatible scheme was found.")

def get_scheme(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
"""ModelSlim MXFP8 scheme for pre-quantized weight inference on Ascend NPU.

Loads weights pre-quantized by msmodelslim (float8_e4m3fn weights,
uint8 scales) and runs MXFP8 matmul at inference.
"""

from typing import List, Optional

import torch

from sglang.multimodal_gen.runtime.platforms import current_platform

_is_npu = current_platform.is_npu()

if _is_npu:
import torch_npu

from sglang.multimodal_gen.runtime.models.parameter import (
GroupQuantScaleParameter,
ModelWeightParameter,
)
from sglang.srt.layers.quantization.modelslim.schemes import ModelSlimLinearScheme

MXFP8_BLOCK_SIZE = 32


class ModelSlimMXFP8Scheme(ModelSlimLinearScheme):

def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
weight_loader = extra_weight_attrs.get("weight_loader")
output_size_per_partition = sum(output_partition_sizes)

# msmodelslim exports weight as float8_e4m3fn, shape [out, in]
weight = ModelWeightParameter(
data=torch.empty(
(output_size_per_partition, input_size_per_partition),
dtype=torch.float8_e4m3fn,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
layer.register_parameter("weight", weight)

# msmodelslim exports weight_scale as uint8, shape [out, in/32].
# NOTE: This parameter is intentionally named "weight_scale" (not
# "weight_scale_inv" as used in mxfp8_npu.py) because the weight loader
# matches parameter names to checkpoint keys, and msmodelslim checkpoints
# store this tensor under the key "<layer>.weight_scale".
scale_dim = input_size_per_partition // MXFP8_BLOCK_SIZE
weight_scale = GroupQuantScaleParameter(
data=torch.empty(
(output_size_per_partition, scale_dim),
dtype=torch.uint8,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
layer.register_parameter("weight_scale", weight_scale)
Comment thread
TallMessiWu marked this conversation as resolved.

def process_weights_after_loading(self, layer: torch.nn.Module):
# weight is already float8_e4m3fn, no cast needed
weight = layer.weight.data
layer.weight = torch.nn.Parameter(weight, requires_grad=False)

# Reshape weight_scale: [out, in/32] -> [out, in/32//2, 2]
weight_scale = layer.weight_scale.data
weight_scale = weight_scale.reshape(weight_scale.shape[0], -1, 2)
layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)

def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:

original_dtype = x.dtype
if original_dtype not in (torch.float16, torch.bfloat16):
# npu_dynamic_mx_quant only accepts fp16/bf16 activations
x = x.to(torch.bfloat16)
original_dtype = torch.bfloat16

# npu_dynamic_mx_quant requires a 2D input [tokens, hidden_size].
# Diffusion transformer inputs are typically 3D [batch, seq, hidden] or
# higher. Flattening to 2D merges all leading dimensions into a single
# token axis so the NPU kernel can compute per-token MXFP8 scales, then
# we restore the original shape from the output.
input_shape = x.shape
x_2d = x.reshape(-1, x.shape[-1])

# Dynamic MXFP8 activation quantisation
qx, input_scale = torch_npu.npu_dynamic_mx_quant(
x_2d, dst_type=torch_npu.float8_e4m3fn
)

# MXFP8 matmul
output = torch_npu.npu_quant_matmul(
qx,
layer.weight.transpose(0, 1),
layer.weight_scale.transpose(0, 1),
scale_dtype=torch_npu.float8_e8m0fnu,
pertoken_scale=input_scale,
pertoken_scale_dtype=torch_npu.float8_e8m0fnu,
bias=bias.to(torch.float32) if bias is not None else None,
output_dtype=original_dtype,
group_sizes=[1, 1, MXFP8_BLOCK_SIZE],
)

# Restore original shape
output_shape = list(input_shape[:-1]) + [output.shape[-1]]
output = output.reshape(output_shape)

return output
Loading
Loading