diff --git a/docs/advanced_features/quantization.md b/docs/advanced_features/quantization.md index 2d244ca2e780..8e68d5d10b93 100644 --- a/docs/advanced_features/quantization.md +++ b/docs/advanced_features/quantization.md @@ -43,6 +43,7 @@ The following table summarizes quantization method support across NVIDIA and AMD | `bitsandbytes` | Yes | Experimental | No | Depends on bitsandbytes ROCm support | | `torchao` (`int4wo`, etc.) | Yes | Partial | No | `int4wo` not supported on AMD; other methods may work | | `modelslim` | No | No | Yes | Ascend quantization; Uses CANN kernels | +| `mxfp8` (diffusion) | No | No | Yes (A5) | Ascend NPU only; online and offline (ModelSlim) MXFP8 quantization for diffusion models (e.g., Wan2.2); requires CANN ≥ 8.0.RC3 | On AMD, several of these methods use [Aiter](https://github.com/ROCm/aiter) for acceleration -- set `SGLANG_USE_AITER=1` where noted. See [AMD GPU setup](../platforms/amd_gpu.md) for installation and configuration details. @@ -590,6 +591,36 @@ SGLang running on AMD GPUs (CDNA3 or CDNA4 architecture) supports the quantizati Other layers (e.g. projections in the attention layers) have their weights quantized online to float8 directly. +## Diffusion Model Quantization on Ascend NPU + +SGLang-Diffusion supports MXFP8 quantization for diffusion models (such as Wan2.2) on Ascend A5 NPUs, in both online and offline (ModelSlim) modes. This is separate from the LLM serving path and uses the `sglang serve` / `sglang generate` CLI. 
+ +**Requirements:** Ascend A5, CANN ≥ 8.0.RC3 + +### Online MXFP8 + +Pass `--quantization mxfp8` to dynamically quantize FP16/BF16 transformer weights to MXFP8 at load time: + +```bash +sglang serve \ + --model-path Wan-AI/Wan2.2-T2V-A14B-Diffusers \ + --quantization mxfp8 \ + --num-gpus 4 +``` + +### Offline MXFP8 (ModelSlim) + +Pre-quantize with [msModelSlim](https://gitcode.com/Ascend/msmodelslim) and load the checkpoint directly — the quantization scheme is auto-detected from `quant_model_description.json`: + +```bash +sglang generate \ + --model-path /path/to/wan2_2_mxfp8_diffusers \ + --prompt "a beautiful sunset" \ + --save-output +``` + +For the full quantization + format conversion workflow and a complete list of supported schemes, see [Diffusion Quantization on Ascend NPU](../platforms/ascend/ascend_npu_quantization.md#diffusion-model-quantization-on-ascend-npu) and [SGLang-Diffusion Quantization](../diffusion/quantization.md#modelslim). + ## Reference - [GPTQModel](https://github.com/ModelCloud/GPTQModel) diff --git a/docs/diffusion/quantization.md b/docs/diffusion/quantization.md index e9340c54bba2..541cd46c5356 100644 --- a/docs/diffusion/quantization.md +++ b/docs/diffusion/quantization.md @@ -374,4 +374,4 @@ MindStudio-ModelSlim (msModelSlim) is a model offline quantization compression t - [x] ```W4A4_DYNAMIC``` linear with online quantization of activations - [x] ```W8A8``` linear with offline quantization of activations - [x] ```W8A8_DYNAMIC``` linear with online quantization of activations - - [ ] ```mxfp8``` linear in progress + - [x] ```mxfp8``` linear with online/offline MXFP8 quantization (Ascend A5, CANN ≥ 8.0.RC3; see [Ascend NPU quantization](../platforms/ascend/ascend_npu_quantization.md#diffusion-model-quantization-on-ascend-npu)) diff --git a/docs/platforms/ascend/ascend_npu_quantization.md b/docs/platforms/ascend/ascend_npu_quantization.md index 2524d2e3c9c3..e60173850d82 100644 --- a/docs/platforms/ascend/ascend_npu_quantization.md 
+++ b/docs/platforms/ascend/ascend_npu_quantization.md @@ -5,16 +5,16 @@ To load already quantized models, simply load the model weights and config. Agai SGLang support **mix-bits** quantization (independently defines and loads each layer depending on the type of quantification specified in the `quant_model_description'.json`). [Advanced mix-bits for MoE](https://github.com/sgl-project/sglang/pull/17361) in progress, will add independent quantization determination for the w13 (up-gate) and w2 (down) layers. [ModelSlim on Ascend support](https://github.com/sgl-project/sglang/pull/14504) -| Quantization scheme | Layer type | A2 Supported | A3 Supported | A5 Supported | Diffusion models | -|-----------------------------------------------------------|--------------------------|:----------------------------------------:|:----------------------------------------:|:------------------------------------------:|:------------------------------------------:| -| W4A4 dynamic | Linear | **√** | **√** | **TBD** | **√** | -| W8A8 static | Linear | **√** | **√** | **TBD** | **√** | -| W8A8 dynamic | Linear | **√** | **√** | **TBD** | **√** | -| [MXFP8](https://github.com/sgl-project/sglang/pull/20922) | Linear | **x** | **x** | **WIP** | **WIP** | -| W4A4 dynamic | MoE | **√** | **√** | **TBD** | **x** | -| W4A8 dynamic | MoE | **√** | **√** | **TBD** | **x** | -| W8A8 dynamic | MoE | **√** | **√** | **TBD** | **x** | -| [MXFP8](https://github.com/sgl-project/sglang/pull/20922) | MoE | **x** | **x** | **WIP** | **x** | +| Quantization scheme | `quant_type` in JSON | Scheme class | Layer type | A2 Supported | A3 Supported | A5 Supported | Diffusion models | +|-----------------------------------------------------------|----------------------|--------------------------|--------------------------|:----------------------------------------:|:----------------------------------------:|:------------------------------------------:|:------------------------------------------:| +| W4A4 dynamic | 
`W4A4_DYNAMIC` | `ModelSlimW4A4Int4` | Linear | **√** | **√** | **TBD** | **√** | +| W8A8 static | `W8A8` | `ModelSlimW8A8Int8` | Linear | **√** | **√** | **TBD** | **√** | +| W8A8 dynamic | `W8A8_DYNAMIC` | `ModelSlimW8A8Int8` | Linear | **√** | **√** | **TBD** | **√** | +| [MXFP8](https://github.com/sgl-project/sglang/pull/20922) | `W8A8_MXFP8` | `ModelSlimMXFP8Scheme` | Linear | **x** | **x** | **WIP** | **√** (A5) | +| W4A4 dynamic | `W4A4_DYNAMIC` | `ModelSlimW4A4Int4` | MoE | **√** | **√** | **TBD** | **x** | +| W4A8 dynamic | `W4A8_DYNAMIC` | `ModelSlimW4A8Int8MoE` | MoE | **√** | **√** | **TBD** | **x** | +| W8A8 dynamic | `W8A8_DYNAMIC` | `ModelSlimW8A8Int8` | MoE | **√** | **√** | **TBD** | **x** | +| [MXFP8](https://github.com/sgl-project/sglang/pull/20922) | `W8A8_MXFP8` | `ModelSlimMXFP8Scheme` | MoE | **x** | **x** | **WIP** | **x** | [AWQ on Ascend support](https://github.com/sgl-project/sglang/pull/10158): | Quantization scheme | Layer type | A2 Supported | A3 Supported | A5 Supported | @@ -54,3 +54,81 @@ Compressed-tensors (LLM Compressor) on Ascend support: | [GGUF (all types)](https://github.com/sgl-project/sglang/pull/17883) | MoE | **√** | **√** | **TBD** | > Note: On Ascend, GGUF weights are pre-dequantized to FP16/BF16 during model loading to ensure optimal inference performance. This enables support for all GGUF quantization types (Q2_K, Q4_K_M, IQ4_XS, etc.) while maintaining high inference speed. + +in progress + +## Diffusion Model Quantization on Ascend NPU + +SGLang-Diffusion supports MXFP8 online and offline quantization for diffusion models (such as Wan2.2) on Ascend NPUs. MXFP8 requires A5; the ModelSlim W8A8/W4A4 schemes work on A2/A3. 
+ +**Requirements for MXFP8:** CANN ≥ 8.0.RC3, Ascend A5 + +| Quantization method | `quant_type` in JSON | Scheme class | Mode | A2/A3 Supported | A5 Supported | Trigger | +|---------------------|-----------------------|-------------------------------|---------|:--------------------------------------------:|:----------------------------------------:|---------------------------------------------------| +| MXFP8 (W8A8) | — | `MXFP8Config` | Online | **x** | **√** | `--quantization mxfp8` | +| MXFP8 (W8A8) | `W8A8_MXFP8` | `ModelSlimMXFP8Scheme` | Offline | **x** | **√** | auto-detected from `quant_model_description.json` | +| W8A8 static | `W8A8` | `ModelSlimW8A8Int8` | Offline | **√** | **TBD** | auto-detected from `quant_model_description.json` | +| W8A8 dynamic | `W8A8_DYNAMIC` | `ModelSlimW8A8Int8` | Offline | **√** | **TBD** | auto-detected from `quant_model_description.json` | +| W4A4 dynamic | `W4A4_DYNAMIC` | `ModelSlimW4A4Int4` | Offline | **√** | **TBD** | auto-detected from `quant_model_description.json` | + +### Online MXFP8 Quantization + +Online quantization dynamically quantizes FP16/BF16 weights to MXFP8 at load time using `npu_dynamic_mx_quant` + `npu_quant_matmul` CANN kernels. Pass `--quantization mxfp8` to override auto-detection. + +```bash +# Start the diffusion server with online MXFP8 quantization +sglang serve \ + --model-path Wan-AI/Wan2.2-T2V-A14B-Diffusers \ + --quantization mxfp8 \ + --num-gpus 4 +``` + +```bash +# One-shot generation +sglang generate \ + --model-path Wan-AI/Wan2.2-T2V-A14B-Diffusers \ + --quantization mxfp8 \ + --prompt "a beautiful sunset over the mountains" \ + --save-output +``` + +### Offline MXFP8 Quantization (ModelSlim) + +For offline quantization, pre-quantize the model with msModelSlim and load the resulting checkpoint. The quantization scheme is auto-detected from `quant_model_description.json`, so no extra `--quantization` flag is needed. 
+ +**Step 1: Quantize with msModelSlim** + +```bash +msmodelslim quant \ + --model_path /path/to/wan2_2_float_weights \ + --save_path /path/to/wan2_2_mxfp8_weights \ + --device npu \ + --model_type Wan2_2 \ + --quant_type mxfp8 \ + --trust_remote_code True +``` + +> Note: SGLang does not support quantized embeddings; disable embedding quantization when using msmodelslim. + +**Step 2: Convert to Diffusers format** + +msModelSlim saves quantized Wan2.2 weights in the original Wan format. Convert to Diffusers format using the provided repack script: + +```bash +python python/sglang/multimodal_gen/tools/wan_repack.py \ + --input-path /path/to/wan2_2_mxfp8_weights \ + --output-path /path/to/wan2_2_mxfp8_diffusers +``` + +Then copy all files from the original Diffusers checkpoint (except the `transformer`/`transformer_2` folders) into the output directory. + +**Step 3: Run inference** + +```bash +sglang generate \ + --model-path /path/to/wan2_2_mxfp8_diffusers \ + --prompt "a beautiful sunset over the mountains" \ + --save-output +``` + +For pre-quantized checkpoints available on ModelScope, see [modelscope/Eco-Tech](https://modelscope.cn/models/Eco-Tech). 
diff --git a/python/sglang/multimodal_gen/runtime/layers/quantization/__init__.py b/python/sglang/multimodal_gen/runtime/layers/quantization/__init__.py index c4cdb687c928..ac9340697ba0 100644 --- a/python/sglang/multimodal_gen/runtime/layers/quantization/__init__.py +++ b/python/sglang/multimodal_gen/runtime/layers/quantization/__init__.py @@ -14,9 +14,10 @@ ModelOptFp8Config, ) from sglang.multimodal_gen.runtime.layers.quantization.modelslim import ModelSlimConfig +from sglang.multimodal_gen.runtime.layers.quantization.mxfp8_npu import MXFP8Config QuantizationMethods = Literal[ - "fp8", "modelopt", "modelopt_fp8", "modelopt_fp4", "modelslim" + "fp8", "modelopt", "modelopt_fp8", "modelopt_fp4", "modelslim", "mxfp8" ] QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods)) @@ -28,6 +29,7 @@ "modelopt_fp4": ModelOptFp4Config, "modelslim": ModelSlimConfig, "fp8": Fp8Config, + "mxfp8": MXFP8Config, } diff --git a/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim.py b/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim.py index afb9a31e4db9..4a9b96f9c9c9 100644 --- a/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim.py +++ b/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim.py @@ -119,6 +119,12 @@ def _get_scheme_from_parts( return ModelSlimW4A4Int4( quant_config=self.quant_description, prefix=layer_name ) + elif quant_type == "W8A8_MXFP8": + from sglang.multimodal_gen.runtime.layers.quantization.modelslim_mxfp8_scheme import ( + ModelSlimMXFP8Scheme, + ) + + return ModelSlimMXFP8Scheme() raise NotImplementedError("No modelslim compatible scheme was found.") def get_scheme( diff --git a/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp8_scheme.py b/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp8_scheme.py new file mode 100644 index 000000000000..1bc49779d081 --- /dev/null +++ 
b/python/sglang/multimodal_gen/runtime/layers/quantization/modelslim_mxfp8_scheme.py @@ -0,0 +1,124 @@ +"""ModelSlim MXFP8 scheme for pre-quantized weight inference on Ascend NPU. + +Loads weights pre-quantized by msmodelslim (float8_e4m3fn weights, +uint8 scales) and runs MXFP8 matmul at inference. +""" + +from typing import List, Optional + +import torch + +from sglang.multimodal_gen.runtime.platforms import current_platform + +_is_npu = current_platform.is_npu() + +if _is_npu: + import torch_npu + +from sglang.multimodal_gen.runtime.models.parameter import ( + GroupQuantScaleParameter, + ModelWeightParameter, +) +from sglang.srt.layers.quantization.modelslim.schemes import ModelSlimLinearScheme + +MXFP8_BLOCK_SIZE = 32 + + +class ModelSlimMXFP8Scheme(ModelSlimLinearScheme): + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + weight_loader = extra_weight_attrs.get("weight_loader") + output_size_per_partition = sum(output_partition_sizes) + + # msmodelslim exports weight as float8_e4m3fn, shape [out, in] + weight = ModelWeightParameter( + data=torch.empty( + (output_size_per_partition, input_size_per_partition), + dtype=torch.float8_e4m3fn, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + # msmodelslim exports weight_scale as uint8, shape [out, in/32]. + # NOTE: This parameter is intentionally named "weight_scale" (not + # "weight_scale_inv" as used in mxfp8_npu.py) because the weight loader + # matches parameter names to checkpoint keys, and msmodelslim checkpoints + # store this tensor under the key ".weight_scale". 
+ scale_dim = input_size_per_partition // MXFP8_BLOCK_SIZE + weight_scale = GroupQuantScaleParameter( + data=torch.empty( + (output_size_per_partition, scale_dim), + dtype=torch.uint8, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight_scale", weight_scale) + + def process_weights_after_loading(self, layer: torch.nn.Module): + # weight is already float8_e4m3fn, no cast needed + weight = layer.weight.data + layer.weight = torch.nn.Parameter(weight, requires_grad=False) + + # Reshape weight_scale: [out, in/32] -> [out, in/32//2, 2] + weight_scale = layer.weight_scale.data + weight_scale = weight_scale.reshape(weight_scale.shape[0], -1, 2) + layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) + + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + original_dtype = x.dtype + if original_dtype not in (torch.float16, torch.bfloat16): + # npu_dynamic_mx_quant only accepts fp16/bf16 activations + x = x.to(torch.bfloat16) + original_dtype = torch.bfloat16 + + # npu_dynamic_mx_quant requires a 2D input [tokens, hidden_size]. + # Diffusion transformer inputs are typically 3D [batch, seq, hidden] or + # higher. Flattening to 2D merges all leading dimensions into a single + # token axis so the NPU kernel can compute per-token MXFP8 scales, then + # we restore the original shape from the output. 
+ input_shape = x.shape + x_2d = x.reshape(-1, x.shape[-1]) + + # Dynamic MXFP8 activation quantisation + qx, input_scale = torch_npu.npu_dynamic_mx_quant( + x_2d, dst_type=torch_npu.float8_e4m3fn + ) + + # MXFP8 matmul + output = torch_npu.npu_quant_matmul( + qx, + layer.weight.transpose(0, 1), + layer.weight_scale.transpose(0, 1), + scale_dtype=torch_npu.float8_e8m0fnu, + pertoken_scale=input_scale, + pertoken_scale_dtype=torch_npu.float8_e8m0fnu, + bias=bias.to(torch.float32) if bias is not None else None, + output_dtype=original_dtype, + group_sizes=[1, 1, MXFP8_BLOCK_SIZE], + ) + + # Restore original shape + output_shape = list(input_shape[:-1]) + [output.shape[-1]] + output = output.reshape(output_shape) + + return output diff --git a/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp8_npu.py b/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp8_npu.py new file mode 100644 index 000000000000..17a4370cfdbf --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/layers/quantization/mxfp8_npu.py @@ -0,0 +1,176 @@ +"""Online MXFP8 quantization for Diffusion models on Ascend NPU. + +Provides ``MXFP8Config`` (registered as ``"mxfp8"``) and +``NPUMXFP8DiffusionLinearMethod`` which quantise FP16/BF16 weights to MXFP8 +at load time and use ``npu_dynamic_mx_quant`` + ``npu_quant_matmul`` for +inference, mirroring the LLM-side ``NPUMXFP8LinearMethod``. 
+""" + +from __future__ import annotations + +from typing import Any, Dict, List, Optional + +import torch +from torch.nn.parameter import Parameter + +from sglang.multimodal_gen.runtime.platforms import current_platform + +_is_npu = current_platform.is_npu() + +if _is_npu: + import torch_npu + +from sglang.multimodal_gen.runtime.layers.linear import LinearBase, LinearMethodBase +from sglang.multimodal_gen.runtime.layers.quantization.configs.base_config import ( + QuantizationConfig, + QuantizeMethodBase, +) +from sglang.multimodal_gen.runtime.models.parameter import ModelWeightParameter +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + +MXFP8_BLOCK_SIZE = 32 + + +class MXFP8Config(QuantizationConfig): + """Config for online MXFP8 quantization on Ascend NPU (Diffusion).""" + + def __init__(self) -> None: + super().__init__() + + @classmethod + def get_name(cls) -> str: + return "mxfp8" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.bfloat16, torch.float16] + + @classmethod + def get_min_capability(cls) -> int: + return 0 # NPU, not CUDA + + @classmethod + def get_config_filenames(cls) -> List[str]: + return [] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "MXFP8Config": + return cls() + + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional[QuantizeMethodBase]: + if isinstance(layer, LinearBase): + return NPUMXFP8DiffusionLinearMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class NPUMXFP8DiffusionLinearMethod(LinearMethodBase): + """Ascend NPU MXFP8 linear method for Diffusion models. + + Online mode: loads FP16/BF16 weights → quantises to MXFP8 at load time. + Inference: dynamic MXFP8 activation quant + MXFP8 matmul (block_size=32). 
+ """ + + def __init__(self, quant_config: MXFP8Config): + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + + layer.logical_widths = output_partition_sizes + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + layer.orig_dtype = params_dtype + + # Load weights in original dtype; quantise later in process_weights_after_loading + weight = ModelWeightParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=params_dtype, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + + weight_fp = layer.weight.data + if weight_fp.dtype not in (torch.float16, torch.bfloat16): + weight_fp = weight_fp.to(torch.bfloat16) + + # Move weight to NPU if needed. We intentionally use a conditional + # move rather than an assert because `dit_cpu_offload` defaults to + # True in ServerArgs, which causes fsdp_load to move every parameter + # back to CPU after loading (even when the target device is NPU). + # npu_dynamic_mx_quant requires an NPU tensor, so we must transfer + # here. The quantized fp8 weights produced below will remain on NPU + # for inference; if the model still needs to be offloaded after + # quantization (e.g. very large model on a small NPU), a higher-level + # offload pass can move them back afterwards. 
+ if not weight_fp.is_npu: + weight_fp = weight_fp.to(f"npu:{torch.npu.current_device()}") + + # Online MXFP8 quantisation of weights (block_size=32) + qw, w_scale = torch_npu.npu_dynamic_mx_quant( + weight_fp, dst_type=torch_npu.float8_e4m3fn + ) + layer.weight = Parameter(qw, requires_grad=False) + layer.weight_scale_inv = Parameter(w_scale, requires_grad=False) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + original_dtype = x.dtype + if original_dtype not in (torch.float16, torch.bfloat16): + x = x.to(torch.bfloat16) + original_dtype = torch.bfloat16 + + # Flatten to 2D [tokens, hidden] so npu_dynamic_mx_quant returns 3D scale + input_shape = x.shape + x_2d = x.reshape(-1, x.shape[-1]) + + # Dynamic MXFP8 activation quantisation + qx, input_scale = torch_npu.npu_dynamic_mx_quant( + x_2d, dst_type=torch_npu.float8_e4m3fn + ) + + # MXFP8 matmul + output = torch_npu.npu_quant_matmul( + qx, + layer.weight.transpose(0, 1), + layer.weight_scale_inv.transpose(0, 1), + scale_dtype=torch_npu.float8_e8m0fnu, + pertoken_scale=input_scale, + pertoken_scale_dtype=torch_npu.float8_e8m0fnu, + bias=bias.to(torch.float32) if bias is not None else None, + output_dtype=original_dtype, + group_sizes=[1, 1, MXFP8_BLOCK_SIZE], + ) + + # Restore original shape (replace last dim with output features) + output_shape = list(input_shape[:-1]) + [output.shape[-1]] + output = output.reshape(output_shape) + + return output diff --git a/python/sglang/multimodal_gen/runtime/loader/fsdp_load.py b/python/sglang/multimodal_gen/runtime/loader/fsdp_load.py index f9bdae93e729..f426bec1efd1 100644 --- a/python/sglang/multimodal_gen/runtime/loader/fsdp_load.py +++ b/python/sglang/multimodal_gen/runtime/loader/fsdp_load.py @@ -507,6 +507,7 @@ def load_model_from_full_model_state_dict( "bias", "norm_q", "norm_k", + "weight_scale", ] for new_param_name in unused_keys: meta_sharded_param = meta_sd.get(new_param_name) diff 
--git a/python/sglang/multimodal_gen/runtime/loader/transformer_load_utils.py b/python/sglang/multimodal_gen/runtime/loader/transformer_load_utils.py index 2f2057b8dd65..c4f3c7623791 100644 --- a/python/sglang/multimodal_gen/runtime/loader/transformer_load_utils.py +++ b/python/sglang/multimodal_gen/runtime/loader/transformer_load_utils.py @@ -477,8 +477,17 @@ def _resolve_quant_config( ) -> Optional[QuantizationConfig]: """ resolve quant config from checkpoints' metadata - priority: model config.json -> safetensors metadata -> format-specific fallback + priority: explicit --quantization flag -> model config.json -> safetensors metadata -> format-specific fallback """ + # priority: explicit --quantization flag (e.g. mxfp8, mxfp4, modelslim) + if server_args.quantization is not None: + from sglang.multimodal_gen.runtime.layers.quantization import ( + get_quantization_config, + ) + + quant_cls = get_quantization_config(server_args.quantization) + return quant_cls.from_config({}) + arch_config = server_args.pipeline_config.dit_config.arch_config param_names_mapping_dict = arch_config.param_names_mapping reverse_param_names_mapping_dict = getattr( diff --git a/python/sglang/multimodal_gen/runtime/server_args.py b/python/sglang/multimodal_gen/runtime/server_args.py index 6aad57647229..be5098263cdc 100644 --- a/python/sglang/multimodal_gen/runtime/server_args.py +++ b/python/sglang/multimodal_gen/runtime/server_args.py @@ -208,6 +208,10 @@ class ServerArgs(DisaggArgsMixin): disable_autocast: bool | None = None + # Explicit quantization method override (e.g. "mxfp8", "fp8", "modelslim"). + # When set, the transformer loader will use this instead of auto-detection. 
+ quantization: str | None = None + # Quantization / Nunchaku SVDQuant configuration nunchaku_config: NunchakuSVDQuantArgs | NunchakuConfig | None = field( default_factory=NunchakuSVDQuantArgs, repr=False @@ -1006,6 +1010,14 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: help="Disable autocast for denoising loop and vae decoding in pipeline sampling", ) + parser.add_argument( + "--quantization", + type=str, + default=None, + help='Quantization method override (e.g. "mxfp8", "fp8", "modelslim"). ' + "When set, the transformer loader will use this instead of auto-detection.", + ) + # Nunchaku SVDQuant quantization parameters NunchakuSVDQuantArgs.add_cli_args(parser) diff --git a/python/sglang/multimodal_gen/runtime/utils/quantization_utils.py b/python/sglang/multimodal_gen/runtime/utils/quantization_utils.py index 033728f2d2fa..d90dc0f682d3 100644 --- a/python/sglang/multimodal_gen/runtime/utils/quantization_utils.py +++ b/python/sglang/multimodal_gen/runtime/utils/quantization_utils.py @@ -97,9 +97,16 @@ def _load_quant_cls(quant_cfg: dict): def find_quant_modelslim_config(model_config, component_model_path): + # Try exact name first, then glob for variant filenames (e.g. 
after repack) quant_config_file = Path(component_model_path, "quant_model_description.json") + if not quant_config_file.is_file(): + candidates = sorted( + Path(component_model_path).glob("quant_model_description*.json") + ) + quant_config_file = candidates[0] if candidates else None + quant_cfg = None - if quant_config_file.is_file(): + if quant_config_file is not None and Path(quant_config_file).is_file(): with open(quant_config_file) as f: quant_cfg = json.load(f) # This field is required for flagless model loading but is not present in diff --git a/python/sglang/multimodal_gen/test/unit/test_transformer_quant.py b/python/sglang/multimodal_gen/test/unit/test_transformer_quant.py index dd0cd3685cc3..6d0ac4b9f28f 100644 --- a/python/sglang/multimodal_gen/test/unit/test_transformer_quant.py +++ b/python/sglang/multimodal_gen/test/unit/test_transformer_quant.py @@ -84,6 +84,7 @@ def _make_server_args(self, **overrides): ), ), nunchaku_config=None, + quantization=None, tp_size=1, dit_cpu_offload=False, text_encoder_cpu_offload=False, diff --git a/python/sglang/multimodal_gen/tools/wan_repack.py b/python/sglang/multimodal_gen/tools/wan_repack.py index 2d7132747e7a..308b229d8593 100644 --- a/python/sglang/multimodal_gen/tools/wan_repack.py +++ b/python/sglang/multimodal_gen/tools/wan_repack.py @@ -1,115 +1,225 @@ -### Based on https://github.com/huggingface/diffusers/blob/main/scripts/convert_wan_to_diffusers.py - -import argparse -import json -import pathlib -from typing import Any, Dict, Tuple - -from safetensors.torch import load_file, save_file - -TRANSFORMER_KEYS_RENAME_DICT = { - "time_embedding.0": "condition_embedder.time_embedder.linear_1", - "time_embedding.2": "condition_embedder.time_embedder.linear_2", - "text_embedding.0": "condition_embedder.text_embedder.linear_1", - "text_embedding.2": "condition_embedder.text_embedder.linear_2", - "time_projection.1": "condition_embedder.time_proj", - "head.modulation": "scale_shift_table", - "head.head": "proj_out", 
- "modulation": "scale_shift_table", - "ffn.0": "ffn.net.0.proj", - "ffn.2": "ffn.net.2", - # Hack to swap the layer names - # The original model calls the norms in following order: norm1, norm3, norm2 - # We convert it to: norm1, norm2, norm3 - "norm2": "norm__placeholder", - "norm3": "norm2", - "norm__placeholder": "norm3", - # For the I2V model - "img_emb.proj.0": "condition_embedder.image_embedder.norm1", - "img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj", - "img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2", - "img_emb.proj.4": "condition_embedder.image_embedder.norm2", - # for the FLF2V model - "img_emb.emb_pos": "condition_embedder.image_embedder.pos_embed", - # Add attention component mappings - "self_attn.q": "attn1.to_q", - "self_attn.k": "attn1.to_k", - "self_attn.v": "attn1.to_v", - "self_attn.o": "attn1.to_out.0", - "self_attn.norm_q": "attn1.norm_q", - "self_attn.norm_k": "attn1.norm_k", - "cross_attn.q": "attn2.to_q", - "cross_attn.k": "attn2.to_k", - "cross_attn.v": "attn2.to_v", - "cross_attn.o": "attn2.to_out.0", - "cross_attn.norm_q": "attn2.norm_q", - "cross_attn.norm_k": "attn2.norm_k", - "attn2.to_k_img": "attn2.add_k_proj", - "attn2.to_v_img": "attn2.add_v_proj", - "attn2.norm_k_img": "attn2.norm_added_k", -} - - -def get_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]: - if model_type == "Wan-T2V-14B": - RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT - return RENAME_DICT - - -def update_dict_(dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]: - dict[new_key] = dict.pop(old_key) - - -def load_sharded_safetensors(path: pathlib.Path): - file_path = path - state_dict = {} - state_dict.update(load_file(file_path)) - return state_dict - - -def convert_transformer(model_type: str, model_dir: str, output_dir: str): - pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True) - RENAME_DICT = get_transformer_config(model_type) - - original_state_dict = load_sharded_safetensors( - 
pathlib.Path(model_dir, "*model*.safetensors") - ) - with open(pathlib.Path(model_dir, "*quant_model_description*.json")) as f: - original_quant_config = json.load(f) - - for key in list(original_state_dict.keys()): - new_key = key[:] - for replace_key, rename_key in RENAME_DICT.items(): - new_key = new_key.replace(replace_key, rename_key) - update_dict_(original_state_dict, key, new_key) - update_dict_(original_quant_config, key, new_key) - - save_file( - original_state_dict, - pathlib.Path(output_dir, "diffusion_pytorch_model.safetensors"), - ) - - with open(pathlib.Path(output_dir, "quant_model_description.json"), "w") as f: - json.dump(original_quant_config, f) - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument("--input-path", type=str, required=True) - parser.add_argument("--output-path", type=str, required=True) - return parser.parse_args() - - -if __name__ == "__main__": - args = get_args() - - convert_transformer( - "Wan-T2V-14B", - model_dir=pathlib.Path(args.input_path, "high_noise_model"), - output_dir=pathlib.Path(args.output_path, "transformer"), - ) - convert_transformer( - "Wan-T2V-14B", - model_dir=pathlib.Path(args.input_path, "low_noise_model"), - output_dir=pathlib.Path(args.output_path, "transformer_2"), - ) +### Based on https://github.com/huggingface/diffusers/blob/main/scripts/convert_wan_to_diffusers.py + +import argparse +import json +import pathlib +import shutil +from typing import Any, Dict, List + +from safetensors.torch import load_file, save_file + +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + +TRANSFORMER_KEYS_RENAME_DICT = { + "time_embedding.0": "condition_embedder.time_embedder.linear_1", + "time_embedding.2": "condition_embedder.time_embedder.linear_2", + "text_embedding.0": "condition_embedder.text_embedder.linear_1", + "text_embedding.2": "condition_embedder.text_embedder.linear_2", + "time_projection.1": 
"condition_embedder.time_proj", + "head.modulation": "scale_shift_table", + "head.head": "proj_out", + "modulation": "scale_shift_table", + "ffn.0": "ffn.net.0.proj", + "ffn.2": "ffn.net.2", + # Hack to swap the layer names + # The original model calls the norms in following order: norm1, norm3, norm2 + # We convert it to: norm1, norm2, norm3 + "norm2": "norm__placeholder", + "norm3": "norm2", + "norm__placeholder": "norm3", + # For the I2V model + "img_emb.proj.0": "condition_embedder.image_embedder.norm1", + "img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj", + "img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2", + "img_emb.proj.4": "condition_embedder.image_embedder.norm2", + # for the FLF2V model + "img_emb.emb_pos": "condition_embedder.image_embedder.pos_embed", + # Add attention component mappings + "self_attn.q": "attn1.to_q", + "self_attn.k": "attn1.to_k", + "self_attn.v": "attn1.to_v", + "self_attn.o": "attn1.to_out.0", + "self_attn.norm_q": "attn1.norm_q", + "self_attn.norm_k": "attn1.norm_k", + "cross_attn.q": "attn2.to_q", + "cross_attn.k": "attn2.to_k", + "cross_attn.v": "attn2.to_v", + "cross_attn.o": "attn2.to_out.0", + "cross_attn.norm_q": "attn2.norm_q", + "cross_attn.norm_k": "attn2.norm_k", + "attn2.to_k_img": "attn2.add_k_proj", + "attn2.to_v_img": "attn2.add_v_proj", + "attn2.norm_k_img": "attn2.norm_added_k", +} + +SUPPORTED_MODEL_TYPES = ["Wan2.2-T2V-A14B", "Wan2.2-I2V-A14B", "Wan2.2-TI2V-5B"] + +# Cascade models have two transformers (high_noise + low_noise) +CASCADE_MODEL_TYPES = {"Wan2.2-T2V-A14B", "Wan2.2-I2V-A14B"} + + +def get_transformer_config(model_type: str) -> Dict[str, Any]: + if model_type in SUPPORTED_MODEL_TYPES: + return TRANSFORMER_KEYS_RENAME_DICT + else: + raise ValueError( + f"Unsupported model_type: {model_type}. 
Supported: {SUPPORTED_MODEL_TYPES}" + ) + + +def get_transformer_dirs(model_type: str) -> List[str]: + """Return the list of transformer directory names for a given model type.""" + if model_type in CASCADE_MODEL_TYPES: + return ["transformer", "transformer_2"] + return ["transformer"] + + +def get_quant_subpath( + model_type: str, quant_path: pathlib.Path, transformer_dir: str +) -> pathlib.Path: + """Return the quant weights subdirectory for a given transformer.""" + if model_type in CASCADE_MODEL_TYPES: + sub = ( + "high_noise_model" + if transformer_dir == "transformer" + else "low_noise_model" + ) + return quant_path / sub + return quant_path + + +def update_dict_(d: Dict[str, Any], old_key: str, new_key: str) -> None: + d[new_key] = d.pop(old_key) + + +def load_sharded_safetensors(directory: pathlib.Path, pattern: str) -> dict: + candidates = sorted(directory.glob(pattern)) + if not candidates: + raise FileNotFoundError(f"No file matching '{pattern}' found in {directory}") + if len(candidates) > 1: + raise FileNotFoundError( + f"Multiple files matching '{pattern}' found in {directory}: {candidates}" + ) + + state_dict = {} + state_dict.update(load_file(candidates[0])) + return state_dict + + +def convert_transformer( + model_type: str, model_dir: pathlib.Path, output_dir: pathlib.Path +) -> None: + """Convert a single quantized transformer directory into Diffusers format.""" + model_path = pathlib.Path(model_dir) + out_path = pathlib.Path(output_dir) + out_path.mkdir(parents=True, exist_ok=True) + RENAME_DICT = get_transformer_config(model_type) + + state_dict = load_sharded_safetensors(model_path, "quant_model_weight*.safetensors") + + json_candidates = sorted(model_path.glob("quant_model_description*.json")) + if not json_candidates: + raise FileNotFoundError( + f"No quant_model_description*.json found in {model_path}" + ) + with open(json_candidates[0]) as f: + quant_config = json.load(f) + + for key in list(state_dict.keys()): + new_key = key[:] + for 
replace_key, rename_key in RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + if new_key != key: + update_dict_(state_dict, key, new_key) + # The quant JSON only covers quantized layers, not all model keys + if key in quant_config: + update_dict_(quant_config, key, new_key) + + save_file(state_dict, out_path / "diffusion_pytorch_model.safetensors") + + with open(out_path / "quant_model_description.json", "w") as f: + json.dump(quant_config, f, indent=2) + + +def repack( + model_type: str, + original_model_path: pathlib.Path, + quant_path: pathlib.Path, + output_path: pathlib.Path, +) -> None: + """ + Full one-step repack workflow: + 1. Copy the original HF Diffusers model to output_path, excluding transformer dir(s). + 2. For each transformer: convert quant weights and copy config.json from original. + """ + transformer_dirs = get_transformer_dirs(model_type) + + # Step 1: Copy original model, skipping transformer dirs (they will be replaced) + logger.debug(f"Step 1: Copying original model to {output_path}") + logger.debug(f" (skipping: {transformer_dirs})") + shutil.copytree( + str(original_model_path), + str(output_path), + ignore=shutil.ignore_patterns(*transformer_dirs), + ) + + # Step 2+: Convert each transformer + for i, tdir in enumerate(transformer_dirs): + q_path = get_quant_subpath(model_type, quant_path, tdir) + out_tdir = output_path / tdir + logger.debug( + f"\nStep {i + 2}: Converting {tdir} (quant source: {q_path.name})..." + ) + convert_transformer(model_type, q_path, out_tdir) + + # Copy config.json from the original transformer dir + src_config = original_model_path / tdir / "config.json" + if src_config.is_file(): + shutil.copy2(str(src_config), str(out_tdir / "config.json")) + logger.debug(f" Copied config.json from original {tdir}/") + + logger.info(f"\nDone! 
Repacked model saved to: {output_path}") + + +def get_args(): + parser = argparse.ArgumentParser( + description="Repack msmodelslim quantized Wan2.2 weights into HF Diffusers format" + ) + parser.add_argument( + "--model-type", + type=str, + required=True, + choices=SUPPORTED_MODEL_TYPES, + help="Model type to convert", + ) + parser.add_argument( + "--original-model-path", + type=str, + required=True, + help="Path to the original HF Diffusers model (e.g., /weights/Wan2.2-TI2V-5B-Diffusers)", + ) + parser.add_argument( + "--quant-path", + type=str, + required=True, + help="Path to msmodelslim quantized weights directory", + ) + parser.add_argument( + "--output-path", + type=str, + required=True, + help="Output path for the repacked model (e.g., /weights/Wan2.2-TI2V-5B-Diffusers-MXFP8)", + ) + return parser.parse_args() + + +if __name__ == "__main__": + args = get_args() + repack( + model_type=args.model_type, + original_model_path=pathlib.Path(args.original_model_path), + quant_path=pathlib.Path(args.quant_path), + output_path=pathlib.Path(args.output_path), + ) diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index c85776d5bf2d..8056802b1969 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -64,10 +64,7 @@ requant_weight_ue8m0_inplace, ) from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod -from sglang.srt.layers.quantization.marlin_utils_fp8 import ( - apply_fp8_marlin_linear, - prepare_fp8_layer_for_marlin, -) +from sglang.srt.layers.quantization.marlin_utils_fp8 import prepare_fp8_layer_for_marlin from sglang.srt.layers.quantization.unquant import ( UnquantizedFusedMoEMethod, UnquantizedLinearMethod, @@ -707,7 +704,7 @@ def apply( bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: if self.use_marlin: - return apply_fp8_marlin_linear( + return torch.ops.sglang.apply_fp8_marlin_linear( input=x, weight=layer.weight, 
weight_scale=layer.weight_scale, @@ -1077,15 +1074,23 @@ def process_weights_after_loading_block_quant(self, layer: Module) -> None: w2_weight_scale, requires_grad=False ) layer.w2_input_scale = None - - if _use_aiter: + if _use_aiter: + # add this section for MI300 + # Pre-shuffle weights + layer.w13_weight.data = shuffle_weight( + layer.w13_weight.contiguous(), (16, 16) + ) + layer.w2_weight.data = shuffle_weight( + layer.w2_weight.contiguous(), (16, 16) + ) + elif _use_aiter: # Pre-shuffle weights - t = shuffle_weight(layer.w13_weight, (16, 16)) - layer.w13_weight.copy_(t) - del t - t = shuffle_weight(layer.w2_weight, (16, 16)) - layer.w2_weight.copy_(t) - del t + layer.w13_weight.data = shuffle_weight( + layer.w13_weight.contiguous(), (16, 16) + ) + layer.w2_weight.data = shuffle_weight( + layer.w2_weight.contiguous(), (16, 16) + ) elif _is_cpu: assert ( _is_cpu_amx_available diff --git a/scripts/ci/check_registered_tests.py b/scripts/ci/check_registered_tests.py index 8f4a910edf29..3a9e9b87b242 100755 --- a/scripts/ci/check_registered_tests.py +++ b/scripts/ci/check_registered_tests.py @@ -26,7 +26,7 @@ def main() -> int: files = sorted( f for f in glob.glob("test/registered/**/*.py", recursive=True) - if not f.endswith("/conftest.py") and not f.endswith("/__init__.py") + if os.path.basename(f) not in ("conftest.py", "__init__.py") ) if not files: return 0 diff --git a/scripts/ci/check_workflow_job_names.py b/scripts/ci/check_workflow_job_names.py index 75e2009ea1d5..dde84de8c357 100755 --- a/scripts/ci/check_workflow_job_names.py +++ b/scripts/ci/check_workflow_job_names.py @@ -29,7 +29,7 @@ def main() -> int: job_to_files: dict[str, list[str]] = defaultdict(list) for wf in workflows: - with open(wf) as f: + with open(wf, encoding="utf-8") as f: data = yaml.safe_load(f) if not data or "jobs" not in data: continue