-
Notifications
You must be signed in to change notification settings - Fork 6k
🚧 [llm][npu][quant] Add W4A8 MXFP quantization support for Qwen3 Dense on Ascend NPU #23650
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
ef874c0
d2d19c6
c838ade
be3b684
fd79b23
df61b29
490ad0b
cc80690
b9aa785
22bee9e
a29bb3d
3bbf703
250fe65
e146b03
711bb8b
1604d4e
1101cf5
f1c652b
97c45b6
6026a18
3025e2d
da92418
29c04bc
6693352
e9dec3c
066becd
93d542a
456bd14
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,118 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||
| """ModelSlim MXFP8 scheme for pre-quantized weight inference on Ascend NPU. | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| Loads weights pre-quantized by msmodelslim (float8_e4m3fn weights, | ||||||||||||||||||||||||||||||||||||||||||||||||
| uint8 scales) and runs MXFP8 matmul at inference. | ||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| from typing import List, Optional | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||||||||||||||||
| import torch_npu | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| from sglang.multimodal_gen.runtime.models.parameter import ( | ||||||||||||||||||||||||||||||||||||||||||||||||
| GroupQuantScaleParameter, | ||||||||||||||||||||||||||||||||||||||||||||||||
| ModelWeightParameter, | ||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||
| from sglang.srt.layers.quantization.modelslim.schemes import ModelSlimLinearScheme | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| MXFP8_BLOCK_SIZE = 32 | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| class ModelSlimMXFP8Scheme(ModelSlimLinearScheme): | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| def create_weights( | ||||||||||||||||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||||||||||||||||
| layer: torch.nn.Module, | ||||||||||||||||||||||||||||||||||||||||||||||||
| input_size_per_partition: int, | ||||||||||||||||||||||||||||||||||||||||||||||||
| output_partition_sizes: List[int], | ||||||||||||||||||||||||||||||||||||||||||||||||
| input_size: int, | ||||||||||||||||||||||||||||||||||||||||||||||||
| output_size: int, | ||||||||||||||||||||||||||||||||||||||||||||||||
| params_dtype: torch.dtype, | ||||||||||||||||||||||||||||||||||||||||||||||||
| **extra_weight_attrs, | ||||||||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||||||||
| weight_loader = extra_weight_attrs.get("weight_loader") | ||||||||||||||||||||||||||||||||||||||||||||||||
| output_size_per_partition = sum(output_partition_sizes) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| # msmodelslim exports weight as float8_e4m3fn, shape [out, in] | ||||||||||||||||||||||||||||||||||||||||||||||||
| weight = ModelWeightParameter( | ||||||||||||||||||||||||||||||||||||||||||||||||
| data=torch.empty( | ||||||||||||||||||||||||||||||||||||||||||||||||
| (output_size_per_partition, input_size_per_partition), | ||||||||||||||||||||||||||||||||||||||||||||||||
| dtype=torch.float8_e4m3fn, | ||||||||||||||||||||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||||||||||||||||||||
| input_dim=1, | ||||||||||||||||||||||||||||||||||||||||||||||||
| output_dim=0, | ||||||||||||||||||||||||||||||||||||||||||||||||
| weight_loader=weight_loader, | ||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||
| layer.register_parameter("weight", weight) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| # msmodelslim exports weight_scale as uint8, shape [out, in/32]. | ||||||||||||||||||||||||||||||||||||||||||||||||
| # NOTE: This parameter is intentionally named "weight_scale" (not | ||||||||||||||||||||||||||||||||||||||||||||||||
| # "weight_scale_inv" as used in mxfp8_npu.py) because the weight loader | ||||||||||||||||||||||||||||||||||||||||||||||||
| # matches parameter names to checkpoint keys, and msmodelslim checkpoints | ||||||||||||||||||||||||||||||||||||||||||||||||
| # store this tensor under the key "<layer>.weight_scale". | ||||||||||||||||||||||||||||||||||||||||||||||||
| scale_dim = input_size_per_partition // MXFP8_BLOCK_SIZE | ||||||||||||||||||||||||||||||||||||||||||||||||
| weight_scale = GroupQuantScaleParameter( | ||||||||||||||||||||||||||||||||||||||||||||||||
| data=torch.empty( | ||||||||||||||||||||||||||||||||||||||||||||||||
| (output_size_per_partition, scale_dim), | ||||||||||||||||||||||||||||||||||||||||||||||||
| dtype=torch.uint8, | ||||||||||||||||||||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||||||||||||||||||||
| input_dim=1, | ||||||||||||||||||||||||||||||||||||||||||||||||
| output_dim=0, | ||||||||||||||||||||||||||||||||||||||||||||||||
| weight_loader=weight_loader, | ||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||
| layer.register_parameter("weight_scale", weight_scale) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| def process_weights_after_loading(self, layer: torch.nn.Module): | ||||||||||||||||||||||||||||||||||||||||||||||||
| # weight is already float8_e4m3fn, no cast needed | ||||||||||||||||||||||||||||||||||||||||||||||||
| weight = layer.weight.data | ||||||||||||||||||||||||||||||||||||||||||||||||
| layer.weight = torch.nn.Parameter(weight, requires_grad=False) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| # Reshape weight_scale: [out, in/32] -> [out, in/32//2, 2] | ||||||||||||||||||||||||||||||||||||||||||||||||
| weight_scale = layer.weight_scale.data | ||||||||||||||||||||||||||||||||||||||||||||||||
| weight_scale = weight_scale.reshape(weight_scale.shape[0], -1, 2) | ||||||||||||||||||||||||||||||||||||||||||||||||
| layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+65
to
+73
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Pre-transposing the weight and scale tensors during model loading avoids the overhead of performing transposes on every forward pass. This optimization is already present in the SRT implementation of this scheme. def process_weights_after_loading(self, layer: torch.nn.Module):
# weight is already float8_e4m3fn, no cast needed
weight = layer.weight.data
# Pre-transpose weight and scale to [in, out] for npu_quant_matmul.
# Use .data assignment without .contiguous() to preserve the transpose
# view strides — npu_quant_matmul reads strides correctly.
layer.weight = torch.nn.Parameter(weight.transpose(0, 1), requires_grad=False)
# Reshape weight_scale: [out, in/32] -> [out, in/32//2, 2]
weight_scale = layer.weight_scale.data
weight_scale = weight_scale.reshape(weight_scale.shape[0], -1, 2)
layer.weight_scale = torch.nn.Parameter(weight_scale.transpose(0, 1), requires_grad=False) |
||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| def apply_weights( | ||||||||||||||||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||||||||||||||||
| layer: torch.nn.Module, | ||||||||||||||||||||||||||||||||||||||||||||||||
| x: torch.Tensor, | ||||||||||||||||||||||||||||||||||||||||||||||||
| bias: Optional[torch.Tensor] = None, | ||||||||||||||||||||||||||||||||||||||||||||||||
| ) -> torch.Tensor: | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| original_dtype = x.dtype | ||||||||||||||||||||||||||||||||||||||||||||||||
| if original_dtype not in (torch.float16, torch.bfloat16): | ||||||||||||||||||||||||||||||||||||||||||||||||
| # npu_dynamic_mx_quant only accepts fp16/bf16 activations | ||||||||||||||||||||||||||||||||||||||||||||||||
| x = x.to(torch.bfloat16) | ||||||||||||||||||||||||||||||||||||||||||||||||
| original_dtype = torch.bfloat16 | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| # npu_dynamic_mx_quant requires a 2D input [tokens, hidden_size]. | ||||||||||||||||||||||||||||||||||||||||||||||||
| # Diffusion transformer inputs are typically 3D [batch, seq, hidden] or | ||||||||||||||||||||||||||||||||||||||||||||||||
| # higher. Flattening to 2D merges all leading dimensions into a single | ||||||||||||||||||||||||||||||||||||||||||||||||
| # token axis so the NPU kernel can compute per-token MXFP8 scales, then | ||||||||||||||||||||||||||||||||||||||||||||||||
| # we restore the original shape from the output. | ||||||||||||||||||||||||||||||||||||||||||||||||
| input_shape = x.shape | ||||||||||||||||||||||||||||||||||||||||||||||||
| x_2d = x.reshape(-1, x.shape[-1]) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| # Dynamic MXFP8 activation quantisation | ||||||||||||||||||||||||||||||||||||||||||||||||
| qx, input_scale = torch_npu.npu_dynamic_mx_quant( | ||||||||||||||||||||||||||||||||||||||||||||||||
| x_2d, dst_type=torch_npu.float8_e4m3fn | ||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| # MXFP8 matmul | ||||||||||||||||||||||||||||||||||||||||||||||||
| output = torch_npu.npu_quant_matmul( | ||||||||||||||||||||||||||||||||||||||||||||||||
| qx, | ||||||||||||||||||||||||||||||||||||||||||||||||
| layer.weight.transpose(0, 1), | ||||||||||||||||||||||||||||||||||||||||||||||||
| layer.weight_scale.transpose(0, 1), | ||||||||||||||||||||||||||||||||||||||||||||||||
| scale_dtype=torch_npu.float8_e8m0fnu, | ||||||||||||||||||||||||||||||||||||||||||||||||
| pertoken_scale=input_scale, | ||||||||||||||||||||||||||||||||||||||||||||||||
| pertoken_scale_dtype=torch_npu.float8_e8m0fnu, | ||||||||||||||||||||||||||||||||||||||||||||||||
| bias=bias.to(torch.float32) if bias is not None else None, | ||||||||||||||||||||||||||||||||||||||||||||||||
| output_dtype=original_dtype, | ||||||||||||||||||||||||||||||||||||||||||||||||
| group_sizes=[1, 1, MXFP8_BLOCK_SIZE], | ||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+102
to
+112
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use the pre-transposed weights and the robust dtype fallback in the matmul call.
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| # Restore original shape | ||||||||||||||||||||||||||||||||||||||||||||||||
| output_shape = list(input_shape[:-1]) + [output.shape[-1]] | ||||||||||||||||||||||||||||||||||||||||||||||||
| output = output.reshape(output_shape) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| return output | ||||||||||||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,170 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Online MXFP8 quantization for Diffusion models on Ascend NPU. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Provides ``MXFP8Config`` (registered as ``"mxfp8"``) and | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ``NPUMXFP8DiffusionLinearMethod`` which quantise FP16/BF16 weights to MXFP8 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| at load time and use ``npu_dynamic_mx_quant`` + ``npu_quant_matmul`` for | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| inference, mirroring the LLM-side ``NPUMXFP8LinearMethod``. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from __future__ import annotations | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from typing import Any, Dict, List, Optional | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import torch_npu | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from torch.nn.parameter import Parameter | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from sglang.multimodal_gen.runtime.layers.linear import LinearBase, LinearMethodBase | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from sglang.multimodal_gen.runtime.layers.quantization.configs.base_config import ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| QuantizationConfig, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| QuantizeMethodBase, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from sglang.multimodal_gen.runtime.models.parameter import ModelWeightParameter | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| logger = init_logger(__name__) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| MXFP8_BLOCK_SIZE = 32 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| class MXFP8Config(QuantizationConfig): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Config for online MXFP8 quantization on Ascend NPU (Diffusion).""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def __init__(self) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| super().__init__() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @classmethod | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def get_name(cls) -> str: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return "mxfp8" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @classmethod | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def get_supported_act_dtypes(cls) -> List[torch.dtype]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return [torch.bfloat16, torch.float16] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @classmethod | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def get_min_capability(cls) -> int: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return 0 # NPU, not CUDA | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @classmethod | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def get_config_filenames(cls) -> List[str]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @classmethod | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def from_config(cls, config: Dict[str, Any]) -> "MXFP8Config": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return cls() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def get_quant_method( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self, layer: torch.nn.Module, prefix: str | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) -> Optional[QuantizeMethodBase]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(layer, LinearBase): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return NPUMXFP8DiffusionLinearMethod(self) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def get_scaled_act_names(self) -> List[str]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| class NPUMXFP8DiffusionLinearMethod(LinearMethodBase): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Ascend NPU MXFP8 linear method for Diffusion models. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Online mode: loads FP16/BF16 weights → quantises to MXFP8 at load time. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Inference: dynamic MXFP8 activation quant + MXFP8 matmul (block_size=32). | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def __init__(self, quant_config: MXFP8Config): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.quant_config = quant_config | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def create_weights( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| layer: torch.nn.Module, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| input_size_per_partition: int, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| output_partition_sizes: List[int], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| input_size: int, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| output_size: int, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| params_dtype: torch.dtype, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| **extra_weight_attrs, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| output_size_per_partition = sum(output_partition_sizes) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| weight_loader = extra_weight_attrs.get("weight_loader") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| layer.logical_widths = output_partition_sizes | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| layer.input_size_per_partition = input_size_per_partition | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| layer.output_size_per_partition = output_size_per_partition | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| layer.orig_dtype = params_dtype | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Load weights in original dtype; quantise later in process_weights_after_loading | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| weight = ModelWeightParameter( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| data=torch.empty( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| output_size_per_partition, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| input_size_per_partition, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dtype=params_dtype, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| input_dim=1, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| output_dim=0, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| weight_loader=weight_loader, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| layer.register_parameter("weight", weight) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def process_weights_after_loading(self, layer: torch.nn.Module) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| weight_fp = layer.weight.data | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if weight_fp.dtype not in (torch.float16, torch.bfloat16): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| weight_fp = weight_fp.to(torch.bfloat16) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Move weight to NPU if needed. We intentionally use a conditional | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # move rather than an assert because `dit_cpu_offload` defaults to | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # True in ServerArgs, which causes fsdp_load to move every parameter | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # back to CPU after loading (even when the target device is NPU). | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # npu_dynamic_mx_quant requires an NPU tensor, so we must transfer | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # here. The quantized fp8 weights produced below will remain on NPU | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # for inference; if the model still needs to be offloaded after | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # quantization (e.g. very large model on a small NPU), a higher-level | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # offload pass can move them back afterwards. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if not weight_fp.is_npu: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| weight_fp = weight_fp.to(f"npu:{torch.npu.current_device()}") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Online MXFP8 quantisation of weights (block_size=32) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| qw, w_scale = torch_npu.npu_dynamic_mx_quant( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| weight_fp, dst_type=torch_npu.float8_e4m3fn | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| layer.weight = Parameter(qw, requires_grad=False) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| layer.weight_scale_inv = Parameter(w_scale, requires_grad=False) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+108
to
+131
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Pre-transpose the weights and scales during model loading to improve inference performance. Since this is online quantization, using
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def apply( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| layer: torch.nn.Module, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| x: torch.Tensor, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| bias: Optional[torch.Tensor] = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) -> torch.Tensor: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| original_dtype = x.dtype | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if original_dtype not in (torch.float16, torch.bfloat16): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| x = x.to(torch.bfloat16) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| original_dtype = torch.bfloat16 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Flatten to 2D [tokens, hidden] so npu_dynamic_mx_quant returns 3D scale | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| input_shape = x.shape | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| x_2d = x.reshape(-1, x.shape[-1]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Dynamic MXFP8 activation quantisation | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| qx, input_scale = torch_npu.npu_dynamic_mx_quant( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| x_2d, dst_type=torch_npu.float8_e4m3fn | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # MXFP8 matmul | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| output = torch_npu.npu_quant_matmul( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| qx, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| layer.weight.transpose(0, 1), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| layer.weight_scale_inv.transpose(0, 1), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| scale_dtype=torch_npu.float8_e8m0fnu, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| pertoken_scale=input_scale, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| pertoken_scale_dtype=torch_npu.float8_e8m0fnu, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| bias=bias.to(torch.float32) if bias is not None else None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| output_dtype=original_dtype, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| group_sizes=[1, 1, MXFP8_BLOCK_SIZE], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+154
to
+164
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Update the matmul call to use the pre-transposed parameters and the robust dtype fallback.
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Restore original shape (replace last dim with output features) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| output_shape = list(input_shape[:-1]) + [output.shape[-1]] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| output = output.reshape(output_shape) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return output | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For robustness across different versions of
torch_npuandtorch, it is better to use a fallback mechanism for thefloat8_e8m0fnudtype, similar to the implementation in the SRT backend.