vllm_ascend/quantization/quant_config.py (5 changes: 3 additions & 2 deletions)
@@ -45,7 +45,7 @@
 from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, flashcomm2_enable,
                                mlp_tp_enable, oproj_tp_enable)
 
-from .utils import get_quant_method
+from .utils import get_quant_method, is_mx_quant_type
 
 
 @register_quantization_config(ASCEND_QUANTIZATION_METHOD)
@@ -393,7 +393,8 @@ def create_weights(
         set_weight_attrs(param, {"output_dim": 0})
         layer.register_parameter(pergroup_name, param)
         set_weight_attrs(param, extra_weight_attrs)
-        if "weight_scale_second" in pergroup_name or "weight_offset_second" in pergroup_name:
+        if "weight_scale_second" in pergroup_name or "weight_offset_second" in pergroup_name \
+                or is_mx_quant_type(self.quant_method):
             setattr(param, "input_dim", 1)
             param.input_dim = 1
 
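Note: the new `is_mx_quant_type` branch marks MX per-group scale parameters as sharded along the input dimension. Below is a minimal sketch of what `input_dim = 1` implies under tensor parallelism, using illustrative shapes and a hypothetical helper rather than vLLM's actual weight loader:

```python
# Hypothetical sketch: a vLLM-style loader consults `input_dim` when splitting
# a parameter across tensor-parallel ranks. Names and shapes here are assumed.
import torch

def shard_along_dim(full_param: torch.Tensor, dim: int, tp_rank: int,
                    tp_size: int) -> torch.Tensor:
    """Return this rank's slice of a parameter sharded on `dim`."""
    shard = full_param.shape[dim] // tp_size
    return full_param.narrow(dim, tp_rank * shard, shard)

# An MXFP8 per-group scale has shape (output_size, input_size // group_size),
# so a row-parallel layer must split it along dim 1, mirroring the split of
# the weight's input dimension.
scale = torch.empty(4096, 8192 // 32, dtype=torch.uint8)
local = shard_along_dim(scale, dim=1, tp_rank=0, tp_size=4)
print(local.shape)  # torch.Size([4096, 64])
```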
vllm_ascend/quantization/utils.py (12 changes: 11 additions & 1 deletion)
@@ -14,6 +14,7 @@
                                AscendW8A8DynamicLinearMethod)
 from .w8a8_pdmix import (AscendW8A8PDMixFusedMoeMethod,
                          AscendW8A8PDMixLinearMethod)
+from .w8a8mxfp8 import AscendW8A8MXFP8DynamicLinearMethod
 from .w8a16 import AscendW8A16LinearMethod
 
 ASCEND_QUANTIZATION_METHOD_MAP: Dict[str, Dict[str, Type[Any]]] = {
@@ -40,7 +41,10 @@
     },
     "W8A16": {
         "linear": AscendW8A16LinearMethod,
-    }
+    },
+    "W8A8_MXFP8": {
+        "linear": AscendW8A8MXFP8DynamicLinearMethod,
+    },
 }
 
 
@@ -113,3 +117,9 @@ def get_quant_method_modelslim(
     )
     raise NotImplementedError("Currently, vLLM Ascend only supports following quant types:" \
                               f"{list(ASCEND_QUANTIZATION_METHOD_MAP.keys())}")
+
+
+def is_mx_quant_type(instance: Any) -> bool:
+    """Checks whether the quantization method is an MX (microscaling) type."""
+    MX_QUANT_TYPES = (AscendW8A8MXFP8DynamicLinearMethod, )
+    return isinstance(instance, MX_QUANT_TYPES)
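For orientation, here is a hedged usage sketch of how the new registry entry and helper fit together. It assumes a vllm-ascend install; constructing the method for real requires an active vLLM config, since its `__init__` reads `get_current_vllm_config()`:

```python
from vllm_ascend.quantization.utils import (ASCEND_QUANTIZATION_METHOD_MAP,
                                            is_mx_quant_type)

# The quant description string "W8A8_MXFP8" now resolves to the new method.
method_cls = ASCEND_QUANTIZATION_METHOD_MAP["W8A8_MXFP8"]["linear"]
print(method_cls.__name__)  # AscendW8A8MXFP8DynamicLinearMethod

# is_mx_quant_type() takes an instance; inside create_weights() it is used as
#     if ... or is_mx_quant_type(self.quant_method):
#         param.input_dim = 1   # shard the per-group scale along K
```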
vllm_ascend/quantization/w8a8mxfp8.py (98 changes: 98 additions & 0 deletions)
@@ -0,0 +1,98 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Any, Dict, Optional

import torch
import torch_npu
from vllm.config import get_current_vllm_config


class AscendW8A8MXFP8DynamicLinearMethod:
"""Linear method for Ascend W8A8_DYNAMIC.
"""
model_dtype = None

def __init__(self):
vllm_config = get_current_vllm_config()
self.group_size = vllm_config.quant_config.quant_description.get(
"group_size", 32)

@staticmethod
def get_weight(input_size: int, output_size: int,
params_dtype: torch.dtype) -> Dict[str, Any]:
params_dict = {
"weight":
torch.empty(output_size, input_size, dtype=torch.float8_e4m3fn)
}
return params_dict

@staticmethod
def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
return {}

@staticmethod
def get_perchannel_param(
output_size: int,
params_dtype: torch.dtype,
) -> Dict[str, Any]:
return {}

def get_pergroup_param(self,
input_size: int,
output_size: int,
params_dtype: torch.dtype,
layer_type: Optional[str] = None) -> Dict[str, Any]:
        # One uint8-encoded (E8M0) scale exponent per group of `group_size`
        # input elements.
        params_dict = {}
        params_dict["weight_scale"] = torch.empty(output_size,
                                                  input_size //
                                                  self.group_size,
                                                  dtype=torch.uint8)

def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
tp_rank: Optional[int] = 0,
) -> torch.Tensor:

        # Quantize the activation on the fly to FP8 (E4M3); the kernel also
        # returns the per-group MX scales for the quantized activation.
        quantized_x, dynamic_scale = torch_npu.npu_dynamic_mx_quant(
            x, dst_type=torch.float8_e4m3fn)
        pertoken_scale = dynamic_scale
        output_dtype = x.dtype

        # FP8 x FP8 matmul; both the weight and the activation carry
        # per-group power-of-two (E8M0) scales.
        output = torch_npu.npu_quant_matmul(
            quantized_x,
            layer.weight,
            layer.weight_scale,
            scale_dtype=torch_npu.float8_e8m0fnu,
            pertoken_scale=pertoken_scale,
            pertoken_scale_dtype=torch_npu.float8_e8m0fnu,
            bias=bias,
            output_dtype=output_dtype,
            group_sizes=[1, 1, self.group_size])

return output

    def process_weights_after_loading(self, layer):
        # Pair adjacent K-groups of the per-group scales:
        # (N, K/g) -> (N, K/g/2, 2).
        n_dim, k_dim = layer.weight_scale.data.shape
        layer.weight_scale.data = layer.weight_scale.data.reshape(
            n_dim, k_dim // 2, 2)
        # Transpose weight and scales to a K-major layout.
        layer.weight.data = layer.weight.data.transpose(0, 1)
        layer.weight_scale.data = layer.weight_scale.data.transpose(0, 1)
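The heavy lifting in `apply` is done on-device by `npu_dynamic_mx_quant` and `npu_quant_matmul`. As a mental model only, here is a pure-PyTorch sketch of MX-style block quantization in the OCP Microscaling convention (32 elements share one power-of-two scale); the NPU kernel's exact rounding and scale encoding may differ:

```python
import torch

GROUP_SIZE = 32          # elements per shared scale (the file's default group_size)
FP8_E4M3_MAX = 448.0     # largest finite float8_e4m3fn magnitude

def mx_quant_ref(x: torch.Tensor):
    """Quantize the last dim of x in blocks of GROUP_SIZE; return FP8 data
    plus per-block power-of-two scales."""
    blocks = x.reshape(*x.shape[:-1], -1, GROUP_SIZE)
    amax = blocks.abs().amax(dim=-1, keepdim=True).clamp(min=2.0**-127)
    # Shared scale exponent: floor(log2(amax)) - 8, where 8 = floor(log2(448)),
    # the E4M3 max exponent; the scale stays an exact power of two.
    exp = torch.floor(torch.log2(amax)) - 8
    scale = torch.pow(2.0, exp)
    q = (blocks / scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX).to(torch.float8_e4m3fn)
    return q.reshape(x.shape), scale

x = torch.randn(4, 64)
q, scale = mx_quant_ref(x)
deq = q.to(torch.float32).reshape(4, -1, GROUP_SIZE) * scale
print((deq.reshape(4, 64) - x).abs().max())  # small block-wise quantization error
```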
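Finally, a small shape trace of `process_weights_after_loading` with assumed dimensions (not from the PR), showing how the scales are paired along K and both tensors end up K-major:

```python
import torch

n, k, g = 4096, 8192, 32                                # illustrative sizes
weight = torch.empty(n, k, dtype=torch.float8_e4m3fn)   # as created by get_weight()
scale = torch.empty(n, k // g, dtype=torch.uint8)       # one E8M0 exponent per group

scale = scale.reshape(n, k // g // 2, 2)   # (4096, 128, 2): adjacent K-groups paired
weight = weight.transpose(0, 1)            # (8192, 4096): K-major
scale = scale.transpose(0, 1)              # (128, 4096, 2)
print(weight.shape, scale.shape)
```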