From 93624e6dfd238614d300d08e9f810d67da67ed4d Mon Sep 17 00:00:00 2001 From: menogrey <1299267905@qq.com> Date: Mon, 9 Feb 2026 14:09:47 +0000 Subject: [PATCH 1/8] [Quantization] Refactor compressed-tensors quantization implement to reuse upstream implement. And add w4a16 support. Signed-off-by: menogrey <1299267905@qq.com> --- vllm_ascend/ops/fused_moe/fused_moe.py | 6 +- vllm_ascend/ops/layernorm.py | 6 +- vllm_ascend/patch/worker/__init__.py | 1 + .../patch/worker/patch_quantization.py | 13 ++ .../compressed_tensors/__init__.py | 1 + .../compressed_tensors/schemes/__init__.py | 0 .../compressed_tensors/schemes/wNa16.py | 154 ++++++++++++++++++ .../quantization/compressed_tensors_config.py | 12 +- vllm_ascend/quantization/kernels/__init__.py | 1 + .../kernels/mixed_precision/__init__.py | 0 .../kernels/mixed_precision/npu.py | 66 ++++++++ vllm_ascend/quantization/methods/w4a16.py | 71 +------- vllm_ascend/quantization/utils.py | 72 ++++++++ 13 files changed, 320 insertions(+), 83 deletions(-) create mode 100644 vllm_ascend/patch/worker/patch_quantization.py create mode 100644 vllm_ascend/quantization/compressed_tensors/__init__.py create mode 100644 vllm_ascend/quantization/compressed_tensors/schemes/__init__.py create mode 100644 vllm_ascend/quantization/compressed_tensors/schemes/wNa16.py create mode 100644 vllm_ascend/quantization/kernels/__init__.py create mode 100644 vllm_ascend/quantization/kernels/mixed_precision/__init__.py create mode 100644 vllm_ascend/quantization/kernels/mixed_precision/npu.py diff --git a/vllm_ascend/ops/fused_moe/fused_moe.py b/vllm_ascend/ops/fused_moe/fused_moe.py index 7d7b581d64a..dfec5c75386 100644 --- a/vllm_ascend/ops/fused_moe/fused_moe.py +++ b/vllm_ascend/ops/fused_moe/fused_moe.py @@ -315,7 +315,11 @@ def __init__(self, *args, **kwargs): "weight_loader": self.weight_loader, } # need full intermediate size pre-sharding for WNA16 act order - if self.quant_method.__class__.__name__ in ("GPTQMarlinMoEMethod", 
"CompressedTensorsWNA16MoEMethod"): + if self.quant_method.__class__.__name__ in ( + "GPTQMarlinMoEMethod", + "CompressedTensorsWNA16MarlinMoEMethod", + "CompressedTensorsWNA16MoEMethod", + ): moe_quant_params["intermediate_size_full"] = intermediate_size self.quant_method.create_weights(layer=self, **moe_quant_params) diff --git a/vllm_ascend/ops/layernorm.py b/vllm_ascend/ops/layernorm.py index 17214afbddf..8498d896e77 100644 --- a/vllm_ascend/ops/layernorm.py +++ b/vllm_ascend/ops/layernorm.py @@ -38,8 +38,10 @@ def __init__( vllm_config = get_current_vllm_config() self.bias = None # quantization with anti_method m4 will generate none-zero norm bias - if vllm_config.quant_config is not None and any( - "norm.bias" in name for name in vllm_config.quant_config.quant_description + if ( + vllm_config.quant_config is not None + and hasattr(vllm_config.quant_config, "quant_description") + and any("norm.bias" in name for name in vllm_config.quant_config.quant_description) ): self.bias = torch.nn.Parameter(torch.zeros(hidden_size), requires_grad=False) diff --git a/vllm_ascend/patch/worker/__init__.py b/vllm_ascend/patch/worker/__init__.py index ad2429d120e..a95de34eb14 100644 --- a/vllm_ascend/patch/worker/__init__.py +++ b/vllm_ascend/patch/worker/__init__.py @@ -35,3 +35,4 @@ import vllm_ascend.patch.worker.patch_routed_experts_capturer # noqa import vllm_ascend.patch.worker.patch_npugraph_ex_triton # noqa import vllm_ascend.patch.worker.patch_kimi_k25 # noqa +import vllm_ascend.patch.worker.patch_quantization # noqa diff --git a/vllm_ascend/patch/worker/patch_quantization.py b/vllm_ascend/patch/worker/patch_quantization.py new file mode 100644 index 00000000000..7c1252f91c1 --- /dev/null +++ b/vllm_ascend/patch/worker/patch_quantization.py @@ -0,0 +1,13 @@ +import vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe as ct_moe_module +import vllm.model_executor.layers.quantization.kernels.mixed_precision as mixed_precision_module +from 
vllm.platforms import PlatformEnum + +from vllm_ascend.quantization.compressed_tensors.schemes.wNa16 import AscendW4A16FusedMoEMethod +from vllm_ascend.quantization.kernels.mixed_precision.npu import AscendwNa16LinearKernel + +mixed_precision_module._POSSIBLE_KERNELS[PlatformEnum.OOT] = [AscendwNa16LinearKernel] + +ct_moe_module.CompressedTensorsWNA16MarlinMoEMethod.apply = AscendW4A16FusedMoEMethod.apply +ct_moe_module.CompressedTensorsWNA16MarlinMoEMethod.process_weights_after_loading = ( + AscendW4A16FusedMoEMethod.process_weights_after_loading +) diff --git a/vllm_ascend/quantization/compressed_tensors/__init__.py b/vllm_ascend/quantization/compressed_tensors/__init__.py new file mode 100644 index 00000000000..8b137891791 --- /dev/null +++ b/vllm_ascend/quantization/compressed_tensors/__init__.py @@ -0,0 +1 @@ + diff --git a/vllm_ascend/quantization/compressed_tensors/schemes/__init__.py b/vllm_ascend/quantization/compressed_tensors/schemes/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/vllm_ascend/quantization/compressed_tensors/schemes/wNa16.py b/vllm_ascend/quantization/compressed_tensors/schemes/wNa16.py new file mode 100644 index 00000000000..f9ace63637a --- /dev/null +++ b/vllm_ascend/quantization/compressed_tensors/schemes/wNa16.py @@ -0,0 +1,154 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from collections.abc import Callable + +import torch +from vllm.forward_context import get_forward_context + +from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.ops.fused_moe.experts_selector import select_experts +from vllm_ascend.quantization.utils import pack_to_int32, unpack_from_int32 + + +class AscendW4A16FusedMoEMethod: + """FusedMoE method for Ascend W4A16.""" + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + global_num_experts: int = -1, + expert_map: torch.Tensor | None = None, + topk_group: int | None = None, + num_expert_group: int | None = None, + custom_routing_function: Callable | None = None, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, + e_score_correction_bias: torch.Tensor | None = None, + is_prefill: bool = True, + enable_force_load_balance: bool = True, + log2phy: torch.Tensor | None = None, + global_redundant_expert_num: int = 0, + **kwargs, + ) -> torch.Tensor: + assert router_logits.shape[1] == global_num_experts - global_redundant_expert_num, ( + "Number of global experts mismatch (excluding redundancy)" + ) + + topk_weights, topk_ids = select_experts( + hidden_states=x, + router_logits=router_logits, + top_k=top_k, + use_grouped_topk=use_grouped_topk, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + global_num_experts=global_num_experts, + ) + + topk_ids = topk_ids.to(torch.int32) + topk_weights = topk_weights.to(x.dtype) + + moe_comm_method = get_forward_context().moe_comm_method + dynamic_eplb = get_ascend_config().eplb_config.dynamic_eplb + return moe_comm_method.fused_experts( + hidden_states=x, + w1=layer.w13_weight_packed, + w2=layer.w2_weight_packed, + w1_scale=layer.w13_weight_scale, + 
w2_scale=layer.w2_weight_scale, + w1_offset=layer.w13_weight_offset, + w2_offset=layer.w2_weight_offset, + topk_weights=topk_weights, + topk_ids=topk_ids, + use_int4_w4a16=True, + expert_map=expert_map, + log2phy=log2phy, + dynamic_eplb=dynamic_eplb, + mc2_mask=kwargs.get("mc2_mask"), + ) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + """Convert weights from Marlin format to Ascend NPU format. + + vllm create_weights weights: + - w13_weight_packed: [e, k//8, 2*n] (Marlin) + - w2_weight_packed: [e, n//8, k] + - w13_weight_scale: [e, num_groups, 2*n] + - w2_weight_scale: [e, num_groups, k] + + Needed by Ascend ops: + - w13_weight_packed: [e, 2*n, k//8] (NPU int4pack) + - w2_weight_packed: [e, k, n//8] + - w13_weight_scale: [e, num_groups, 2*n] + - w2_weight_scale: [e, num_groups, k] + """ + + num_bits = 4 + pack_factor = 8 + + layer.w13_weight_packed.data = layer.w13_weight_packed.data.transpose(1, 2).contiguous() + layer.w2_weight_packed.data = layer.w2_weight_packed.data.transpose(1, 2).contiguous() + + w13_shape = layer.w13_weight_packed.data.shape + w2_shape = layer.w2_weight_packed.data.shape + unpacked_w13_weight = ( + unpack_from_int32( + layer.w13_weight_packed.data.flatten(0, 1), + torch.Size([w13_shape[0] * w13_shape[1], w13_shape[2] * pack_factor]), + num_bits, + ) + .view(w13_shape[0], w13_shape[1], -1) + .transpose(1, 2) + .contiguous() + .int() + ) + unpacked_w2_weight = ( + unpack_from_int32( + layer.w2_weight_packed.data.flatten(0, 1), + torch.Size([w2_shape[0] * w2_shape[1], w2_shape[2] * pack_factor]), + num_bits, + ) + .view(w2_shape[0], w2_shape[1], -1) + .transpose(1, 2) + .contiguous() + .int() + ) + + layer.w13_weight_packed.data = pack_to_int32(unpacked_w13_weight) + layer.w2_weight_packed.data = pack_to_int32(unpacked_w2_weight) + + layer.w13_weight_scale.data = layer.w13_weight_scale.data.contiguous() + layer.w2_weight_scale.data = layer.w2_weight_scale.data.contiguous() + + # Only symmetric quantization is 
supported, the offset is for quant_apply_mlp function branch + layer.w13_weight_offset = torch.nn.Parameter( + torch.zeros_like(layer.w13_weight_scale.data), + requires_grad=False, + ) + layer.w2_weight_offset = torch.nn.Parameter( + torch.zeros_like(layer.w2_weight_scale.data), + requires_grad=False, + ) diff --git a/vllm_ascend/quantization/compressed_tensors_config.py b/vllm_ascend/quantization/compressed_tensors_config.py index ea13811015c..482652db025 100644 --- a/vllm_ascend/quantization/compressed_tensors_config.py +++ b/vllm_ascend/quantization/compressed_tensors_config.py @@ -24,7 +24,6 @@ from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod -from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS, register_quantization_config from vllm.model_executor.layers.quantization.base_config import QuantizationConfig, QuantizeMethodBase from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( find_matched_target, @@ -40,18 +39,11 @@ logger = init_logger(__name__) -# Remove the original compressed_tensors method to replace with our implementation -def _remove_quantization_method(): - if COMPRESSED_TENSORS_METHOD in QUANTIZATION_METHODS: - QUANTIZATION_METHODS.remove(COMPRESSED_TENSORS_METHOD) - - -_remove_quantization_method() - QUANTIZATION_SCHEME_MAP_TYPE = dict[str, dict[str, "QuantizationArgs"] | None] -@register_quantization_config(COMPRESSED_TENSORS_METHOD) +# TODO: Remove all the AscendCompressedTensorsConfig implement when finishing all compressed-tensors schemes +# @register_quantization_config(COMPRESSED_TENSORS_METHOD) class AscendCompressedTensorsConfig(QuantizationConfig): """Config class for LLM-Compressor (compressed_tensors) quantization on Ascend. 
diff --git a/vllm_ascend/quantization/kernels/__init__.py b/vllm_ascend/quantization/kernels/__init__.py new file mode 100644 index 00000000000..8b137891791 --- /dev/null +++ b/vllm_ascend/quantization/kernels/__init__.py @@ -0,0 +1 @@ + diff --git a/vllm_ascend/quantization/kernels/mixed_precision/__init__.py b/vllm_ascend/quantization/kernels/mixed_precision/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/vllm_ascend/quantization/kernels/mixed_precision/npu.py b/vllm_ascend/quantization/kernels/mixed_precision/npu.py new file mode 100644 index 00000000000..bd2743db0eb --- /dev/null +++ b/vllm_ascend/quantization/kernels/mixed_precision/npu.py @@ -0,0 +1,66 @@ +import torch +import torch_npu +from vllm.model_executor.layers.quantization.kernels.mixed_precision.MPLinearKernel import ( + MPLinearKernel, + MPLinearLayerConfig, +) +from vllm.scalar_type import scalar_types + +from vllm_ascend.quantization.utils import unpack_from_int32 + + +class AscendwNa16LinearKernel(MPLinearKernel): + @classmethod + def get_min_capability(cls) -> int: + return 0 + + @classmethod + def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]: + if not torch.npu.is_available(): + return False, "Ascend wNa16 only supported on NPU devices" + + return True, None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # Get original shape before transpose + weight_shape = layer.weight_packed.data.shape + + pack_factor = 8 + num_bits = 4 + if self.config.weight_type in [scalar_types.uint4, scalar_types.uint4b8]: + num_bits = 4 + elif self.config.weight_type in [scalar_types.uint8, scalar_types.uint8b128]: + num_bits = 8 + + # Unpack from int32 to int8 (with int4 range) + unpacked_weight = unpack_from_int32( + weight=layer.weight_packed.data, + shape=torch.Size([weight_shape[0], weight_shape[1] * pack_factor]), + num_bits=num_bits, + packed_dim=1, + ) + + # Transpose: [n, k] -> [k, n] + unpacked_weight = 
unpacked_weight.transpose(0, 1).contiguous().int() + + # Repack to int32 using NPU int4 packing + layer.weight_packed.data = torch_npu.npu_convert_weight_to_int4pack(unpacked_weight) + + # Transpose scales and offsets: [n, num_groups] -> [num_groups, n] + layer.weight_scale.data = layer.weight_scale.data.transpose(0, 1).contiguous() + + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + ) -> torch.Tensor: + output = torch_npu.npu_weight_quant_batchmatmul( + x=x, + weight=layer.weight_packed, + antiquant_scale=layer.weight_scale, + antiquant_offset=None, + antiquant_group_size=self.config.group_size, + bias=bias, + ) + return output diff --git a/vllm_ascend/quantization/methods/w4a16.py b/vllm_ascend/quantization/methods/w4a16.py index f30ff88e465..65a2a67244a 100644 --- a/vllm_ascend/quantization/methods/w4a16.py +++ b/vllm_ascend/quantization/methods/w4a16.py @@ -19,86 +19,17 @@ from typing import Any import torch -import torch_npu from vllm.config import get_current_vllm_config from vllm.forward_context import get_forward_context from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ops.fused_moe.experts_selector import select_experts +from vllm_ascend.quantization.utils import pack_to_int32, unpack_from_int32 from .base import AscendMoEScheme from .registry import register_scheme -def unpack_from_int32( - weight: torch.Tensor, - shape: torch.Size, - num_bits: int, - packed_dim: int = 1, -) -> torch.Tensor: - """Unpacks quantized weights from int32 format back to original bits. 
- - :param weight: The packed int32 tensor containing quantized weights - :param shape: Original shape to restore, defaults to None - :param num_bits: The number of bits used for quantization (<= 8) - :param packed_dim: Dimension along which weights are packed (0 or 1), defaults to 1 - :return: Unpacked tensor with int8 dtype after applying offset correction - """ - assert weight.dtype == torch.int32, f"Expecting `weight.dtype` is torch.int32 but got {weight.dtype}." - assert num_bits <= 8, f"Expecting `num_bits` should not be larger than 8 but got {num_bits}." - - pack_factor = 32 // num_bits - mask = (1 << num_bits) - 1 - - if packed_dim == 1: - unpacked_weight = torch.zeros( - (weight.shape[0], weight.shape[1] * pack_factor), - device=weight.device, - dtype=torch.int32, - ) - for i in range(pack_factor): - unpacked_weight[:, i::pack_factor] = (weight >> (num_bits * i)) & mask - original_row_size = int(shape[1]) - unpacked_weight = unpacked_weight[:, :original_row_size] - else: - unpacked_weight = torch.zeros( - (weight.shape[0] * pack_factor, weight.shape[1]), - device=weight.device, - dtype=torch.int32, - ) - for i in range(pack_factor): - unpacked_weight[i::pack_factor, :] = (weight >> (num_bits * i)) & mask - original_row_size = int(shape[0]) - unpacked_weight = unpacked_weight[:original_row_size, :] - - offset = pow(2, num_bits) // 2 - unpacked_weight = (unpacked_weight - offset).to(torch.int8) - - return unpacked_weight - - -def pack_to_int32(weight: torch.Tensor) -> torch.Tensor: - """Packs quantized weights into int32 format for storage. - - :param weight: The 3D tensor to pack, must be int8 or int32 dtype - :return: Packed tensor with int32 dtype optimized for storage - """ - assert weight.dim() == 3, f"Expecting `weight.dim()` is 3 ([e, n, k] or [e, k, n]) but got {weight.dim()}." - assert weight.dtype in [torch.int8, torch.int32], ( - f"Expecting `weight.dtype` is torch.int8 or torch.int32 bug got {weight.dtype}." 
- ) - - if weight.dtype == torch.int32: - assert weight.shape[-1] % 8 == 0, "the last dim of weight needs to be divided by 8." - packed_weight = torch_npu.npu_convert_weight_to_int4pack(weight.flatten(0, 1)) - packed_weight = packed_weight.view(weight.shape[0], weight.shape[1], -1) - else: - assert weight.shape[-1] % 4 == 0, "the last dim of weight needs to be divided by 4." - packed_weight = weight.view(torch.int32).contiguous() - - return packed_weight - - @register_scheme("W4A16", "moe") class AscendW4A16FusedMoEMethod(AscendMoEScheme): """FusedMoE method for Ascend W4A16.""" diff --git a/vllm_ascend/quantization/utils.py b/vllm_ascend/quantization/utils.py index 7c9570b4ac7..d4620e87c69 100644 --- a/vllm_ascend/quantization/utils.py +++ b/vllm_ascend/quantization/utils.py @@ -1,5 +1,6 @@ # # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,6 +18,8 @@ import json import os +import torch +import torch_npu from vllm.logger import init_logger @@ -145,3 +148,72 @@ def maybe_auto_detect_quantization(vllm_config) -> None: from vllm.config import VllmConfig as _VllmConfig vllm_config.quant_config = _VllmConfig._get_quantization_config(model_config, vllm_config.load_config) + + +def unpack_from_int32( + weight: torch.Tensor, + shape: torch.Size, + num_bits: int, + packed_dim: int = 1, +) -> torch.Tensor: + """Unpacks quantized weights from int32 format back to original bits. 
+
+    :param weight: The packed int32 tensor containing quantized weights
+    :param shape: Original shape to restore, defaults to None
+    :param num_bits: The number of bits used for quantization (<= 8)
+    :param packed_dim: Dimension along which weights are packed (0 or 1), defaults to 1
+    :return: Unpacked tensor with int8 dtype after applying offset correction
+    """
+    assert weight.dtype == torch.int32, f"Expecting `weight.dtype` is torch.int32 but got {weight.dtype}."
+    assert num_bits <= 8, f"Expecting `num_bits` should not be larger than 8 but got {num_bits}."
+
+    pack_factor = 32 // num_bits
+    mask = (1 << num_bits) - 1
+
+    if packed_dim == 1:
+        unpacked_weight = torch.zeros(
+            (weight.shape[0], weight.shape[1] * pack_factor),
+            device=weight.device,
+            dtype=torch.int32,
+        )
+        for i in range(pack_factor):
+            unpacked_weight[:, i::pack_factor] = (weight >> (num_bits * i)) & mask
+        original_row_size = int(shape[1])
+        unpacked_weight = unpacked_weight[:, :original_row_size]
+    else:
+        unpacked_weight = torch.zeros(
+            (weight.shape[0] * pack_factor, weight.shape[1]),
+            device=weight.device,
+            dtype=torch.int32,
+        )
+        for i in range(pack_factor):
+            unpacked_weight[i::pack_factor, :] = (weight >> (num_bits * i)) & mask
+        original_row_size = int(shape[0])
+        unpacked_weight = unpacked_weight[:original_row_size, :]
+
+    offset = pow(2, num_bits) // 2
+    unpacked_weight = (unpacked_weight - offset).to(torch.int8)
+
+    return unpacked_weight
+
+
+def pack_to_int32(weight: torch.Tensor) -> torch.Tensor:
+    """Packs quantized weights into int32 format for storage.
+
+    :param weight: The 3D tensor to pack, must be int8 or int32 dtype
+    :return: Packed tensor with int32 dtype optimized for storage
+    """
+    assert weight.dim() == 3, f"Expecting `weight.dim()` is 3 ([e, n, k] or [e, k, n]) but got {weight.dim()}."
+    assert weight.dtype in [torch.int8, torch.int32], (
+        f"Expecting `weight.dtype` is torch.int8 or torch.int32 but got {weight.dtype}."
+ ) + + if weight.dtype == torch.int32: + assert weight.shape[-1] % 8 == 0, "the last dim of weight needs to be divided by 8." + packed_weight = torch_npu.npu_convert_weight_to_int4pack(weight.flatten(0, 1)) + packed_weight = packed_weight.view(weight.shape[0], weight.shape[1], -1) + else: + assert weight.shape[-1] % 4 == 0, "the last dim of weight needs to be divided by 4." + packed_weight = weight.view(torch.int32).contiguous() + + return packed_weight From a7203ae5776fc6972a5b691a4ff5c8a3f83d6041 Mon Sep 17 00:00:00 2001 From: menogrey <1299267905@qq.com> Date: Thu, 26 Feb 2026 12:55:05 +0000 Subject: [PATCH 2/8] Refactor compressed-tensors w8a8 and w4a8. Signed-off-by: menogrey <1299267905@qq.com> --- .../patch/worker/patch_quantization.py | 36 +++- .../compressed_tensors/__init__.py | 1 - .../compressed_tensors/schemes/w4a8.py | 200 ++++++++++++++++++ .../compressed_tensors/schemes/w8a8.py | 196 +++++++++++++++++ .../kernels/mixed_precision/npu.py | 59 ++++++ .../kernels/scaled_mm/__init__.py | 0 .../quantization/kernels/scaled_mm/npu.py | 156 ++++++++++++++ 7 files changed, 645 insertions(+), 3 deletions(-) create mode 100644 vllm_ascend/quantization/compressed_tensors/schemes/w4a8.py create mode 100644 vllm_ascend/quantization/compressed_tensors/schemes/w8a8.py create mode 100644 vllm_ascend/quantization/kernels/scaled_mm/__init__.py create mode 100644 vllm_ascend/quantization/kernels/scaled_mm/npu.py diff --git a/vllm_ascend/patch/worker/patch_quantization.py b/vllm_ascend/patch/worker/patch_quantization.py index 7c1252f91c1..004ffc96ee7 100644 --- a/vllm_ascend/patch/worker/patch_quantization.py +++ b/vllm_ascend/patch/worker/patch_quantization.py @@ -1,13 +1,45 @@ import vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe as ct_moe_module import vllm.model_executor.layers.quantization.kernels.mixed_precision as mixed_precision_module +import vllm.model_executor.layers.quantization.kernels.scaled_mm as scaled_mm_module from 
vllm.platforms import PlatformEnum +from vllm_ascend.quantization.compressed_tensors.schemes.w4a8 import ( + CompressedTensorsAscendW4A8DynamicFusedMoEMethod, +) +from vllm_ascend.quantization.compressed_tensors.schemes.w8a8 import ( + AscendCompressedTensorsW8A8Int8DynamicFusedMoEMethod, +) from vllm_ascend.quantization.compressed_tensors.schemes.wNa16 import AscendW4A16FusedMoEMethod -from vllm_ascend.quantization.kernels.mixed_precision.npu import AscendwNa16LinearKernel +from vllm_ascend.quantization.kernels.mixed_precision.npu import ( + AscendW4A8LinearKernel, + AscendwNa16LinearKernel, +) +from vllm_ascend.quantization.kernels.scaled_mm.npu import ( + AscendDynamicInt8ScaledMMLinearKernel, + AscendStaticInt8ScaledMMLinearKernel, +) -mixed_precision_module._POSSIBLE_KERNELS[PlatformEnum.OOT] = [AscendwNa16LinearKernel] +mixed_precision_module._POSSIBLE_KERNELS[PlatformEnum.OOT] = [ + AscendW4A8LinearKernel, + AscendwNa16LinearKernel, +] +scaled_mm_module._POSSIBLE_INT8_KERNELS[PlatformEnum.OOT] = [ + AscendDynamicInt8ScaledMMLinearKernel, + AscendStaticInt8ScaledMMLinearKernel, +] ct_moe_module.CompressedTensorsWNA16MarlinMoEMethod.apply = AscendW4A16FusedMoEMethod.apply ct_moe_module.CompressedTensorsWNA16MarlinMoEMethod.process_weights_after_loading = ( AscendW4A16FusedMoEMethod.process_weights_after_loading ) + +ct_moe_module.CompressedTensorsW8A8Int8MoEMethod.apply = AscendCompressedTensorsW8A8Int8DynamicFusedMoEMethod.apply +ct_moe_module.CompressedTensorsW8A8Int8MoEMethod.process_weights_after_loading = ( + AscendCompressedTensorsW8A8Int8DynamicFusedMoEMethod.process_weights_after_loading +) + +ct_moe_module.CompressedTensorsW4A8Int8MoEMethod.__init__ = CompressedTensorsAscendW4A8DynamicFusedMoEMethod.__init__ +ct_moe_module.CompressedTensorsW4A8Int8MoEMethod.apply = CompressedTensorsAscendW4A8DynamicFusedMoEMethod.apply +ct_moe_module.CompressedTensorsW4A8Int8MoEMethod.process_weights_after_loading = ( + 
CompressedTensorsAscendW4A8DynamicFusedMoEMethod.process_weights_after_loading +) diff --git a/vllm_ascend/quantization/compressed_tensors/__init__.py b/vllm_ascend/quantization/compressed_tensors/__init__.py index 8b137891791..e69de29bb2d 100644 --- a/vllm_ascend/quantization/compressed_tensors/__init__.py +++ b/vllm_ascend/quantization/compressed_tensors/__init__.py @@ -1 +0,0 @@ - diff --git a/vllm_ascend/quantization/compressed_tensors/schemes/w4a8.py b/vllm_ascend/quantization/compressed_tensors/schemes/w4a8.py new file mode 100644 index 00000000000..af95f59512b --- /dev/null +++ b/vllm_ascend/quantization/compressed_tensors/schemes/w4a8.py @@ -0,0 +1,200 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from collections.abc import Callable +from typing import Any + +import numpy as np +import torch +import torch_npu +from compressed_tensors.quantization import ( + QuantizationArgs, + QuantizationStrategy, +) +from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig +from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import CompressedTensorsMoEMethod + +from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.ops.fused_moe.experts_selector import select_experts + + +def _normalize_weight_strategy(strategy: Any) -> str: + if strategy is None: + return "group" + if hasattr(strategy, "value"): + strategy = strategy.value + if isinstance(strategy, str): + lowered = strategy.lower() + if "group" in lowered: + return "group" + if "channel" in lowered: + return "channel" + raise ValueError(f"Unsupported weight strategy: {strategy}") + + +def _process_scale_compressed_tensors(scale: torch.Tensor) -> torch.Tensor: + scale = scale.transpose(1, 2).to(torch.float32).contiguous() + scale_np = scale.cpu().numpy() + scale_np.dtype = np.uint32 + scale_uint64_tensor = torch.from_numpy(scale_np.astype(np.int64)).npu() + return scale_uint64_tensor + + +def _update_bias_compressed_tensors(weight: torch.Tensor, scale: torch.Tensor, strategy: str) -> torch.Tensor: + group_num, k, n = weight.shape + scale = scale.transpose(1, 2).contiguous() + scale = scale.reshape(group_num, -1, n) + group_num, quantgroup_num, n = scale.shape + + if strategy == "group": + tmp = weight.to(torch.float32).reshape([group_num, quantgroup_num, -1, n]) * scale.reshape( + [group_num, quantgroup_num, 1, n] + ) + tmp = tmp.reshape([group_num, k, n]) + return 8 * tmp.sum(axis=1) + if strategy == "channel": + return 8 * (weight.to(torch.float32) * scale).sum(axis=1) + raise ValueError(f"Unsupported weight strategy: {strategy}") + + +class CompressedTensorsAscendW4A8DynamicFusedMoEMethod: + 
def __init__( + self, + weight_quant: QuantizationArgs, + input_quant: QuantizationArgs, + moe: FusedMoEConfig, + layer_name: str | None = None, + ): + CompressedTensorsMoEMethod.__init__(self, moe) + self.has_bias = self.moe.has_bias + self.weight_quant = weight_quant + self.input_quant = input_quant + + # Validate scheme: weights=W4 (channel or group), + # activations=dynamic TOKEN (A8) + + # Must be dynamic per-token activations + if input_quant.strategy != QuantizationStrategy.TOKEN or not input_quant.dynamic: + raise ValueError("W4A8-int MoE needs dynamic per-token activation quantization.") + + # Weight can be channel-wise (group_size=None) or group-wise + self.group_size = weight_quant.group_size if (weight_quant.group_size is not None) else -1 + if weight_quant.num_bits != 4: + raise ValueError("This method only supports 4-bit weights (num_bits=4).") + + self.static_input_scales = False # always dynamic per token + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + global_num_experts: int = -1, + expert_map: torch.Tensor | None = None, + topk_group: int | None = None, + num_expert_group: int | None = None, + custom_routing_function: Callable | None = None, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, + e_score_correction_bias: torch.Tensor | None = None, + is_prefill: bool = True, + enable_force_load_balance: bool = False, + log2phy: torch.Tensor | None = None, + global_redundant_expert_num: int = 0, + **kwargs, + ) -> torch.Tensor: + assert router_logits.shape[1] == global_num_experts - global_redundant_expert_num, ( + "Number of global experts mismatch (excluding redundancy)" + ) + + topk_weights, topk_ids = select_experts( + hidden_states=x, + router_logits=router_logits, + top_k=top_k, + use_grouped_topk=use_grouped_topk, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + 
custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, + e_score_correction_bias=e_score_correction_bias, + global_num_experts=global_num_experts, + ) + + if enable_force_load_balance: + random_matrix = torch.rand( + topk_ids.size(0), global_num_experts - global_redundant_expert_num, device=topk_ids.device + ) + topk_ids = torch.argsort(random_matrix, dim=1)[:, : topk_ids.size(1)].to(topk_ids.dtype) + + topk_weights = topk_weights.to(x.dtype) + + moe_comm_method = get_forward_context().moe_comm_method + return moe_comm_method.fused_experts( + hidden_states=x, + w1=[layer.w13_weight], + w2=[layer.w2_weight], + w1_scale=[layer.w13_weight_scale], + w2_scale=[layer.w2_weight_scale], + w1_scale_bias=layer.w13_scale_bias if hasattr(layer, "w13_scale_bias") else None, + w2_scale_bias=layer.w2_scale_bias if hasattr(layer, "w2_scale_bias") else None, + topk_weights=topk_weights, + topk_ids=topk_ids, + use_int4_w4a8=True, + expert_map=expert_map, + log2phy=log2phy, + dynamic_eplb=get_ascend_config().eplb_config.dynamic_eplb, + mc2_mask=kwargs.get("mc2_mask"), + ) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + layer.w13_weight.data = layer.w13_weight.data.transpose(1, 2).contiguous() + layer.w2_weight.data = layer.w2_weight.data.transpose(1, 2).contiguous() + + weight_quant = getattr(self, "weight_quant", None) + strategy = _normalize_weight_strategy(weight_quant.strategy if weight_quant is not None else None) + w13_bias = _update_bias_compressed_tensors(layer.w13_weight.data, layer.w13_weight_scale.data, strategy) + w2_bias = _update_bias_compressed_tensors(layer.w2_weight.data, layer.w2_weight_scale.data, strategy) + + layer.w13_weight_scale.data = _process_scale_compressed_tensors(layer.w13_weight_scale.data) + layer.w2_weight_scale.data = _process_scale_compressed_tensors(layer.w2_weight_scale.data) + + w13_scale_bias = torch.nn.Parameter(w13_bias, requires_grad=False) + 
layer.register_parameter("w13_scale_bias", w13_scale_bias) + w2_scale_bias = torch.nn.Parameter(w2_bias, requires_grad=False) + layer.register_parameter("w2_scale_bias", w2_scale_bias) + + def _pack_to_int32(weight: torch.Tensor) -> torch.Tensor: + return torch_npu.npu_quantize( + weight.to(torch.float32), + torch.tensor([1.0]).npu(), + None, + torch.quint4x2, + -1, + False, + ) + + # Accuracy problem in nz format + # layer.w13_weight.data = maybe_trans_nz(layer.w13_weight.data) + # layer.w2_weight.data = maybe_trans_nz(layer.w2_weight.data) + layer.w13_weight.data = _pack_to_int32(layer.w13_weight.data) + layer.w2_weight.data = _pack_to_int32(layer.w2_weight.data) diff --git a/vllm_ascend/quantization/compressed_tensors/schemes/w8a8.py b/vllm_ascend/quantization/compressed_tensors/schemes/w8a8.py new file mode 100644 index 00000000000..77f3a151276 --- /dev/null +++ b/vllm_ascend/quantization/compressed_tensors/schemes/w8a8.py @@ -0,0 +1,196 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from collections.abc import Callable +from typing import Any + +import torch +import torch_npu +from vllm.config import get_current_vllm_config +from vllm.forward_context import get_forward_context + +import vllm_ascend.envs as envs_ascend +from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.ascend_forward_context import MoECommType +from vllm_ascend.flash_common3_context import get_flash_common3_context +from vllm_ascend.ops.fused_moe.experts_selector import select_experts, zero_experts_compute +from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ + + +def scale_from_float_to_int64(scale): + """Convert float32 scale to int64 representation.""" + import numpy as np + + scale = torch.from_numpy( + np.frombuffer(scale.cpu().to(torch.float32).numpy().tobytes(), dtype=np.int32).astype(np.int64) + ).to(scale.device) + return scale + + +class AscendCompressedTensorsW8A8Int8DynamicFusedMoEMethod: + def process_weights_after_loading(self, layer): + self.dtype = get_current_vllm_config().model_config.dtype + layer.w13_weight.data = layer.w13_weight.data.transpose(1, 2).contiguous() + layer.w2_weight.data = layer.w2_weight.data.transpose(1, 2).contiguous() + # TODO(zzzzwwjj): Currently, `torch_npu.npu_grouped_matmul_swiglu_quant` + # can only support weight nz. 
+ layer.w13_weight.data = torch_npu.npu_format_cast(layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ) + layer.w2_weight.data = torch_npu.npu_format_cast(layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ) + layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(layer.w13_weight_scale.data.shape[0], -1).to( + torch.bfloat16 + ) + layer.w13_weight_scale_fp32 = layer.w13_weight_scale.data.to(torch.float32) + if hasattr(layer, "w13_weight_offset"): + layer.w13_weight_offset.data = layer.w13_weight_offset.data.view(layer.w13_weight_offset.data.shape[0], -1) + layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(layer.w2_weight_scale.data.shape[0], -1).to( + torch.bfloat16 + ) + if hasattr(layer, "w2_weight_offset"): + layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(layer.w2_weight_offset.data.shape[0], -1) + + layer.w13_weight_offset = torch.nn.Parameter( + torch.zeros_like(layer.w13_weight_scale.data), + requires_grad=False, + ) + layer.w2_weight_offset = torch.nn.Parameter( + torch.zeros_like(layer.w2_weight_scale.data), + requires_grad=False, + ) + + layer.fused_w1_scale = scale_from_float_to_int64(layer.w13_weight_scale.data) + layer.fused_w2_scale = scale_from_float_to_int64(layer.w2_weight_scale.data) + + if get_ascend_config().eplb_config.dynamic_eplb: + layer.w13_weight_list = [weight.clone() for weight in layer.w13_weight.data.unbind(dim=0)] + layer.w2_weight_list = [weight.clone() for weight in layer.w2_weight.data.unbind(dim=0)] + layer.w13_weight_scale_fp32_list = [ + weight.clone() for weight in layer.w13_weight_scale_fp32.data.unbind(dim=0) + ] + layer.w2_weight_scale_list = [weight.clone() for weight in layer.w2_weight_scale.data.unbind(dim=0)] + del layer.w13_weight + del layer.w2_weight + del layer.w13_weight_scale + del layer.w13_weight_scale_fp32 + del layer.w2_weight_scale + torch.npu.empty_cache() + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, 
+ use_grouped_topk: bool = False, + global_num_experts: int = -1, + expert_map: torch.Tensor | None = None, + topk_group: int | None = None, + num_expert_group: int | None = None, + custom_routing_function: Callable | None = None, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, + e_score_correction_bias: torch.Tensor | None = None, + is_prefill: bool = True, + enable_force_load_balance: bool = False, + log2phy: torch.Tensor | None = None, + global_redundant_expert_num: int = 0, + pertoken_scale: Any | None = None, + **kwargs, + ) -> torch.Tensor: + zero_expert_num = getattr(layer, "zero_expert_num", 0) + zero_expert_type = getattr(layer, "zero_expert_type", None) + if zero_expert_num == 0 or zero_expert_type is None: + assert router_logits.shape[1] == global_num_experts - global_redundant_expert_num, ( + "Number of global experts mismatch (excluding redundancy)" + ) + + if get_ascend_config().multistream_overlap_gate: + fc3_context = get_flash_common3_context() + assert fc3_context is not None + topk_weights = fc3_context.topk_weights + topk_ids = fc3_context.topk_ids + else: + topk_weights, topk_ids = select_experts( + hidden_states=x, + router_logits=router_logits, + top_k=top_k, + use_grouped_topk=use_grouped_topk, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, + e_score_correction_bias=e_score_correction_bias, + global_num_experts=global_num_experts, + ) + assert topk_ids is not None + assert topk_weights is not None + if zero_expert_num > 0 and zero_expert_type is not None: + topk_ids, topk_weights, zero_expert_result = zero_experts_compute( + expert_indices=topk_ids, + expert_scales=topk_weights, + num_experts=global_num_experts, + zero_expert_type=zero_expert_type, + hidden_states=x, + ) + # this is a naive implementation for experts load balance so as + # to avoid 
accumulating too much tokens on a single rank. + # currently it is only activated when doing profile runs. + if enable_force_load_balance: + random_matrix = torch.rand( + topk_ids.size(0), global_num_experts - global_redundant_expert_num, device=topk_ids.device + ) + topk_ids = torch.argsort(random_matrix, dim=1)[:, : topk_ids.size(1)].to(topk_ids.dtype) + + assert topk_weights is not None + topk_weights = topk_weights.to(self.dtype) + + moe_comm_method = get_forward_context().moe_comm_method + if get_ascend_config().eplb_config.dynamic_eplb: + w1 = layer.w13_weight_list + w1_scale = layer.w13_weight_scale_fp32_list + w2 = layer.w2_weight_list + w2_scale = layer.w2_weight_scale_list + else: + w1 = [layer.w13_weight] + w1_scale = [layer.w13_weight_scale_fp32] + w2 = [layer.w2_weight] + w2_scale = [layer.w2_weight_scale] + + fused_scale_flag = ( + get_forward_context().moe_comm_type == MoECommType.FUSED_MC2 + and envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1 + ) + final_hidden_states = moe_comm_method.fused_experts( + hidden_states=x, + pertoken_scale=pertoken_scale, + w1=w1, + w1_scale=[layer.fused_w1_scale] if fused_scale_flag else w1_scale, + w2=w2, + w2_scale=[layer.fused_w2_scale] if fused_scale_flag else w2_scale, + topk_weights=topk_weights, + topk_ids=topk_ids, + use_int8_w8a8=True, + expert_map=expert_map, + log2phy=log2phy, + dynamic_eplb=get_ascend_config().eplb_config.dynamic_eplb, + mc2_mask=kwargs.get("mc2_mask"), + ) + if zero_expert_num > 0 and zero_expert_type is not None: + final_hidden_states += zero_expert_result + return final_hidden_states diff --git a/vllm_ascend/quantization/kernels/mixed_precision/npu.py b/vllm_ascend/quantization/kernels/mixed_precision/npu.py index bd2743db0eb..ca664605e9f 100644 --- a/vllm_ascend/quantization/kernels/mixed_precision/npu.py +++ b/vllm_ascend/quantization/kernels/mixed_precision/npu.py @@ -7,6 +7,7 @@ from vllm.scalar_type import scalar_types from vllm_ascend.quantization.utils import unpack_from_int32 +from 
vllm_ascend.utils import maybe_trans_nz class AscendwNa16LinearKernel(MPLinearKernel): @@ -64,3 +65,61 @@ def apply_weights( bias=bias, ) return output + + +class AscendW4A8LinearKernel(MPLinearKernel): + SUPPORTED_QUANT_TYPES = [scalar_types.int4] + + @classmethod + def get_min_capability(cls) -> int: + return 0 + + @classmethod + def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]: + if not torch.npu.is_available(): + return False, "Ascend W4A8 only supported on NPU devices" + if c.weight_type not in cls.SUPPORTED_QUANT_TYPES: + return False, f"Unsupported quant type {c.weight_type}" + if c.full_weight_shape[0] % c.group_size != 0: + return ( + False, + f"Group size ({c.group_size}) does not evenly divide " + f"the number of input features ({c.full_weight_shape[0]})", + ) + return True, None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + weight = getattr(layer, self.w_q_name).data + weight = weight.transpose(0, 1).contiguous() + weight = maybe_trans_nz(weight) + + # Pack int4 values (stored as int8 in [-8, 7]) into NPU int4pack format. 
+ weight = torch_npu.npu_quantize( + weight.to(torch.float32), + torch.tensor([1.0]).npu(), + None, + torch.quint4x2, + -1, + False, + ) + getattr(layer, self.w_q_name).data = weight + + scale = getattr(layer, self.w_s_name).data + getattr(layer, self.w_s_name).data = scale.contiguous().to(torch.float32) + + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + ) -> torch.Tensor: + weight, scale, _, _ = self._get_weight_params(layer) + output = torch_npu.npu_weight_quant_batchmatmul( + x=x, + weight=weight, + antiquant_scale=scale.to(x.dtype), + antiquant_offset=None, + antiquant_group_size=self.config.group_size, + bias=bias, + ) + return output diff --git a/vllm_ascend/quantization/kernels/scaled_mm/__init__.py b/vllm_ascend/quantization/kernels/scaled_mm/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/vllm_ascend/quantization/kernels/scaled_mm/npu.py b/vllm_ascend/quantization/kernels/scaled_mm/npu.py new file mode 100644 index 00000000000..ac574a130e7 --- /dev/null +++ b/vllm_ascend/quantization/kernels/scaled_mm/npu.py @@ -0,0 +1,156 @@ +import torch +import torch_npu +from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( + Int8ScaledMMLinearKernel, + Int8ScaledMMLinearLayerConfig, +) + +from vllm_ascend.utils import get_weight_prefetch_method, maybe_trans_nz + + +class AscendDynamicInt8ScaledMMLinearKernel(Int8ScaledMMLinearKernel): + @classmethod + def is_supported(cls, compute_capability: int | None = None) -> tuple[bool, str | None]: + if not torch.npu.is_available(): + return False, "requires Ascend NPU." + return True, None + + @classmethod + def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: + if c.is_static_input_scheme: + return False, "AscendDynamicInt8ScaledMMLinearKernel does not support static input quantization scheme." 
+ return True, None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + layer.weight.data = layer.weight.data.transpose(0, 1).contiguous() + # cast quantized weight tensors in NZ format for higher inference speed + layer.weight.data = maybe_trans_nz(layer.weight.data) + layer.weight_scale.data = layer.weight_scale.data.flatten() + # layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32) + # layer.weight_offset.data = layer.weight_offset.data.flatten() + + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + ) -> torch.Tensor: + quantized_x, pertoken_scale = torch_npu.npu_dynamic_quant(x) + output = torch_npu.npu_quant_matmul( + quantized_x, + layer.weight, + layer.weight_scale, + pertoken_scale=pertoken_scale, + bias=bias, + output_dtype=x.dtype, + ) + return output + + +class AscendStaticInt8ScaledMMLinearKernel(Int8ScaledMMLinearKernel): + @classmethod + def is_supported(cls, compute_capability: int | None = None) -> tuple[bool, str | None]: + if not torch.npu.is_available(): + return False, "requires Ascend NPU." 
+ return True, None + + @classmethod + def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: + if not c.is_static_input_scheme: + return ( + False, + "AscendStaticInt8ScaledMMLinearLayerConfig does not support dynamic input quantization scheme.", + ) + return True, None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + layer.weight_scale.data = layer.weight_scale.data.to(torch.bfloat16) + layer.input_scale.data = layer.input_scale.data.to(torch.bfloat16) + expanding_factor = layer.weight.data.shape[1] + layer.aclnn_input_scale = torch.nn.Parameter( + layer.input_scale.data.repeat(expanding_factor), requires_grad=False + ) + layer.aclnn_input_scale_reciprocal = 1 / torch.nn.Parameter( + layer.input_scale.data.repeat(expanding_factor), requires_grad=False + ) + + if layer.input_zero_point is None: + layer.input_zero_point = torch.nn.Parameter( + torch.zeros(1, dtype=torch.int8, device=layer.weight.device), requires_grad=False + ) + + layer.aclnn_input_offset = torch.nn.Parameter( + layer.input_zero_point.data.repeat(expanding_factor), requires_grad=False + ).to(layer.aclnn_input_scale.dtype) + + layer.weight.data = layer.weight.data.transpose(0, 1).contiguous() + layer.weight.data = maybe_trans_nz(layer.weight.data) + layer.weight_scale.data = torch.flatten(layer.weight_scale.data) + # layer.weight_offset.data = torch.flatten(layer.weight_offset.data) + + deq_scale = layer.input_scale.data * layer.weight_scale.data + layer.deq_scale = torch.nn.Parameter(deq_scale, requires_grad=False) + + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + ) -> torch.Tensor: + if x.dtype != torch.int8: + layer_cls_name = layer.__class__.__name__ + weight_prefetch_method = get_weight_prefetch_method() + # prefetch qkvo_proj.weight preprocess + if weight_prefetch_method: + weight_prefetch_method.maybe_prefetch_attn_weight_preprocess( + layer_cls_name=layer_cls_name, + 
weight=layer.weight, + start_flag=x, + ) + try: + quant_comm_config = layer._quant_comm_config + except AttributeError: + quant_comm_config = {} + comm_fn = quant_comm_config.get("communication_fn") + enable_flashcomm2_quant_comm = comm_fn is not None and ( + "o_proj" in layer.prefix or "out_proj" in layer.prefix + ) + if enable_flashcomm2_quant_comm: + quant_input_x = x.contiguous().view(-1, layer.aclnn_input_scale_reciprocal.size(0)) + quant_x = torch.ops.vllm.quantize( + quant_input_x, + layer.aclnn_input_scale, + layer.aclnn_input_scale_reciprocal, + layer.aclnn_input_offset, + ) + comm_input = quant_x.view(x.size(0), -1) + assert comm_fn is not None + x = comm_fn(comm_input) + else: + # quant + x = torch.ops.vllm.quantize( + x, + layer.aclnn_input_scale, + layer.aclnn_input_scale_reciprocal, + layer.aclnn_input_offset, + ) + + # prefetch qkvo_proj.weight postprocess + if weight_prefetch_method: + weight_prefetch_method.maybe_prefetch_attn_weight_postprocess( + layer_cls_name=layer_cls_name, + stop_flag=x, + ) + + # quant_bias = layer.quant_bias if tp_rank == 0 else None + + quant_bias = bias + + output = torch_npu.npu_quant_matmul( + x, + layer.weight, + layer.deq_scale, + bias=quant_bias, + output_dtype=layer.params_dtype, + ) + return output From 9e27e4672268b94ef115702aa8bb53be542961a3 Mon Sep 17 00:00:00 2001 From: menogrey <1299267905@qq.com> Date: Fri, 27 Feb 2026 09:17:46 +0000 Subject: [PATCH 3/8] Clean up messy code Signed-off-by: menogrey <1299267905@qq.com> --- .../patch/worker/patch_quantization.py | 3 + .../compressed_tensors/schemes/w8a8.py | 43 +++++++------ .../kernels/mixed_precision/npu.py | 28 ++++----- .../quantization/kernels/scaled_mm/npu.py | 61 +++++++++---------- .../quantization/methods/w8a8_dynamic.py | 4 -- 5 files changed, 71 insertions(+), 68 deletions(-) diff --git a/vllm_ascend/patch/worker/patch_quantization.py b/vllm_ascend/patch/worker/patch_quantization.py index 004ffc96ee7..2f209363085 100644 --- 
a/vllm_ascend/patch/worker/patch_quantization.py +++ b/vllm_ascend/patch/worker/patch_quantization.py @@ -33,6 +33,9 @@ AscendW4A16FusedMoEMethod.process_weights_after_loading ) +ct_moe_module.CompressedTensorsW8A8Int8MoEMethod.create_weights = ( + AscendCompressedTensorsW8A8Int8DynamicFusedMoEMethod.create_weights +) ct_moe_module.CompressedTensorsW8A8Int8MoEMethod.apply = AscendCompressedTensorsW8A8Int8DynamicFusedMoEMethod.apply ct_moe_module.CompressedTensorsW8A8Int8MoEMethod.process_weights_after_loading = ( AscendCompressedTensorsW8A8Int8DynamicFusedMoEMethod.process_weights_after_loading diff --git a/vllm_ascend/quantization/compressed_tensors/schemes/w8a8.py b/vllm_ascend/quantization/compressed_tensors/schemes/w8a8.py index 77f3a151276..ce5c0c4272f 100644 --- a/vllm_ascend/quantization/compressed_tensors/schemes/w8a8.py +++ b/vllm_ascend/quantization/compressed_tensors/schemes/w8a8.py @@ -22,6 +22,9 @@ import torch_npu from vllm.config import get_current_vllm_config from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( + CompressedTensorsW8A8Int8MoEMethod, +) import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config @@ -41,7 +44,26 @@ def scale_from_float_to_int64(scale): return scale +_original_create_weights = CompressedTensorsW8A8Int8MoEMethod.create_weights + + class AscendCompressedTensorsW8A8Int8DynamicFusedMoEMethod: + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + _original_create_weights( + self, layer, num_experts, hidden_size, intermediate_size_per_partition, params_dtype, **extra_weight_attrs + ) + # Adapt for Ascend process, original create_weights use the float32 dtype. 
+ layer.w13_weight_scale.data = layer.w13_weight_scale.data.to(params_dtype) + layer.w2_weight_scale.data = layer.w2_weight_scale.data.to(params_dtype) + def process_weights_after_loading(self, layer): self.dtype = get_current_vllm_config().model_config.dtype layer.w13_weight.data = layer.w13_weight.data.transpose(1, 2).contiguous() @@ -50,26 +72,9 @@ def process_weights_after_loading(self, layer): # can only support weight nz. layer.w13_weight.data = torch_npu.npu_format_cast(layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ) layer.w2_weight.data = torch_npu.npu_format_cast(layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ) - layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(layer.w13_weight_scale.data.shape[0], -1).to( - torch.bfloat16 - ) + layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(layer.w13_weight_scale.data.shape[0], -1) layer.w13_weight_scale_fp32 = layer.w13_weight_scale.data.to(torch.float32) - if hasattr(layer, "w13_weight_offset"): - layer.w13_weight_offset.data = layer.w13_weight_offset.data.view(layer.w13_weight_offset.data.shape[0], -1) - layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(layer.w2_weight_scale.data.shape[0], -1).to( - torch.bfloat16 - ) - if hasattr(layer, "w2_weight_offset"): - layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(layer.w2_weight_offset.data.shape[0], -1) - - layer.w13_weight_offset = torch.nn.Parameter( - torch.zeros_like(layer.w13_weight_scale.data), - requires_grad=False, - ) - layer.w2_weight_offset = torch.nn.Parameter( - torch.zeros_like(layer.w2_weight_scale.data), - requires_grad=False, - ) + layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(layer.w2_weight_scale.data.shape[0], -1) layer.fused_w1_scale = scale_from_float_to_int64(layer.w13_weight_scale.data) layer.fused_w2_scale = scale_from_float_to_int64(layer.w2_weight_scale.data) diff --git a/vllm_ascend/quantization/kernels/mixed_precision/npu.py b/vllm_ascend/quantization/kernels/mixed_precision/npu.py 
index ca664605e9f..60c991021db 100644 --- a/vllm_ascend/quantization/kernels/mixed_precision/npu.py +++ b/vllm_ascend/quantization/kernels/mixed_precision/npu.py @@ -23,8 +23,9 @@ def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]: return True, None def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + weight, scale, _, _ = self._get_weight_params(layer) # Get original shape before transpose - weight_shape = layer.weight_packed.data.shape + weight_shape = weight.data.shape pack_factor = 8 num_bits = 4 @@ -35,7 +36,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # Unpack from int32 to int8 (with int4 range) unpacked_weight = unpack_from_int32( - weight=layer.weight_packed.data, + weight=weight.data, shape=torch.Size([weight_shape[0], weight_shape[1] * pack_factor]), num_bits=num_bits, packed_dim=1, @@ -45,10 +46,10 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: unpacked_weight = unpacked_weight.transpose(0, 1).contiguous().int() # Repack to int32 using NPU int4 packing - layer.weight_packed.data = torch_npu.npu_convert_weight_to_int4pack(unpacked_weight) + weight.data = torch_npu.npu_convert_weight_to_int4pack(unpacked_weight) # Transpose scales and offsets: [n, num_groups] -> [num_groups, n] - layer.weight_scale.data = layer.weight_scale.data.transpose(0, 1).contiguous() + scale.data = scale.data.transpose(0, 1).contiguous() def apply_weights( self, @@ -56,10 +57,11 @@ def apply_weights( x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: + weight, scale, _, _ = self._get_weight_params(layer) output = torch_npu.npu_weight_quant_batchmatmul( x=x, - weight=layer.weight_packed, - antiquant_scale=layer.weight_scale, + weight=weight, + antiquant_scale=scale, antiquant_offset=None, antiquant_group_size=self.config.group_size, bias=bias, @@ -89,23 +91,21 @@ def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]: return True, None def 
process_weights_after_loading(self, layer: torch.nn.Module) -> None: - weight = getattr(layer, self.w_q_name).data - weight = weight.transpose(0, 1).contiguous() - weight = maybe_trans_nz(weight) + weight, scale, _, _ = self._get_weight_params(layer) + + weight.data = maybe_trans_nz(weight.data.transpose(0, 1).contiguous()) # Pack int4 values (stored as int8 in [-8, 7]) into NPU int4pack format. - weight = torch_npu.npu_quantize( - weight.to(torch.float32), + weight.data = torch_npu.npu_quantize( + weight.data.to(torch.float32), torch.tensor([1.0]).npu(), None, torch.quint4x2, -1, False, ) - getattr(layer, self.w_q_name).data = weight - scale = getattr(layer, self.w_s_name).data - getattr(layer, self.w_s_name).data = scale.contiguous().to(torch.float32) + scale.data = scale.data.flatten().to(torch.float32) def apply_weights( self, diff --git a/vllm_ascend/quantization/kernels/scaled_mm/npu.py b/vllm_ascend/quantization/kernels/scaled_mm/npu.py index ac574a130e7..7feb1832cb9 100644 --- a/vllm_ascend/quantization/kernels/scaled_mm/npu.py +++ b/vllm_ascend/quantization/kernels/scaled_mm/npu.py @@ -22,12 +22,11 @@ def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | No return True, None def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - layer.weight.data = layer.weight.data.transpose(0, 1).contiguous() + w_q, w_s, _, _, _ = self._get_layer_params(layer) + w_q.data = w_q.data.transpose(0, 1).contiguous() # cast quantized weight tensors in NZ format for higher inference speed - layer.weight.data = maybe_trans_nz(layer.weight.data) - layer.weight_scale.data = layer.weight_scale.data.flatten() - # layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32) - # layer.weight_offset.data = layer.weight_offset.data.flatten() + w_q.data = maybe_trans_nz(w_q.data) + w_s.data = w_s.data.flatten() def apply_weights( self, @@ -35,11 +34,12 @@ def apply_weights( x: torch.Tensor, bias: torch.Tensor | None = None, ) -> 
torch.Tensor: + w_q, w_s, _, _, _ = self._get_layer_params(layer) quantized_x, pertoken_scale = torch_npu.npu_dynamic_quant(x) output = torch_npu.npu_quant_matmul( quantized_x, - layer.weight, - layer.weight_scale, + w_q, + w_s, pertoken_scale=pertoken_scale, bias=bias, output_dtype=x.dtype, @@ -59,36 +59,36 @@ def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | No if not c.is_static_input_scheme: return ( False, - "AscendStaticInt8ScaledMMLinearLayerConfig does not support dynamic input quantization scheme.", + "AscendStaticInt8ScaledMMLinearKernel does not support dynamic input quantization scheme.", ) return True, None def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - layer.weight_scale.data = layer.weight_scale.data.to(torch.bfloat16) - layer.input_scale.data = layer.input_scale.data.to(torch.bfloat16) - expanding_factor = layer.weight.data.shape[1] - layer.aclnn_input_scale = torch.nn.Parameter( - layer.input_scale.data.repeat(expanding_factor), requires_grad=False - ) + w_q, w_s, i_s, i_zp, _ = self._get_layer_params(layer) + + w_s.data = w_s.data.to(layer.params_dtype) + i_s.data = i_s.data.to(layer.params_dtype) + expanding_factor = w_q.data.shape[1] + layer.aclnn_input_scale = torch.nn.Parameter(i_s.data.repeat(expanding_factor), requires_grad=False) layer.aclnn_input_scale_reciprocal = 1 / torch.nn.Parameter( - layer.input_scale.data.repeat(expanding_factor), requires_grad=False + i_s.data.repeat(expanding_factor), requires_grad=False ) - if layer.input_zero_point is None: - layer.input_zero_point = torch.nn.Parameter( - torch.zeros(1, dtype=torch.int8, device=layer.weight.device), requires_grad=False + if i_zp is None: + input_zero_point = torch.zeros(1, dtype=torch.int8, device=w_q.device) + layer.aclnn_input_offset = torch.nn.Parameter( + input_zero_point.repeat(expanding_factor), requires_grad=False + ).to(layer.aclnn_input_scale.dtype) + else: + layer.aclnn_input_offset = 
torch.nn.Parameter(i_zp.data.repeat(expanding_factor), requires_grad=False).to( + layer.aclnn_input_scale.dtype ) - layer.aclnn_input_offset = torch.nn.Parameter( - layer.input_zero_point.data.repeat(expanding_factor), requires_grad=False - ).to(layer.aclnn_input_scale.dtype) + w_q.data = w_q.data.transpose(0, 1).contiguous() + w_q.data = maybe_trans_nz(w_q.data) + w_s.data = torch.flatten(w_s.data) - layer.weight.data = layer.weight.data.transpose(0, 1).contiguous() - layer.weight.data = maybe_trans_nz(layer.weight.data) - layer.weight_scale.data = torch.flatten(layer.weight_scale.data) - # layer.weight_offset.data = torch.flatten(layer.weight_offset.data) - - deq_scale = layer.input_scale.data * layer.weight_scale.data + deq_scale = i_s.data * w_s.data layer.deq_scale = torch.nn.Parameter(deq_scale, requires_grad=False) def apply_weights( @@ -97,6 +97,7 @@ def apply_weights( x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: + w_q, _, _, _, _ = self._get_layer_params(layer) if x.dtype != torch.int8: layer_cls_name = layer.__class__.__name__ weight_prefetch_method = get_weight_prefetch_method() @@ -104,7 +105,7 @@ def apply_weights( if weight_prefetch_method: weight_prefetch_method.maybe_prefetch_attn_weight_preprocess( layer_cls_name=layer_cls_name, - weight=layer.weight, + weight=w_q, start_flag=x, ) try: @@ -142,13 +143,11 @@ def apply_weights( stop_flag=x, ) - # quant_bias = layer.quant_bias if tp_rank == 0 else None - quant_bias = bias output = torch_npu.npu_quant_matmul( x, - layer.weight, + w_q, layer.deq_scale, bias=quant_bias, output_dtype=layer.params_dtype, diff --git a/vllm_ascend/quantization/methods/w8a8_dynamic.py b/vllm_ascend/quantization/methods/w8a8_dynamic.py index b150d1a5875..d1095b77688 100644 --- a/vllm_ascend/quantization/methods/w8a8_dynamic.py +++ b/vllm_ascend/quantization/methods/w8a8_dynamic.py @@ -101,8 +101,6 @@ def process_weights_after_loading(self, layer): # cast quantized weight tensors in NZ format for higher 
inference speed layer.weight.data = maybe_trans_nz(layer.weight.data) layer.weight_scale.data = layer.weight_scale.data.flatten() - layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32) - layer.weight_offset.data = layer.weight_offset.data.flatten() @register_scheme("W8A8_DYNAMIC", "moe") @@ -278,9 +276,7 @@ def process_weights_after_loading(self, layer): layer.w2_weight.data = torch_npu.npu_format_cast(layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ) layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(layer.w13_weight_scale.data.shape[0], -1) layer.w13_weight_scale_fp32 = layer.w13_weight_scale.data.to(torch.float32) - layer.w13_weight_offset.data = layer.w13_weight_offset.data.view(layer.w13_weight_offset.data.shape[0], -1) layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(layer.w2_weight_scale.data.shape[0], -1) - layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(layer.w2_weight_offset.data.shape[0], -1) layer.fused_w1_scale = scale_from_float_to_int64(layer.w13_weight_scale.data) layer.fused_w2_scale = scale_from_float_to_int64(layer.w2_weight_scale.data) From 3b09da13daadc62f1e7796f3927bd37e6909947f Mon Sep 17 00:00:00 2001 From: menogrey <1299267905@qq.com> Date: Fri, 27 Feb 2026 09:38:55 +0000 Subject: [PATCH 4/8] Clean up legacy code after refactoring. 
Signed-off-by: menogrey <1299267905@qq.com> --- vllm_ascend/platform.py | 2 +- vllm_ascend/quantization/__init__.py | 8 +- .../quantization/compressed_tensors_config.py | 422 ------------------ vllm_ascend/quantization/methods/w4a8.py | 115 +---- .../quantization/methods/w8a8_static.py | 12 - 5 files changed, 4 insertions(+), 555 deletions(-) delete mode 100644 vllm_ascend/quantization/compressed_tensors_config.py diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index 2d161a42eb2..2e49b412b7c 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -147,7 +147,7 @@ def pre_register_and_update(cls, parser: FlexibleArgumentParser | None = None) - quant_action.choices.append(ASCEND_QUANTIZATION_METHOD) if not is_310p(): - from vllm_ascend.quantization import AscendCompressedTensorsConfig, AscendModelSlimConfig # noqa: F401 + from vllm_ascend.quantization import AscendModelSlimConfig # noqa: F401 else: from vllm_ascend._310p.quantization import AscendModelSlimConfig310 # noqa: F401 diff --git a/vllm_ascend/quantization/__init__.py b/vllm_ascend/quantization/__init__.py index 1bf2912537a..7d13281960e 100644 --- a/vllm_ascend/quantization/__init__.py +++ b/vllm_ascend/quantization/__init__.py @@ -20,20 +20,16 @@ Supported quantization tools: - ModelSlim: Use AscendModelSlimConfig -- LLM-Compressor (compressed_tensors): Use AscendCompressedTensorsConfig +- LLM-Compressor (compressed_tensors) Public API: -- Config classes: AscendModelSlimConfig, AscendCompressedTensorsConfig +- Config classes: AscendModelSlimConfig - For scheme implementations, import from vllm_ascend.quantization.methods """ -# LLM-Compressor (compressed_tensors) quantization config -from .compressed_tensors_config import AscendCompressedTensorsConfig - # ModelSlim quantization config from .modelslim_config import AscendModelSlimConfig __all__ = [ "AscendModelSlimConfig", - "AscendCompressedTensorsConfig", ] diff --git a/vllm_ascend/quantization/compressed_tensors_config.py 
b/vllm_ascend/quantization/compressed_tensors_config.py deleted file mode 100644 index 482652db025..00000000000 --- a/vllm_ascend/quantization/compressed_tensors_config.py +++ /dev/null @@ -1,422 +0,0 @@ -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# Copyright 2023 The vLLM team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# This file is a part of the vllm-ascend project. -# -"""LLM-Compressor (compressed_tensors) quantization configuration for Ascend.""" - -from typing import Any, Optional, cast - -import torch -from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy, QuantizationType -from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe import FusedMoE -from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod -from vllm.model_executor.layers.quantization.base_config import QuantizationConfig, QuantizeMethodBase -from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( - find_matched_target, - is_activation_quantization_format, - should_ignore_layer, -) -from vllm.model_executor.models.utils import WeightsMapper - -from vllm_ascend.utils import COMPRESSED_TENSORS_METHOD - -from .methods import AscendLinearScheme, AscendMoEScheme - -logger = init_logger(__name__) - - -QUANTIZATION_SCHEME_MAP_TYPE = dict[str, dict[str, "QuantizationArgs"] | None] - - -# TODO: Remove all the AscendCompressedTensorsConfig implement when finishing all 
compressed-tensors schemes -# @register_quantization_config(COMPRESSED_TENSORS_METHOD) -class AscendCompressedTensorsConfig(QuantizationConfig): - """Config class for LLM-Compressor (compressed_tensors) quantization on Ascend. - - This class adapts the compressed_tensors format to work with Ascend's - quantization implementations. - """ - - def __init__( - self, - target_scheme_map: dict[str, Any], - ignore: list[str], - quant_format: str, - config: dict[str, Any] | None = None, - ): - super().__init__() - self.ignore = ignore - self.quant_format = quant_format - # Map from [target -> scheme] - self.target_scheme_map = target_scheme_map - self.quant_description = config - - def get_name(self) -> str: - return "compressed-tensors" - - @classmethod - def get_supported_act_dtypes(cls) -> list[torch.dtype]: - return [torch.int8, torch.float16, torch.bfloat16] - - @classmethod - def get_min_capability(cls) -> int: - raise NotImplementedError('Ascend hardware dose not support "get_min_capability" feature.') - - @classmethod - def get_config_filenames(cls) -> list[str]: - return [] - - def _add_fused_moe_to_target_scheme_map(self): - """ - Helper function to update target_scheme_map - since linear layers get fused into FusedMoE - targeting 'Linear' needs to also match - FusedMoE modules. 
- """ - if "Linear" not in self.target_scheme_map or "FusedMoE" in self.target_scheme_map: - return - self.target_scheme_map["FusedMoE"] = self.target_scheme_map["Linear"] - - @classmethod - def from_config(cls, config: dict[str, Any]) -> "AscendCompressedTensorsConfig": - ignore: list[str] = cast(list[str], config.get("ignore", [])) - quant_format = cast(str, config.get("format")) - target_scheme_map = cls._quantization_scheme_map_from_config(config=config) - - return cls( - target_scheme_map=target_scheme_map, - ignore=ignore, - quant_format=quant_format, - config=config, - ) - - @classmethod - def _quantization_scheme_map_from_config(cls, config: dict[str, Any]) -> QUANTIZATION_SCHEME_MAP_TYPE: - """Build target scheme map from config. - - :param config: The `quantization_config` dictionary from config.json - :return: A dictionary mapping target layer names to their corresponding - quantization_args for weights and input activations - """ - - target_scheme_map: dict[str, Any] = dict() - quant_format = cast(str, config.get("format")) - - config_groups = config.get("config_groups", dict()) - for _, quant_config in config_groups.items(): - targets = quant_config.get("targets") - for target in targets: - target_scheme_map[target] = {} - target_scheme_map[target]["weights"] = QuantizationArgs.model_validate(quant_config.get("weights")) - - target_scheme_map[target]["input_activations"] = None - target_scheme_map[target]["format"] = quant_config.get("format") - format = target_scheme_map[target].get("format") - # If no per-config format defined, use global format in config - act_quant_format = ( - is_activation_quantization_format(format) - if format is not None - else is_activation_quantization_format(quant_format) - ) - input_activations = quant_config.get("input_activations") - if act_quant_format and input_activations is not None: - target_scheme_map[target]["input_activations"] = QuantizationArgs.model_validate( - quant_config.get("input_activations") - ) - 
return target_scheme_map - - def get_quant_method( - self, - layer: torch.nn.Module, - prefix: str, - ) -> Optional["QuantizeMethodBase"]: - from .method_adapters import AscendFusedMoEMethod, AscendLinearMethod - - if isinstance(layer, LinearBase): - layer.ascend_quant_method = COMPRESSED_TENSORS_METHOD - # Get the scheme for this layer - linear_scheme = self._get_linear_scheme(layer=layer, layer_name=prefix) - - # Return unquantized method if no scheme found - if linear_scheme is None: - return UnquantizedLinearMethod() - - # Store scheme on layer for reference (optional, for debugging) - layer.scheme = linear_scheme - logger.info_once("Using the vLLM Ascend llmcompressor Quantization now!") - return AscendLinearMethod(linear_scheme) - - if isinstance(layer, FusedMoE): - # Delayed import to avoid circular import - from vllm_ascend.ops.fused_moe.fused_moe import AscendUnquantizedFusedMoEMethod - - layer.ascend_quant_method = COMPRESSED_TENSORS_METHOD - layer_name = prefix + ".0.gate_proj" - # Get the scheme for this layer - moe_scheme = self._get_moe_scheme(layer=layer, layer_name=layer_name) - - # Return unquantized method if no scheme found - if moe_scheme is None: - return AscendUnquantizedFusedMoEMethod(layer.moe_config) - - # Store scheme on layer for reference (optional, for debugging) - layer.scheme = moe_scheme - logger.info_once("Using the vLLM Ascend llmcompressor Quantization now!") - return AscendFusedMoEMethod(moe_scheme, layer.moe_config) - - return None - - def _get_linear_scheme(self, layer: torch.nn.Module, layer_name: str | None = None) -> AscendLinearScheme | None: - """Get the linear quantization scheme for a layer. - - Returns: - An AscendLinearScheme instance, or None if the layer - should use unquantized method. 
- """ - weight_quant, input_quant, format = self._get_quant_args(layer, layer_name) - if weight_quant is None: - return None - - scheme = self._create_scheme_for_layer_type( - weight_quant=weight_quant, - input_quant=input_quant, - format=format, - layer_type="linear", - ) - return cast(AscendLinearScheme, scheme) - - def _get_moe_scheme(self, layer: torch.nn.Module, layer_name: str | None = None) -> AscendMoEScheme | None: - """Get the MoE quantization scheme for a layer. - - Returns: - An AscendMoEScheme instance, or None if the layer - should use unquantized method. - """ - # Add FusedMoE to target scheme map if needed - self._add_fused_moe_to_target_scheme_map() - - weight_quant, input_quant, format = self._get_quant_args(layer, layer_name) - if weight_quant is None: - return None - - scheme = self._create_scheme_for_layer_type( - weight_quant=weight_quant, - input_quant=input_quant, - format=format, - layer_type="moe", - ) - return cast(AscendMoEScheme, scheme) - - def _get_quant_args( - self, layer: torch.nn.Module, layer_name: str | None = None - ) -> tuple[Optional["QuantizationArgs"], Optional["QuantizationArgs"], str | None]: - """Extract quantization arguments for a layer. - - compressed-tensors supports non uniform in the following way: - - targets of config_groups: There can be N config_groups which each - have a quantization scheme. Each config_group has a list of targets - which can be a full layer_name, a regex for a layer_name, or - an nn.Module name. - - Detect whether a layer_name is found in any target and - use the quantization scheme corresponding to the matched target. - - Returns: - A tuple of (weight_quant, input_quant, format). weight_quant is - None if the layer should use unquantized method. 
- """ - scheme_dict = self.get_scheme_dict(layer, layer_name) - weight_quant = None - input_quant = None - format = None - if scheme_dict: - weight_quant = scheme_dict.get("weights") - input_quant = scheme_dict.get("input_activations") - format = scheme_dict.get("format") - - if weight_quant is None: - logger.warning_once( - "Acceleration for non-quantized schemes is " - "not supported by Compressed Tensors. " - "Falling back to UnquantizedLinearMethod" - ) - - return weight_quant, input_quant, format - - def get_scheme_dict( - self, layer: torch.nn.Module, layer_name: str | None = None - ) -> dict[str, QuantizationArgs | str | None] | None: - """ - Extract the QuantizationArgs for a given layer. - - Returns: - dict with { - "weights": QuantizationArgs, - "input_activations": QuantizationArgs | None, - "format": str | None - } | None - """ - if should_ignore_layer(layer_name, ignore=self.ignore, fused_mapping=self.packed_modules_mapping): - return None - - if self.target_scheme_map: - matched_target = find_matched_target( - layer_name=layer_name, - module=layer, - targets=self.target_scheme_map.keys(), - fused_mapping=self.packed_modules_mapping, - ) - scheme_dict = self.target_scheme_map[matched_target] - if scheme_dict.get("format") is None: - scheme_dict["format"] = self.quant_format - return scheme_dict - - return None - - def _create_scheme_for_layer_type( - self, - weight_quant: "QuantizationArgs", - input_quant: Optional["QuantizationArgs"], - format: str | None, - layer_type: str, - ) -> AscendLinearScheme | AscendMoEScheme: - """Create the appropriate Ascend scheme based on quantization args and layer type. - - Args: - weight_quant: Weight quantization arguments. - input_quant: Input activation quantization arguments. - format: Per-layer format, if defined. - layer_type: Type of layer ("linear" or "moe"). - - Returns: - An instance of the appropriate Ascend quantization scheme. 
- """ - from .methods import get_scheme_class - - # Determine the quantization type - quant_type = self._detect_quant_type(weight_quant, input_quant, format) - - # Get the scheme class from registry - scheme_cls = get_scheme_class(quant_type, layer_type) - if scheme_cls is None: - raise NotImplementedError( - f"No compressed-tensors compatible scheme was found for " - f"quant_type={quant_type}, layer_type={layer_type}." - ) - - return scheme_cls() - - def _detect_quant_type( - self, - weight_quant: "QuantizationArgs", - input_quant: Optional["QuantizationArgs"], - format: str | None, - ) -> str: - """Detect the quantization type from quantization arguments. - - Args: - weight_quant: Weight quantization arguments. - input_quant: Input activation quantization arguments. - format: Per-layer format, if defined. - - Returns: - A string representing the quantization type (e.g., "W8A8", "W8A8_DYNAMIC"). - """ - # use the per-layer format if defined, otherwise, use global format - format = format if format is not None else self.quant_format - act_quant_format = is_activation_quantization_format(format) - - if act_quant_format and input_quant is not None: - if self._is_static_tensor_w8a8(weight_quant, input_quant): - return "W8A8" - - if self._is_dynamic_token_w8a8(weight_quant, input_quant): - return "W8A8_DYNAMIC" - - if self._is_dynamic_token_w4a8(weight_quant, input_quant): - return "W4A8_DYNAMIC" - - if self._is_w4a16(weight_quant, input_quant): - return "W4A16" - - raise NotImplementedError("No compressed-tensors compatible quantization type was found.") - - def _is_static_tensor_w8a8(self, weight_quant: "QuantizationArgs", input_quant: "QuantizationArgs") -> bool: - is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8 - weight_strategy = weight_quant.strategy == QuantizationStrategy.CHANNEL.value - is_tensor = weight_strategy and input_quant.strategy == QuantizationStrategy.TENSOR.value - is_static = not weight_quant.dynamic and not input_quant.dynamic - 
is_symmetric = weight_quant.symmetric and input_quant.symmetric - - # Only symmetric input quantization supported. - # Only symmetric weight quantization supported. - return is_8_bits and is_tensor and is_symmetric and is_static - - def _is_dynamic_token_w8a8(self, weight_quant: "QuantizationArgs", input_quant: "QuantizationArgs") -> bool: - is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8 - weight_strategy = weight_quant.strategy == QuantizationStrategy.CHANNEL.value - is_token = weight_strategy and input_quant.strategy == QuantizationStrategy.TOKEN.value - is_dynamic = not weight_quant.dynamic and input_quant.dynamic - is_symmetric = weight_quant.symmetric and input_quant.symmetric - - # Only symmetric input quantization supported. - # Only symmetric weight quantization supported. - return is_8_bits and is_token and is_symmetric and is_dynamic - - def _is_dynamic_token_w4a8(self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs) -> bool: - is_4_bits = weight_quant.num_bits == 4 - is_8_bits = input_quant.num_bits == 8 - weight_strategy = (weight_quant.strategy == QuantizationStrategy.CHANNEL.value) or ( - weight_quant.strategy == QuantizationStrategy.GROUP.value - ) - is_token = weight_strategy and input_quant.strategy == QuantizationStrategy.TOKEN.value - is_dynamic = not weight_quant.dynamic and input_quant.dynamic - is_symmetric = weight_quant.symmetric and input_quant.symmetric - - # Adapt for AscendW4A8DynamicFusedMoEMethod - assert self.quant_description is not None, "quant_description should not be None" - if weight_strategy: - self.quant_description["group_size"] = weight_quant.group_size if weight_quant.group_size else 0 - - self.quant_description["version"] = "0" - self.quant_description["ascend_quant_method"] = COMPRESSED_TENSORS_METHOD - self.quant_description["weight_strategy"] = str(weight_quant.strategy) - - # Only symmetric input quantization supported. - # Only symmetric weight quantization supported. 
- return is_4_bits and is_8_bits and is_token and is_symmetric and is_dynamic - - def _is_w4a16(self, weight_quant: "QuantizationArgs", input_quant: Optional["QuantizationArgs"]) -> bool: - # Confirm weights quantized. - if weight_quant is None: - return False - - # Confirm we have integer type. - if weight_quant.type != QuantizationType.INT: - return False - - input_quant_none = input_quant is None - is_4_bits = weight_quant.num_bits == 4 - is_group = weight_quant.strategy == QuantizationStrategy.GROUP.value - is_static = not weight_quant.dynamic - - return input_quant_none and is_4_bits and is_group and is_static - - def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"): - self.target_scheme_map = hf_to_vllm_mapper.apply_dict(self.target_scheme_map) - self.ignore = hf_to_vllm_mapper.apply_list(self.ignore) diff --git a/vllm_ascend/quantization/methods/w4a8.py b/vllm_ascend/quantization/methods/w4a8.py index 8a5ebca226a..1d373ed96c2 100644 --- a/vllm_ascend/quantization/methods/w4a8.py +++ b/vllm_ascend/quantization/methods/w4a8.py @@ -28,7 +28,7 @@ from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.ops.fused_moe.experts_selector import select_experts -from vllm_ascend.utils import COMPRESSED_TENSORS_METHOD, maybe_trans_nz +from vllm_ascend.utils import maybe_trans_nz from .base import AscendLinearScheme, AscendMoEScheme, QuantType from .registry import register_scheme @@ -195,10 +195,6 @@ def __init__(self): # NOTE: new quantize weights: 2 int4 pack into int8 self.new_quant_version = quant_version == "1.0.0" - self.quant_method = vllm_config.quant_config.quant_description.get("ascend_quant_method", "") - if self.quant_method == COMPRESSED_TENSORS_METHOD: - self.weight_strategy = vllm_config.quant_config.quant_description.get("weight_strategy", "group") - self.tp_size = 1 if vllm_config.parallel_config.enable_expert_parallel else self.ep_group.world_size 
self.dynamic_eplb = get_ascend_config().eplb_config.dynamic_eplb if self.new_quant_version and self.tp_size > 16: @@ -215,28 +211,6 @@ def __init__(self): def get_weight( self, num_experts: int, intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype - ) -> dict[str, Any]: - if self.quant_method == COMPRESSED_TENSORS_METHOD: - return self.get_weight_compressed_tensors( - num_experts, intermediate_size_per_partition, hidden_sizes, params_dtype - ) - else: - return self.get_weight_modelslim(num_experts, intermediate_size_per_partition, hidden_sizes, params_dtype) - - def get_weight_compressed_tensors( - self, num_experts: int, intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype - ) -> dict[str, Any]: - param_dict = {} - E = num_experts - H = hidden_sizes - IN = intermediate_size_per_partition - - param_dict["w13_weight"] = torch.empty(E, 2 * IN, H, dtype=torch.int8) - param_dict["w2_weight"] = torch.empty(E, H, IN, dtype=torch.int8) - return param_dict - - def get_weight_modelslim( - self, num_experts: int, intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype ) -> dict[str, Any]: param_dict = {} if self.new_quant_version: @@ -254,38 +228,6 @@ def get_weight_modelslim( def get_dynamic_quant_param( self, num_experts: int, intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype - ) -> dict[str, Any]: - if self.quant_method == COMPRESSED_TENSORS_METHOD: - return self.get_dynamic_quant_param_compressed_tensors( - num_experts, intermediate_size_per_partition, hidden_sizes, params_dtype - ) - else: - return self.get_dynamic_quant_param_modelslim( - num_experts, intermediate_size_per_partition, hidden_sizes, params_dtype - ) - - def get_dynamic_quant_param_compressed_tensors( - self, num_experts: int, intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype - ) -> dict[str, Any]: - param_dict = {} - - E = num_experts - H = hidden_sizes 
- IN = intermediate_size_per_partition - g = self.group_size - - # Per-row scale columns - def _n_scale_cols(in_features: int) -> int: - return 1 if g <= 0 else (in_features // g) - - param_dict["w13_weight_scale"] = torch.empty(E, 2 * IN, _n_scale_cols(H), dtype=torch.bfloat16) - - param_dict["w2_weight_scale"] = torch.empty(E, H, _n_scale_cols(IN), dtype=torch.bfloat16) - - return param_dict - - def get_dynamic_quant_param_modelslim( - self, num_experts: int, intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype ) -> dict[str, Any]: param_dict = {} param_dict["w13_weight_scale"] = torch.empty( @@ -447,61 +389,6 @@ def pack_to_int32(self, weight: torch.Tensor): ) def process_weights_after_loading(self, layer): - if self.quant_method == COMPRESSED_TENSORS_METHOD: - self.process_weights_after_loading_compressed_tensors(layer) - else: - self.process_weights_after_loading_modelslim(layer) - - def process_weights_after_loading_compressed_tensors(self, layer): - layer.w13_weight.data = layer.w13_weight.data.transpose(1, 2).contiguous() - layer.w2_weight.data = layer.w2_weight.data.transpose(1, 2).contiguous() - - def process_scale_compressed_tensors(scale: torch.Tensor): - scale = scale.transpose(1, 2).to(torch.float32).contiguous() - scale_np = scale.cpu().numpy() - scale_np.dtype = np.uint32 - scale_uint64_tensor = torch.from_numpy(scale_np.astype(np.int64)).npu() - return scale_uint64_tensor - - def update_bias_compressed_tensors(weight: torch.Tensor, scale: torch.Tensor, strategy: str): - group_num, k, n = weight.shape - scale = scale.transpose(1, 2).contiguous() - scale = scale.reshape(group_num, -1, n) - group_num, quantgroup_num, n = scale.shape - - bias = None - if strategy == "group": - tmp = weight.to(torch.float32).reshape([group_num, quantgroup_num, -1, n]) * scale.reshape( - [group_num, quantgroup_num, 1, n] - ) - tmp = tmp.reshape([group_num, k, n]) - bias = 8 * tmp.sum(axis=1) - elif strategy == "channel": - bias = 8 * 
(weight.to(torch.float32) * scale).sum(axis=1) - else: - raise ValueError(f"Unsupported weight strategy: {strategy}") - return bias - - w13_bias = update_bias_compressed_tensors( - layer.w13_weight.data, layer.w13_weight_scale.data, self.weight_strategy - ) - w2_bias = update_bias_compressed_tensors(layer.w2_weight.data, layer.w2_weight_scale.data, self.weight_strategy) - - layer.w13_weight_scale.data = process_scale_compressed_tensors(layer.w13_weight_scale.data) - layer.w2_weight_scale.data = process_scale_compressed_tensors(layer.w2_weight_scale.data) - - w13_scale_bias = torch.nn.Parameter(w13_bias, requires_grad=False) - layer.register_parameter("w13_scale_bias", w13_scale_bias) - w2_scale_bias = torch.nn.Parameter(w2_bias, requires_grad=False) - layer.register_parameter("w2_scale_bias", w2_scale_bias) - - # Accuracy problem in nz format - # layer.w13_weight.data = maybe_trans_nz(layer.w13_weight.data) - # layer.w2_weight.data = maybe_trans_nz(layer.w2_weight.data) - layer.w13_weight.data = self.pack_to_int32(layer.w13_weight.data) - layer.w2_weight.data = self.pack_to_int32(layer.w2_weight.data) - - def process_weights_after_loading_modelslim(self, layer): layer.w13_weight.data = layer.w13_weight.data.transpose(1, 2).contiguous() layer.w2_weight.data = layer.w2_weight.data.transpose(1, 2).contiguous() diff --git a/vllm_ascend/quantization/methods/w8a8_static.py b/vllm_ascend/quantization/methods/w8a8_static.py index 47ffc201004..64720104113 100644 --- a/vllm_ascend/quantization/methods/w8a8_static.py +++ b/vllm_ascend/quantization/methods/w8a8_static.py @@ -21,7 +21,6 @@ import torch_npu from vllm_ascend.utils import ( - COMPRESSED_TENSORS_METHOD, get_weight_prefetch_method, maybe_trans_nz, ) @@ -123,13 +122,6 @@ def apply( quant_bias = layer.quant_bias if tp_rank == 0 else None - try: - ascend_quant_method = layer.ascend_quant_method - except AttributeError: - ascend_quant_method = "" - if ascend_quant_method == COMPRESSED_TENSORS_METHOD: - quant_bias = bias 
- output = torch_npu.npu_quant_matmul( x, layer.weight, @@ -155,7 +147,3 @@ def process_weights_after_loading(self, layer): layer.weight.data = maybe_trans_nz(layer.weight.data) layer.weight_scale.data = torch.flatten(layer.weight_scale.data) layer.weight_offset.data = torch.flatten(layer.weight_offset.data) - ascend_quant_method = getattr(layer, "ascend_quant_method", "") - if ascend_quant_method == COMPRESSED_TENSORS_METHOD: - deq_scale = layer.input_scale.data * layer.weight_scale.data - layer.deq_scale = torch.nn.Parameter(deq_scale, requires_grad=False) From 64b91cdf81263bd6479dab00c528e7ef981a3ad5 Mon Sep 17 00:00:00 2001 From: menogrey <1299267905@qq.com> Date: Fri, 27 Feb 2026 10:30:44 +0000 Subject: [PATCH 5/8] Update and resolve conflicts. Signed-off-by: menogrey <1299267905@qq.com> --- vllm_ascend/patch/worker/patch_quantization.py | 7 +++---- vllm_ascend/quantization/kernels/mixed_precision/npu.py | 2 +- vllm_ascend/quantization/kernels/scaled_mm/npu.py | 2 +- vllm_ascend/quantization/modelslim_config.py | 4 +--- vllm_ascend/quantization/utils.py | 6 ++++-- 5 files changed, 10 insertions(+), 11 deletions(-) diff --git a/vllm_ascend/patch/worker/patch_quantization.py b/vllm_ascend/patch/worker/patch_quantization.py index 2f209363085..5fbd7aafc15 100644 --- a/vllm_ascend/patch/worker/patch_quantization.py +++ b/vllm_ascend/patch/worker/patch_quantization.py @@ -1,6 +1,5 @@ +import vllm.model_executor.kernels.linear as linear_module import vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe as ct_moe_module -import vllm.model_executor.layers.quantization.kernels.mixed_precision as mixed_precision_module -import vllm.model_executor.layers.quantization.kernels.scaled_mm as scaled_mm_module from vllm.platforms import PlatformEnum from vllm_ascend.quantization.compressed_tensors.schemes.w4a8 import ( @@ -19,11 +18,11 @@ AscendStaticInt8ScaledMMLinearKernel, ) -mixed_precision_module._POSSIBLE_KERNELS[PlatformEnum.OOT] = [ 
+linear_module._POSSIBLE_KERNELS[PlatformEnum.OOT] = [ AscendW4A8LinearKernel, AscendwNa16LinearKernel, ] -scaled_mm_module._POSSIBLE_INT8_KERNELS[PlatformEnum.OOT] = [ +linear_module._POSSIBLE_INT8_KERNELS[PlatformEnum.OOT] = [ AscendDynamicInt8ScaledMMLinearKernel, AscendStaticInt8ScaledMMLinearKernel, ] diff --git a/vllm_ascend/quantization/kernels/mixed_precision/npu.py b/vllm_ascend/quantization/kernels/mixed_precision/npu.py index 60c991021db..525ea0af922 100644 --- a/vllm_ascend/quantization/kernels/mixed_precision/npu.py +++ b/vllm_ascend/quantization/kernels/mixed_precision/npu.py @@ -1,6 +1,6 @@ import torch import torch_npu -from vllm.model_executor.layers.quantization.kernels.mixed_precision.MPLinearKernel import ( +from vllm.model_executor.kernels.linear.mixed_precision.MPLinearKernel import ( MPLinearKernel, MPLinearLayerConfig, ) diff --git a/vllm_ascend/quantization/kernels/scaled_mm/npu.py b/vllm_ascend/quantization/kernels/scaled_mm/npu.py index 7feb1832cb9..bc745c4349a 100644 --- a/vllm_ascend/quantization/kernels/scaled_mm/npu.py +++ b/vllm_ascend/quantization/kernels/scaled_mm/npu.py @@ -1,6 +1,6 @@ import torch import torch_npu -from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( +from vllm.model_executor.kernels.linear.scaled_mm.ScaledMMLinearKernel import ( Int8ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig, ) diff --git a/vllm_ascend/quantization/modelslim_config.py b/vllm_ascend/quantization/modelslim_config.py index f88a3a90286..cbd8c9a52c0 100644 --- a/vllm_ascend/quantization/modelslim_config.py +++ b/vllm_ascend/quantization/modelslim_config.py @@ -38,13 +38,11 @@ from vllm.model_executor.layers.vocab_parallel_embedding import UnquantizedEmbeddingMethod, VocabParallelEmbedding from vllm.model_executor.models.utils import WeightsMapper +from vllm_ascend.quantization.utils import MODELSLIM_CONFIG_FILENAME from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD from .methods import 
get_scheme_class -# The config filename that ModelSlim generates after quantizing a model. -MODELSLIM_CONFIG_FILENAME = "quant_model_description.json" - logger = init_logger(__name__) # key: model_type diff --git a/vllm_ascend/quantization/utils.py b/vllm_ascend/quantization/utils.py index d4620e87c69..6cff2898a63 100644 --- a/vllm_ascend/quantization/utils.py +++ b/vllm_ascend/quantization/utils.py @@ -18,16 +18,18 @@ import json import os + import torch import torch_npu - from vllm.logger import init_logger -from vllm_ascend.quantization.modelslim_config import MODELSLIM_CONFIG_FILENAME from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD, COMPRESSED_TENSORS_METHOD logger = init_logger(__name__) +# The config filename that ModelSlim generates after quantizing a model. +MODELSLIM_CONFIG_FILENAME = "quant_model_description.json" + def detect_quantization_method(model_path: str) -> str | None: """Auto-detect the quantization method from model directory files. From ccd38f18f58a9bff6fa88e30b20b7b7bddd70772 Mon Sep 17 00:00:00 2001 From: menogrey <1299267905@qq.com> Date: Sat, 28 Feb 2026 04:00:36 +0000 Subject: [PATCH 6/8] Fix vllm==0.16.0 error. 
Signed-off-by: menogrey <1299267905@qq.com> --- .../patch/worker/patch_quantization.py | 33 ++++++++++++++----- .../kernels/mixed_precision/npu.py | 17 +++++++--- .../quantization/kernels/scaled_mm/npu.py | 17 +++++++--- 3 files changed, 48 insertions(+), 19 deletions(-) diff --git a/vllm_ascend/patch/worker/patch_quantization.py b/vllm_ascend/patch/worker/patch_quantization.py index 5fbd7aafc15..43c8fa83298 100644 --- a/vllm_ascend/patch/worker/patch_quantization.py +++ b/vllm_ascend/patch/worker/patch_quantization.py @@ -1,4 +1,3 @@ -import vllm.model_executor.kernels.linear as linear_module import vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe as ct_moe_module from vllm.platforms import PlatformEnum @@ -17,15 +16,31 @@ AscendDynamicInt8ScaledMMLinearKernel, AscendStaticInt8ScaledMMLinearKernel, ) +from vllm_ascend.utils import vllm_version_is -linear_module._POSSIBLE_KERNELS[PlatformEnum.OOT] = [ - AscendW4A8LinearKernel, - AscendwNa16LinearKernel, -] -linear_module._POSSIBLE_INT8_KERNELS[PlatformEnum.OOT] = [ - AscendDynamicInt8ScaledMMLinearKernel, - AscendStaticInt8ScaledMMLinearKernel, -] +if vllm_version_is("0.16.0"): + import vllm.model_executor.layers.quantization.kernels.mixed_precision as mixed_precision_module + import vllm.model_executor.layers.quantization.kernels.scaled_mm as scaled_mm_module + + mixed_precision_module._POSSIBLE_KERNELS[PlatformEnum.OOT] = [ + AscendW4A8LinearKernel, + AscendwNa16LinearKernel, + ] + scaled_mm_module._POSSIBLE_INT8_KERNELS[PlatformEnum.OOT] = [ + AscendDynamicInt8ScaledMMLinearKernel, + AscendStaticInt8ScaledMMLinearKernel, + ] +else: + import vllm.model_executor.kernels.linear as linear_module + + linear_module._POSSIBLE_KERNELS[PlatformEnum.OOT] = [ + AscendW4A8LinearKernel, + AscendwNa16LinearKernel, + ] + linear_module._POSSIBLE_INT8_KERNELS[PlatformEnum.OOT] = [ + AscendDynamicInt8ScaledMMLinearKernel, + AscendStaticInt8ScaledMMLinearKernel, + ] 
ct_moe_module.CompressedTensorsWNA16MarlinMoEMethod.apply = AscendW4A16FusedMoEMethod.apply ct_moe_module.CompressedTensorsWNA16MarlinMoEMethod.process_weights_after_loading = ( diff --git a/vllm_ascend/quantization/kernels/mixed_precision/npu.py b/vllm_ascend/quantization/kernels/mixed_precision/npu.py index 525ea0af922..4bf96cf8fa4 100644 --- a/vllm_ascend/quantization/kernels/mixed_precision/npu.py +++ b/vllm_ascend/quantization/kernels/mixed_precision/npu.py @@ -1,13 +1,20 @@ import torch import torch_npu -from vllm.model_executor.kernels.linear.mixed_precision.MPLinearKernel import ( - MPLinearKernel, - MPLinearLayerConfig, -) from vllm.scalar_type import scalar_types from vllm_ascend.quantization.utils import unpack_from_int32 -from vllm_ascend.utils import maybe_trans_nz +from vllm_ascend.utils import maybe_trans_nz, vllm_version_is + +if vllm_version_is("0.16.0"): + from vllm.model_executor.layers.quantization.kernels.mixed_precision.MPLinearKernel import ( + MPLinearKernel, + MPLinearLayerConfig, + ) +else: + from vllm.model_executor.kernels.linear.mixed_precision.MPLinearKernel import ( + MPLinearKernel, + MPLinearLayerConfig, + ) class AscendwNa16LinearKernel(MPLinearKernel): diff --git a/vllm_ascend/quantization/kernels/scaled_mm/npu.py b/vllm_ascend/quantization/kernels/scaled_mm/npu.py index bc745c4349a..ae60625f41f 100644 --- a/vllm_ascend/quantization/kernels/scaled_mm/npu.py +++ b/vllm_ascend/quantization/kernels/scaled_mm/npu.py @@ -1,11 +1,18 @@ import torch import torch_npu -from vllm.model_executor.kernels.linear.scaled_mm.ScaledMMLinearKernel import ( - Int8ScaledMMLinearKernel, - Int8ScaledMMLinearLayerConfig, -) -from vllm_ascend.utils import get_weight_prefetch_method, maybe_trans_nz +from vllm_ascend.utils import get_weight_prefetch_method, maybe_trans_nz, vllm_version_is + +if vllm_version_is("0.16.0"): + from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( + Int8ScaledMMLinearKernel, + 
Int8ScaledMMLinearLayerConfig, + ) +else: + from vllm.model_executor.kernels.linear.scaled_mm.ScaledMMLinearKernel import ( + Int8ScaledMMLinearKernel, + Int8ScaledMMLinearLayerConfig, + ) class AscendDynamicInt8ScaledMMLinearKernel(Int8ScaledMMLinearKernel): From 9b93d715f42e4cb4ecb5a51de08f90b72ab5c71a Mon Sep 17 00:00:00 2001 From: menogrey <1299267905@qq.com> Date: Sat, 28 Feb 2026 08:03:36 +0000 Subject: [PATCH 7/8] Fix mypy error. Signed-off-by: menogrey <1299267905@qq.com> --- vllm_ascend/patch/worker/patch_quantization.py | 4 ++-- vllm_ascend/quantization/compressed_tensors/schemes/w4a8.py | 2 +- vllm_ascend/quantization/kernels/mixed_precision/npu.py | 2 +- vllm_ascend/quantization/kernels/scaled_mm/npu.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm_ascend/patch/worker/patch_quantization.py b/vllm_ascend/patch/worker/patch_quantization.py index 43c8fa83298..2bbc4241a7d 100644 --- a/vllm_ascend/patch/worker/patch_quantization.py +++ b/vllm_ascend/patch/worker/patch_quantization.py @@ -19,8 +19,8 @@ from vllm_ascend.utils import vllm_version_is if vllm_version_is("0.16.0"): - import vllm.model_executor.layers.quantization.kernels.mixed_precision as mixed_precision_module - import vllm.model_executor.layers.quantization.kernels.scaled_mm as scaled_mm_module + import vllm.model_executor.layers.quantization.kernels.mixed_precision as mixed_precision_module # type: ignore + import vllm.model_executor.layers.quantization.kernels.scaled_mm as scaled_mm_module # type: ignore mixed_precision_module._POSSIBLE_KERNELS[PlatformEnum.OOT] = [ AscendW4A8LinearKernel, diff --git a/vllm_ascend/quantization/compressed_tensors/schemes/w4a8.py b/vllm_ascend/quantization/compressed_tensors/schemes/w4a8.py index af95f59512b..fdbf682d1bb 100644 --- a/vllm_ascend/quantization/compressed_tensors/schemes/w4a8.py +++ b/vllm_ascend/quantization/compressed_tensors/schemes/w4a8.py @@ -81,7 +81,7 @@ def __init__( layer_name: str | None = None, ): 
CompressedTensorsMoEMethod.__init__(self, moe) - self.has_bias = self.moe.has_bias + self.has_bias = self.moe.has_bias # type: ignore self.weight_quant = weight_quant self.input_quant = input_quant diff --git a/vllm_ascend/quantization/kernels/mixed_precision/npu.py b/vllm_ascend/quantization/kernels/mixed_precision/npu.py index 4bf96cf8fa4..8fc28ee8277 100644 --- a/vllm_ascend/quantization/kernels/mixed_precision/npu.py +++ b/vllm_ascend/quantization/kernels/mixed_precision/npu.py @@ -6,7 +6,7 @@ from vllm_ascend.utils import maybe_trans_nz, vllm_version_is if vllm_version_is("0.16.0"): - from vllm.model_executor.layers.quantization.kernels.mixed_precision.MPLinearKernel import ( + from vllm.model_executor.layers.quantization.kernels.mixed_precision.MPLinearKernel import ( # type: ignore MPLinearKernel, MPLinearLayerConfig, ) diff --git a/vllm_ascend/quantization/kernels/scaled_mm/npu.py b/vllm_ascend/quantization/kernels/scaled_mm/npu.py index ae60625f41f..4acc52a583f 100644 --- a/vllm_ascend/quantization/kernels/scaled_mm/npu.py +++ b/vllm_ascend/quantization/kernels/scaled_mm/npu.py @@ -4,7 +4,7 @@ from vllm_ascend.utils import get_weight_prefetch_method, maybe_trans_nz, vllm_version_is if vllm_version_is("0.16.0"): - from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( + from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # type: ignore Int8ScaledMMLinearKernel, Int8ScaledMMLinearLayerConfig, ) From 75e9cd153bf4810ec077621638a563767ded6d94 Mon Sep 17 00:00:00 2001 From: menogrey <1299267905@qq.com> Date: Sat, 28 Feb 2026 10:15:41 +0000 Subject: [PATCH 8/8] Fix UT error and add e2e w4a16 testcase. 
Signed-off-by: menogrey <1299267905@qq.com> --- .github/workflows/misc/model_list.json | 1 + .../multicard/2-cards/test_quantization.py | 88 ++++++++---------- tests/ut/quantization/test_quant_utils.py | 87 ++++++++++++++++++ tests/ut/quantization/test_w4a16.py | 90 +------------------ 4 files changed, 127 insertions(+), 139 deletions(-) diff --git a/.github/workflows/misc/model_list.json b/.github/workflows/misc/model_list.json index eb2fdd2c2c9..34c01f813e4 100644 --- a/.github/workflows/misc/model_list.json +++ b/.github/workflows/misc/model_list.json @@ -12,6 +12,7 @@ "BAAI/bge-small-en-v1.5", "BAAI/kernel_meta", "ByteDance-Seed/BAGEL-7B-MoT", + "cpatonn-mirror/Qwen3-30B-A3B-Thinking-2507-AWQ-4bit", "DeepSeek-ai/DeepSeek-OCR", "DevQuasar/deepseek-ai.DeepSeek-V3.2-BF16", "Eco-Tech/DeepSeek-V3.1-w8a8-mtp-QuaRot", diff --git a/tests/e2e/multicard/2-cards/test_quantization.py b/tests/e2e/multicard/2-cards/test_quantization.py index da45628bc37..80ec499235b 100644 --- a/tests/e2e/multicard/2-cards/test_quantization.py +++ b/tests/e2e/multicard/2-cards/test_quantization.py @@ -21,67 +21,53 @@ from tests.e2e.conftest import VllmRunner -def test_qwen2_5_w8a8_external_quantized_tp2(): - example_prompts = [ - "The president of the United States is", - ] - max_tokens = 5 - with VllmRunner( - "neuralmagic/Qwen2.5-3B-quantized.w8a8", - tensor_parallel_size=2, - cudagraph_capture_sizes=[1, 2, 4, 8], - max_model_len=4096, - gpu_memory_utilization=0.8, - ) as vllm_model: - vllm_output = vllm_model.generate_greedy(example_prompts, max_tokens) +TEST_CASES = [ + pytest.param( + "neuralmagic/Qwen2.5-3B-quantized.w8a8", + [ + "The president of the United States is the head of state and", + ], + id="dense-w8a8", + ), + pytest.param( + "vllm-ascend/Qwen3-30B-A3B-Instruct-2507-quantized.w8a8", + [ + "The president of the United States is the head of state and", + ], + id="moe-w8a8-dynamic", + ), + pytest.param( + "vllm-ascend/Qwen3-30B-A3B-Instruct-2507-quantized.w4a8", + [ + "The 
president of the United States is the head of state and", + ], + id="moe-w4a8-dynamic", + ), + pytest.param( + "cpatonn-mirror/Qwen3-30B-A3B-Thinking-2507-AWQ-4bit", + [ + "The president of the United States is the head of state and", + ], + id="moe-w4a16-dynamic", + ), +] - golden_results = [ - 'The president of the United States is the head of state and', - ] - for i in range(len(vllm_output)): - assert golden_results[i] == vllm_output[i][1] - print(f"Generated text: {vllm_output[i][1]!r}") - - -def test_qwen3_moe_w8a8_dynamic_llm_compressor(): +@pytest.mark.parametrize("model_id, golden_results", TEST_CASES) +def test_compressed_tensors_tp2(model_id, golden_results): example_prompts = [ "The president of the United States is", ] max_tokens = 5 with VllmRunner( - "vllm-ascend/Qwen3-30B-A3B-Instruct-2507-quantized.w8a8", - tensor_parallel_size=2, - max_model_len=4096, - gpu_memory_utilization=0.8, + model_id, + max_model_len=4096, + tensor_parallel_size=2, + cudagraph_capture_sizes=[1, 2, 4, 8], + gpu_memory_utilization=0.8, ) as vllm_model: vllm_output = vllm_model.generate_greedy(example_prompts, max_tokens) - golden_results = [ - 'The president of the United States is the head of state and', - ] - - for i in range(len(vllm_output)): - assert golden_results[i] == vllm_output[i][1] - print(f"Generated text: {vllm_output[i][1]!r}") - -def test_qwen3_moe_w4a8_dynamic_llm_compressor(): - example_prompts = [ - "The president of the United States is", - ] - max_tokens = 5 - with VllmRunner( - "vllm-ascend/Qwen3-30B-A3B-Instruct-2507-quantized.w4a8", - tensor_parallel_size=2, - max_model_len=4096, - gpu_memory_utilization=0.8, - ) as vllm_model: - vllm_output = vllm_model.generate_greedy(example_prompts, max_tokens) - - golden_results = [ - 'The president of the United States is the head of state and', - ] - for i in range(len(vllm_output)): assert golden_results[i] == vllm_output[i][1] print(f"Generated text: {vllm_output[i][1]!r}") diff --git 
a/tests/ut/quantization/test_quant_utils.py b/tests/ut/quantization/test_quant_utils.py index f148342e504..c50b7e74dd9 100644 --- a/tests/ut/quantization/test_quant_utils.py +++ b/tests/ut/quantization/test_quant_utils.py @@ -3,12 +3,15 @@ import os import tempfile from unittest.mock import MagicMock, patch +import torch from tests.ut.base import TestBase from vllm_ascend.quantization.modelslim_config import MODELSLIM_CONFIG_FILENAME from vllm_ascend.quantization.utils import ( detect_quantization_method, maybe_auto_detect_quantization, + pack_to_int32, + unpack_from_int32, ) from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD, COMPRESSED_TENSORS_METHOD @@ -180,3 +183,87 @@ def test_no_detection_emits_no_log(self, mock_detect): maybe_auto_detect_quantization(vllm_config) self.assertIsNone(vllm_config.model_config.quantization) + + +class TestUnpackFromInt32(TestBase): + + def test_unpack_from_int32_packed_dim_1(self): + weight = torch.tensor([[305419896, -1420531520]], dtype=torch.int32) + shape = torch.Size([1, 8]) + num_bits = 4 + + result = unpack_from_int32(weight, shape, num_bits, packed_dim=1) + + self.assertEqual(result.dtype, torch.int8) + self.assertEqual(result.shape, shape) + + def test_unpack_from_int32_packed_dim_0(self): + weight = torch.tensor([[305419896], [-1420531520]], dtype=torch.int32) + shape = torch.Size([8, 1]) + num_bits = 4 + + result = unpack_from_int32(weight, shape, num_bits, packed_dim=0) + + self.assertEqual(result.dtype, torch.int8) + self.assertEqual(result.shape, shape) + + def test_unpack_from_int32_assertions(self): + with self.assertRaises(AssertionError): + weight = torch.tensor([[1, 2]], dtype=torch.int64) + unpack_from_int32(weight, torch.Size([8, 1]), 4) + + with self.assertRaises(AssertionError): + weight = torch.tensor([[1, 2]], dtype=torch.int32) + unpack_from_int32(weight, torch.Size([8, 1]), 16) + + +class TestPackToInt32(TestBase): + + @patch( + 
"vllm_ascend.quantization.utils.torch_npu.npu_convert_weight_to_int4pack" + ) + def test_pack_to_int32_int8(self, mock_npu_convert_weight_to_int4pack): + mock_npu_convert_weight_to_int4pack.return_value = torch.zeros( + (2, 4), dtype=torch.int32) + + weight = torch.zeros((2, 8, 16), dtype=torch.int8) + result = pack_to_int32(weight) + + self.assertEqual(result.dtype, torch.int32) + mock_npu_convert_weight_to_int4pack.assert_not_called() + + self.assertEqual(result.shape, torch.Size([2, 8, 4])) + + @patch( + "vllm_ascend.quantization.utils.torch_npu.npu_convert_weight_to_int4pack" + ) + def test_pack_to_int32_int32(self, mock_npu_convert_weight_to_int4pack): + + def mock_convert_weight(weight): + return weight + + mock_npu_convert_weight_to_int4pack.side_effect = mock_convert_weight + weight = torch.zeros((2, 8, 8), dtype=torch.int32) + result = pack_to_int32(weight) + + self.assertEqual(result.dtype, torch.int32) + self.assertEqual(result.shape, weight.shape) + + def test_pack_to_int32_assertion_dim(self): + with self.assertRaises(AssertionError): + weight = torch.zeros((8, 8), dtype=torch.int8) + pack_to_int32(weight) + + def test_pack_to_int32_assertion_dtype(self): + with self.assertRaises(AssertionError): + weight = torch.zeros((2, 8, 8), dtype=torch.float32) + pack_to_int32(weight) + + def test_pack_to_int32_assertion_divisible(self): + with self.assertRaises(AssertionError): + weight = torch.zeros((2, 8, 7), dtype=torch.int32) + pack_to_int32(weight) + + with self.assertRaises(AssertionError): + weight = torch.zeros((2, 8, 7), dtype=torch.int8) + pack_to_int32(weight) diff --git a/tests/ut/quantization/test_w4a16.py b/tests/ut/quantization/test_w4a16.py index adf4f706dd2..6ded45d9940 100644 --- a/tests/ut/quantization/test_w4a16.py +++ b/tests/ut/quantization/test_w4a16.py @@ -3,93 +3,7 @@ import torch from tests.ut.base import TestBase -from vllm_ascend.quantization.methods.w4a16 import (AscendW4A16FusedMoEMethod, - pack_to_int32, - unpack_from_int32) - - 
-class TestUnpackFromInt32(TestBase): - - def test_unpack_from_int32_packed_dim_1(self): - weight = torch.tensor([[305419896, -1420531520]], dtype=torch.int32) - shape = torch.Size([1, 8]) - num_bits = 4 - - result = unpack_from_int32(weight, shape, num_bits, packed_dim=1) - - self.assertEqual(result.dtype, torch.int8) - self.assertEqual(result.shape, shape) - - def test_unpack_from_int32_packed_dim_0(self): - weight = torch.tensor([[305419896], [-1420531520]], dtype=torch.int32) - shape = torch.Size([8, 1]) - num_bits = 4 - - result = unpack_from_int32(weight, shape, num_bits, packed_dim=0) - - self.assertEqual(result.dtype, torch.int8) - self.assertEqual(result.shape, shape) - - def test_unpack_from_int32_assertions(self): - with self.assertRaises(AssertionError): - weight = torch.tensor([[1, 2]], dtype=torch.int64) - unpack_from_int32(weight, torch.Size([8, 1]), 4) - - with self.assertRaises(AssertionError): - weight = torch.tensor([[1, 2]], dtype=torch.int32) - unpack_from_int32(weight, torch.Size([8, 1]), 16) - - -class TestPackToInt32(TestBase): - - @patch( - "vllm_ascend.quantization.methods.w4a16.torch_npu.npu_convert_weight_to_int4pack" - ) - def test_pack_to_int32_int8(self, mock_npu_convert_weight_to_int4pack): - mock_npu_convert_weight_to_int4pack.return_value = torch.zeros( - (2, 4), dtype=torch.int32) - - weight = torch.zeros((2, 8, 16), dtype=torch.int8) - result = pack_to_int32(weight) - - self.assertEqual(result.dtype, torch.int32) - mock_npu_convert_weight_to_int4pack.assert_not_called() - - self.assertEqual(result.shape, torch.Size([2, 8, 4])) - - @patch( - "vllm_ascend.quantization.methods.w4a16.torch_npu.npu_convert_weight_to_int4pack" - ) - def test_pack_to_int32_int32(self, mock_npu_convert_weight_to_int4pack): - - def mock_convert_weight(weight): - return weight - - mock_npu_convert_weight_to_int4pack.side_effect = mock_convert_weight - weight = torch.zeros((2, 8, 8), dtype=torch.int32) - result = pack_to_int32(weight) - - 
self.assertEqual(result.dtype, torch.int32) - self.assertEqual(result.shape, weight.shape) - - def test_pack_to_int32_assertion_dim(self): - with self.assertRaises(AssertionError): - weight = torch.zeros((8, 8), dtype=torch.int8) - pack_to_int32(weight) - - def test_pack_to_int32_assertion_dtype(self): - with self.assertRaises(AssertionError): - weight = torch.zeros((2, 8, 8), dtype=torch.float32) - pack_to_int32(weight) - - def test_pack_to_int32_assertion_divisible(self): - with self.assertRaises(AssertionError): - weight = torch.zeros((2, 8, 7), dtype=torch.int32) - pack_to_int32(weight) - - with self.assertRaises(AssertionError): - weight = torch.zeros((2, 8, 7), dtype=torch.int8) - pack_to_int32(weight) +from vllm_ascend.quantization.methods.w4a16 import AscendW4A16FusedMoEMethod class TestAscendW4A16FusedMoEMethod(TestBase): @@ -219,7 +133,7 @@ def build_layer(self): return layer @patch( - "vllm_ascend.quantization.methods.w4a16.torch_npu.npu_convert_weight_to_int4pack" + "vllm_ascend.quantization.utils.torch_npu.npu_convert_weight_to_int4pack" ) def test_process_weights_after_loading_with_transpose( self, mock_npu_convert_weight_to_int4pack):