diff --git a/vllm_ascend/quantization/quant_config.py b/vllm_ascend/quantization/quant_config.py index f6a9824161f..48b42241751 100644 --- a/vllm_ascend/quantization/quant_config.py +++ b/vllm_ascend/quantization/quant_config.py @@ -45,7 +45,7 @@ from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, flashcomm2_enable, mlp_tp_enable, oproj_tp_enable) -from .utils import get_quant_method +from .utils import get_quant_method, is_mx_quant_type @register_quantization_config(ASCEND_QUANTIZATION_METHOD) @@ -393,7 +393,8 @@ def create_weights( set_weight_attrs(param, {"output_dim": 0}) layer.register_parameter(pergroup_name, param) set_weight_attrs(param, extra_weight_attrs) - if "weight_scale_second" in pergroup_name or "weight_offset_second" in pergroup_name: + if "weight_scale_second" in pergroup_name or "weight_offset_second" in pergroup_name \ + or is_mx_quant_type(self.quant_method): setattr(param, "input_dim", 1) param.input_dim = 1 diff --git a/vllm_ascend/quantization/utils.py b/vllm_ascend/quantization/utils.py index 71db5269b09..43b039d9f96 100644 --- a/vllm_ascend/quantization/utils.py +++ b/vllm_ascend/quantization/utils.py @@ -14,6 +14,7 @@ AscendW8A8DynamicLinearMethod) from .w8a8_pdmix import (AscendW8A8PDMixFusedMoeMethod, AscendW8A8PDMixLinearMethod) +from .w8a8mxfp8 import AscendW8A8MXFP8DynamicLinearMethod from .w8a16 import AscendW8A16LinearMethod ASCEND_QUANTIZATION_METHOD_MAP: Dict[str, Dict[str, Type[Any]]] = { @@ -40,7 +41,10 @@ }, "W8A16": { "linear": AscendW8A16LinearMethod, - } + }, + "W8A8_MXFP8": { + "linear": AscendW8A8MXFP8DynamicLinearMethod, + }, } @@ -113,3 +117,9 @@ def get_quant_method_modelslim( ) raise NotImplementedError("Currently, vLLM Ascend only supports following quant types:" \ f"{list(ASCEND_QUANTIZATION_METHOD_MAP.keys())}") + + +def is_mx_quant_type(instance: Any) -> bool: + """Checks if the quantization method is a mix-precision type.""" + MX_QUANT_TYPES = (AscendW8A8MXFP8DynamicLinearMethod, ) + return isinstance(instance, MX_QUANT_TYPES) diff --git a/vllm_ascend/quantization/w8a8mxfp8.py b/vllm_ascend/quantization/w8a8mxfp8.py new file mode 100644 index 00000000000..2997a176d70 --- /dev/null +++ b/vllm_ascend/quantization/w8a8mxfp8.py @@ -0,0 +1,98 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Any, Dict, Optional + +import torch +import torch_npu +from vllm.config import get_current_vllm_config + + +class AscendW8A8MXFP8DynamicLinearMethod: + """Linear method for Ascend W8A8_DYNAMIC. + """ + model_dtype = None + + def __init__(self): + vllm_config = get_current_vllm_config() + self.group_size = vllm_config.quant_config.quant_description.get( + "group_size", 32) + + @staticmethod + def get_weight(input_size: int, output_size: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + params_dict = { + "weight": + torch.empty(output_size, input_size, dtype=torch.float8_e4m3fn) + } + return params_dict + + @staticmethod + def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]: + return {} + + @staticmethod + def get_perchannel_param( + output_size: int, + params_dtype: torch.dtype, + ) -> Dict[str, Any]: + return {} + + def get_pergroup_param(self, + input_size: int, + output_size: int, + params_dtype: torch.dtype, + layer_type: Optional[str] = None) -> Dict[str, Any]: + params_dict = {} + params_dict["weight_scale"] = torch.empty(output_size, + input_size // + self.group_size, + dtype=torch.uint8) + return params_dict + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + tp_rank: Optional[int] = 0, + ) -> torch.Tensor: + + quantized_x, dynamic_scale = torch_npu.npu_dynamic_mx_quant( + x, dst_type=torch.float8_e4m3fn) + pertoken_scale = dynamic_scale + output_dtype = x.dtype + + output = torch_npu.npu_quant_matmul( + quantized_x, + layer.weight, + layer.weight_scale, + scale_dtype=torch_npu.float8_e8m0fnu, + pertoken_scale=pertoken_scale, + pertoken_scale_dtype=torch_npu.float8_e8m0fnu, + bias=bias, + output_dtype=output_dtype, + group_sizes=[1, 1, self.group_size]) + + return output + + def process_weights_after_loading(self, layer): + n_dim, k_dim = layer.weight_scale.data.shape + layer.weight_scale.data = layer.weight_scale.data.reshape( + n_dim, k_dim // 2, 2) + layer.weight.data = layer.weight.data.transpose(0, 1) + layer.weight_scale.data = layer.weight_scale.data.transpose(0, 1)