diff --git a/tests/e2e/multicard/2-cards/test_quantization.py b/tests/e2e/multicard/2-cards/test_quantization.py index c3ae5f0bcc0..641d7f65289 100644 --- a/tests/e2e/multicard/2-cards/test_quantization.py +++ b/tests/e2e/multicard/2-cards/test_quantization.py @@ -16,63 +16,40 @@ # This file is a part of the vllm-ascend project. # Adapted from vllm/tests/basic_correctness/test_basic_correctness.py # +import pytest from tests.e2e.conftest import VllmRunner - -def test_qwen2_5_w8a8_external_quantized_tp2(): - example_prompts = [ - "The president of the United States is", - ] - max_tokens = 5 - with VllmRunner( +TEST_CASES = [ + pytest.param( "neuralmagic/Qwen2.5-3B-quantized.w8a8", - tensor_parallel_size=2, - cudagraph_capture_sizes=[1, 2, 4, 8], - max_model_len=4096, - gpu_memory_utilization=0.8, - ) as vllm_model: - vllm_output = vllm_model.generate_greedy(example_prompts, max_tokens) - - golden_results = [ - "The president of the United States is the head of state and", - ] - - for i in range(len(vllm_output)): - assert golden_results[i] == vllm_output[i][1] - print(f"Generated text: {vllm_output[i][1]!r}") - - -def test_qwen3_moe_w8a8_dynamic_llm_compressor(): - example_prompts = [ - "The president of the United States is", - ] - max_tokens = 5 - with VllmRunner( + id="dense-w8a8", + ), + pytest.param( "vllm-ascend/Qwen3-30B-A3B-Instruct-2507-quantized.w8a8", - tensor_parallel_size=2, - max_model_len=4096, - gpu_memory_utilization=0.8, - ) as vllm_model: - vllm_output = vllm_model.generate_greedy(example_prompts, max_tokens) - - golden_results = [ - "The president of the United States is the head of state and", - ] - - for i in range(len(vllm_output)): - assert golden_results[i] == vllm_output[i][1] - print(f"Generated text: {vllm_output[i][1]!r}") + id="moe-w8a8-dynamic", + ), + pytest.param( + "vllm-ascend/Qwen3-30B-A3B-Instruct-2507-quantized.w4a8", + id="moe-w4a8-dynamic", + ), + pytest.param( + "billy800/Qwen3-30B-A3B-Instruct-2507-AWQ", + id="moe-awq-4bit", + ), +] -def test_qwen3_moe_w4a8_dynamic_llm_compressor(): +@pytest.mark.parametrize("model_id", TEST_CASES) +def test_quantization_tp2(model_id): example_prompts = [ "The president of the United States is", ] max_tokens = 5 with VllmRunner( - "vllm-ascend/Qwen3-30B-A3B-Instruct-2507-quantized.w4a8", + model_id, tensor_parallel_size=2, + cudagraph_capture_sizes=[1, 2, 4, 8], max_model_len=4096, gpu_memory_utilization=0.8, ) as vllm_model: diff --git a/tests/ut/quantization/test_w4a16_awq.py b/tests/ut/quantization/test_w4a16_awq.py new file mode 100644 index 00000000000..0fa91d6d509 --- /dev/null +++ b/tests/ut/quantization/test_w4a16_awq.py @@ -0,0 +1,318 @@ +# +# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. 
+# +from unittest.mock import patch + +import torch + +from tests.ut.base import TestBase +from vllm_ascend.quantization.awq_config import AWQConfig +from vllm_ascend.quantization.methods.w4a16_awq import (AscendW4A16AWQFusedMoEMethod, + AscendW4A16AWQLinearMethod, + _unpack_qzero_from_int32, + _unpack_weight_from_int32) + + +class TestAWQConfig(TestBase): + """Test AWQConfig class.""" + + def test_awq_config_init(self): + """Test AWQConfig initialization with valid parameters.""" + config = AWQConfig( + weight_bits=4, + group_size=128, + zero_point=True, + modules_to_not_convert=["lm_head"], + ) + + self.assertEqual(config.weight_bits, 4) + self.assertEqual(config.group_size, 128) + self.assertTrue(config.zero_point) + self.assertEqual(config.modules_to_not_convert, ["lm_head"]) + self.assertEqual(config.pack_factor, 8) + + def test_awq_config_invalid_weight_bits(self): + """Test AWQConfig raises error for non-4-bit weight quantization.""" + with self.assertRaises(ValueError) as context: + AWQConfig(weight_bits=8, group_size=128, zero_point=True) + + self.assertIn("only 4-bit weight quantization is supported", str(context.exception)) + + def test_awq_config_from_config(self): + """Test AWQConfig from_config method.""" + config_dict = { + "w_bit": 4, + "q_group_size": 128, + "zero_point": True, + "modules_to_not_convert": ["lm_head"], + } + + config = AWQConfig.from_config(config_dict) + + self.assertEqual(config.weight_bits, 4) + self.assertEqual(config.group_size, 128) + self.assertTrue(config.zero_point) + + +class TestAscendW4A16AWQLinearMethod(TestBase): + """Test AscendW4A16AWQLinearMethod class.""" + + def setUp(self): + super().setUp() + self.quant_config = AWQConfig( + weight_bits=4, + group_size=128, + zero_point=True, + ) + self.quant_method = AscendW4A16AWQLinearMethod(self.quant_config) + + def test_init(self): + """Test AscendW4A16AWQLinearMethod initialization.""" + self.assertEqual(self.quant_method.pack_factor, 8) + self.assertEqual(self.quant_method.group_size, 128) + + def test_process_weights_after_loading(self): + """Test process_weights_after_loading converts weights correctly.""" + layer = torch.nn.Module() + hidden_size = 512 + out_features = 1024 + pack_factor = 8 + group_size = 128 + + # Original vLLM AWQ format weights + num_groups = hidden_size // group_size + layer.qweight = torch.nn.Parameter( + torch.randint(0, 100, (hidden_size, out_features // pack_factor), dtype=torch.int32), + requires_grad=False + ) + layer.qzeros = torch.nn.Parameter( + torch.randint(0, 100, (num_groups, out_features // pack_factor), dtype=torch.int32), + requires_grad=False + ) + layer.scales = torch.nn.Parameter( + torch.ones((num_groups, out_features), dtype=torch.bfloat16), + requires_grad=False + ) + + # Process weights + self.quant_method.process_weights_after_loading(layer) + + # Verify qweight shape is unchanged and contiguous + self.assertEqual(layer.qweight.shape, (hidden_size, out_features // pack_factor)) + self.assertTrue(layer.qweight.data.is_contiguous()) + + # Verify qzeros is unpacked from (num_groups, out//pack) to (num_groups, out), bfloat16 + self.assertEqual(layer.qzeros.shape, (num_groups, out_features)) + self.assertEqual(layer.qzeros.dtype, torch.bfloat16) + self.assertTrue(layer.qzeros.data.is_contiguous()) + + # Verify parameters require no gradient + self.assertFalse(layer.qweight.requires_grad) + self.assertFalse(layer.scales.requires_grad) + self.assertFalse(layer.qzeros.requires_grad) + + def _build_layer(self, hidden_size: int, out_features: int) -> 
torch.nn.Module: + """Build a post-process_weights_after_loading mock linear layer.""" + group_size = self.quant_method.group_size + pack_factor = self.quant_method.pack_factor + layer = torch.nn.Module() + layer.qweight = torch.nn.Parameter( + torch.randint(0, 100, (hidden_size, out_features // pack_factor), dtype=torch.int32), + requires_grad=False, + ) + layer.scales = torch.nn.Parameter( + torch.ones((hidden_size // group_size, out_features), dtype=torch.bfloat16), + requires_grad=False, + ) + layer.qzeros = torch.nn.Parameter( + torch.zeros((hidden_size // group_size, out_features), dtype=torch.bfloat16), + requires_grad=False, + ) + return layer + + @patch("vllm_ascend.quantization.methods.w4a16_awq.torch_npu.npu_weight_quant_batchmatmul") + def test_apply(self, mock_npu_matmul): + """Test apply method calls npu_weight_quant_batchmatmul.""" + batch_size = 2 + seq_len = 8 + hidden_size = 512 + out_features = 1024 + + mock_output = torch.randn(batch_size, seq_len, out_features, dtype=torch.float32) + mock_npu_matmul.return_value = mock_output + + layer = self._build_layer(hidden_size, out_features) + x = torch.randn(batch_size, seq_len, hidden_size, dtype=torch.bfloat16) + + result = self.quant_method.apply(layer, x) + + mock_npu_matmul.assert_called_once() + self.assertEqual(result.shape, (batch_size, seq_len, out_features)) + + @patch("vllm_ascend.quantization.methods.w4a16_awq.torch_npu.npu_weight_quant_batchmatmul") + def test_apply_with_bias(self, mock_npu_matmul): + """Test apply method handles bias correctly.""" + batch_size = 1 + seq_len = 1 + hidden_size = 256 + out_features = 512 + + mock_output = torch.randn(batch_size, seq_len, out_features, dtype=torch.float32) + mock_npu_matmul.return_value = mock_output + + layer = self._build_layer(hidden_size, out_features) + x = torch.randn(batch_size, seq_len, hidden_size, dtype=torch.bfloat16) + bias = torch.randn(out_features, dtype=torch.bfloat16) + + # Call apply with bias + result = self.quant_method.apply(layer, x, bias) + + # Verify result is returned and bias is converted to float + self.assertIsNotNone(result) + call_kwargs = mock_npu_matmul.call_args.kwargs + self.assertEqual(call_kwargs["bias"].dtype, torch.float32) + + +class TestAscendW4A16AWQFusedMoEMethod(TestBase): + """Test AscendW4A16AWQFusedMoEMethod class.""" + + def setUp(self): + super().setUp() + self.quant_config = AWQConfig( + weight_bits=4, + group_size=128, + zero_point=True, + ) + self.quant_method = AscendW4A16AWQFusedMoEMethod(self.quant_config) + + def test_init(self): + """Test AscendW4A16AWQFusedMoEMethod initialization.""" + self.assertEqual(self.quant_method.pack_factor, 8) + self.assertEqual(self.quant_method.group_size, 128) + + def test_get_weight(self): + """Test get_weight returns correctly shaped weight tensors.""" + num_experts = 4 + intermediate = 512 + hidden = 256 + + result = self.quant_method.get_weight(num_experts, intermediate, hidden, torch.bfloat16) + + self.assertIn("w13_qweight", result) + self.assertIn("w2_qweight", result) + self.assertEqual(result["w13_qweight"].shape, (num_experts, hidden, 2 * intermediate // 8)) + self.assertEqual(result["w2_qweight"].shape, (num_experts, intermediate, hidden // 8)) + self.assertEqual(result["w13_qweight"].dtype, torch.int32) + self.assertEqual(result["w2_qweight"].dtype, torch.int32) + + def test_get_dynamic_quant_param(self): + """Test get_dynamic_quant_param returns correctly shaped scale/zero tensors.""" + num_experts = 4 + intermediate = 512 + hidden = 256 + group_size = 128 + + result = 
self.quant_method.get_dynamic_quant_param(num_experts, intermediate, hidden, torch.bfloat16) + + num_groups_w13 = hidden // group_size + num_groups_w2 = intermediate // group_size + + self.assertEqual(result["w13_scales"].shape, (num_experts, num_groups_w13, intermediate * 2)) + self.assertEqual(result["w2_scales"].shape, (num_experts, num_groups_w2, hidden)) + self.assertEqual(result["w13_qzeros"].shape, (num_experts, num_groups_w13, 2 * intermediate // 8)) + self.assertEqual(result["w2_qzeros"].shape, (num_experts, num_groups_w2, hidden // 8)) + self.assertEqual(result["w13_qzeros"].dtype, torch.int32) + self.assertEqual(result["w2_qzeros"].dtype, torch.int32) + + +class TestUnpackQzeroFromInt32(TestBase): + """Test unpack_qzero_from_int32 function for AWQ zero-points.""" + + def test_unpack_qzero_from_int32_linear_layer(self): + """Test unpacking zero-points for linear layer.""" + weight = torch.tensor([[305419896, -1420531520]], dtype=torch.int32) + param_dtype = torch.bfloat16 + + result = _unpack_qzero_from_int32(weight, param_dtype, pack_factor=8, is_moe_layer=False) + + # (1, 2) packed → (1, 16) unpacked (2 elements × 8 nibbles each) + self.assertEqual(result.shape, (1, 16)) + self.assertEqual(result.dtype, param_dtype) + self.assertTrue(result.is_contiguous()) + + def test_unpack_qzero_from_int32_moe_layer(self): + """Test unpacking zero-points for MoE layer.""" + weight = torch.tensor([[[305419896, -1420531520]]], dtype=torch.int32) + param_dtype = torch.bfloat16 + + result = _unpack_qzero_from_int32(weight, param_dtype, pack_factor=8, is_moe_layer=True) + + # (1, 1, 2) packed → (1, 1, 16) unpacked (2 elements × 8 nibbles each) + self.assertEqual(result.shape, (1, 1, 16)) + self.assertEqual(result.dtype, param_dtype) + self.assertTrue(result.is_contiguous()) + + def test_unpack_qzero_from_int32_unsigned_to_signed(self): + """Test unsigned int4 [0,15] to signed int4 [-8,7] conversion.""" + weight = torch.tensor([[0, 1, 7, 8, 9, 10, 15, 0]], dtype=torch.int32) + param_dtype = torch.bfloat16 + + result = _unpack_qzero_from_int32(weight, param_dtype, pack_factor=8, is_moe_layer=False) + + # Each int32 element unpacks to 8 nibbles; element k's lowest nibble lands at index k*8. + self.assertEqual(result[0, 0].item(), 8) # element 0: 0 -> -(0-8) = 8 + self.assertEqual(result[0, 8].item(), 7) # element 1: 1 -> -(1-8) = 7 + self.assertEqual(result[0, 24].item(), 0) # element 3: 8 -> -(8-8) = 0 (zero point) + self.assertEqual(result[0, 48].item(), -7) # element 6: 15 -> -(15-8) = -7 + + +class TestUnpackWeightFromInt32(TestBase): + """Test unpack_weight_from_int32 function for AWQ weights.""" + + def test_unpack_weight_from_int32_basic(self): + """Test unpacking weights with XOR transformation.""" + weight = torch.tensor([[305419896, -1420531520]], dtype=torch.int32) + + result = _unpack_weight_from_int32(weight, pack_factor=8) + + # Output shape is unchanged — repacking stays within the same int32 layout + self.assertEqual(result.shape, weight.shape) + self.assertEqual(result.dtype, torch.int32) + self.assertTrue(result.is_contiguous()) + + def test_unpack_weight_from_int32_xor_transformation(self): + """Test XOR 0x88888888 transformation is applied.""" + weight = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int32) + + result = _unpack_weight_from_int32(weight, pack_factor=8) + + # All-zero input → repack loop produces all-zero weight_tmp → XOR with + # 0x88888888 makes every int32 element 0x88888888 = -2004318072 (signed int32). 
+ self.assertEqual(result[0, 0].item(), -2004318072) # 0x88888888 as int32 + + def test_unpack_weight_from_int32_contiguous(self): + """Test output is contiguous.""" + weight = torch.randint(0, 100, (16, 8), dtype=torch.int32) + + result = _unpack_weight_from_int32(weight, pack_factor=8) + + self.assertTrue(result.is_contiguous()) + + +if __name__ == "__main__": + import unittest + unittest.main() \ No newline at end of file diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index 0ba64ee424c..199453b5e61 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -36,6 +36,7 @@ # isort: off from vllm_ascend.utils import ( ASCEND_QUANTIZATION_METHOD, + AWQ_QUANTIZATION_METHOD, COMPILATION_PASS_KEY, COMPRESSED_TENSORS_METHOD, AscendDeviceType, @@ -101,7 +102,7 @@ class NPUPlatform(Platform): device_control_env_var: str = "ASCEND_RT_VISIBLE_DEVICES" dispatch_key: str = "PrivateUse1" - supported_quantization: list[str] = [ASCEND_QUANTIZATION_METHOD, COMPRESSED_TENSORS_METHOD] + supported_quantization: list[str] = [ASCEND_QUANTIZATION_METHOD, AWQ_QUANTIZATION_METHOD, COMPRESSED_TENSORS_METHOD] def is_sleep_mode_available(self) -> bool: return True @@ -148,6 +149,7 @@ def pre_register_and_update(cls, parser: FlexibleArgumentParser | None = None) - quant_action.choices.append(ASCEND_QUANTIZATION_METHOD) if not is_310p(): + from vllm_ascend.quantization.awq_config import AWQConfig # noqa: F401 from vllm_ascend.quantization import AscendCompressedTensorsConfig, AscendModelSlimConfig # noqa: F401 else: from vllm_ascend._310p.quantization import AscendModelSlimConfig310 # noqa: F401 diff --git a/vllm_ascend/quantization/__init__.py b/vllm_ascend/quantization/__init__.py index 575c352634b..750b140f337 100644 --- a/vllm_ascend/quantization/__init__.py +++ b/vllm_ascend/quantization/__init__.py @@ -24,11 +24,13 @@ from typing import TYPE_CHECKING, Any if TYPE_CHECKING: + from .awq_config import AWQConfig from .compressed_tensors_config import AscendCompressedTensorsConfig from .modelslim_config import AscendModelSlimConfig __all__ = [ "AscendModelSlimConfig", + "AWQConfig", "AscendCompressedTensorsConfig", ] @@ -42,4 +44,8 @@ def __getattr__(name: str) -> Any: from .compressed_tensors_config import AscendCompressedTensorsConfig return AscendCompressedTensorsConfig + if name == "AWQConfig": + from .awq_config import AWQConfig + + return AWQConfig raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/vllm_ascend/quantization/awq_config.py b/vllm_ascend/quantization/awq_config.py new file mode 100644 index 00000000000..4f04ad94641 --- /dev/null +++ b/vllm_ascend/quantization/awq_config.py @@ -0,0 +1,118 @@ +# +# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. 
+# +from typing import Any, Union + +import torch +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.quantization import register_quantization_config +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig, QuantizeMethodBase +from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped + +from vllm_ascend.ops.fused_moe.fused_moe import AscendUnquantizedFusedMoEMethod +from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod +from vllm_ascend.utils import AWQ_QUANTIZATION_METHOD + +from .method_adapters import AscendFusedMoEMethod +from .methods import get_scheme_class +from .methods.w4a16_awq import AscendW4A16AWQLinearMethod + + +@register_quantization_config(AWQ_QUANTIZATION_METHOD) +class AWQConfig(QuantizationConfig): + """AWQ quantization config for Ascend NPU. + + Replaces vLLM's native AWQ config to route linear and MoE layers through + Ascend-specific scheme implementations (AscendW4A16AWQLinearMethod, + AscendW4A16AWQFusedMoEMethod) . + """ + + def __init__( + self, + weight_bits: int, + group_size: int, + zero_point: bool, + modules_to_not_convert: list[str] | None = None, + quant_config: dict[str, Any] | None = None, + ): + self.quant_description = quant_config if quant_config is not None else {} + super().__init__() + + self.weight_bits = weight_bits + self.group_size = group_size + self.zero_point = zero_point + self.modules_to_not_convert = modules_to_not_convert or [] + + if self.weight_bits != 4: + raise ValueError( + f"Currently, only 4-bit weight quantization is supported for AWQ, but got {self.weight_bits} bits." + ) + self.pack_factor = 32 // self.weight_bits + + def get_name(self) -> str: + return AWQ_QUANTIZATION_METHOD + + @classmethod + def get_supported_act_dtypes(cls) -> list[torch.dtype]: + return [torch.float16, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + raise NotImplementedError("Ascend hardware does not support 'get_min_capability' feature.") + + @staticmethod + def get_config_filenames() -> list[str]: + return [ + "quant_config.json", + "quantize_config.json", + ] + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "AWQConfig": + weight_bits = cls.get_from_keys(config, ["w_bit", "bits"]) + group_size = cls.get_from_keys(config, ["q_group_size", "group_size"]) + zero_point = cls.get_from_keys(config, ["zero_point"]) + modules_to_not_convert = cls.get_from_keys_or(config, ["modules_to_not_convert"], None) + return cls(weight_bits, group_size, zero_point, modules_to_not_convert, config) + + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Union["LinearMethodBase", "QuantizeMethodBase"] | None: + if isinstance(layer, LinearBase): + if is_layer_skipped( + prefix, + self.modules_to_not_convert, + self.packed_modules_mapping, + skip_with_substr=True, + ): + return AscendUnquantizedLinearMethod() + return AscendW4A16AWQLinearMethod(self) + + elif isinstance(layer, FusedMoE): + if is_layer_skipped( + prefix, + self.modules_to_not_convert, + skip_with_substr=True, + ): + return AscendUnquantizedFusedMoEMethod(layer.moe_config) + scheme_cls = get_scheme_class("W4A16_AWQ", "moe") + if scheme_cls is None: + raise NotImplementedError(f"W4A16_AWQ moe scheme not found for layer {prefix}") + return AscendFusedMoEMethod(scheme_cls(self), layer.moe_config) + + return None diff --git a/vllm_ascend/quantization/method_adapters.py 
b/vllm_ascend/quantization/method_adapters.py index cd68ddcdd10..05e7185882a 100644 --- a/vllm_ascend/quantization/method_adapters.py +++ b/vllm_ascend/quantization/method_adapters.py @@ -213,6 +213,10 @@ def create_weights( params_dtype: torch.dtype, **extra_weight_attrs, ) -> None: + # Merge quant_method's weight_attrs into extra_weight_attrs + quant_weight_attrs = getattr(self.quant_method, "weight_attrs", {}) + extra_weight_attrs = {**quant_weight_attrs, **extra_weight_attrs} + weight_param = self.quant_method.get_weight( num_experts, intermediate_size_per_partition, hidden_size, params_dtype ) @@ -222,7 +226,7 @@ def create_weights( set_weight_attrs(param, extra_weight_attrs) extra_weight_attrs.update({"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}) - per_group_param = ["weight_scale_second", "weight_offset_second", "scale_bias"] + ( + per_group_param = ["weight_scale_second", "weight_offset_second", "scale_bias", "qzeros", "scales"] + ( ["weight_scale", "weight_offset"] if hasattr(self.quant_method, "group_size") and self.quant_method.group_size > 0 else [] @@ -230,6 +234,7 @@ def create_weights( dynamic_quant_param = self.quant_method.get_dynamic_quant_param( num_experts, intermediate_size_per_partition, hidden_size, params_dtype ) + for param_key, param_value in dynamic_quant_param.items(): param = torch.nn.Parameter(param_value, requires_grad=False) layer.register_parameter(param_key, param) diff --git a/vllm_ascend/quantization/methods/__init__.py b/vllm_ascend/quantization/methods/__init__.py index 59c75a05863..e74b58a4c67 100644 --- a/vllm_ascend/quantization/methods/__init__.py +++ b/vllm_ascend/quantization/methods/__init__.py @@ -42,6 +42,7 @@ from .w4a4_mxfp4 import AscendW4A4MXFP4DynamicFusedMoEMethod, AscendW4A4MXFP4DynamicLinearMethod from .w4a8 import AscendW4A8DynamicFusedMoEMethod, AscendW4A8DynamicLinearMethod from .w4a16 import AscendW4A16FusedMoEMethod +from .w4a16_awq import AscendW4A16AWQFusedMoEMethod, AscendW4A16AWQLinearMethod from .w8a8_dynamic import AscendW8A8DynamicFusedMoEMethod, AscendW8A8DynamicLinearMethod from .w8a8_mxfp8 import AscendW8A8MXFP8DynamicLinearMethod from .w8a8_pdmix import AscendW8A8PDMixFusedMoeMethod, AscendW8A8PDMixLinearMethod @@ -81,6 +82,8 @@ def is_mx_quant_type(instance: Any) -> bool: "AscendW4A8DynamicLinearMethod", "AscendW4A8DynamicFusedMoEMethod", "AscendW4A16FusedMoEMethod", + "AscendW4A16AWQLinearMethod", + "AscendW4A16AWQFusedMoEMethod", "AscendW4A4FlatQuantDynamicLinearMethod", "AscendW4A4LaosDynamicLinearMethod", "AscendFAQuantAttentionMethod", diff --git a/vllm_ascend/quantization/methods/w4a16_awq.py b/vllm_ascend/quantization/methods/w4a16_awq.py new file mode 100644 index 00000000000..c5024537d50 --- /dev/null +++ b/vllm_ascend/quantization/methods/w4a16_awq.py @@ -0,0 +1,331 @@ +# +# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. 
+# +from collections.abc import Callable +from typing import TYPE_CHECKING, Any + +import torch +import torch_npu +from vllm.model_executor.layers.fused_moe.activation import MoEActivation +from vllm.model_executor.layers.quantization.awq import AWQLinearMethod + +from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.ascend_forward_context import _EXTRA_CTX +from vllm_ascend.ops.fused_moe.experts_selector import select_experts +from vllm_ascend.ops.fused_moe.moe_runtime_args import build_fused_experts_input + +from .base import AscendMoEScheme, QuantType +from .registry import register_scheme + +if TYPE_CHECKING: + from vllm_ascend.quantization.awq_config import AWQConfig + +# Bit shift pattern for unpacking 4-bit values from int32, see +# https://github.com/casper-hansen/AutoAWQ/blob/v0.2.8/awq/utils/quant_utils.py +REVERSE_AWQ_PACK_ORDER = [0, 4, 1, 5, 2, 6, 3, 7] + + +def _unpack_qzero_from_int32( + weight: torch.Tensor, + param_dtype: torch.dtype, + pack_factor: int = 8, + is_moe_layer: bool = False, +) -> torch.Tensor: + """Unpack and convert AWQ zero-points (qzeros) from int32 to target dtype. + + :param weight: Packed int32 tensor containing zero-points + :param param_dtype: Target dtype (e.g., bfloat16) + :param pack_factor: Number of 4-bit values per int32 (default: 8) + :param is_moe_layer: Whether this is for MoE layer (default: False) + + :return: Unpacked and converted zero-points tensor + """ + weight_list = [] + + for i in range(pack_factor): + shift_num = REVERSE_AWQ_PACK_ORDER[i] * 4 + weight_list.append((weight.reshape(-1, 1) >> shift_num) & 0xF) + + if is_moe_layer: + weight = torch.cat(weight_list, dim=-1).reshape(weight.shape[0], weight.shape[1], -1) + else: + weight = torch.cat(weight_list, dim=-1).reshape(weight.shape[0], -1) + + # Convert unsigned int4 [0,15] to signed int4 [-8,7] + weight = -(weight - 8) + return weight.to(param_dtype).contiguous() + + +def _unpack_weight_from_int32( + weight: torch.Tensor, + pack_factor: int = 8, +) -> torch.Tensor: + """Unpack and convert AWQ weights (qweight) from int32 to NPU format. 
+ + :param weight: Packed int32 tensor containing quantized weights + :param pack_factor: Number of 4-bit values per int32 (default: 8) + + :return: Unpacked and NPU-formatted weight tensor + """ + weight_tmp = torch.zeros_like(weight) + for i in range(pack_factor): + shift_num = REVERSE_AWQ_PACK_ORDER[i] * 4 + weight_tmp.bitwise_or_(((weight >> shift_num) * (2 ** (4 * i))) & (0xF << (4 * i))) + weight_tmp.bitwise_xor_(0x88888888) + return weight_tmp.contiguous() + + +class AscendW4A16AWQLinearMethod(AWQLinearMethod): + """Linear method for Ascend W4A16 AWQ quantization.""" + + def __init__(self, quant_config: "AWQConfig"): + self.quant_config = quant_config + self.pack_factor = self.quant_config.pack_factor + self.group_size = self.quant_config.group_size + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False) + layer.qzeros = torch.nn.Parameter( + _unpack_qzero_from_int32( + weight=layer.qzeros.data, + param_dtype=layer.scales.data.dtype, + pack_factor=self.pack_factor, + ), + requires_grad=False, + ) + layer.qweight = torch.nn.Parameter( + _unpack_weight_from_int32(weight=layer.qweight.data, pack_factor=self.pack_factor), requires_grad=False + ) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + ) -> torch.Tensor: + qweight = layer.qweight + if bias is not None and bias.dtype == torch.bfloat16: + bias = bias.float() + + reshaped_x = x.reshape(-1, x.shape[-1]) + out = torch_npu.npu_weight_quant_batchmatmul( + reshaped_x, + qweight, + antiquant_scale=layer.scales, + antiquant_offset=layer.qzeros, + antiquant_group_size=self.group_size, + bias=bias, + ) + out_shape = x.shape[:-1] + (qweight.shape[-1] * self.pack_factor,) + return out.reshape(out_shape) + + +@register_scheme("W4A16_AWQ", "moe") +class AscendW4A16AWQFusedMoEMethod(AscendMoEScheme): + """FusedMoE method for Ascend W4A16 AWQ quantization.""" + + quant_type: QuantType = QuantType.W4A16_AWQ + weight_attrs: dict = {"is_transposed": True} + + def __init__(self, quant_config: "AWQConfig"): + self.quant_config = quant_config + self.pack_factor = self.quant_config.pack_factor + self.group_size = self.quant_config.group_size + self.dynamic_eplb = get_ascend_config().eplb_config.dynamic_eplb + + def get_weight( + self, + num_experts: int, + intermediate_size_per_partition: int, + hidden_sizes: int, + params_dtype: torch.dtype, + ) -> dict[str, Any]: + assert intermediate_size_per_partition % self.pack_factor == 0, ( + f"Expecting `intermediate_size_per_partition` {intermediate_size_per_partition} " + f"can be divided by `pack_factor` {self.pack_factor}" + ) + assert hidden_sizes % self.pack_factor == 0, ( + f"Expecting `hidden_sizes` {hidden_sizes} can be divided by `pack_factor` {self.pack_factor}" + ) + + param_dict = {} + param_dict["w13_qweight"] = torch.empty( + num_experts, + hidden_sizes, + 2 * intermediate_size_per_partition // self.pack_factor, + dtype=torch.int32, + ) + param_dict["w2_qweight"] = torch.empty( + num_experts, + intermediate_size_per_partition, + hidden_sizes // self.pack_factor, + dtype=torch.int32, + ) + return param_dict + + def get_dynamic_quant_param( + self, + num_experts: int, + intermediate_size_per_partition: int, + hidden_sizes: int, + params_dtype: torch.dtype, + ) -> dict[str, Any]: + assert intermediate_size_per_partition % self.group_size == 0, ( + f"Expecting `intermediate_size_per_partition` {intermediate_size_per_partition} " + f"can be divided 
by `group_size` {self.group_size}"
+        )
+        assert hidden_sizes % self.group_size == 0, (
+            f"Expecting `hidden_sizes` {hidden_sizes} can be divided by `group_size` {self.group_size}"
+        )
+
+        param_dict = {}
+        num_groups_w13 = hidden_sizes // self.group_size
+        num_groups_w2 = intermediate_size_per_partition // self.group_size
+        # WEIGHT_SCALES
+        # Allocate 2 scales for w1 and w3 respectively.
+        param_dict["w13_scales"] = torch.empty(
+            num_experts,
+            num_groups_w13,
+            intermediate_size_per_partition * 2,
+            dtype=params_dtype,
+        )
+        param_dict["w2_scales"] = torch.empty(num_experts, num_groups_w2, hidden_sizes, dtype=params_dtype)
+        # WEIGHT_ZERO_POINT
+        # Allocate 2 zero points for w1 and w3 respectively.
+        param_dict["w13_qzeros"] = torch.empty(
+            num_experts,
+            num_groups_w13,
+            2 * intermediate_size_per_partition // self.pack_factor,
+            dtype=torch.int32,
+        )
+        param_dict["w2_qzeros"] = torch.empty(
+            num_experts,
+            num_groups_w2,
+            hidden_sizes // self.pack_factor,
+            dtype=torch.int32,
+        )
+        return param_dict
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        w13_qzeros = torch.nn.Parameter(
+            _unpack_qzero_from_int32(
+                weight=layer.w13_qzeros.data,
+                param_dtype=layer.w13_scales.data.dtype,
+                pack_factor=self.pack_factor,
+                is_moe_layer=True,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_qzeros", w13_qzeros)
+        w13_qweight = torch.nn.Parameter(
+            _unpack_weight_from_int32(
+                weight=layer.w13_qweight.data,
+                pack_factor=self.pack_factor,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_qweight", w13_qweight)
+
+        w2_qzeros = torch.nn.Parameter(
+            _unpack_qzero_from_int32(
+                weight=layer.w2_qzeros.data,
+                param_dtype=layer.w2_scales.data.dtype,
+                pack_factor=self.pack_factor,
+                is_moe_layer=True,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_qzeros", w2_qzeros)
+        w2_qweight = torch.nn.Parameter(
+            _unpack_weight_from_int32(
+                weight=layer.w2_qweight.data,
+                pack_factor=self.pack_factor,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_qweight", w2_qweight)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool = False,
+        global_num_experts: int = -1,
+        expert_map: torch.Tensor | None = None,
+        topk_group: int | None = None,
+        num_expert_group: int | None = None,
+        custom_routing_function: Callable | None = None,
+        scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
+        e_score_correction_bias: torch.Tensor | None = None,
+        is_prefill: bool = True,
+        enable_force_load_balance: bool = False,
+        log2phy: torch.Tensor | None = None,
+        global_redundant_expert_num=0,
+        pertoken_scale: torch.Tensor | None = None,
+        activation: MoEActivation = MoEActivation.SILU,
+        apply_router_weight_on_input: bool = False,
+        mc2_mask: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        assert activation == MoEActivation.SILU, "Only SiLU activation is supported."
+ + topk_weights, topk_ids = select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, + e_score_correction_bias=e_score_correction_bias, + global_num_experts=global_num_experts, + ) + + topk_ids = topk_ids.to(torch.int32) + topk_weights = topk_weights.to(x.dtype) + + moe_comm_method = _EXTRA_CTX.moe_comm_method + return moe_comm_method.fused_experts( + fused_experts_input=build_fused_experts_input( + hidden_states=x, + topk_weights=topk_weights, + topk_ids=topk_ids, + w1=layer.w13_qweight, + w2=layer.w2_qweight, + quant_type=self.quant_type, + dynamic_eplb=self.dynamic_eplb, + expert_map=expert_map, + global_redundant_expert_num=global_redundant_expert_num, + mc2_mask=mc2_mask, + apply_router_weight_on_input=apply_router_weight_on_input, + log2phy=log2phy, + pertoken_scale=pertoken_scale, + activation=activation, + w1_scale=layer.w13_scales, + w2_scale=layer.w2_scales, + w1_offset=layer.w13_qzeros, + w2_offset=layer.w2_qzeros, + ) + ) diff --git a/vllm_ascend/quantization/quant_type.py b/vllm_ascend/quantization/quant_type.py index d252adf4da0..7d01eba0feb 100644 --- a/vllm_ascend/quantization/quant_type.py +++ b/vllm_ascend/quantization/quant_type.py @@ -32,3 +32,4 @@ class QuantType(Enum): MXFP8 = 3 W4A16 = 4 MXFP4 = 5 + W4A16_AWQ = 6 diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index b39a9f27cce..2a4fc6e724b 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -44,6 +44,8 @@ COMPILATION_PASS_KEY = "graph_fusion_manager" ASCEND_QUANTIZATION_METHOD = "ascend" +# AWQ quantization method identifier for Ascend NPU +AWQ_QUANTIZATION_METHOD = "awq" COMPRESSED_TENSORS_METHOD = "compressed-tensors" SOC_VERSION_INFERENCE_SERIES = ["Ascend310P3"] REGISTERED_ASCEND_OPS = {}
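
The nibble layout that the new unit tests assert (element k's lowest nibble landing at index k*8, unsigned 0..15 mapped to 8..-7) can be checked without an NPU. A minimal standalone sketch, assuming nothing beyond the pack order and sign conversion visible in _unpack_qzero_from_int32; the helper name unpack_qzeros is illustrative and not part of the patch:

import torch

REVERSE_AWQ_PACK_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]


def unpack_qzeros(packed: torch.Tensor, pack_factor: int = 8) -> torch.Tensor:
    # Each int32 carries eight 4-bit zero-points; pull them out nibble by
    # nibble in AWQ's reverse pack order.
    nibbles = [
        (packed.reshape(-1, 1) >> (REVERSE_AWQ_PACK_ORDER[i] * 4)) & 0xF
        for i in range(pack_factor)
    ]
    unpacked = torch.cat(nibbles, dim=-1).reshape(packed.shape[0], -1)
    # Convert unsigned int4 [0, 15] to the signed form used on the NPU side,
    # centered on the AWQ zero point 8: z -> -(z - 8).
    return -(unpacked - 8)


packed = torch.tensor([[0, 1, 7, 8, 9, 10, 15, 0]], dtype=torch.int32)
zeros = unpack_qzeros(packed)
assert zeros.shape == (1, 64)        # 8 packed int32 values -> 64 zero-points
assert zeros[0, 0].item() == 8       # element 0, lowest nibble 0  -> -(0 - 8)
assert zeros[0, 8].item() == 7       # element 1, lowest nibble 1  -> -(1 - 8)
assert zeros[0, 24].item() == 0      # element 3, lowest nibble 8  -> the zero point
assert zeros[0, 48].item() == -7     # element 6, lowest nibble 15 -> -(15 - 8)

Running the same input through _unpack_qzero_from_int32(..., is_moe_layer=False) should produce identical values, which is what test_unpack_qzero_from_int32_unsigned_to_signed pins down.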
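
The qweight path relies on a companion identity: XOR-ing each 4-bit field with 0x8 (the 0x88888888 constant in _unpack_weight_from_int32 and its unit test) reinterprets an unsigned int4 value u as the signed value u - 8. Paired with the qzeros transform z -> -(z - 8), the difference that AWQ dequantization needs, u - z, is preserved, presumably so that the antiquant offset-plus-scale path reproduces scale * (u - z). A quick sketch of that identity, using only plain Python arithmetic; to_signed_int4 is a hypothetical helper for illustration:

def to_signed_int4(nibble: int) -> int:
    # Interpret a 4-bit pattern as a signed int4 value in [-8, 7].
    return nibble - 16 if nibble >= 8 else nibble


for u in range(16):                       # every unsigned int4 weight value
    w_signed = to_signed_int4(u ^ 0x8)    # what the 0x88888888 XOR does per nibble
    assert w_signed == u - 8
    for z in range(16):                   # every unsigned int4 zero-point
        z_offset = -(z - 8)               # what _unpack_qzero_from_int32 produces
        # Signed weight plus transformed offset equals the original u - z,
        # so scale * (u - z) from the AWQ recipe is unchanged.
        assert w_signed + z_offset == u - z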