diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index d74e073feebb..11b10037fefe 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -16,6 +16,7 @@ /python/sglang/srt/function_call @CatherineSue @JustinTong0323 /python/sglang/srt/grpc @CatherineSue @slin1237 /python/sglang/srt/hardware_backend/npu @ping1jing2 @iforgetmyname +/python/sglang/srt/hardware_backend/npu/quantization @OrangeRedeng @TamirBaydasov @iforgetmyname /python/sglang/srt/layers @merrymercy @Ying1123 @Fridge003 @ispobock @HaiShaw @ch-wan @BBuf @Edwardf0t1 /python/sglang/srt/layers/attention @merrymercy @Fridge003 @ispobock @Qiaolin-Yu @hebiao064 /python/sglang/srt/layers/attention/fla @yizhang2077 @hebiao064 diff --git a/docs/platforms/ascend_npu_deepseek_example.md b/docs/platforms/ascend_npu_deepseek_example.md index d75b942704b2..612f0ed6f2f0 100644 --- a/docs/platforms/ascend_npu_deepseek_example.md +++ b/docs/platforms/ascend_npu_deepseek_example.md @@ -30,7 +30,6 @@ python3 -m sglang.launch_server \ --trust-remote-code \ --attention-backend ascend \ --device npu \ - --quantization modelslim \ --watchdog-timeout 9000 \ --cuda-graph-bs 8 16 24 28 32 \ --mem-fraction-static 0.68 \ @@ -88,7 +87,6 @@ python -m sglang.launch_server \ --mem-fraction-static 0.6 \ --attention-backend ascend \ --device npu \ - --quantization modelslim \ --max-running-requests 8 \ --context-length 8192 \ --disable-radix-cache \ @@ -144,7 +142,6 @@ python -m sglang.launch_server \ --max-running-requests 352 \ --attention-backend ascend \ --device npu \ - --quantization modelslim \ --moe-a2a-backend deepep \ --enable-dp-attention \ --deepep-mode low_latency \ @@ -216,7 +213,6 @@ do --mem-fraction-static 0.81 \ --attention-backend ascend \ --device npu \ - --quantization modelslim \ --max-running-requests 8 \ --context-length 8192 \ --disable-radix-cache \ @@ -279,7 +275,6 @@ do --max-running-requests 832 \ --attention-backend ascend \ --device npu \ - --quantization modelslim \ --moe-a2a-backend deepep \ 
--enable-dp-attention \ --deepep-mode low_latency \ diff --git a/docs/platforms/ascend_npu_quantization.md b/docs/platforms/ascend_npu_quantization.md new file mode 100644 index 000000000000..4c40fde6e170 --- /dev/null +++ b/docs/platforms/ascend_npu_quantization.md @@ -0,0 +1,21 @@ +# Quantization on Ascend + +To load already quantized models, simply load the model weights and config. If the model has been quantized offline, there's no need to add the `--quantization` argument when starting the engine. The quantization method will be automatically parsed from the downloaded `quant_model_description.json` or `config.json` config. + +[ModelSlim on Ascend support](https://github.com/sgl-project/sglang/pull/14504): +- [x] W4A4 dynamic linear +- [x] W8A8 static linear +- [x] W8A8 dynamic linear +- [x] W4A8 dynamic MOE +- [x] W8A8 dynamic MOE + +[AWQ on Ascend support](https://github.com/sgl-project/sglang/pull/10158): +- [x] W4A16 linear +- [x] W8A16 linear # Need to test +- [x] W4A16 MOE # Need to test + +Compressed-tensors (LLM Compressor) on Ascend support: +- [x] [W4A8 dynamic MOE with/without activation clip](https://github.com/sgl-project/sglang/pull/14736) # Need to test +- [x] [W4A16 MOE](https://github.com/sgl-project/sglang/pull/12759) +- [x] [W8A8 dynamic linear](https://github.com/sgl-project/sglang/pull/14504) +- [x] [W8A8 dynamic MOE](https://github.com/sgl-project/sglang/pull/14504) diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index efed08608c14..64b871485dff 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -17,6 +17,7 @@ import math import os from enum import Enum, IntEnum, auto +from pathlib import Path from typing import Any, List, Optional, Set, Union import torch @@ -632,6 +633,18 @@ def _parse_quant_hf_config(self): quant_cfg = self._parse_modelopt_quant_config(quant_config_dict) return quant_cfg + def _find_quant_modelslim_config(self):
+ quant_config_file = Path(self.model_path, "quant_model_description.json") + quant_cfg = None + if quant_config_file.is_file(): + with open(quant_config_file) as f: + quant_cfg = json.load(f) + # This field is required for flagless model loading but is not present in + # modelslim model description, so we're adding it here manually. + quant_cfg["quant_method"] = "modelslim" + + return quant_cfg + def _parse_modelopt_quant_config(self, quant_config_dict: dict) -> Optional[dict]: """Parse ModelOpt quantization config and return the appropriate quant_method.""" json_quant_configs = quant_config_dict["quantization"] @@ -744,6 +757,7 @@ def _verify_quantization(self) -> None: "w4afp8", "petit_nvfp4", "quark", + "modelslim", ] compatible_quantization_methods = { "modelopt_fp8": ["modelopt"], @@ -755,8 +769,19 @@ def _verify_quantization(self) -> None: if self.quantization is not None: self.quantization = self.quantization.lower() - # Parse quantization method from the HF model config, if available. - quant_cfg = self._parse_quant_hf_config() + # Parse quantization method from the HF and ModelSlim model config, if available. + # Only one function should return config, other should return None. 
+ cfg_list = [] + cfg_list.append(self._parse_quant_hf_config()) + cfg_list.append(self._find_quant_modelslim_config()) + + # Filter out None values + cfg_list = [item for item in cfg_list if item is not None] + if len(cfg_list) > 1: + raise ValueError( + "Config list contains configs from 2 methods, must be only 1" + ) + quant_cfg = cfg_list[0] if cfg_list else None if quant_cfg is not None: quant_method = quant_cfg.get( diff --git a/python/sglang/srt/hardware_backend/npu/quantization/fused_moe_method_npu.py b/python/sglang/srt/hardware_backend/npu/quantization/fused_moe_method_npu.py index e1b2f6e2b378..91a5da075807 100644 --- a/python/sglang/srt/hardware_backend/npu/quantization/fused_moe_method_npu.py +++ b/python/sglang/srt/hardware_backend/npu/quantization/fused_moe_method_npu.py @@ -1,18 +1,17 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import numpy as np import torch from sglang.srt.hardware_backend.npu.utils import npu_format_cast from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase -from sglang.srt.utils import set_weight_attrs if TYPE_CHECKING: - from sglang.srt.layers.moe import MoeRunnerConfig from sglang.srt.layers.moe.token_dispatcher import ( CombineInput, StandardDispatchOutput, ) + from sglang.srt.layers.quantization.base_config import QuantizationConfig def npu_fused_experts( @@ -140,79 +139,18 @@ def npu_fused_moe_without_routing_weights_bf16( return hidden_states -class NPUW8A8Int8DynamicMoEMethod(FusedMoEMethodBase): +class _NPUFusedMoEMethodBase(FusedMoEMethodBase): - def create_weights( + def __init__( self, - layer: torch.nn.Module, - num_experts: int, - hidden_size: int, - intermediate_size_per_partition: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ) -> None: - from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported + quant_config: Optional["QuantizationConfig"] = None, + ): + self.quant_config = quant_config - self.num_experts = num_experts - 
extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value} - ) - # weight - w13_weight = torch.nn.Parameter( - torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - hidden_size, - dtype=torch.int8, - ), - requires_grad=False, - ) - layer.register_parameter("w13_weight", w13_weight) - set_weight_attrs(w13_weight, extra_weight_attrs) - w2_weight = torch.nn.Parameter( - torch.empty( - num_experts, - hidden_size, - intermediate_size_per_partition, - dtype=torch.int8, - ), - requires_grad=False, - ) - layer.register_parameter("w2_weight", w2_weight) - set_weight_attrs(w2_weight, extra_weight_attrs) - # scale - w13_weight_scale = torch.nn.Parameter( - torch.empty( - num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32 - ), - requires_grad=False, - ) - layer.register_parameter("w13_weight_scale", w13_weight_scale) - set_weight_attrs(w13_weight_scale, extra_weight_attrs) - w2_weight_scale = torch.nn.Parameter( - torch.empty(num_experts, hidden_size, 1, dtype=torch.float32), - requires_grad=False, - ) - layer.register_parameter("w2_weight_scale", w2_weight_scale) - set_weight_attrs(w2_weight_scale, extra_weight_attrs) - # offset - w13_weight_offset = torch.nn.Parameter( - torch.empty( - num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32 - ), - requires_grad=False, - ) - layer.register_parameter("w13_weight_offset", w13_weight_offset) - set_weight_attrs(w13_weight_offset, extra_weight_attrs) - w2_weight_offset = torch.nn.Parameter( - torch.empty(num_experts, hidden_size, 1, dtype=torch.float32), - requires_grad=False, - ) - layer.register_parameter("w2_weight_offset", w2_weight_offset) - set_weight_attrs(w2_weight_offset, extra_weight_attrs) +class NPUW8A8Int8DynamicMoEMethod(_NPUFusedMoEMethodBase): - def release_weight_cache(self, weight: torch.Tensor): + def _release_weight_cache(self, weight: torch.Tensor): # .contiguous() introduces additional memory overhead and needs to be released 
using resize_(0) origin_weight = weight.data.transpose(1, 2) new_weight = origin_weight.contiguous() @@ -220,10 +158,10 @@ def release_weight_cache(self, weight: torch.Tensor): return new_weight def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - weight_data = self.release_weight_cache(layer.w13_weight.data) + weight_data = self._release_weight_cache(layer.w13_weight.data) layer.w13_weight = torch.nn.Parameter(weight_data, requires_grad=False) - weight_data = self.release_weight_cache(layer.w2_weight.data) + weight_data = self._release_weight_cache(layer.w2_weight.data) layer.w2_weight = torch.nn.Parameter(weight_data, requires_grad=False) layer.w13_weight_scale = torch.nn.Parameter( @@ -233,21 +171,21 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.w2_weight_scale = torch.nn.Parameter( layer.w2_weight_scale.data.squeeze(-1).contiguous(), requires_grad=False ) - layer.w13_weight_offset = torch.nn.Parameter( - layer.w13_weight_offset.data.squeeze(-1).contiguous(), requires_grad=False - ) - layer.w2_weight_offset = torch.nn.Parameter( - layer.w2_weight_offset.data.squeeze(-1).contiguous(), requires_grad=False - ) + # Compressed-tensors format doesn't have this field + if hasattr(layer, "w13_weight_offset"): + layer.w13_weight_offset = torch.nn.Parameter( + layer.w13_weight_offset.data.squeeze(-1).contiguous(), + requires_grad=False, + ) + if hasattr(layer, "w2_weight_offset"): + layer.w2_weight_offset = torch.nn.Parameter( + layer.w2_weight_offset.data.squeeze(-1).contiguous(), + requires_grad=False, + ) layer.w13_weight.data = npu_format_cast(layer.w13_weight.data) layer.w2_weight.data = npu_format_cast(layer.w2_weight.data) - def create_moe_runner( - self, layer: torch.nn.Module, moe_runner_config: "MoeRunnerConfig" - ): - self.moe_runner_config = moe_runner_config - def apply( self, layer, @@ -321,237 +259,19 @@ def apply_without_routing_weights( return hidden_states -class 
NPUW4A8Int4DynamicMoEMethod(FusedMoEMethodBase): - - def __init__(self, activation_use_clip: bool) -> None: - self.group_size = 0 - self.tp_size = 1 - self.activation_use_clip = activation_use_clip - - def create_weights( - self, - layer: torch.nn.Module, - num_experts: int, - hidden_size: int, - intermediate_size_per_partition: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ) -> None: - from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported - - self.is_per_channel_weight = self.group_size == 0 - self.num_experts = num_experts - extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value} - ) - - # >> weight - w13_output_size = intermediate_size_per_partition - w2_output_size = hidden_size // 2 - w13_weight = torch.nn.Parameter( - torch.empty(num_experts, w13_output_size, hidden_size, dtype=torch.int8), - requires_grad=False, - ) - layer.register_parameter("w13_weight", w13_weight) - set_weight_attrs(w13_weight, extra_weight_attrs) - w2_weight = torch.nn.Parameter( - torch.empty( - num_experts, - w2_output_size, - intermediate_size_per_partition, - dtype=torch.int8, - ), - requires_grad=False, - ) - layer.register_parameter("w2_weight", w2_weight) - set_weight_attrs(w2_weight, extra_weight_attrs) - - # >> scale - weight_scale_dtype = torch.int64 if self.activation_use_clip else torch.float32 - w13_weight_scale = torch.nn.Parameter( - torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - 1, - dtype=weight_scale_dtype, - ), - requires_grad=False, - ) - layer.register_parameter("w13_weight_scale", w13_weight_scale) - set_weight_attrs(w13_weight_scale, extra_weight_attrs) - - w2_weight_scale = torch.nn.Parameter( - torch.empty(num_experts, hidden_size, 1, dtype=weight_scale_dtype), - requires_grad=False, - ) - layer.register_parameter("w2_weight_scale", w2_weight_scale) - set_weight_attrs(w2_weight_scale, extra_weight_attrs) - - # >> offset - w13_weight_offset = torch.nn.Parameter( - 
torch.empty( - num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32 - ), - requires_grad=False, - ) - layer.register_parameter("w13_weight_offset", w13_weight_offset) - set_weight_attrs(w13_weight_offset, extra_weight_attrs) - - w2_weight_offset = torch.nn.Parameter( - torch.empty(num_experts, hidden_size, 1, dtype=torch.float32), - requires_grad=False, - ) - layer.register_parameter("w2_weight_offset", w2_weight_offset) - set_weight_attrs(w2_weight_offset, extra_weight_attrs) - - # >>> special param for w4a8 - if self.activation_use_clip: - self._init_activation_clip_params( - layer, - num_experts, - hidden_size, - intermediate_size_per_partition, - extra_weight_attrs, - ) - else: - self._init_extra_scale_params( - layer, - num_experts, - hidden_size, - intermediate_size_per_partition, - extra_weight_attrs, - ) - - def _init_activation_clip_params( - self, - layer: torch.nn.Module, - num_experts: int, - hidden_size: int, - intermediate_size_per_partition: int, - extra_weight_attrs: dict, - ) -> None: - """ - Initializes bias and alpha parameters for quantization schemes that use activation clipping. - - This helper registers `w13_bias`, `w2_bias`, and `w2_alpha`, which are required to - shift and scale the activations or outputs to compensate for the precision loss - introduced by clamping activations. 
- """ - w13_bias = torch.nn.Parameter( - torch.ones( - num_experts, 2 * intermediate_size_per_partition, dtype=torch.float - ), - requires_grad=False, - ) - layer.register_parameter("w13_bias", w13_bias) - set_weight_attrs(w13_bias, extra_weight_attrs) - - w2_bias = torch.nn.Parameter( - torch.ones(num_experts, hidden_size, dtype=torch.float), - requires_grad=False, - ) - layer.register_parameter("w2_bias", w2_bias) - set_weight_attrs(w2_bias, extra_weight_attrs) - - w2_alpha = torch.nn.Parameter( - torch.ones(num_experts, dtype=torch.float), requires_grad=False - ) - layer.register_parameter("w2_alpha", w2_alpha) - set_weight_attrs(w2_alpha, extra_weight_attrs) - - def _init_extra_scale_params( - self, - layer: torch.nn.Module, - num_experts: int, - hidden_size: int, - intermediate_size_per_partition: int, - extra_weight_attrs: dict, - ) -> None: - """ - Initializes additional scaling, offset, and bias parameters for quantization schemes without activation clipping. - - This method registers the following parameters: - 1. Scale Biases: `w13_scale_bias` and `w2_scale_bias`. - 2. Secondary Quantization Params (initialized only for grouped quantization): - `w13_weight_scale_second`, `w13_weight_offset_second`, - `w2_weight_scale_second`, and `w2_weight_offset_second`. 
- """ - if not self.is_per_channel_weight: - w13_weight_scale_second = torch.nn.Parameter( - torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - hidden_size // self.group_size, - dtype=torch.float32, - ), - requires_grad=False, - ) - layer.register_parameter("w13_weight_scale_second", w13_weight_scale_second) - set_weight_attrs(w13_weight_scale_second, extra_weight_attrs) - - w13_weight_offset_second = torch.nn.Parameter( - torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - hidden_size // self.group_size, - dtype=torch.float32, - ), - requires_grad=False, - ) - layer.register_parameter( - "w13_weight_offset_second", w13_weight_offset_second - ) - set_weight_attrs(w13_weight_offset_second, extra_weight_attrs) - - w2_weight_scale_second = torch.nn.Parameter( - torch.empty( - num_experts, - hidden_size, - intermediate_size_per_partition // self.group_size, - dtype=torch.float32, - ), - requires_grad=False, - ) - layer.register_parameter("w2_weight_scale_second", w2_weight_scale_second) - set_weight_attrs(w2_weight_scale_second, extra_weight_attrs) - - w2_weight_offset_second = torch.nn.Parameter( - torch.empty( - num_experts, - hidden_size, - intermediate_size_per_partition // self.group_size, - dtype=torch.float32, - ), - requires_grad=False, - ) - layer.register_parameter("w2_weight_offset_second", w2_weight_offset_second) - set_weight_attrs(w2_weight_offset_second, extra_weight_attrs) - - w13_scale_bias = torch.nn.Parameter( - torch.empty( - num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32 - ), - requires_grad=False, - ) - layer.register_parameter("w13_scale_bias", w13_scale_bias) - set_weight_attrs(w13_scale_bias, extra_weight_attrs) - - w2_scale_bias = torch.nn.Parameter( - torch.empty( - num_experts, hidden_size, 16 // self.tp_size, dtype=torch.float32 - ), - requires_grad=False, - ) - layer.register_parameter("w2_scale_bias", w2_scale_bias) - set_weight_attrs(w2_scale_bias, extra_weight_attrs) +class 
NPUW4A8Int8DynamicMoEMethod(_NPUFusedMoEMethodBase): - def process_scale(self, weight: torch.Tensor, scale, per_group_scale): + def _process_scale( + self, weight: torch.Tensor, scale, per_group_scale, is_per_channel_weight + ): scale = scale.transpose(1, 2).contiguous() - if self.is_per_channel_weight: + + if is_per_channel_weight: scale_np = scale.cpu().numpy() scale_np.dtype = np.uint32 scale_uint64_tensor = torch.from_numpy(scale_np.astype(np.int64)).npu() return scale_uint64_tensor, None + per_group_scale = per_group_scale.transpose(1, 2).contiguous() group_num, k, n = weight.shape # the weight of the new version is reduced by half by pack n, so it needs to be restored @@ -576,7 +296,7 @@ def process_scale(self, weight: torch.Tensor, scale, per_group_scale): sscale_uint64_tensor = sscale_uint64_tensor.npu() return sscale_uint64_tensor, bias - def update_bias(self, layer, w13_bias, w2_bias): + def _update_bias(self, layer, w13_bias, w2_bias): layer.w13_scale_bias.data = ( layer.w13_scale_bias.data.transpose(1, 2).contiguous().sum(axis=1) ) @@ -584,16 +304,18 @@ def update_bias(self, layer, w13_bias, w2_bias): layer.w2_scale_bias.data.transpose(1, 2).contiguous().sum(axis=1) ) - def pack_to_int32(self, weight: torch.Tensor): + def _pack_to_int32(self, weight: torch.Tensor): # pack 4 int8(int4*2) to int32, because in pytorch, we need to use int32 to represent int4 assert ( weight.shape[-1] % 4 == 0 ), "the last dim of weight needs to be divided by 4" return weight.view(torch.int32).contiguous() - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - if not self.activation_use_clip: - self._process_weights_without_clip(layer) + def process_weights_after_loading( + self, layer: torch.nn.Module, is_per_channel_weight, activation_use_clip + ) -> None: + if not activation_use_clip: + self._process_weights_without_clip(layer, is_per_channel_weight) else: self._process_weights_with_clip(layer) @@ -607,10 +329,12 @@ def 
process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.w13_weight.data = npu_format_cast(layer.w13_weight.data) layer.w2_weight.data = npu_format_cast(layer.w2_weight.data) - layer.w13_weight.data = self.pack_to_int32(layer.w13_weight.data) - layer.w2_weight.data = self.pack_to_int32(layer.w2_weight.data) + layer.w13_weight.data = self._pack_to_int32(layer.w13_weight.data) + layer.w2_weight.data = self._pack_to_int32(layer.w2_weight.data) - def _process_weights_without_clip(self, layer: torch.nn.Module) -> None: + def _process_weights_without_clip( + self, layer: torch.nn.Module, is_per_channel_weight + ) -> None: w13_weight_scale_second = ( layer.w13_weight_scale_second.data if hasattr(layer, "w13_weight_scale_second") @@ -621,11 +345,17 @@ def _process_weights_without_clip(self, layer: torch.nn.Module) -> None: if hasattr(layer, "w2_weight_scale_second") else None ) - layer.w13_weight_scale.data, w13_bias = self.process_scale( - layer.w13_weight, layer.w13_weight_scale.data, w13_weight_scale_second + layer.w13_weight_scale.data, w13_bias = self._process_scale( + layer.w13_weight, + layer.w13_weight_scale.data, + w13_weight_scale_second, + is_per_channel_weight, ) - layer.w2_weight_scale.data, w2_bias = self.process_scale( - layer.w2_weight, layer.w2_weight_scale.data, w2_weight_scale_second + layer.w2_weight_scale.data, w2_bias = self._process_scale( + layer.w2_weight, + layer.w2_weight_scale.data, + w2_weight_scale_second, + is_per_channel_weight, ) if hasattr(layer, "w13_weight_scale_second"): # scale_second is no longer used, release this part of the memory @@ -634,7 +364,7 @@ def _process_weights_without_clip(self, layer: torch.nn.Module) -> None: del layer.w13_weight_offset_second del layer.w2_weight_offset_second - self.update_bias(layer, w13_bias, w2_bias) + self._update_bias(layer, w13_bias, w2_bias) def _process_weights_with_clip(self, layer: torch.nn.Module) -> None: w13_weight_scale = ( @@ -650,20 +380,89 @@ def 
_process_weights_with_clip(self, layer: torch.nn.Module) -> None: layer.w13_scale_bias = layer.w13_bias layer.w2_scale_bias = layer.w2_bias - def create_moe_runner( - self, layer: torch.nn.Module, moe_runner_config: "MoeRunnerConfig" - ): - self.moe_runner_config = moe_runner_config - def apply( self, layer, dispatch_output: "StandardDispatchOutput", ) -> "CombineInput": - # FIXME W4A8 only support with deepep - raise NotImplementedError( - f"W4A8 only support with deepep for now, please enable --moe-a2a-backend deepep" + from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput + + hidden_states = dispatch_output.hidden_states + topk_output = dispatch_output.topk_output + + topk_weights, topk_ids, _ = topk_output + top_k = topk_ids.shape[1] + group_list_type = 1 + original_shape = hidden_states.shape + topk_weights = topk_weights + + num_tokens = hidden_states.shape[:-1].numel() + + first_expert_idx = 0 + last_expert_idx = layer.num_experts + global_num_experts = layer.num_experts + + sorted_hidden_states, expanded_row_idx, expert_tokens, pertoken_scale = ( + torch.ops.npu.npu_moe_init_routing_v2( + hidden_states, + topk_ids, + active_num=num_tokens * top_k, + expert_num=global_num_experts, + expert_tokens_num_type=1, + expert_tokens_num_flag=True, + active_expert_range=[first_expert_idx, last_expert_idx], + quant_mode=1, + ) + ) + + expert_tokens = expert_tokens.to(torch.int64) + + bias1 = [layer.w13_scale_bias] + bias2 = [layer.w2_scale_bias] + w1_scale = [layer.w13_weight_scale] + w2_scale = [layer.w2_weight_scale] + _output_dtype = torch.bfloat16 + + hidden_states = torch.ops.npu.npu_grouped_matmul( + x=[sorted_hidden_states], + weight=[layer.w13_weight], + scale=w1_scale, + bias=bias1, + per_token_scale=[pertoken_scale], + group_list=expert_tokens, + split_item=2, + group_type=0, + group_list_type=group_list_type, + output_dtype=_output_dtype, + )[0] + + # act_fn: swiglu + hidden_states = torch.ops.npu.npu_swiglu(hidden_states) + hidden_states, 
swiglu_out_scale = torch.ops.npu.npu_dynamic_quant(hidden_states) + + output = torch.ops.npu.npu_grouped_matmul( + x=[hidden_states], + weight=[layer.w2_weight], + scale=w2_scale, + bias=bias2, + per_token_scale=[swiglu_out_scale], + group_list=expert_tokens, + split_item=2, + group_type=0, + group_list_type=group_list_type, + output_dtype=_output_dtype, + )[0] + + assert original_shape is not None + final_hidden_states = torch.ops.npu.npu_moe_token_unpermute( + permuted_tokens=output, + sorted_indices=torch.abs(expanded_row_idx), + probs=topk_weights, ) + if len(original_shape) == 3: + final_hidden_states = final_hidden_states.view(original_shape) + + return StandardCombineInput(hidden_states=final_hidden_states) def apply_without_routing_weights( self, @@ -709,118 +508,9 @@ def apply_without_routing_weights( return hidden_states -class NPUW4A16Int4DynamicMoEMethod(FusedMoEMethodBase): - - def __init__(self, quantization_config) -> None: - self.pack_factor = 8 # weight dtype is int4, but use int32 to create - target = ( - "MoEGMM" if "MoEGMM" in quantization_config.target_scheme_map else "Linear" - ) - if target in quantization_config.target_scheme_map: - self.group_size = quantization_config.target_scheme_map[target][ - "weights" - ].group_size - else: - self.group_size = 128 - - def create_weights( - self, - layer: torch.nn.Module, - num_experts: int, - hidden_size: int, - intermediate_size_per_partition: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ) -> None: - from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported - - self.num_experts = num_experts - if ( - extra_weight_attrs.get( - "intermediate_size_full", intermediate_size_per_partition - ) - // intermediate_size_per_partition - > 1 - ): - quant_method = FusedMoeWeightScaleSupported.GROUP.value - else: - quant_method = FusedMoeWeightScaleSupported.CHANNEL.value - extra_weight_attrs.update({"quant_method": quant_method}) - # weight - w13_weight = torch.nn.Parameter( - 
torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - hidden_size // self.pack_factor, - dtype=torch.int32, - ), - requires_grad=False, - ) - layer.register_parameter("w13_weight", w13_weight) - set_weight_attrs(w13_weight, extra_weight_attrs) - w2_weight = torch.nn.Parameter( - torch.empty( - num_experts, - hidden_size, - intermediate_size_per_partition // self.pack_factor, - dtype=torch.int32, - ), - requires_grad=False, - ) - layer.register_parameter("w2_weight", w2_weight) - set_weight_attrs(w2_weight, extra_weight_attrs) - - # scale - weight_scale_dtype = torch.bfloat16 - w13_weight_scale = torch.nn.Parameter( - torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - hidden_size // self.group_size, - dtype=weight_scale_dtype, - ), - requires_grad=False, - ) - layer.register_parameter("w13_weight_scale", w13_weight_scale) - set_weight_attrs(w13_weight_scale, extra_weight_attrs) - w2_weight_scale = torch.nn.Parameter( - torch.empty( - num_experts, - hidden_size, - intermediate_size_per_partition // self.group_size, - dtype=weight_scale_dtype, - ), - requires_grad=False, - ) - layer.register_parameter("w2_weight_scale", w2_weight_scale) - set_weight_attrs(w2_weight_scale, extra_weight_attrs) - - # offset - w13_weight_offset = torch.nn.Parameter( - torch.zeros( - num_experts, - 2 * intermediate_size_per_partition, - hidden_size // self.group_size, - dtype=weight_scale_dtype, - ), - requires_grad=False, - ) - layer.register_parameter("w13_weight_offset", w13_weight_offset) - set_weight_attrs(w13_weight_offset, extra_weight_attrs) - - w2_weight_offset = torch.nn.Parameter( - torch.zeros( - num_experts, - hidden_size, - intermediate_size_per_partition // self.group_size, - dtype=weight_scale_dtype, - ), - requires_grad=False, - ) - layer.register_parameter("w2_weight_offset", w2_weight_offset) - set_weight_attrs(w2_weight_offset, extra_weight_attrs) +class NPUW4A16Int4DynamicMoEMethod(_NPUFusedMoEMethodBase): - def pack_to_int32(self, 
weight: torch.Tensor): + def _pack_to_int32(self, weight: torch.Tensor): assert weight.dim() == 3 if weight.dtype == torch.int32: # pack 8 int4 to int32, we use a int32 to represent a int4 @@ -841,7 +531,7 @@ def pack_to_int32(self, weight: torch.Tensor): raise ValueError(f"{weight.dtype=} is not supported !") return new_weight - def unpack_from_int32( + def _unpack_from_int32( self, value: torch.Tensor, num_bits: int, @@ -926,31 +616,26 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # w13_weight = layer.w13_weight.data.transpose(1, 2).contiguous() # w2_weight = layer.w2_weight.data.transpose(1, 2).contiguous() unpacked_w13_weight = ( - self.unpack_from_int32(layer.w13_weight.data.flatten(0, 1), 4) + self._unpack_from_int32(layer.w13_weight.data.flatten(0, 1), 4) .view(layer.w13_weight.data.shape[0], layer.w13_weight.data.shape[1], -1) .transpose(1, 2) .contiguous() .int() ) unpacked_w2_weight = ( - self.unpack_from_int32(layer.w2_weight.data.flatten(0, 1), 4) + self._unpack_from_int32(layer.w2_weight.data.flatten(0, 1), 4) .view(layer.w2_weight.data.shape[0], layer.w2_weight.data.shape[1], -1) .transpose(1, 2) .contiguous() .int() ) - w13_weight = self.pack_to_int32(unpacked_w13_weight) - w2_weight = self.pack_to_int32(unpacked_w2_weight) + w13_weight = self._pack_to_int32(unpacked_w13_weight) + w2_weight = self._pack_to_int32(unpacked_w2_weight) layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) - def create_moe_runner( - self, layer: torch.nn.Module, moe_runner_config: "MoeRunnerConfig" - ): - self.moe_runner_config = moe_runner_config - def apply( self, layer, diff --git a/python/sglang/srt/hardware_backend/npu/quantization/linear_method_npu.py b/python/sglang/srt/hardware_backend/npu/quantization/linear_method_npu.py index 46db893b3495..3a99f6ac7c3b 100644 --- a/python/sglang/srt/hardware_backend/npu/quantization/linear_method_npu.py +++ 
b/python/sglang/srt/hardware_backend/npu/quantization/linear_method_npu.py @@ -1,13 +1,8 @@ -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, Optional import torch from sglang.srt.hardware_backend.npu.utils import npu_format_cast -from sglang.srt.layers.parameter import ( - ChannelQuantScaleParameter, - ModelWeightParameter, - PerTensorScaleParameter, -) from sglang.srt.layers.quantization.base_config import LinearMethodBase if TYPE_CHECKING: @@ -20,82 +15,33 @@ def __init__( self, quant_config: Optional["QuantizationConfig"] = None, ): - super().__init__() self.quant_config = quant_config class NPUW8A8Int8LinearMethod(_NPULinearMethodBase): - def create_weights( - self, - layer: torch.nn.Module, - input_size_per_partition: int, - output_partition_sizes: List[int], - input_size: int, - output_size: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ): - weight_loader = extra_weight_attrs.get("weight_loader") - output_size_per_partition = sum(output_partition_sizes) - - weight = ModelWeightParameter( - data=torch.empty( - (output_size_per_partition, input_size_per_partition), dtype=torch.int8 - ), - input_dim=1, - output_dim=0, - weight_loader=weight_loader, - ) - layer.register_parameter("weight", weight) - - weight_scale = ChannelQuantScaleParameter( - data=torch.empty((output_size_per_partition, 1), dtype=params_dtype), - output_dim=0, - weight_loader=weight_loader, - ) - layer.register_parameter("weight_scale", weight_scale) - - weight_offset = ChannelQuantScaleParameter( - data=torch.empty((output_size_per_partition, 1), dtype=params_dtype), - output_dim=0, - weight_loader=weight_loader, - ) - layer.register_parameter("weight_offset", weight_offset) + def process_weights_after_loading(self, layer: torch.nn.Module): + layer.weight.data = layer.weight.data.transpose(0, 1).contiguous() + layer.weight.data = npu_format_cast(layer.weight.data) - input_scale = PerTensorScaleParameter( - data=torch.empty(1, 
dtype=params_dtype), - weight_loader=weight_loader, - ) - input_scale.ignore_warning = True - layer.register_parameter("input_scale", input_scale) + layer.weight_scale.data = layer.weight_scale.data.flatten() + # Compressed-tensors format doesn't have this field + if hasattr(layer, "weight_offset"): + layer.weight_offset.data = layer.weight_offset.data.flatten() - input_offset = PerTensorScaleParameter( - data=torch.empty(1, dtype=params_dtype), - weight_loader=weight_loader, + expanding_factor = layer.weight.data.shape[0] + layer.aclnn_input_scale = torch.nn.Parameter( + layer.input_scale.data.repeat(expanding_factor).to(device="npu"), + requires_grad=False, ) - input_offset.ignore_warning = True - layer.register_parameter("input_offset", input_offset) - - quant_bias = ChannelQuantScaleParameter( - data=torch.empty(output_size_per_partition, dtype=torch.int32), - output_dim=0, - weight_loader=weight_loader, + layer.aclnn_input_scale_reciprocal = 1 / torch.nn.Parameter( + layer.input_scale.data.repeat(expanding_factor).to(device="npu"), + requires_grad=False, ) - layer.register_parameter("quant_bias", quant_bias) - - if params_dtype == torch.bfloat16: - deq_scale_dtype = torch.float32 - elif params_dtype == torch.float16: - deq_scale_dtype = torch.int64 - else: - raise ValueError(f"Unsupported params_dtype: {params_dtype}") - deq_scale = ChannelQuantScaleParameter( - data=torch.empty(output_size_per_partition, dtype=deq_scale_dtype), - output_dim=0, - weight_loader=weight_loader, + layer.aclnn_input_offset = torch.nn.Parameter( + layer.input_offset.data.repeat(expanding_factor).to(device="npu"), + requires_grad=False, ) - layer.register_parameter("deq_scale", deq_scale) def apply( self, @@ -129,75 +75,58 @@ def apply( output_dtype=original_dtype, ) + +class NPUW8A8Int8DynamicLinearMethod(_NPULinearMethodBase): + def process_weights_after_loading(self, layer: torch.nn.Module): layer.weight.data = layer.weight.data.transpose(0, 1).contiguous() layer.weight.data = 
npu_format_cast(layer.weight.data) - layer.weight_scale.data = torch.flatten(layer.weight_scale.data) - layer.weight_offset.data = torch.flatten(layer.weight_offset.data) - - expanding_factor = layer.weight.data.shape[0] - layer.aclnn_input_scale = torch.nn.Parameter( - layer.input_scale.data.repeat(expanding_factor).to(device="npu"), - requires_grad=False, - ) - layer.aclnn_input_scale_reciprocal = 1 / torch.nn.Parameter( - layer.input_scale.data.repeat(expanding_factor).to(device="npu"), - requires_grad=False, - ) - layer.aclnn_input_offset = torch.nn.Parameter( - layer.input_offset.data.repeat(expanding_factor).to(device="npu"), - requires_grad=False, - ) - - -class NPUW8A8Int8DynamicLinearMethod(_NPULinearMethodBase): + layer.weight_scale.data = layer.weight_scale.data.flatten() + # Compressed-tensors format doesn't have this field + if hasattr(layer, "weight_offset"): + layer.weight_offset.data = layer.weight_offset.data.flatten() - def create_weights( + def apply( self, layer: torch.nn.Module, - input_size_per_partition: int, - output_partition_sizes: List[int], - input_size: int, - output_size: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ): - weight_loader = extra_weight_attrs.get("weight_loader") - output_size_per_partition = sum(output_partition_sizes) - - weight = ModelWeightParameter( - data=torch.empty( - (output_size_per_partition, input_size_per_partition), dtype=torch.int8 - ), - input_dim=1, - output_dim=0, - weight_loader=weight_loader, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + original_dtype = x.dtype + quant_out, dynamic_scale = torch.ops.npu.npu_dynamic_quant(x) + return torch.ops.npu.npu_quant_matmul( + quant_out, + layer.weight, + layer.weight_scale, + pertoken_scale=dynamic_scale, + bias=bias, + output_dtype=original_dtype, ) - layer.register_parameter("weight", weight) - weight_scale = ChannelQuantScaleParameter( - data=torch.empty((output_size_per_partition, 1), dtype=params_dtype), - 
output_dim=0, - weight_loader=weight_loader, - ) - layer.register_parameter("weight_scale", weight_scale) - weight_offset = ChannelQuantScaleParameter( - data=torch.empty((output_size_per_partition, 1), dtype=params_dtype), - output_dim=0, - weight_loader=weight_loader, +class NPU_W4A4DynamicLinearMethod(_NPULinearMethodBase): + + def process_weights_after_loading(self, layer): + layer.weight.data = layer.weight.data.transpose(0, 1).contiguous() + layer.weight_scale.data = layer.weight_scale.data.flatten() + layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32) + layer.weight_offset.data = layer.weight_offset.data.flatten() + layer.weight.data = torch.ops.npu.npu_convert_weight_to_int4pack( + layer.weight.data.to(torch.int32) ) - layer.register_parameter("weight_offset", weight_offset) def apply( self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None, + tp_rank: Optional[int] = 0, ) -> torch.Tensor: original_dtype = x.dtype - quant_out, dynamic_scale = torch.ops.npu.npu_dynamic_quant(x) + quant_out, dynamic_scale = torch.ops.npu.npu_dynamic_quant( + x, dst_type=torch.quint4x2 + ) return torch.ops.npu.npu_quant_matmul( quant_out, layer.weight, @@ -206,10 +135,3 @@ def apply( bias=bias, output_dtype=original_dtype, ) - - def process_weights_after_loading(self, layer: torch.nn.Module): - layer.weight.data = layer.weight.data.transpose(0, 1).contiguous() - layer.weight.data = npu_format_cast(layer.weight.data) - - layer.weight_scale.data = layer.weight_scale.data.flatten() - layer.weight_offset.data = layer.weight_offset.data.flatten() diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index 53c25b069ffb..552346cb0f6b 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -7,9 +7,6 @@ from sglang.srt.compilation.piecewise_context_manager import is_in_piecewise_cuda_graph from sglang.srt.environ import envs -from 
sglang.srt.hardware_backend.npu.quantization.fused_moe_method_npu import ( - NPUW4A16Int4DynamicMoEMethod, -) from sglang.srt.layers import deep_gemm_wrapper from sglang.srt.layers.moe import ( get_deepep_mode, @@ -27,6 +24,9 @@ ) from sglang.srt.layers.moe.topk import TopKOutput, TopKOutputChecker from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_moe import ( + NPUCompressedTensorsW4A16Int4DynamicMoEMethod, +) from sglang.srt.layers.quantization.fp8 import Fp8Config from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod @@ -374,7 +374,7 @@ def forward_npu( else: input_quant = get_bool_env_var("DEEP_NORMAL_MODE_USE_INT8_QUANT") if not input_quant and not isinstance( - self.quant_method, NPUW4A16Int4DynamicMoEMethod + self.quant_method, NPUCompressedTensorsW4A16Int4DynamicMoEMethod ): hidden_states, hidden_states_scale = torch_npu.npu_dynamic_quant( hidden_states diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py index 0e19019534f6..c16d8a1a1a3b 100644 --- a/python/sglang/srt/layers/quantization/__init__.py +++ b/python/sglang/srt/layers/quantization/__init__.py @@ -31,6 +31,7 @@ def override_quantization_method(self, *args, **kwargs): ModelOptFp4Config, ModelOptFp8Config, ) +from sglang.srt.layers.quantization.modelslim.modelslim import ModelSlimConfig from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config from sglang.srt.layers.quantization.mxfp4 import Mxfp4Config from sglang.srt.layers.quantization.petit import PetitNvFp4Config @@ -69,6 +70,7 @@ def override_quantization_method(self, *args, **kwargs): "fbgemm_fp8": FBGEMMFp8Config, "quark": QuarkConfig, "auto-round": AutoRoundConfig, + "modelslim": ModelSlimConfig, "quark_int4fp8_moe": QuarkInt4Fp8Config, } @@ -80,15 +82,6 @@ def 
override_quantization_method(self, *args, **kwargs): } ) -if is_npu(): - from sglang.srt.hardware_backend.npu.quantization.modelslim import ModelSlimConfig - - BASE_QUANTIZATION_METHODS.update( - { - "modelslim": ModelSlimConfig, - } - ) - QUANTIZATION_METHODS = {**BASE_QUANTIZATION_METHODS} diff --git a/python/sglang/srt/layers/quantization/awq.py b/python/sglang/srt/layers/quantization/awq.py index 5af8707802d1..173bc2bb2400 100644 --- a/python/sglang/srt/layers/quantization/awq.py +++ b/python/sglang/srt/layers/quantization/awq.py @@ -628,8 +628,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: qzeros_tmp = -(qzeros_tmp - 8) qzeros_tmp = qzeros_tmp.to(layer.scales.data.dtype) - layer.qzeros = torch.nn.Parameter(qzeros_tmp, requires_grad=False) - layer.qweight = torch.nn.Parameter(qweight_tmp, requires_grad=False) + layer.zeros = torch.nn.Parameter(qzeros_tmp, requires_grad=False) + layer.weight = torch.nn.Parameter(qweight_tmp, requires_grad=False) def apply( self, @@ -637,9 +637,9 @@ def apply( x: torch.Tensor, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: - qweight = layer.qweight + qweight = layer.weight scales = layer.scales - qzeros = layer.qzeros + qzeros = layer.zeros pack_factor = self.quant_config.pack_factor out_shape = x.shape[:-1] + (qweight.shape[-1] * pack_factor,) reshaped_x = x.reshape(-1, x.shape[-1]) diff --git a/python/sglang/srt/layers/quantization/base_config.py b/python/sglang/srt/layers/quantization/base_config.py index 94d6a0427c48..48511d09f0bf 100644 --- a/python/sglang/srt/layers/quantization/base_config.py +++ b/python/sglang/srt/layers/quantization/base_config.py @@ -17,7 +17,6 @@ class QuantizeMethodBase(ABC): """Base class for different quantized methods.""" - @abstractmethod def create_weights( self, layer: torch.nn.Module, *weight_args, **extra_weight_attrs ): @@ -44,7 +43,6 @@ def process_weights_after_loading(self, layer: nn.Module) -> None: class LinearMethodBase(QuantizeMethodBase): 
"""Base class for different (maybe quantized) linear methods.""" - @abstractmethod def create_weights( self, layer: torch.nn.Module, @@ -84,7 +82,6 @@ def apply( class FusedMoEMethodBase(QuantizeMethodBase): - @abstractmethod def create_weights( self, layer: torch.nn.Module, @@ -96,7 +93,6 @@ def create_weights( ): raise NotImplementedError - @abstractmethod def create_moe_runner( self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig ): diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py index e1349794209a..8253f6ebead4 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py @@ -45,6 +45,7 @@ CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8, CompressedTensorsWNA16, + NPUCompressedTensorsW8A8Int8, ) from sglang.srt.layers.quantization.compressed_tensors.utils import ( find_matched_target, @@ -53,6 +54,10 @@ ) from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod +from sglang.srt.utils import is_cuda, is_npu + +_is_cuda = is_cuda() +_is_npu = is_npu() if TYPE_CHECKING: from sglang.srt.models.utils import WeightsMapper @@ -305,6 +310,28 @@ def _check_scheme_supported(self, min_capability: int, error: bool = True) -> bo else: return False + def _is_dynamic_token_w4a8( + self, weight_quant: BaseModel, input_quant: BaseModel + ) -> bool: + is_weight_4_bits = weight_quant.num_bits == 4 + is_activation_8_bits = input_quant.num_bits == 8 + weight_strategy = ( + weight_quant.strategy == QuantizationStrategy.GROUP.value + or weight_quant.strategy == QuantizationStrategy.CHANNEL.value + ) + is_token = ( + weight_strategy and input_quant.strategy == QuantizationStrategy.TOKEN.value + ) + is_dynamic = not weight_quant.dynamic and 
input_quant.dynamic + + return ( + is_weight_4_bits + and is_activation_8_bits + and is_token + and weight_quant.symmetric + and is_dynamic + ) + def _is_static_tensor_w8a8( self, weight_quant: BaseModel, input_quant: BaseModel ) -> bool: @@ -444,6 +471,29 @@ def _is_wNa16_group_channel( return is_channel_group and input_quant_none and is_symmetric and is_static + def _is_dynamic_token_w4( + self, weight_quant: BaseModel, input_quant: BaseModel + ) -> bool: + is_w4 = weight_quant.num_bits == 4 + weight_strategy = ( + weight_quant.strategy == QuantizationStrategy.TENSOR.value + or weight_quant.strategy == QuantizationStrategy.CHANNEL.value + or weight_quant.strategy == QuantizationStrategy.GROUP.value + ) + if input_quant is not None: + is_token = ( + weight_strategy + and input_quant.strategy == QuantizationStrategy.TOKEN.value + ) + is_dynamic = not weight_quant.dynamic and input_quant.dynamic + else: + is_token = weight_strategy + is_dynamic = not weight_quant.dynamic + + # Both symmetric and asymmetric input quantization supported. + # Only symmetric weight quantization supported. 
+ return is_w4 and weight_quant.symmetric and is_token and is_dynamic + def _get_scheme_from_parts( self, weight_quant: BaseModel, input_quant: BaseModel ) -> CompressedTensorsScheme: @@ -505,18 +555,32 @@ def _get_scheme_from_parts( ) if self._is_static_tensor_w8a8(weight_quant, input_quant): - return CompressedTensorsW8A8Int8( - strategy=weight_quant.strategy, - is_static_input_scheme=True, - input_symmetric=input_quant.symmetric, - ) + if not _is_npu: + return CompressedTensorsW8A8Int8( + strategy=weight_quant.strategy, + is_static_input_scheme=True, + input_symmetric=input_quant.symmetric, + ) + else: + return NPUCompressedTensorsW8A8Int8( + strategy=weight_quant.strategy, + is_static_input_scheme=True, + input_symmetric=input_quant.symmetric, + ) if self._is_dynamic_token_w8a8(weight_quant, input_quant): - return CompressedTensorsW8A8Int8( - strategy=weight_quant.strategy, - is_static_input_scheme=False, - input_symmetric=input_quant.symmetric, - ) + if not _is_npu: + return CompressedTensorsW8A8Int8( + strategy=weight_quant.strategy, + is_static_input_scheme=False, + input_symmetric=input_quant.symmetric, + ) + else: + return NPUCompressedTensorsW8A8Int8( + strategy=weight_quant.strategy, + is_static_input_scheme=False, + input_symmetric=input_quant.symmetric, + ) raise NotImplementedError("No compressed-tensors compatible scheme was found.") @@ -594,7 +658,9 @@ def get_scheme( # Raise error if device does not support the scheme # (e.g. 
fp8 needs ada lovelace) - self._check_scheme_supported(scheme.get_min_capability()) + # Note: NPU devices do not support min_capability function + if not _is_npu: + self._check_scheme_supported(scheme.get_min_capability()) logger.debug("Using scheme: %s for %s", scheme.__class__.__name__, layer_name) return scheme diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 55f3dda2338d..4b7c5b6c0974 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -15,6 +15,11 @@ from sglang.srt.distributed.device_communicators.pynccl_allocator import ( use_symmetric_memory, ) +from sglang.srt.hardware_backend.npu.quantization.fused_moe_method_npu import ( + NPUW4A8Int8DynamicMoEMethod, + NPUW4A16Int4DynamicMoEMethod, + NPUW8A8Int8DynamicMoEMethod, +) from sglang.srt.layers.dp_attention import is_allocation_symmetric from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType @@ -43,6 +48,7 @@ get_bool_env_var, is_cuda, is_hip, + is_npu, next_power_of_2, set_weight_attrs, ) @@ -58,6 +64,7 @@ ) _is_hip = is_hip() +_is_npu = is_npu() _is_cuda = is_cuda() _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip @@ -79,8 +86,11 @@ class GPTQMarlinState(Enum): __all__ = [ "CompressedTensorsMoEMethod", "CompressedTensorsW4A4Nvfp4MoEMethod", + "NPUCompressedTensorsW4A8Int8DynamicMoEMethod", "CompressedTensorsW8A8Fp8MoEMethod", + "NPUCompressedTensorsW8A8Int8MoEMethod", "CompressedTensorsWNA16MoEMethod", + "NPUCompressedTensorsW4A16Int4DynamicMoEMethod", ] @@ -103,14 +113,40 @@ def get_moe_method( input_quant = quant_config.target_scheme_map["Linear"].get("input_activations") if 
quant_config._is_wNa16_group_channel(weight_quant, input_quant): - logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod") - return CompressedTensorsWNA16MoEMethod(quant_config) + if not _is_npu: + logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod") + return CompressedTensorsWNA16MoEMethod(quant_config) + else: + if ( + quant_config._is_dynamic_token_w4(weight_quant, input_quant) + and input_quant is None + ): + logger.info_once( + "Using NPUCompressedTensorsW4A16Int4DynamicMoEMethod" + ) + return NPUCompressedTensorsW4A16Int4DynamicMoEMethod(quant_config) elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant): logger.info_once("Using CompressedTensorsW4A4Nvfp4MoEMethod") return CompressedTensorsW4A4Nvfp4MoEMethod(quant_config) elif quant_config._is_fp8_w8a8(weight_quant, input_quant): logger.info_once("Using CompressedTensorsW8A8Fp8MoEMethod") return CompressedTensorsW8A8Fp8MoEMethod(quant_config) + elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant): + if _is_npu: + logger.info_once("Using NPUCompressedTensorsW8A8Int8DynamicMoEMethod") + return NPUCompressedTensorsW8A8Int8DynamicMoEMethod(quant_config) + else: + raise NotImplementedError( + f"The W8A8Int8 Fused MoE scheme is implemented only for NPU for now." + ) + elif quant_config._is_dynamic_token_w4a8(weight_quant, input_quant): + if _is_npu: + logger.info_once("Using NPUCompressedTensorsW4A8Int8DynamicMoEMethod") + return NPUCompressedTensorsW4A8Int8DynamicMoEMethod(quant_config) + else: + raise NotImplementedError( + f"The W4A8Int8 Fused MoE scheme is implemented only for NPU for now." 
+ ) else: raise RuntimeError( f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}" @@ -853,6 +889,117 @@ def apply( return self.runner.run(dispatch_output, quant_info) +class NPUCompressedTensorsW8A8Int8DynamicMoEMethod(CompressedTensorsMoEMethod): + + def __init__(self, quant_config: CompressedTensorsConfig): + self.quant_config = quant_config + self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights") + self.input_quant = self.quant_config.target_scheme_map["Linear"].get( + "input_activations" + ) + self.kernel = NPUW8A8Int8DynamicMoEMethod() + + self.static_input_scales = not self.input_quant.dynamic + per_channel = ( + self.weight_quant.strategy == QuantizationStrategy.CHANNEL + and self.input_quant.strategy == QuantizationStrategy.TOKEN + ) + if not per_channel: + raise ValueError( + "For INT8 Fused MoE layers, we require channelwise, " + "dynamic per token quantization. Found " + f"{self.weight_quant}, {self.input_quant}" + ) + + self.static_input_scales = not self.input_quant.dynamic + if self.static_input_scales: + raise ValueError( + "For INT8 Fused MoE layers, we require channelwise, " + "dynamic per token quantization. Found static input scales." 
+ ) + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + + from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported + + params_dtype = torch.int8 + + # WEIGHTS + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # WEIGHT_SCALES + assert self.weight_quant.strategy == QuantizationStrategy.CHANNEL + w13_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32 + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + w2_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, hidden_size, 1, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + # Add PER-CHANNEL quantization for FusedMoE.weight_loader. 
+ extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value} + ) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + # INPUT_SCALES + assert not self.static_input_scales + layer.w13_input_scale = None + layer.w2_input_scale = None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + self.kernel.process_weights_after_loading(layer) + + def create_moe_runner( + self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig + ): + self.moe_runner_config = moe_runner_config + + def apply( + self, + layer: torch.nn.Module, + dispatch_output: StandardDispatchOutput, + ) -> CombineInput: + + return self.kernel.apply(layer, dispatch_output) + + class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): def __init__(self, quant_config: CompressedTensorsConfig, num_gpu_experts=-1): @@ -1200,3 +1347,421 @@ def apply( routed_scaling_factor=self.moe_runner_config.routed_scaling_factor, ) return StandardCombineInput(hidden_states=output) + + +class NPUCompressedTensorsW4A8Int8DynamicMoEMethod(CompressedTensorsMoEMethod): + + ### TODO: Get rid of code duplication with python/sglang/srt/modelslim/modelslim_moe.py @OrangeRedeng @TamirBaydasov + def __init__(self, quantization_config) -> None: + self.group_size = 0 + self.is_per_channel_weight = self.group_size == 0 + self.tp_size = 1 + self.activation_use_clip = ( + self.quantization_config.get("config_groups", {}) + .get("group_1", {}) + .get("activation_use_clip", False) + ) + self.kernel = NPUW4A8Int8DynamicMoEMethod() + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ) -> None: + from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported + + self.num_experts = num_experts + extra_weight_attrs.update( + {"quant_method": 
FusedMoeWeightScaleSupported.CHANNEL.value} + ) + + # >> weight + w13_output_size = intermediate_size_per_partition + w2_output_size = hidden_size // 2 + w13_weight = torch.nn.Parameter( + torch.empty(num_experts, w13_output_size, hidden_size, dtype=torch.int8), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + w2_output_size, + intermediate_size_per_partition, + dtype=torch.int8, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # >> scale + weight_scale_dtype = torch.int64 if self.activation_use_clip else torch.float32 + w13_weight_scale = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + 1, + dtype=weight_scale_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + + w2_weight_scale = torch.nn.Parameter( + torch.empty(num_experts, hidden_size, 1, dtype=weight_scale_dtype), + requires_grad=False, + ) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + # >> offset + w13_weight_offset = torch.nn.Parameter( + torch.empty( + num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32 + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_offset", w13_weight_offset) + set_weight_attrs(w13_weight_offset, extra_weight_attrs) + + w2_weight_offset = torch.nn.Parameter( + torch.empty(num_experts, hidden_size, 1, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w2_weight_offset", w2_weight_offset) + set_weight_attrs(w2_weight_offset, extra_weight_attrs) + + # >>> special param for w4a8 + if self.activation_use_clip: + self._init_activation_clip_params( + layer, + 
num_experts, + hidden_size, + intermediate_size_per_partition, + extra_weight_attrs, + ) + else: + self._init_extra_scale_params( + layer, + num_experts, + hidden_size, + intermediate_size_per_partition, + extra_weight_attrs, + ) + + def _init_activation_clip_params( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + extra_weight_attrs: dict, + ) -> None: + """ + Initializes bias and alpha parameters for quantization schemes that use activation clipping. + + This helper registers `w13_bias`, `w2_bias`, and `w2_alpha`, which are required to + shift and scale the activations or outputs to compensate for the precision loss + introduced by clamping activations. + """ + w13_bias = torch.nn.Parameter( + torch.ones( + num_experts, 2 * intermediate_size_per_partition, dtype=torch.float + ), + requires_grad=False, + ) + layer.register_parameter("w13_bias", w13_bias) + set_weight_attrs(w13_bias, extra_weight_attrs) + + w2_bias = torch.nn.Parameter( + torch.ones(num_experts, hidden_size, dtype=torch.float), + requires_grad=False, + ) + layer.register_parameter("w2_bias", w2_bias) + set_weight_attrs(w2_bias, extra_weight_attrs) + + w2_alpha = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float), requires_grad=False + ) + layer.register_parameter("w2_alpha", w2_alpha) + set_weight_attrs(w2_alpha, extra_weight_attrs) + + def _init_extra_scale_params( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + extra_weight_attrs: dict, + ) -> None: + """ + Initializes additional scaling, offset, and bias parameters for quantization schemes without activation clipping. + + This method registers the following parameters: + 1. Scale Biases: `w13_scale_bias` and `w2_scale_bias`. + 2. 
Secondary Quantization Params (initialized only for grouped quantization): + `w13_weight_scale_second`, `w13_weight_offset_second`, + `w2_weight_scale_second`, and `w2_weight_offset_second`. + """ + if not self.is_per_channel_weight: + w13_weight_scale_second = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size // self.group_size, + dtype=torch.float32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale_second", w13_weight_scale_second) + set_weight_attrs(w13_weight_scale_second, extra_weight_attrs) + + w13_weight_offset_second = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size // self.group_size, + dtype=torch.float32, + ), + requires_grad=False, + ) + layer.register_parameter( + "w13_weight_offset_second", w13_weight_offset_second + ) + set_weight_attrs(w13_weight_offset_second, extra_weight_attrs) + + w2_weight_scale_second = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition // self.group_size, + dtype=torch.float32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight_scale_second", w2_weight_scale_second) + set_weight_attrs(w2_weight_scale_second, extra_weight_attrs) + + w2_weight_offset_second = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition // self.group_size, + dtype=torch.float32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight_offset_second", w2_weight_offset_second) + set_weight_attrs(w2_weight_offset_second, extra_weight_attrs) + + w13_scale_bias = torch.nn.Parameter( + torch.empty( + num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32 + ), + requires_grad=False, + ) + layer.register_parameter("w13_scale_bias", w13_scale_bias) + set_weight_attrs(w13_scale_bias, extra_weight_attrs) + + w2_scale_bias = torch.nn.Parameter( + torch.empty( + num_experts, hidden_size, 16 
class NPUCompressedTensorsW4A16Int4DynamicMoEMethod(CompressedTensorsMoEMethod):
    """Fused-MoE method for W4A16 compressed-tensors checkpoints on Ascend NPU.

    Weights are int4 values packed into int32 storage (``pack_factor == 8``
    values per int32); scales/offsets are grouped along the input dimension,
    ``group_size`` elements per group. All compute is delegated to the NPU
    kernel wrapper ``NPUW4A16Int4DynamicMoEMethod``.
    """

    def __init__(self, quantization_config) -> None:
        self.pack_factor = 8  # weight dtype is int4, but use int32 to create
        # Prefer the MoE-specific "MoEGMM" scheme entry when present;
        # otherwise fall back to the generic "Linear" entry.
        target = (
            "MoEGMM" if "MoEGMM" in quantization_config.target_scheme_map else "Linear"
        )
        if target in quantization_config.target_scheme_map:
            self.group_size = quantization_config.target_scheme_map[target][
                "weights"
            ].group_size
        else:
            # Default group size when no scheme entry describes the weights.
            self.group_size = 128

        self.kernel = NPUW4A16Int4DynamicMoEMethod()

    # TODO: See if we can merge this method's logic
    # with CompressedTensorsWNA16MoEMethod. Need more models and tests.
    # @OrangeRedeng @TamirBaydasov
    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        """Register packed int4 weights plus grouped scales/offsets on ``layer``.

        Shapes (E = num_experts, I = intermediate_size_per_partition,
        H = hidden_size, G = group_size, P = pack_factor):
          w13_weight: (E, 2I, H // P) int32; w2_weight: (E, H, I // P) int32;
          *_weight_scale / *_weight_offset: input dim divided by G, bfloat16.
        """
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported

        self.num_experts = num_experts
        # If the full intermediate size spans multiple partitions, report the
        # GROUP scale layout to the loader; otherwise CHANNEL.
        if (
            extra_weight_attrs.get(
                "intermediate_size_full", intermediate_size_per_partition
            )
            // intermediate_size_per_partition
            > 1
        ):
            quant_method = FusedMoeWeightScaleSupported.GROUP.value
        else:
            quant_method = FusedMoeWeightScaleSupported.CHANNEL.value
        extra_weight_attrs.update({"quant_method": quant_method})
        # weight
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size // self.pack_factor,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)
        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition // self.pack_factor,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # scale
        weight_scale_dtype = torch.bfloat16
        w13_weight_scale = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size // self.group_size,
                dtype=weight_scale_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        w2_weight_scale = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition // self.group_size,
                dtype=weight_scale_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight_scale", w2_weight_scale)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # offset (zero-initialized, unlike the scales which are left empty)
        w13_weight_offset = torch.nn.Parameter(
            torch.zeros(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size // self.group_size,
                dtype=weight_scale_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight_offset", w13_weight_offset)
        set_weight_attrs(w13_weight_offset, extra_weight_attrs)

        w2_weight_offset = torch.nn.Parameter(
            torch.zeros(
                num_experts,
                hidden_size,
                intermediate_size_per_partition // self.group_size,
                dtype=weight_scale_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight_offset", w2_weight_offset)
        set_weight_attrs(w2_weight_offset, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Post-load repacking is kernel-specific; forward to the NPU kernel.
        self.kernel.process_weights_after_loading(layer)

    def create_moe_runner(
        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
    ):
        # No dedicated runner object; only the config is recorded.
        self.moe_runner_config = moe_runner_config

    def apply(
        self,
        layer: torch.nn.Module,
        dispatch_output: StandardDispatchOutput,
    ) -> CombineInput:
        """Run the fused-MoE forward via the NPU kernel."""
        return self.kernel.apply(layer, dispatch_output)

    def apply_without_routing_weights(
        self,
        layer,
        hidden_states,
        hidden_states_scale,
        group_list_type,
        group_list,
        output_dtype,
    ):
        """Grouped-GEMM path where routing weights were applied upstream."""
        return self.kernel.apply_without_routing_weights(
            layer,
            hidden_states,
            hidden_states_scale,
            group_list_type,
            group_list,
            output_dtype,
        )
    def create_weights(
        self,
        layer: torch.nn.Module,
        output_partition_sizes: list[int],
        input_size_per_partition: int,
        params_dtype: torch.dtype,
        weight_loader: Callable,
        **kwargs,
    ):
        """Register the int8 weight and its quantization parameters on ``layer``.

        - ``weight``: int8, (sum(output_partition_sizes), input_size_per_partition).
        - ``weight_scale``: float32; per output channel for the CHANNEL
          strategy, or one scale per logical partition for TENSOR.
        - ``input_scale`` / ``input_zero_point``: only for static input
          quantization (zero point only when the input is asymmetric).
        """
        output_size_per_partition = sum(output_partition_sizes)
        # Remember per-partition widths so fused modules (q/k/v, gate/up)
        # can be dequantized partition-by-partition later.
        layer.logical_widths = output_partition_sizes

        # WEIGHT
        weight = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition, input_size_per_partition, dtype=torch.int8
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )

        layer.register_parameter("weight", weight)

        # WEIGHT SCALE
        if self.strategy == QuantizationStrategy.CHANNEL:
            weight_scale = ChannelQuantScaleParameter(
                data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
                output_dim=0,
                weight_loader=weight_loader,
            )
        else:
            # Only CHANNEL and TENSOR weight strategies are handled here.
            assert self.strategy == QuantizationStrategy.TENSOR
            weight_scale = PerTensorScaleParameter(
                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
                weight_loader=weight_loader,
            )
        layer.register_parameter("weight_scale", weight_scale)

        # INPUT SCALE
        if self.is_static_input_scheme:
            input_scale = PerTensorScaleParameter(
                data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader
            )
            layer.register_parameter("input_scale", input_scale)

            if not self.input_symmetric:
                # Note: compressed-tensors stores the zp using the same dtype
                # as the weights
                # AZP loaded as int8 but used as int32
                input_zero_point = PerTensorScaleParameter(
                    data=torch.empty(1, dtype=torch.int8), weight_loader=weight_loader
                )
                layer.register_parameter("input_zero_point", input_zero_point)
class NPUCompressedTensorsW8A8Int8(CompressedTensorsW8A8Int8):
    """W8A8 int8 compressed-tensors scheme for Ascend NPU.

    Reuses the parent's weight creation but delegates post-load processing
    and the forward pass to the NPU dynamic-quant linear kernel. Only the
    dynamic input scheme is supported.
    """

    def __init__(
        self, strategy: str, is_static_input_scheme: bool, input_symmetric: bool
    ):
        super().__init__(strategy, is_static_input_scheme, input_symmetric)
        # TODO: Currently, NPU kernel for static quant requires quant_bias field,
        # which can't be replicated in compressed-tensors.
        if self.is_static_input_scheme:
            raise NotImplementedError(
                "Static compressed-tensors scheme is not yet supported on NPU."
            )
        self.kernel = NPUW8A8Int8DynamicLinearMethod()

    @classmethod
    def get_min_capability(cls) -> int:
        # Ascend has no CUDA compute capability. BUG FIX: the original did
        # `return NotImplementedError`, handing callers the exception *class*
        # (any numeric comparison would then fail with TypeError); raise it
        # instead to signal "not applicable" explicitly.
        raise NotImplementedError

    def process_weights_after_loading(self, layer):
        # Kernel-specific repacking/casting of the loaded int8 weights.
        return self.kernel.process_weights_after_loading(layer)

    def apply_weights(self, layer, x, bias):
        # Dynamic per-token quantization + int8 matmul on NPU.
        return self.kernel.apply(layer, x, bias)
+ +ModelSlim was developed in the format of compressed_tensors and includes support for various quantization schemes, such as: +- [x] W4A4 dynamic linear +- [x] W8A8 static linear +- [x] W8A8 dynamic linear +- [x] W4A8 dynamic MOE +- [x] W8A8 dynamic MOE + +Also ModelSlim module include: +- [x] Automated config detection for modelslim format (without the need to specify --quantization modelslim flag) +- [x] Unit-tests for w4a4 modelslim, w8a8 modelslim diff --git a/python/sglang/srt/hardware_backend/npu/quantization/modelslim.py b/python/sglang/srt/layers/quantization/modelslim/modelslim.py similarity index 63% rename from python/sglang/srt/hardware_backend/npu/quantization/modelslim.py rename to python/sglang/srt/layers/quantization/modelslim/modelslim.py index 3d274496381a..20b6c88a1d9d 100644 --- a/python/sglang/srt/hardware_backend/npu/quantization/modelslim.py +++ b/python/sglang/srt/layers/quantization/modelslim/modelslim.py @@ -1,31 +1,30 @@ from __future__ import annotations +import logging from types import MappingProxyType from typing import Any, Dict, List, Mapping, Optional, Tuple, Union, cast import torch -from compressed_tensors.quantization import QuantizationStrategy -from sglang.srt.hardware_backend.npu.quantization.fused_moe_method_npu import ( - NPUW4A8Int4DynamicMoEMethod, - NPUW4A16Int4DynamicMoEMethod, - NPUW8A8Int8DynamicMoEMethod, -) from sglang.srt.hardware_backend.npu.quantization.linear_method_npu import ( - NPUW8A8Int8DynamicLinearMethod, - NPUW8A8Int8LinearMethod, + _NPULinearMethodBase, ) from sglang.srt.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, ) -from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import ( - CompressedTensorsConfig, -) from sglang.srt.layers.quantization.compressed_tensors.utils import should_ignore_layer +from sglang.srt.layers.quantization.modelslim.modelslim_moe import ModelSlimMoEMethod +from sglang.srt.layers.quantization.modelslim.schemes import ( + 
    def get_linear_method(self) -> ModelSlimLinearMethod:
        """Return the linear-layer quantization method backed by this config."""
        return ModelSlimLinearMethod(self)
def get_name(self) -> str: + def get_name(cls) -> str: return "modelslim" @classmethod @@ -163,37 +144,47 @@ def get_quant_method( prefix_in_quant_config = prefix.replace( proj_name, packed_modules_mapping_subset[proj_name][0] ) - self.is_dynamic = ( - self.quant_description.get(prefix_in_quant_config + ".weight", "") - == "W8A8_DYNAMIC" - or self.quant_description.get("quant_method", "") - == "modelslim" # TODO: This path is for compress-tensor config,needs refactor @zhengdqin - ) + if self.is_layer_skipped(prefix, packed_modules_mapping_subset): return UnquantizedLinearMethod() - return ( - NPUW8A8Int8DynamicLinearMethod(self) - if self.is_dynamic - else NPUW8A8Int8LinearMethod(self) - ) + scheme = self.get_scheme(layer=layer, layer_name=prefix_in_quant_config) + layer.scheme = scheme + return ModelSlimLinearMethod(self) elif isinstance(layer, FusedMoE): - prefix_in_quant_config = prefix + ".0.down_proj.weight" - is_moe_w4a8_dynamic = ( - self.quant_description.get(prefix_in_quant_config, "STATIC") - == "W4A8_DYNAMIC" - ) - if ( - self.is_moe_w4_dynamic and self.is_moe_input_quant is not None - ) or is_moe_w4a8_dynamic: - return NPUW4A8Int4DynamicMoEMethod( - activation_use_clip=self.activation_use_clip - ) - elif self.is_moe_w4_dynamic and self.is_moe_input_quant is None: - return NPUW4A16Int4DynamicMoEMethod(self) - else: - return NPUW8A8Int8DynamicMoEMethod() + return ModelSlimMoEMethod.get_moe_method(self, layer, prefix) return None + def _get_scheme_from_parts( + self, + layer_name: str, + ) -> ModelSlimScheme: + + quant_type = self.quant_description.get(layer_name + ".weight", "") + if quant_type == "W8A8_DYNAMIC" or quant_type == "W8A8": + return ModelSlimW8A8Int8( + quant_config=self.quant_description, prefix=layer_name + ) + elif quant_type == "W4A4_DYNAMIC": + return ModelSlimW4A4Int4( + quant_config=self.quant_description, prefix=layer_name + ) + raise NotImplementedError("No modelslim compatible scheme was found.") + + def get_scheme( + self, layer: 
class ModelSlimLinearMethod(_NPULinearMethodBase):
    """Linear method that defers every quantization step to the per-layer
    ModelSlim scheme stored on ``layer.scheme``.

    The scheme object is attached by ``ModelSlimConfig.get_quant_method``;
    this class is a thin dispatcher around it.
    """

    def __init__(self, quantization_config: ModelSlimConfig):
        # Kept for parity with other linear methods; the per-layer scheme
        # carries all state actually needed at runtime.
        self.quantization_config = quantization_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Delegate parameter creation to the layer's ModelSlim scheme.

        See LinearMethodBase for parameter details.
        """
        layer.scheme.create_weights(
            layer=layer,
            input_size=input_size,
            input_size_per_partition=input_size_per_partition,
            output_partition_sizes=output_partition_sizes,
            output_size=output_size,
            params_dtype=params_dtype,
            weight_loader=extra_weight_attrs.get("weight_loader"),
        )

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Post-load fixups are scheme-specific; forward them unchanged."""
        layer.scheme.process_weights_after_loading(layer)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ):
        """Run the forward pass through the scheme attached to ``layer``.

        See LinearMethodBase for parameter details.
        """
        active_scheme = layer.scheme
        if active_scheme is None:
            raise ValueError("A scheme must be defined for each layer")
        return active_scheme.apply_weights(layer, x, bias=bias)
+ + prefix_in_quant_config = prefix + ".0.down_proj.weight" + is_moe_w4a8_dynamic = ( + quant_config.quant_description.get(prefix_in_quant_config, "STATIC") + == "W4A8_DYNAMIC" + ) + is_moe_w8a8_dynamic = ( + quant_config.quant_description.get(prefix_in_quant_config, "STATIC") + == "W8A8_DYNAMIC" + ) + if is_moe_w4a8_dynamic: + logger.info_once("Using ModelSlimW4A8Int8MoE") + return ModelSlimW4A8Int8MoE(quant_config) + elif is_moe_w8a8_dynamic: + logger.info_once("Using ModelSlimW8A8Int8MoE") + return ModelSlimW8A8Int8MoE(quant_config) + else: + logger.warning( + f"Unsupported FusedMoe modelslim scheme: \ + {quant_config.quant_description.get(prefix_in_quant_config.strip())} \ + in layer: {prefix}" + ) + return None + + +class ModelSlimW4A8Int8MoE(ModelSlimMoEMethod): + + def __init__( + self, + quant_config: Dict[str, Any], + prefix: str = None, + ): + self.quant_config = quant_config + self.group_size = 0 + self.is_per_channel_weight = self.group_size == 0 + self.tp_size = 1 + self.activation_use_clip = False + self.kernel = NPUW4A8Int8DynamicMoEMethod() + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ) -> None: + from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported + + self.is_per_channel_weight = self.group_size == 0 + self.num_experts = num_experts + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value} + ) + + # >> weight + w13_output_size = intermediate_size_per_partition + w2_output_size = hidden_size // 2 + w13_weight = torch.nn.Parameter( + torch.empty(num_experts, w13_output_size, hidden_size, dtype=torch.int8), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + w2_output_size, + 
intermediate_size_per_partition, + dtype=torch.int8, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # >> scale + w13_weight_scale = torch.nn.Parameter( + torch.empty( + num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32 + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + + w2_weight_scale = torch.nn.Parameter( + torch.empty(num_experts, hidden_size, 1, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + # >> offset + w13_weight_offset = torch.nn.Parameter( + torch.empty( + num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32 + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_offset", w13_weight_offset) + set_weight_attrs(w13_weight_offset, extra_weight_attrs) + + w2_weight_offset = torch.nn.Parameter( + torch.empty(num_experts, hidden_size, 1, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w2_weight_offset", w2_weight_offset) + set_weight_attrs(w2_weight_offset, extra_weight_attrs) + + # >>> special param for w4a8 + if not self.is_per_channel_weight: + w13_weight_scale_second = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size // self.group_size, + dtype=torch.float32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale_second", w13_weight_scale_second) + set_weight_attrs(w13_weight_scale_second, extra_weight_attrs) + w13_weight_offset_second = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size // self.group_size, + dtype=torch.float32, + ), + requires_grad=False, + ) + layer.register_parameter( + "w13_weight_offset_second", 
w13_weight_offset_second + ) + set_weight_attrs(w13_weight_offset_second, extra_weight_attrs) + + w2_weight_scale_second = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition // self.group_size, + dtype=torch.float32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight_scale_second", w2_weight_scale_second) + set_weight_attrs(w2_weight_scale_second, extra_weight_attrs) + + w2_weight_offset_second = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition // self.group_size, + dtype=torch.float32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight_offset_second", w2_weight_offset_second) + set_weight_attrs(w2_weight_offset_second, extra_weight_attrs) + + w13_scale_bias = torch.nn.Parameter( + torch.empty( + num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32 + ), + requires_grad=False, + ) + layer.register_parameter("w13_scale_bias", w13_scale_bias) + set_weight_attrs(w13_scale_bias, extra_weight_attrs) + + w2_scale_bias = torch.nn.Parameter( + torch.empty( + num_experts, hidden_size, 16 // self.tp_size, dtype=torch.float32 + ), + requires_grad=False, + ) + layer.register_parameter("w2_scale_bias", w2_scale_bias) + set_weight_attrs(w2_scale_bias, extra_weight_attrs) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + self.kernel.process_weights_after_loading( + layer, self.is_per_channel_weight, self.activation_use_clip + ) + + def create_moe_runner( + self, layer: torch.nn.Module, moe_runner_config: "MoeRunnerConfig" + ): + self.moe_runner_config = moe_runner_config + + def apply( + self, + layer, + dispatch_output: "StandardDispatchOutput", + ) -> "CombineInput": + # FIXME W4A8 without EP can give 0 accuracy + return self.kernel.apply(layer, dispatch_output) + + def apply_without_routing_weights( + self, + layer, + hidden_states, + hidden_states_scale, + group_list_type, + group_list, + 
class ModelSlimW8A8Int8MoE(ModelSlimMoEMethod):
    """W8A8 (int8 weight, dynamically quantized int8 activation) fused-MoE
    method for ModelSlim checkpoints; compute is delegated to the NPU kernel
    ``NPUW8A8Int8DynamicMoEMethod``.
    """

    def __init__(
        self,
        quant_config: Dict[str, Any],
        prefix: str = None,  # NOTE(review): effectively Optional[str]; unused here
    ):
        self.quant_config = quant_config
        self.kernel = NPUW8A8Int8DynamicMoEMethod()

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        """Register per-channel int8 MoE weights, scales and offsets.

        Shapes (E = num_experts, I = intermediate_size_per_partition,
        H = hidden_size):
          w13_weight: (E, 2I, H) int8; w2_weight: (E, H, I) int8;
          *_weight_scale / *_weight_offset: (E, 2I, 1) / (E, H, 1) float32
          (one value per output channel).
        """
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported

        self.num_experts = num_experts
        # Scales are per output channel for this scheme.
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
        )

        # weight
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size,
                dtype=torch.int8,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)
        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=torch.int8,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)
        # scale
        w13_weight_scale = torch.nn.Parameter(
            torch.empty(
                num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        w2_weight_scale = torch.nn.Parameter(
            torch.empty(num_experts, hidden_size, 1, dtype=torch.float32),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight_scale", w2_weight_scale)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)
        # offset
        w13_weight_offset = torch.nn.Parameter(
            torch.empty(
                num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight_offset", w13_weight_offset)
        set_weight_attrs(w13_weight_offset, extra_weight_attrs)
        w2_weight_offset = torch.nn.Parameter(
            torch.empty(num_experts, hidden_size, 1, dtype=torch.float32),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight_offset", w2_weight_offset)
        set_weight_attrs(w2_weight_offset, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Kernel-specific post-load processing of the registered parameters.
        self.kernel.process_weights_after_loading(layer)

    def create_moe_runner(
        self, layer: torch.nn.Module, moe_runner_config: "MoeRunnerConfig"
    ):
        # No dedicated runner object; only the config is recorded.
        self.moe_runner_config = moe_runner_config

    def apply(
        self,
        layer,
        dispatch_output: "StandardDispatchOutput",
    ) -> "CombineInput":
        """Run the fused-MoE forward via the NPU kernel."""
        return self.kernel.apply(layer, dispatch_output)

    def apply_without_routing_weights(
        self,
        layer,
        hidden_states,
        hidden_states_scale,
        group_list_type,
        group_list,
        output_dtype,
    ):
        """Grouped-GEMM path where routing weights were applied upstream."""
        return self.kernel.apply_without_routing_weights(
            layer,
            hidden_states,
            hidden_states_scale,
            group_list_type,
            group_list,
            output_dtype,
        )
000000000000..1d09c384ca9e --- /dev/null +++ b/python/sglang/srt/layers/quantization/modelslim/schemes/modelslim_scheme.py @@ -0,0 +1,48 @@ +# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors +# SPDX-License-Identifier: Apache-2.0 + +from abc import ABC, abstractmethod +from typing import Optional + +import torch + +__all__ = ["ModelSlimScheme"] + + +class ModelSlimScheme(ABC): + """ + Abstract class used to describe the weight creation and forward pass + of different quantization schemes supported by ModelSlim. + """ + + @abstractmethod + def create_weights(self, *args, **kwargs): + """ + Weight creation for the particular scheme. + + """ + raise NotImplementedError + + @abstractmethod + def apply_weights( + self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] + ): + """ + Run the forward pass for the particular scheme. This is where + scheme-specific dequant/quant steps/kernels should be applied. + + :param layer: torch.nn.Module with the registered weights and + other parameters relevant to the particular scheme. + :param x: input to the layer + :param bias: bias parameter + + """ + raise NotImplementedError + + @abstractmethod + def process_weights_after_loading(self, layer: torch.nn.Module): + """ + Called after weight loading is complete for any cleanup that + needs to occur. 
+ """ + raise NotImplementedError diff --git a/python/sglang/srt/layers/quantization/modelslim/schemes/modelslim_w4a4_int4.py b/python/sglang/srt/layers/quantization/modelslim/schemes/modelslim_w4a4_int4.py new file mode 100644 index 000000000000..8e7f08277f99 --- /dev/null +++ b/python/sglang/srt/layers/quantization/modelslim/schemes/modelslim_w4a4_int4.py @@ -0,0 +1,99 @@ +# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Dict, List, Optional + +import torch + +from sglang.srt.hardware_backend.npu.quantization.linear_method_npu import ( + NPU_W4A4DynamicLinearMethod, +) +from sglang.srt.layers.parameter import PerTensorScaleParameter +from sglang.srt.layers.quantization.modelslim.schemes import ModelSlimScheme +from sglang.srt.utils import set_weight_attrs + + +class ModelSlimW4A4Int4(ModelSlimScheme): + + def __init__( + self, + quant_config: Dict[str, any], + prefix: str, + ): + self.quant_config = quant_config + self.is_dynamic = self.quant_config[prefix + ".weight"] == "W4A4_DYNAMIC" + self.kernel = NPU_W4A4DynamicLinearMethod() + + @staticmethod + def get_weight( + input_size: int, output_size: int, params_dtype: torch.dtype + ) -> Dict[str, Any]: + params_dict = {"weight": torch.empty(output_size, input_size, dtype=torch.int8)} + return params_dict + + @staticmethod + def get_perchannel_param( + output_size: int, + params_dtype: torch.dtype, + ) -> Dict[str, Any]: + params_dict = {} + params_dict["weight_scale"] = torch.empty(output_size, 1, dtype=params_dtype) + params_dict["weight_offset"] = torch.empty(output_size, 1, dtype=params_dtype) + return params_dict + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ) -> None: + output_size_per_partition = 
sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + + weight_dict = { + "weight": torch.empty( + output_size_per_partition, input_size_per_partition, dtype=torch.int8 + ) + } + for weight_name, weight_param in weight_dict.items(): + param = torch.nn.Parameter(weight_param, requires_grad=False) + set_weight_attrs(param, {"input_dim": 1, "output_dim": 0}) + layer.register_parameter(weight_name, param) + set_weight_attrs(param, extra_weight_attrs) + + pertensor_dict = {} + for pertensor_name, pertensor_param in pertensor_dict.items(): + param = PerTensorScaleParameter( + data=pertensor_param, weight_loader=weight_loader + ) + # disable warning + param.ignore_warning = True + layer.register_parameter(pertensor_name, param) + + perchannel_dict = {} + perchannel_dict["weight_scale"] = torch.empty( + output_size_per_partition, 1, dtype=params_dtype + ) + perchannel_dict["weight_offset"] = torch.empty( + output_size_per_partition, 1, dtype=params_dtype + ) + for perchannel_name, perchannel_param in perchannel_dict.items(): + param = torch.nn.Parameter(perchannel_param, requires_grad=False) + set_weight_attrs(param, {"output_dim": 0}) + layer.register_parameter(perchannel_name, param) + set_weight_attrs(param, extra_weight_attrs) + + def process_weights_after_loading(self, layer): + self.kernel.process_weights_after_loading(layer) + + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return self.kernel.apply(layer, x, bias) diff --git a/python/sglang/srt/layers/quantization/modelslim/schemes/modelslim_w8a8_int8.py b/python/sglang/srt/layers/quantization/modelslim/schemes/modelslim_w8a8_int8.py new file mode 100644 index 000000000000..16c62d551fa3 --- /dev/null +++ b/python/sglang/srt/layers/quantization/modelslim/schemes/modelslim_w8a8_int8.py @@ -0,0 +1,117 @@ +# Adapted from 
https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors +# SPDX-License-Identifier: Apache-2.0 + +from typing import Dict, List, Optional + +import torch + +from sglang.srt.hardware_backend.npu.quantization.linear_method_npu import ( + NPUW8A8Int8DynamicLinearMethod, + NPUW8A8Int8LinearMethod, +) +from sglang.srt.layers.parameter import ( + ChannelQuantScaleParameter, + ModelWeightParameter, + PerTensorScaleParameter, +) +from sglang.srt.layers.quantization.modelslim.schemes import ModelSlimScheme + + +class ModelSlimW8A8Int8(ModelSlimScheme): + + def __init__( + self, + quant_config: Dict[str, any], + prefix: str, + ): + self.quant_config = quant_config + self.is_dynamic = ( + self.quant_config.get(prefix + ".weight", "") == "W8A8_DYNAMIC" + ) + if self.is_dynamic: + self.kernel = NPUW8A8Int8DynamicLinearMethod() + else: + self.kernel = NPUW8A8Int8LinearMethod() + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + weight_loader = extra_weight_attrs.get("weight_loader") + output_size_per_partition = sum(output_partition_sizes) + + weight = ModelWeightParameter( + data=torch.empty( + (output_size_per_partition, input_size_per_partition), dtype=torch.int8 + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + weight_scale = ChannelQuantScaleParameter( + data=torch.empty((output_size_per_partition, 1), dtype=params_dtype), + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight_scale", weight_scale) + + weight_offset = ChannelQuantScaleParameter( + data=torch.empty((output_size_per_partition, 1), dtype=params_dtype), + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight_offset", weight_offset) + + if not self.is_dynamic: + 
input_scale = PerTensorScaleParameter( + data=torch.empty(1, dtype=params_dtype), + weight_loader=weight_loader, + ) + input_scale.ignore_warning = True + layer.register_parameter("input_scale", input_scale) + + input_offset = PerTensorScaleParameter( + data=torch.empty(1, dtype=params_dtype), + weight_loader=weight_loader, + ) + input_offset.ignore_warning = True + layer.register_parameter("input_offset", input_offset) + + quant_bias = ChannelQuantScaleParameter( + data=torch.empty(output_size_per_partition, dtype=torch.int32), + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("quant_bias", quant_bias) + + if params_dtype == torch.bfloat16: + deq_scale_dtype = torch.float32 + elif params_dtype == torch.float16: + deq_scale_dtype = torch.int64 + else: + raise ValueError(f"Unsupported params_dtype: {params_dtype}") + deq_scale = ChannelQuantScaleParameter( + data=torch.empty(output_size_per_partition, dtype=deq_scale_dtype), + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("deq_scale", deq_scale) + + def process_weights_after_loading(self, layer: torch.nn.Module): + self.kernel.process_weights_after_loading(layer) + + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return self.kernel.apply(layer, x, bias) diff --git a/test/manual/ascend/test_ascend_w8a8_quantization.py b/test/manual/ascend/test_ascend_w8a8_quantization.py index bf139f46a872..e013c150c314 100644 --- a/test/manual/ascend/test_ascend_w8a8_quantization.py +++ b/test/manual/ascend/test_ascend_w8a8_quantization.py @@ -45,8 +45,6 @@ def setUpClass(cls): "npu", "--attention-backend", "ascend", - "--quantization", - "w8a8_int8", ], ) diff --git a/test/srt/ascend/test_ascend_deepep.py b/test/srt/ascend/test_ascend_deepep.py index aa1aa62d10ba..e17281094ec4 100644 --- a/test/srt/ascend/test_ascend_deepep.py +++ b/test/srt/ascend/test_ascend_deepep.py @@ -34,8 +34,6 @@ 
def setUpClass(cls): "--trust-remote-code", "--attention-backend", "ascend", - "--quantization", - "modelslim", "--mem-fraction-static", 0.8, "--disable-radix-cache", diff --git a/test/srt/ascend/test_ascend_deepseek_mtp.py b/test/srt/ascend/test_ascend_deepseek_mtp.py index 2a307b65149d..cbe01a07add1 100644 --- a/test/srt/ascend/test_ascend_deepseek_mtp.py +++ b/test/srt/ascend/test_ascend_deepseek_mtp.py @@ -32,8 +32,6 @@ def setUpClass(cls): "--trust-remote-code", "--attention-backend", "ascend", - "--quantization", - "modelslim", "--mem-fraction-static", 0.8, "--disable-radix-cache", diff --git a/test/srt/ascend/test_ascend_hicache_mla.py b/test/srt/ascend/test_ascend_hicache_mla.py index 5e7c711e868d..d0bc1f378cfa 100644 --- a/test/srt/ascend/test_ascend_hicache_mla.py +++ b/test/srt/ascend/test_ascend_hicache_mla.py @@ -35,8 +35,6 @@ def setUpClass(cls): 0.8, "--attention-backend", "ascend", - "--quantization", - "modelslim", "--tp-size", 4, "--enable-hierarchical-cache", diff --git a/test/srt/ascend/test_ascend_mla_fia_w8a8int8.py b/test/srt/ascend/test_ascend_mla_fia_w8a8int8.py index 1a0eb7f6dd05..bdab4ea05781 100644 --- a/test/srt/ascend/test_ascend_mla_fia_w8a8int8.py +++ b/test/srt/ascend/test_ascend_mla_fia_w8a8int8.py @@ -37,8 +37,6 @@ def setUpClass(cls): 0.8, "--attention-backend", "ascend", - "--quantization", - "modelslim", "--tp-size", 2, "--disable-radix-cache", diff --git a/test/srt/ascend/test_ascend_mla_w8a8int8.py b/test/srt/ascend/test_ascend_mla_w8a8int8.py index eddae3086c6d..3c3e733669ea 100644 --- a/test/srt/ascend/test_ascend_mla_w8a8int8.py +++ b/test/srt/ascend/test_ascend_mla_w8a8int8.py @@ -36,8 +36,6 @@ def setUpClass(cls): 0.8, "--attention-backend", "ascend", - "--quantization", - "modelslim", "--tp-size", 4, "--disable-radix-cache", diff --git a/test/srt/ascend/test_ascend_w4a4_quantization.py b/test/srt/ascend/test_ascend_w4a4_quantization.py new file mode 100644 index 000000000000..22d3f0615181 --- /dev/null +++ 
b/test/srt/ascend/test_ascend_w4a4_quantization.py @@ -0,0 +1,108 @@ +""" +Usage: +python3 -m unittest test_ascend_w4a4_quantization.TestAscendW4A4.test_gsm8k +""" + +import os +import time +import unittest +from types import SimpleNamespace +from urllib.parse import urlparse + +import requests + +from sglang.srt.utils import kill_process_tree +from sglang.test.few_shot_gsm8k import run_eval +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + is_in_ci, + popen_launch_server, +) + +if "ASCEND_RT_VISIBLE_DEVICES" not in os.environ: + os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "0,1,2,3" +DEFAULT_PORT_FOR_SRT_TEST_RUNNER = ( + 7000 + int(os.environ.get("ASCEND_RT_VISIBLE_DEVICES", "0")[0]) * 100 +) +DEFAULT_URL_FOR_TEST = f"http://127.0.0.1:{DEFAULT_PORT_FOR_SRT_TEST_RUNNER + 1000}" + + +class TestAscendW4A4(CustomTestCase): + @classmethod + def setUpClass(cls): + cls.model = "/root/.cache/modelscope/hub/models/Eco-Tech/Qwen3-32B-w4a4-LAOS" + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--trust-remote-code", + "--device", + "npu", + "--attention-backend", + "ascend", + "--tp-size", + "4", + "--mem-fraction-static", + "0.8", + "--cuda-graph-bs", + "64", + "--disable-radix-cache", + ], + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_gsm8k(self): + base_url = DEFAULT_URL_FOR_TEST + url = urlparse(base_url) + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=1319, + max_new_tokens=512, + parallel=64, + host=f"http://{url.hostname}", + port=int(url.port), + ) + metrics = run_eval(args) + print(metrics) + + self.assertGreaterEqual(metrics["accuracy"], 0.80) + self.assertGreaterEqual(metrics["output_throughput"], 1000) + + def run_decode(self, max_new_tokens): + response = requests.post( + self.base_url + "/generate", + 
json={ + "text": "The capital of France is", + "sampling_params": { + "temperature": 0, + "max_new_tokens": max_new_tokens, + }, + "ignore_eos": True, + }, + ) + return response.json() + + def test_throughput(self): + max_tokens = 256 + + tic = time.perf_counter() + res = self.run_decode(max_tokens) + tok = time.perf_counter() + print(res["text"]) + throughput = max_tokens / (tok - tic) + print(f"Throughput: {throughput} tokens/s") + + if is_in_ci(): + self.assertGreaterEqual(throughput, 35) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/ascend/test_ascend_w8a8_quantization.py b/test/srt/ascend/test_ascend_w8a8_quantization.py new file mode 100644 index 000000000000..e0b3545701c6 --- /dev/null +++ b/test/srt/ascend/test_ascend_w8a8_quantization.py @@ -0,0 +1,103 @@ +""" +Usage: +python3 -m unittest test_ascend_w8a8_quantization.TestAscendW8A8.test_gsm8k +""" + +import os +import time +import unittest +from types import SimpleNamespace +from urllib.parse import urlparse + +import requests + +from sglang.srt.utils import kill_process_tree +from sglang.test.few_shot_gsm8k import run_eval +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + is_in_ci, + popen_launch_server, +) + +if "ASCEND_RT_VISIBLE_DEVICES" not in os.environ: + os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "0,1" +DEFAULT_PORT_FOR_SRT_TEST_RUNNER = ( + 7000 + int(os.environ.get("ASCEND_RT_VISIBLE_DEVICES", "0")[0]) * 100 +) +DEFAULT_URL_FOR_TEST = f"http://127.0.0.1:{DEFAULT_PORT_FOR_SRT_TEST_RUNNER + 1000}" + + +class TestAscendW8A8CompressedTensors(CustomTestCase): + @classmethod + def setUpClass(cls): + # TODO: Move model to CI or Modelscope + cls.model = "RedHatAI/Qwen2.5-0.5B-Instruct-quantized.w8a8" + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--trust-remote-code", + "--disable-cuda-graph", 
+ "--device", + "npu", + "--attention-backend", + "ascend", + ], + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_gsm8k(self): + base_url = DEFAULT_URL_FOR_TEST + url = urlparse(base_url) + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=200, + max_new_tokens=512, + parallel=128, + host=f"http://{url.hostname}", + port=int(url.port), + ) + metrics = run_eval(args) + print(metrics) + + self.assertGreaterEqual(metrics["accuracy"], 0.3) + self.assertGreaterEqual(metrics["output_throughput"], 700) + + def run_decode(self, max_new_tokens): + response = requests.post( + self.base_url + "/generate", + json={ + "text": "The capital of France is", + "sampling_params": { + "temperature": 0, + "max_new_tokens": max_new_tokens, + }, + "ignore_eos": True, + }, + ) + return response.json() + + def test_throughput(self): + max_tokens = 256 + + tic = time.perf_counter() + res = self.run_decode(max_tokens) + tok = time.perf_counter() + print(res["text"]) + throughput = max_tokens / (tok - tic) + print(f"Throughput: {throughput} tokens/s") + + if is_in_ci(): + self.assertGreaterEqual(throughput, 25) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 24c5a5f14009..1c4d06542186 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -160,6 +160,7 @@ TestFile("ascend/test_ascend_sampling_backend.py", 400), TestFile("ascend/test_ascend_tp1_bf16.py", 400), TestFile("ascend/test_ascend_compile_graph_tp1_bf16.py", 400), + TestFile("ascend/test_ascend_w8a8_quantization.py", 400), TestFile("test_embed_interpolate_unittest.py", 400), ], "per-commit-2-npu-a2": [ @@ -174,8 +175,9 @@ TestFile("ascend/test_ascend_tp4_bf16.py", 400), ], "per-commit-16-npu-a3": [ - TestFile("ascend/test_ascend_deepep.py", 400), - TestFile("ascend/test_ascend_deepseek_mtp.py", 400), + TestFile("ascend/test_ascend_deepep.py", 3600), + 
TestFile("ascend/test_ascend_deepseek_mtp.py", 2800), + TestFile("ascend/test_ascend_w4a4_quantization.py", 600), ], }