From 5ed63ef06ffc94eab58b9fcc44fd7536059c33a1 Mon Sep 17 00:00:00 2001 From: SlightwindSec Date: Thu, 8 Jan 2026 21:04:29 +0800 Subject: [PATCH 01/20] refactor quantization Signed-off-by: SlightwindSec --- .../feature_guide/quantization.md | 23 +- ...ant_config.py => test_modelslim_config.py} | 54 +- tests/ut/quantization/test_utils.py | 50 -- tests/ut/quantization/test_w4a16.py | 15 +- .../test_w4a4_flatquant_dynamic.py | 12 +- tests/ut/quantization/test_w4a8_dynamic.py | 12 +- tests/ut/quantization/test_w8a16.py | 2 +- tests/ut/quantization/test_w8a8.py | 32 +- tests/ut/quantization/test_w8a8_dynamic.py | 10 +- tests/ut/test_platform.py | 8 +- vllm_ascend/attention/mla_v1.py | 2 +- vllm_ascend/attention/sfa_v1.py | 2 +- vllm_ascend/ops/fused_moe/fused_moe.py | 6 +- vllm_ascend/ops/linear_op.py | 7 +- vllm_ascend/platform.py | 5 +- vllm_ascend/quantization/__init__.py | 38 ++ .../compressed_tensors/__init__.py | 0 ...ensors.py => compressed_tensors_config.py} | 107 ++-- vllm_ascend/quantization/methods/__init__.py | 69 +++ vllm_ascend/quantization/methods/base.py | 218 +++++++ vllm_ascend/quantization/methods/registry.py | 62 ++ .../quantization/{ => methods}/w4a16.py | 569 +++++++++--------- .../w4a4_flatquant.py} | 27 +- .../{w4a8_dynamic.py => methods/w4a8.py} | 33 +- .../quantization/{ => methods}/w8a16.py | 26 +- .../{ => methods}/w8a8_dynamic.py | 60 +- .../quantization/{ => methods}/w8a8_pdmix.py | 49 +- .../{w8a8.py => methods/w8a8_static.py} | 34 +- vllm_ascend/quantization/modelslim_config.py | 408 +++++++++++++ vllm_ascend/quantization/utils.py | 115 ---- .../{quant_config.py => wrappers.py} | 346 ++--------- 31 files changed, 1405 insertions(+), 996 deletions(-) rename tests/ut/quantization/{test_quant_config.py => test_modelslim_config.py} (70%) delete mode 100644 tests/ut/quantization/test_utils.py delete mode 100644 vllm_ascend/quantization/compressed_tensors/__init__.py rename vllm_ascend/quantization/{compressed_tensors/compressed_tensors.py => compressed_tensors_config.py} (76%) create mode 100644 vllm_ascend/quantization/methods/__init__.py create mode 100644 vllm_ascend/quantization/methods/base.py create mode 100644 vllm_ascend/quantization/methods/registry.py rename vllm_ascend/quantization/{ => methods}/w4a16.py (95%) rename vllm_ascend/quantization/{w4a4_flatquant_dynamic.py => methods/w4a4_flatquant.py} (93%) rename vllm_ascend/quantization/{w4a8_dynamic.py => methods/w4a8.py} (97%) rename vllm_ascend/quantization/{ => methods}/w8a16.py (84%) rename vllm_ascend/quantization/{ => methods}/w8a8_dynamic.py (94%) rename vllm_ascend/quantization/{ => methods}/w8a8_pdmix.py (56%) rename vllm_ascend/quantization/{w8a8.py => methods/w8a8_static.py} (88%) create mode 100644 vllm_ascend/quantization/modelslim_config.py delete mode 100644 vllm_ascend/quantization/utils.py rename vllm_ascend/quantization/{quant_config.py => wrappers.py} (52%) diff --git a/docs/source/developer_guide/feature_guide/quantization.md b/docs/source/developer_guide/feature_guide/quantization.md index e84db9c2005..5784243edb7 100644 --- a/docs/source/developer_guide/feature_guide/quantization.md +++ b/docs/source/developer_guide/feature_guide/quantization.md @@ -10,7 +10,7 @@ The current process for registering and obtaining quantization methods in vLLM A ![get_quant_method](../../assets/quantization/get_quant_method.png) -vLLM Ascend registers a custom ascend quantization method. By configuring the `--quantization ascend` parameter (or `quantization="ascend"` for offline), the quantization feature is enabled. When constructing the `quant_config`, the registered `AscendQuantConfig` is initialized and `get_quant_method` is called to obtain the quantization method corresponding to each weight part, stored in the `quant_method` attribute. +vLLM Ascend registers a custom ascend quantization method. By configuring the `--quantization ascend` parameter (or `quantization="ascend"` for offline), the quantization feature is enabled. When constructing the `quant_config`, the registered `AscendModelSlimConfig` is initialized and `get_quant_method` is called to obtain the quantization method corresponding to each weight part, stored in the `quant_method` attribute. Currently supported quantization methods include `AscendLinearMethod`, `AscendFusedMoEMethod`, `AscendEmbeddingMethod`, and their corresponding non-quantized methods: @@ -51,18 +51,21 @@ Based on the above content, we present a brief description of the adaptation pro ### Quantization Algorithm Adaptation - **Step 1: Algorithm Design**. Define the algorithm ID (e.g., `W4A8_DYNAMIC`), determine supported layers (linear, moe, attention), and design the quantization scheme (static/dynamic, pertensor/perchannel/pergroup). -- **Step 2: Registration**. Add the algorithm ID to `ASCEND_QUANTIZATION_METHOD_MAP` in `vllm_ascend/quantization/utils.py` and associate it with the corresponding method class. +- **Step 2: Registration**. Use the `@register_scheme` decorator in `vllm_ascend/quantization/methods/registry.py` to register your quantization scheme class. ```python -ASCEND_QUANTIZATION_METHOD_MAP: Dict[str, Dict[str, Type[Any]]] = { - "W4A8_DYNAMIC": { - "linear": AscendW4A8DynamicLinearMethod, - "moe": AscendW4A8DynamicFusedMoEMethod, - }, -} +from vllm_ascend.quantization.methods import register_scheme, AscendLinearScheme + +@register_scheme("W4A8_DYNAMIC", "linear") +class AscendW4A8DynamicLinearMethod(AscendLinearScheme): + ... + +@register_scheme("W4A8_DYNAMIC", "moe") +class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme): + ... ``` -- **Step 3: Implementation**. Create an algorithm implementation file, such as `vllm_ascend/quantization/w4a8_dynamic.py`, and implement the method class and logic. +- **Step 3: Implementation**. Create an algorithm implementation file, such as `vllm_ascend/quantization/methods/w4a8.py`, and implement the method class and logic. - **Step 4: Testing**. Use your algorithm to generate quantization configurations and verify correctness and performance on target models and hardware. ### Quantized Model Adaptation @@ -70,7 +73,7 @@ ASCEND_QUANTIZATION_METHOD_MAP: Dict[str, Dict[str, Type[Any]]] = { Adapting a new quantized model requires ensuring the following three points: - The original model has been successfully adapted in `vLLM Ascend`. -- **Fused Module Mapping**: Add the model's `model_type` to `packed_modules_model_mapping` in `vllm_ascend/quantization/quant_config.py` (e.g., `qkv_proj`, `gate_up_proj`, `experts`) to ensure sharding consistency and correct loading. +- **Fused Module Mapping**: Add the model's `model_type` to `packed_modules_model_mapping` in `vllm_ascend/quantization/modelslim_config.py` (e.g., `qkv_proj`, `gate_up_proj`, `experts`) to ensure sharding consistency and correct loading. ```python packed_modules_model_mapping = { diff --git a/tests/ut/quantization/test_quant_config.py b/tests/ut/quantization/test_modelslim_config.py similarity index 70% rename from tests/ut/quantization/test_quant_config.py rename to tests/ut/quantization/test_modelslim_config.py index f75f8042b56..f290e74259d 100644 --- a/tests/ut/quantization/test_quant_config.py +++ b/tests/ut/quantization/test_modelslim_config.py @@ -7,11 +7,11 @@ from tests.ut.base import TestBase from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod -from vllm_ascend.quantization.quant_config import AscendQuantConfig +from vllm_ascend.quantization.modelslim_config import AscendModelSlimConfig from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD -class TestAscendQuantConfig(TestBase): +class TestAscendModelSlimConfig(TestBase): def setUp(self): self.sample_config = { @@ -25,7 +25,7 @@ def setUp(self): "shard1.weight": "FLOAT", "shard2.weight": "FLOAT", } - self.ascend_config = AscendQuantConfig(self.sample_config) + self.ascend_config = AscendModelSlimConfig(self.sample_config) self.ascend_config.packed_modules_mapping = None def test_init(self): @@ -34,55 +34,55 @@ def test_init(self): def test_repr(self): repr_str = repr(self.ascend_config) - self.assertTrue(repr_str.startswith("AscendQuantConfig:\n")) + self.assertTrue(repr_str.startswith("AscendModelSlimConfig:\n")) def test_get_name(self): - self.assertEqual(AscendQuantConfig.get_name(), + self.assertEqual(AscendModelSlimConfig.get_name(), ASCEND_QUANTIZATION_METHOD) def test_get_supported_act_dtypes(self): - supported_dtypes = AscendQuantConfig.get_supported_act_dtypes() + supported_dtypes = AscendModelSlimConfig.get_supported_act_dtypes() self.assertEqual(len(supported_dtypes), 3) def test_get_min_capability(self): with self.assertRaises(NotImplementedError): - AscendQuantConfig.get_min_capability() + AscendModelSlimConfig.get_min_capability() def test_get_config_filenames(self): - filenames = AscendQuantConfig.get_config_filenames() + filenames = AscendModelSlimConfig.get_config_filenames() self.assertEqual(filenames, ["quant_model_description.json"]) def test_from_config(self): - config = AscendQuantConfig.from_config(self.sample_config) - self.assertIsInstance(config, AscendQuantConfig) + config = AscendModelSlimConfig.from_config(self.sample_config) + self.assertIsInstance(config, AscendModelSlimConfig) self.assertEqual(config.quant_description, self.sample_config) @patch('torch.npu.is_available') def test_override_quantization_method(self, mock_is_available): # Test when NPU is available mock_is_available.return_value = True - result = AscendQuantConfig.override_quantization_method(None, None) + result = AscendModelSlimConfig.override_quantization_method(None, None) self.assertIsNone(result) hf_quant_cfg = {"quant_method": ""} - result = AscendQuantConfig.override_quantization_method( + result = AscendModelSlimConfig.override_quantization_method( hf_quant_cfg, None) self.assertEqual(result, "ascend") # Test when NPU is not available mock_is_available.return_value = False - result = AscendQuantConfig.override_quantization_method(None, None) + result = AscendModelSlimConfig.override_quantization_method(None, None) self.assertIsNone(result) hf_quant_cfg = {"quant_method": ""} - result = AscendQuantConfig.override_quantization_method( + result = AscendModelSlimConfig.override_quantization_method( hf_quant_cfg, None) self.assertIsNone(result) def test_get_quant_method_for_linear(self): mock_config = MagicMock() - mock_config.model_config.hf_text_config.model_type = None + mock_config.model_config.hf_config.model_type = None linear_layer = MagicMock(spec=LinearBase) # Test skipped layer - with patch("vllm_ascend.quantization.quant_config.get_current_vllm_config", return_value=mock_config), \ + with patch("vllm_ascend.quantization.modelslim_config.get_current_vllm_config", return_value=mock_config), \ patch.object(self.ascend_config, \ 'is_layer_skipped_ascend', return_value=True): @@ -91,8 +91,8 @@ def test_get_quant_method_for_linear(self): # Test quantized layer with patch.object(self.ascend_config, 'is_layer_skipped_ascend', return_value=False), \ - patch("vllm_ascend.quantization.quant_config.get_current_vllm_config", return_value=mock_config), \ - patch('vllm_ascend.quantization.quant_config.AscendLinearMethod', return_value=MagicMock()) as mock_ascend_linear: + patch("vllm_ascend.quantization.modelslim_config.get_current_vllm_config", return_value=mock_config), \ + patch('vllm_ascend.quantization.modelslim_config.AscendLinearMethod', return_value=MagicMock()) as mock_ascend_linear: method = self.ascend_config.get_quant_method(linear_layer, ".attn") self.assertIs(method, mock_ascend_linear.return_value) @@ -103,9 +103,9 @@ def test_get_quant_method_for_linear(self): def test_get_quant_method_for_attention(self): attention_layer = MagicMock(spec=Attention) mock_config = MagicMock() - mock_config.model_config.hf_text_config.model_type = None - with patch("vllm_ascend.quantization.quant_config.get_current_vllm_config", return_value=mock_config), \ - patch('vllm_ascend.quantization.quant_config.AscendKVCacheMethod', \ + mock_config.model_config.hf_config.model_type = None + with patch("vllm_ascend.quantization.modelslim_config.get_current_vllm_config", return_value=mock_config), \ + patch('vllm_ascend.quantization.modelslim_config.AscendKVCacheMethod', \ return_value=MagicMock()) as mock_ascend_kvcache: # Test with fa_quant_type method = self.ascend_config.get_quant_method( @@ -117,20 +117,20 @@ def test_get_quant_method_for_fused_moe(self): fused_moe_layer.moe = MagicMock(spec=FusedMoEConfig) fused_moe_layer.moe_config = MagicMock(spec=FusedMoEConfig) mock_config = MagicMock() - mock_config.model_config.hf_text_config.model_type = None + mock_config.model_config.hf_config.model_type = None # Test skipped layer with patch.object(self.ascend_config, 'is_layer_skipped_ascend', return_value=True), \ - patch("vllm_ascend.quantization.quant_config.get_current_vllm_config", return_value=mock_config), \ - patch('vllm_ascend.quantization.quant_config.AscendUnquantizedFusedMoEMethod', return_value=MagicMock()) as mock_ascend_moe: + patch("vllm_ascend.quantization.modelslim_config.get_current_vllm_config", return_value=mock_config), \ + patch('vllm_ascend.quantization.modelslim_config.AscendUnquantizedFusedMoEMethod', return_value=MagicMock()) as mock_ascend_moe: method = self.ascend_config.get_quant_method( fused_moe_layer, "moe_layer") self.assertIs(method, mock_ascend_moe.return_value) # Test quantized layer with patch.object(self.ascend_config, 'is_layer_skipped_ascend', return_value=False), \ - patch("vllm_ascend.quantization.quant_config.get_current_vllm_config", return_value=mock_config), \ - patch('vllm_ascend.quantization.quant_config.AscendFusedMoEMethod', return_value=MagicMock()) as mock_ascend_moe: + patch("vllm_ascend.quantization.modelslim_config.get_current_vllm_config", return_value=mock_config), \ + patch('vllm_ascend.quantization.modelslim_config.AscendFusedMoEMethod', return_value=MagicMock()) as mock_ascend_moe: method = self.ascend_config.get_quant_method( fused_moe_layer, "moe_layer") self.assertIs(method, mock_ascend_moe.return_value) @@ -150,7 +150,7 @@ def test_is_layer_skipped_ascend(self): # Test inconsistent fused layer shards bad_config = {"shard1.weight": "FLOAT", "shard2.weight": "INT8"} - config = AscendQuantConfig(bad_config) + config = AscendModelSlimConfig(bad_config) with self.assertRaises(ValueError): config.is_layer_skipped_ascend("fused_layer", fused_mapping) diff --git a/tests/ut/quantization/test_utils.py b/tests/ut/quantization/test_utils.py deleted file mode 100644 index f4cbc0c26c6..00000000000 --- a/tests/ut/quantization/test_utils.py +++ /dev/null @@ -1,50 +0,0 @@ -import types - -from tests.ut.base import TestBase -from vllm_ascend.quantization.utils import (ASCEND_QUANTIZATION_METHOD_MAP, - get_quant_method) - - -class TestGetQuantMethod(TestBase): - - def setUp(self): - self.original_quantization_method_map = ASCEND_QUANTIZATION_METHOD_MAP.copy( - ) - for quant_type, layer_map in ASCEND_QUANTIZATION_METHOD_MAP.items(): - for layer_type in layer_map.keys(): - ASCEND_QUANTIZATION_METHOD_MAP[quant_type][ - layer_type] = types.new_class(f"{quant_type}_{layer_type}") - - def tearDown(self): - # Restore original map - ASCEND_QUANTIZATION_METHOD_MAP.clear() - ASCEND_QUANTIZATION_METHOD_MAP.update( - self.original_quantization_method_map) - - def test_linear_quant_methods(self): - for quant_type, layer_map in ASCEND_QUANTIZATION_METHOD_MAP.items(): - if "linear" in layer_map.keys(): - prefix = "linear_layer" - cls = layer_map["linear"] - method = get_quant_method({"linear_layer.weight": quant_type}, - prefix, "linear") - self.assertIsInstance(method, cls) - - def test_moe_quant_methods(self): - for quant_type, layer_map in ASCEND_QUANTIZATION_METHOD_MAP.items(): - if "moe" in layer_map.keys(): - prefix = "layer" - cls = layer_map["moe"] - method = get_quant_method({"layer.weight": quant_type}, prefix, - "moe") - self.assertIsInstance(method, cls) - - def test_invalid_layer_type(self): - quant_description = {"linear_layer.weight": "W8A8"} - with self.assertRaises(NotImplementedError): - get_quant_method(quant_description, "linear_layer", "unsupported") - - def test_invalid_quant_type(self): - quant_description = {"linear_layer.weight": "UNKNOWN"} - with self.assertRaises(NotImplementedError): - get_quant_method(quant_description, "linear_layer", "linear") diff --git a/tests/ut/quantization/test_w4a16.py b/tests/ut/quantization/test_w4a16.py index 5d50e738904..1258b12dc5d 100644 --- a/tests/ut/quantization/test_w4a16.py +++ b/tests/ut/quantization/test_w4a16.py @@ -3,8 +3,9 @@ import torch from tests.ut.base import TestBase -from vllm_ascend.quantization.w4a16 import (AscendW4A16FusedMoEMethod, - pack_to_int32, unpack_from_int32) +from vllm_ascend.quantization.methods.w4a16 import (AscendW4A16FusedMoEMethod, + pack_to_int32, + unpack_from_int32) class TestUnpackFromInt32(TestBase): @@ -42,7 +43,7 @@ def test_unpack_from_int32_assertions(self): class TestPackToInt32(TestBase): @patch( - "vllm_ascend.quantization.w4a16.torch_npu.npu_convert_weight_to_int4pack" + "vllm_ascend.quantization.methods.w4a16.torch_npu.npu_convert_weight_to_int4pack" ) def test_pack_to_int32_int8(self, mock_npu_convert_weight_to_int4pack): mock_npu_convert_weight_to_int4pack.return_value = torch.zeros( @@ -57,7 +58,7 @@ def test_pack_to_int32_int8(self, mock_npu_convert_weight_to_int4pack): self.assertEqual(result.shape, torch.Size([2, 8, 4])) @patch( - "vllm_ascend.quantization.w4a16.torch_npu.npu_convert_weight_to_int4pack" + "vllm_ascend.quantization.methods.w4a16.torch_npu.npu_convert_weight_to_int4pack" ) def test_pack_to_int32_int32(self, mock_npu_convert_weight_to_int4pack): @@ -97,8 +98,8 @@ class TestAscendW4A16FusedMoEMethod(TestBase): output_size = 128 group_size = 32 - @patch("vllm_ascend.quantization.w4a16.get_ascend_config") - @patch("vllm_ascend.quantization.w4a16.get_current_vllm_config") + @patch("vllm_ascend.quantization.methods.w4a16.get_ascend_config") + @patch("vllm_ascend.quantization.methods.w4a16.get_current_vllm_config") def setUp(self, mock_get_current_vllm_config, mock_get_ascend_config): mock_ascend_config = Mock() mock_ascend_config.dynamic_eplb = False @@ -218,7 +219,7 @@ def build_layer(self): return layer @patch( - "vllm_ascend.quantization.w4a16.torch_npu.npu_convert_weight_to_int4pack" + "vllm_ascend.quantization.methods.w4a16.torch_npu.npu_convert_weight_to_int4pack" ) def test_process_weights_after_loading_with_transpose( self, mock_npu_convert_weight_to_int4pack): diff --git a/tests/ut/quantization/test_w4a4_flatquant_dynamic.py b/tests/ut/quantization/test_w4a4_flatquant_dynamic.py index d02ad6bddd1..c3f452c4ea5 100644 --- a/tests/ut/quantization/test_w4a4_flatquant_dynamic.py +++ b/tests/ut/quantization/test_w4a4_flatquant_dynamic.py @@ -4,7 +4,7 @@ import torch import torch.nn as nn -from vllm_ascend.quantization.w4a4_flatquant_dynamic import ( +from vllm_ascend.quantization.methods.w4a4_flatquant import ( AscendW4A4FlatQuantDynamicLinearMethod, get_decompose_dim, pack_int4_weights) @@ -33,7 +33,7 @@ def test_get_decompose_dim(self): self.assertEqual(get_decompose_dim(100), (10, 10)) self.assertEqual(get_decompose_dim(99), (9, 11)) - @patch('vllm_ascend.quantization.w4a4_flatquant_dynamic.torch_npu') + @patch('vllm_ascend.quantization.methods.w4a4_flatquant.torch_npu') def test_pack_int4_weights_npu_success(self, mock_torch_npu): """ Tests weight packing using the mocked NPU kernel. @@ -119,7 +119,7 @@ def _prepare_apply_mocks_and_layer(self, batch_size): x = torch.randn(batch_size, self.input_size, dtype=self.params_dtype) return layer, x, m, n - @patch('vllm_ascend.quantization.w4a4_flatquant_dynamic.torch_npu') + @patch('vllm_ascend.quantization.methods.w4a4_flatquant.torch_npu') def test_apply_small_batch(self, mock_torch_npu): """Tests the apply method with a batch size smaller than MAX_BATCH_SIZE.""" batch_size = 128 @@ -143,9 +143,9 @@ def test_apply_small_batch(self, mock_torch_npu): self.assertEqual(output.shape, (batch_size, self.output_size)) @patch( - 'vllm_ascend.quantization.w4a4_flatquant_dynamic.KRONECKER_QUANT_MAX_BATCH_SIZE', + 'vllm_ascend.quantization.methods.w4a4_flatquant.KRONECKER_QUANT_MAX_BATCH_SIZE', 10) - @patch('vllm_ascend.quantization.w4a4_flatquant_dynamic.torch_npu') + @patch('vllm_ascend.quantization.methods.w4a4_flatquant.torch_npu') def test_apply_large_batch(self, mock_torch_npu): """Tests the apply method with a batch size larger than MAX_BATCH_SIZE.""" batch_size = 25 @@ -178,7 +178,7 @@ def test_apply_dimension_mismatch_error(self): ValueError, "FlatQuant transform matrices dimension mismatch"): self.method.apply(layer, x) - @patch('vllm_ascend.quantization.w4a4_flatquant_dynamic.pack_int4_weights') + @patch('vllm_ascend.quantization.methods.w4a4_flatquant.pack_int4_weights') def test_process_weights_after_loading(self, mock_pack_weights): """Tests weight processing after loading, without transpose.""" layer = nn.Module() diff --git a/tests/ut/quantization/test_w4a8_dynamic.py b/tests/ut/quantization/test_w4a8_dynamic.py index 3ed2a877d2b..4501e37e1cd 100644 --- a/tests/ut/quantization/test_w4a8_dynamic.py +++ b/tests/ut/quantization/test_w4a8_dynamic.py @@ -3,14 +3,14 @@ import torch from tests.ut.base import TestBase -from vllm_ascend.quantization.w4a8_dynamic import ( +from vllm_ascend.quantization.methods.w4a8 import ( AscendW4A8DynamicFusedMoEMethod, AscendW4A8DynamicLinearMethod) class TestAscendW4A8DynamicLinearMethod(TestBase): @patch('vllm.distributed.get_tensor_model_parallel_world_size') - @patch('vllm_ascend.quantization.w4a8_dynamic.get_current_vllm_config') + @patch('vllm_ascend.quantization.methods.w4a8.get_current_vllm_config') def setUp(self, mock_get_current_vllm_config, mock_get_tp_world_size): mock_get_tp_world_size.return_value = 1 mock_vllm_config = Mock() @@ -127,10 +127,10 @@ class TestAscendW4A8DynamicFusedMoEMethod(TestBase): output_size = 56 group_size = 2 - @patch('vllm_ascend.quantization.w4a8_dynamic.get_ascend_config') - @patch('vllm_ascend.quantization.w4a8_dynamic.get_current_vllm_config') - @patch('vllm_ascend.quantization.w4a8_dynamic.get_ep_group') - @patch('vllm_ascend.quantization.w4a8_dynamic.get_mc2_group') + @patch('vllm_ascend.quantization.methods.w4a8.get_ascend_config') + @patch('vllm_ascend.quantization.methods.w4a8.get_current_vllm_config') + @patch('vllm_ascend.quantization.methods.w4a8.get_ep_group') + @patch('vllm_ascend.quantization.methods.w4a8.get_mc2_group') @patch('torch.distributed.get_rank', return_value=0) def setUp(self, mock_get_rank, mock_get_mc2_group, mock_get_ep_group, get_current_vllm_config, mock_get_ascend_config): diff --git a/tests/ut/quantization/test_w8a16.py b/tests/ut/quantization/test_w8a16.py index 1d839bfa763..3454cbfdeb8 100644 --- a/tests/ut/quantization/test_w8a16.py +++ b/tests/ut/quantization/test_w8a16.py @@ -4,7 +4,7 @@ import torch from tests.ut.base import TestBase -from vllm_ascend.quantization.w8a16 import AscendW8A16LinearMethod +from vllm_ascend.quantization.methods.w8a16 import AscendW8A16LinearMethod class TestAscendW8A16LinearMethod(TestBase): diff --git a/tests/ut/quantization/test_w8a8.py b/tests/ut/quantization/test_w8a8.py index b8639cc4814..bbed09e6c78 100644 --- a/tests/ut/quantization/test_w8a8.py +++ b/tests/ut/quantization/test_w8a8.py @@ -4,36 +4,10 @@ import torch from tests.ut.base import TestBase -from vllm_ascend.quantization.w8a8 import (AscendW8A8LinearMethod, - quant_per_tensor) +from vllm_ascend.quantization.methods.w8a8_static import AscendW8A8LinearMethod from vllm_ascend.utils import AscendDeviceType -class TestQuantPerTensor(TestBase): - - @patch("torch_npu.npu_quantize") - def test_quant_per_tensor(self, mock_npu_quantize): - in_tensor = torch.randn(32, 128) - input_scale = torch.tensor(0.1) - input_offset = torch.tensor(0) - - expected_output = torch.randint(-128, 127, (32, 128), dtype=torch.int8) - mock_npu_quantize.return_value = expected_output - - output = quant_per_tensor(in_tensor, input_scale, input_offset) - - mock_npu_quantize.assert_called_once_with( - in_tensor, - input_scale, - input_offset, - torch.qint8, - -1, - False, - ) - - self.assertTrue(torch.equal(output, expected_output)) - - class TestAscendW8A8LinearMethod(TestBase): def setUp(self): @@ -63,7 +37,9 @@ def test_get_perchannel_param(self): self.assertEqual(params['weight_scale'].shape, (10, 1)) self.assertEqual(params['weight_offset'].shape, (10, 1)) - @patch("vllm_ascend.quantization.w8a8.get_weight_prefetch_method") + @patch( + "vllm_ascend.quantization.methods.w8a8_static.get_weight_prefetch_method" + ) @patch("torch.ops.vllm.quantize") @patch("torch_npu.npu_quant_matmul") def test_apply_with_x_not_int8(self, mock_npu_quant_matmul, mock_quantize, diff --git a/tests/ut/quantization/test_w8a8_dynamic.py b/tests/ut/quantization/test_w8a8_dynamic.py index ebd368a72a5..00cc8f136a9 100644 --- a/tests/ut/quantization/test_w8a8_dynamic.py +++ b/tests/ut/quantization/test_w8a8_dynamic.py @@ -3,7 +3,7 @@ import torch from tests.ut.base import TestBase -from vllm_ascend.quantization.w8a8_dynamic import \ +from vllm_ascend.quantization.methods.w8a8_dynamic import \ AscendW8A8DynamicFusedMoEMethod @@ -13,13 +13,13 @@ class TestAscendW8A8FusedMoEMethod(TestBase): intermediate_size = 128 @patch("torch.distributed.get_rank") - @patch("vllm_ascend.quantization.w8a8_dynamic.get_mc2_group") - @patch("vllm_ascend.quantization.w8a8_dynamic.get_ascend_config") - @patch("vllm_ascend.quantization.w8a8_dynamic.get_ep_group") + @patch("vllm_ascend.quantization.methods.w8a8_dynamic.get_mc2_group") + @patch("vllm_ascend.quantization.methods.w8a8_dynamic.get_ascend_config") + @patch("vllm_ascend.quantization.methods.w8a8_dynamic.get_ep_group") def setUp(self, mock_get_ep_group, mock_get_ascend_config, mock_get_mc2_group, mock_get_rank): with patch( - 'vllm_ascend.quantization.w8a8_dynamic.get_current_vllm_config' + 'vllm_ascend.quantization.methods.w8a8_dynamic.get_current_vllm_config' ) as mock_get_current_vllm_config: mock_vllm_config = Mock() mock_vllm_config.quant_config = Mock( diff --git a/tests/ut/test_platform.py b/tests/ut/test_platform.py index 6f2eeec190d..1523ebac1a5 100644 --- a/tests/ut/test_platform.py +++ b/tests/ut/test_platform.py @@ -55,7 +55,7 @@ def test_is_sleep_mode_available(self): self.assertTrue(self.platform.is_sleep_mode_available()) @patch("vllm_ascend.utils.adapt_patch") - @patch("vllm_ascend.quantization.quant_config.AscendQuantConfig") + @patch("vllm_ascend.quantization.modelslim_config.AscendModelSlimConfig") def test_pre_register_and_update_with_parser(self, mock_quant_config, mock_adapt_patch): mock_parser = MagicMock() @@ -71,7 +71,7 @@ def test_pre_register_and_update_with_parser(self, mock_quant_config, self.assertEqual(len(mock_action.choices), 3) # original 2 + ascend @patch("vllm_ascend.utils.adapt_patch") - @patch("vllm_ascend.quantization.quant_config.AscendQuantConfig") + @patch("vllm_ascend.quantization.modelslim_config.AscendModelSlimConfig") def test_pre_register_and_update_without_parser(self, mock_quant_config, mock_adapt_patch): self.platform.pre_register_and_update(None) @@ -79,7 +79,7 @@ def test_pre_register_and_update_without_parser(self, mock_quant_config, mock_adapt_patch.assert_called_once_with(is_global_patch=True) @patch("vllm_ascend.utils.adapt_patch") - @patch("vllm_ascend.quantization.quant_config.AscendQuantConfig") + @patch("vllm_ascend.quantization.modelslim_config.AscendModelSlimConfig") def test_pre_register_and_update_with_parser_no_quant_action( self, mock_quant_config, mock_adapt_patch): mock_parser = MagicMock() @@ -90,7 +90,7 @@ def test_pre_register_and_update_with_parser_no_quant_action( mock_adapt_patch.assert_called_once_with(is_global_patch=True) @patch("vllm_ascend.utils.adapt_patch") - @patch("vllm_ascend.quantization.quant_config.AscendQuantConfig") + @patch("vllm_ascend.quantization.modelslim_config.AscendModelSlimConfig") def test_pre_register_and_update_with_existing_ascend_quant( self, mock_quant_config, mock_adapt_patch): mock_parser = MagicMock() diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 23564888e3c..7d5f2fddd3e 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -37,7 +37,7 @@ register_all_layers_to_shard_weight_series) from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch -from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod +from vllm_ascend.quantization.methods import AscendW8A8LinearMethod from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, maybe_trans_nz, weak_ref_tensors) from vllm_ascend.worker.npu_input_batch import NPUInputBatch diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py index c4a2a51e58f..1876cd382a3 100644 --- a/vllm_ascend/attention/sfa_v1.py +++ b/vllm_ascend/attention/sfa_v1.py @@ -32,7 +32,7 @@ from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla from vllm_ascend.ops.triton.rope import rope_forward_triton from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch -from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod +from vllm_ascend.quantization.methods import AscendW8A8LinearMethod from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, _round_up, dispose_layer, enable_sp, maybe_trans_nz, replace_layer) from vllm_ascend.worker.npu_input_batch import NPUInputBatch diff --git a/vllm_ascend/ops/fused_moe/fused_moe.py b/vllm_ascend/ops/fused_moe/fused_moe.py index d2a55e318f2..e5e661fa1ba 100644 --- a/vllm_ascend/ops/fused_moe/fused_moe.py +++ b/vllm_ascend/ops/fused_moe/fused_moe.py @@ -40,10 +40,8 @@ FusedExpertsResult, setup_moe_comm_method) from vllm_ascend.ops.fused_moe.prepare_finalize import QuantType -from vllm_ascend.quantization.w4a8_dynamic import \ - AscendW4A8DynamicFusedMoEMethod -from vllm_ascend.quantization.w8a8_dynamic import \ - AscendW8A8DynamicFusedMoEMethod +from vllm_ascend.quantization.methods import (AscendW4A8DynamicFusedMoEMethod, + AscendW8A8DynamicFusedMoEMethod) from vllm_ascend.utils import (AscendDeviceType, enable_sp, get_ascend_device_type, maybe_trans_nz, npu_stream_switch, shared_expert_dp_enabled, diff --git a/vllm_ascend/ops/linear_op.py b/vllm_ascend/ops/linear_op.py index 53130e67e73..fa846224e33 100644 --- a/vllm_ascend/ops/linear_op.py +++ b/vllm_ascend/ops/linear_op.py @@ -363,7 +363,8 @@ def otp_maybe_quant_comm(x): "communication_fn"] = otp_maybe_quant_comm actual_quant_method = getattr(self.quant_method, 'quant_method', self.quant_method) - from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod + from vllm_ascend.quantization.methods.w8a8_static import \ + AscendW8A8LinearMethod if not isinstance(actual_quant_method, AscendW8A8LinearMethod): # Check if w8a8 quantization is enabled. If not, communicate immediately. input_parallel = otp_maybe_quant_comm(input_parallel) @@ -548,8 +549,8 @@ def matmul_and_reduce(self, input_parallel: torch.Tensor, from vllm.model_executor.layers.linear import UnquantizedLinearMethod - from vllm_ascend.quantization.quant_config import AscendLinearMethod - from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod + from vllm_ascend.quantization.methods import AscendW8A8LinearMethod + from vllm_ascend.quantization.wrappers import AscendLinearMethod # For unquant if mmrs_fusion and isinstance(self.layer.quant_method, diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index 06f8be7bca2..29ffefc3400 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -137,10 +137,7 @@ def pre_register_and_update(cls, if ASCEND_QUANTIZATION_METHOD not in quant_action.choices: quant_action.choices.append(ASCEND_QUANTIZATION_METHOD) - from vllm_ascend.quantization.compressed_tensors.compressed_tensors import \ - AscendCompressedTensorsConfig # noqa: F401 - from vllm_ascend.quantization.quant_config import \ - AscendQuantConfig # noqa: F401 + from vllm_ascend.quantization import AscendCompressedTensorsConfig, AscendModelSlimConfig # noqa: F401 config_deprecated_logging() diff --git a/vllm_ascend/quantization/__init__.py b/vllm_ascend/quantization/__init__.py index e69de29bb2d..d5b31e33d88 100644 --- a/vllm_ascend/quantization/__init__.py +++ b/vllm_ascend/quantization/__init__.py @@ -0,0 +1,38 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Ascend quantization module. + +This module provides quantization support for Ascend NPU. + +Supported quantization tools: +- ModelSlim: Use AscendModelSlimConfig +- LLM-Compressor (compressed_tensors): Use AscendCompressedTensorsConfig + +Public API: +- Config classes: AscendModelSlimConfig, AscendCompressedTensorsConfig +- For scheme implementations, import from vllm_ascend.quantization.methods +""" + +# LLM-Compressor (compressed_tensors) quantization config +from .compressed_tensors_config import AscendCompressedTensorsConfig +# ModelSlim quantization config +from .modelslim_config import AscendModelSlimConfig + +__all__ = [ + "AscendModelSlimConfig", + "AscendCompressedTensorsConfig", +] diff --git a/vllm_ascend/quantization/compressed_tensors/__init__.py b/vllm_ascend/quantization/compressed_tensors/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/vllm_ascend/quantization/compressed_tensors/compressed_tensors.py b/vllm_ascend/quantization/compressed_tensors_config.py similarity index 76% rename from vllm_ascend/quantization/compressed_tensors/compressed_tensors.py rename to vllm_ascend/quantization/compressed_tensors_config.py index 774bb00628e..a15bdfffa0d 100644 --- a/vllm_ascend/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm_ascend/quantization/compressed_tensors_config.py @@ -1,4 +1,23 @@ -from typing import TYPE_CHECKING, Any, Optional, cast +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# +"""LLM-Compressor (compressed_tensors) quantization configuration for Ascend.""" + +from typing import Any, Optional, cast import torch from compressed_tensors.quantization import (QuantizationArgs, @@ -16,34 +35,48 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( find_matched_target, is_activation_quantization_format, should_ignore_layer) +from vllm.model_executor.models.utils import WeightsMapper from vllm_ascend.ops.fused_moe.fused_moe import AscendUnquantizedFusedMoEMethod -from vllm_ascend.quantization.quant_config import (AscendFusedMoEMethod, - AscendLinearMethod, - AscendQuantConfig) -from vllm_ascend.quantization.w4a16 import AscendW4A16FusedMoEMethod -from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod -from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod from vllm_ascend.utils import COMPRESSED_TENSORS_METHOD -if TYPE_CHECKING: - from vllm.model_executor.models.utils import WeightsMapper - logger = init_logger(__name__) -QUANTIZATION_SCHEME_MAP_TYPE = dict[str, Optional[dict[str, QuantizationArgs]]] - -def remove_quantization_method(): +# Remove the original compressed_tensors method to replace with our implementation +def _remove_quantization_method(): if COMPRESSED_TENSORS_METHOD in QUANTIZATION_METHODS: QUANTIZATION_METHODS.remove(COMPRESSED_TENSORS_METHOD) -remove_quantization_method() +_remove_quantization_method() + +QUANTIZATION_SCHEME_MAP_TYPE = dict[str, Optional[dict[str, + "QuantizationArgs"]]] + + +def get_quant_method_llmcompressor(layer: torch.nn.Module): + """Get quantization method for LLM-Compressor models. + + Args: + layer: The layer module with a scheme attribute. + + Returns: + The scheme from the layer. + """ + logger.info_once("Using the vLLM Ascend llmcompressor Quantization now!") + if layer.scheme is None: + raise ValueError("A scheme must be defined for each layer") + return layer.scheme @register_quantization_config(COMPRESSED_TENSORS_METHOD) class AscendCompressedTensorsConfig(QuantizationConfig): + """Config class for LLM-Compressor (compressed_tensors) quantization on Ascend. + + This class adapts the compressed_tensors format to work with Ascend's + quantization implementations. + """ def __init__( self, @@ -93,23 +126,16 @@ def from_config(cls, config: dict[str, @classmethod def _quantization_scheme_map_from_config( cls, config: dict[str, Any]) -> QUANTIZATION_SCHEME_MAP_TYPE: - """ + """Build target scheme map from config. + :param config: The `quantization_config` dictionary from config.json :return: A dictionary mapping target layer names to their corresponding quantization_args for weights and input activations """ + target_scheme_map: dict[str, Any] = dict() quant_format = cast(str, config.get("format")) - # The quant_config has multiple config_groups, each containing - # an input_activations key with details about how the activations are - # quantized, a weights key indicating how the weights are quantized, - # and a list of targets under the `targets` key, dictating which - # layers are impacted by the quantization details. The quantization - # details follow the structure defined by the QuantizationArgs - # pydantic model, which is used to verify the structure of the - # quant_config and also store the details for later use. - config_groups = config.get("config_groups", dict()) for _, quant_config in config_groups.items(): targets = quant_config.get("targets") @@ -140,6 +166,9 @@ def get_quant_method( layer: torch.nn.Module, prefix: str, ) -> Optional["QuantizeMethodBase"]: + from .modelslim_config import AscendModelSlimConfig + from .wrappers import AscendFusedMoEMethod, AscendLinearMethod + if isinstance(layer, LinearBase): layer.ascend_quant_method = COMPRESSED_TENSORS_METHOD # collect schemes @@ -149,8 +178,8 @@ def get_quant_method( quant_method = UnquantizedLinearMethod() if quant_scheme is not None: layer.scheme = quant_scheme - ascend_quant_config = AscendQuantConfig(self.quant_description - or {}) + ascend_quant_config = AscendModelSlimConfig( + self.quant_description or {}) quant_method = AscendLinearMethod(ascend_quant_config, prefix, None, layer) return quant_method @@ -163,8 +192,8 @@ def get_quant_method( quant_method = AscendUnquantizedFusedMoEMethod(layer.moe_config) if quant_scheme is not None: layer.scheme = quant_scheme - ascend_quant_config = AscendQuantConfig(self.quant_description - or {}) + ascend_quant_config = AscendModelSlimConfig( + self.quant_description or {}) quant_method = AscendFusedMoEMethod( ascend_quant_config, prefix, ascend_quant_config.packed_modules_mapping, layer) @@ -175,7 +204,8 @@ def get_scheme(self, layer: torch.nn.Module, layer_name: Optional[str] = None ) -> Optional["CompressedTensorsScheme"]: - """ + """Get the quantization scheme for a layer. + compressed-tensors supports non uniform in the following way: targets of config_groups: There can be N config_groups which each @@ -224,8 +254,13 @@ def get_scheme(self, return scheme def _get_scheme_from_parts( - self, weight_quant: QuantizationArgs, - input_quant: QuantizationArgs) -> "CompressedTensorsScheme": + self, weight_quant: "QuantizationArgs", + input_quant: "QuantizationArgs") -> "CompressedTensorsScheme": + """Determine the appropriate scheme based on quantization args.""" + from .methods import (AscendW4A16FusedMoEMethod, + AscendW8A8DynamicLinearMethod, + AscendW8A8LinearMethod) + act_quant_format = is_activation_quantization_format(self.quant_format) if act_quant_format and input_quant is not None: if self._is_static_tensor_w8a8(weight_quant, input_quant): @@ -241,8 +276,8 @@ def _get_scheme_from_parts( raise NotImplementedError( "No compressed-tensors compatible scheme was found.") - def _is_static_tensor_w8a8(self, weight_quant: QuantizationArgs, - input_quant: QuantizationArgs) -> bool: + def _is_static_tensor_w8a8(self, weight_quant: "QuantizationArgs", + input_quant: "QuantizationArgs") -> bool: is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8 weight_strategy = ( weight_quant.strategy == QuantizationStrategy.CHANNEL.value) @@ -255,8 +290,8 @@ def _is_static_tensor_w8a8(self, weight_quant: QuantizationArgs, # Only symmetric weight quantization supported. return is_8_bits and is_tensor and is_symmetric and is_static - def _is_dynamic_token_w8a8(self, weight_quant: QuantizationArgs, - input_quant: QuantizationArgs) -> bool: + def _is_dynamic_token_w8a8(self, weight_quant: "QuantizationArgs", + input_quant: "QuantizationArgs") -> bool: is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8 weight_strategy = ( weight_quant.strategy == QuantizationStrategy.CHANNEL.value) @@ -269,7 +304,7 @@ def _is_dynamic_token_w8a8(self, weight_quant: QuantizationArgs, # Only symmetric weight quantization supported. return is_8_bits and is_token and is_symmetric and is_dynamic - def _is_w4a16(self, weight_quant: QuantizationArgs) -> bool: + def _is_w4a16(self, weight_quant: "QuantizationArgs") -> bool: is_4_bits = weight_quant.num_bits == 4 return is_4_bits diff --git a/vllm_ascend/quantization/methods/__init__.py b/vllm_ascend/quantization/methods/__init__.py new file mode 100644 index 00000000000..bcf73cfd239 --- /dev/null +++ b/vllm_ascend/quantization/methods/__init__.py @@ -0,0 +1,69 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Ascend quantization scheme implementations. + +This module provides all quantization scheme implementations for Ascend NPU. +Schemes are automatically registered via the @register_scheme decorator. + +Usage: + from vllm_ascend.quantization.methods import get_scheme_class + + # Get a scheme class by quant_type and layer_type + scheme_cls = get_scheme_class("W8A8_DYNAMIC", "linear") + scheme = scheme_cls() +""" + +# Import base classes +from .base import AscendLinearScheme, AscendMoEScheme +# Import registry functions +from .registry import get_scheme_class, register_scheme +from .w4a4_flatquant import AscendW4A4FlatQuantDynamicLinearMethod +from .w4a8 import (AscendW4A8DynamicFusedMoEMethod, + AscendW4A8DynamicLinearMethod) +from .w4a16 import AscendW4A16FusedMoEMethod +from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod, + AscendW8A8DynamicLinearMethod) +from .w8a8_pdmix import (AscendW8A8PDMixFusedMoeMethod, + AscendW8A8PDMixLinearMethod) +from .w8a8_static import AscendW8A8LinearMethod +from .w8a16 import AscendW8A16LinearMethod + +__all__ = [ + # Base classes + "AscendLinearScheme", + "AscendMoEScheme", + # Registry functions + "register_scheme", + "get_scheme_class", + # W8A8 static + "AscendW8A8LinearMethod", + # W8A8 dynamic + "AscendW8A8DynamicLinearMethod", + "AscendW8A8DynamicFusedMoEMethod", + # W8A8 PDMix + "AscendW8A8PDMixLinearMethod", + "AscendW8A8PDMixFusedMoeMethod", + # W8A16 + "AscendW8A16LinearMethod", + # W4A8 + "AscendW4A8DynamicLinearMethod", + "AscendW4A8DynamicFusedMoEMethod", + # W4A16 + "AscendW4A16FusedMoEMethod", + # W4A4 FlatQuant + "AscendW4A4FlatQuantDynamicLinearMethod", +] diff --git a/vllm_ascend/quantization/methods/base.py b/vllm_ascend/quantization/methods/base.py new file mode 100644 index 00000000000..ee277b927cc --- /dev/null +++ b/vllm_ascend/quantization/methods/base.py @@ -0,0 +1,218 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Abstract base classes for Ascend quantization schemes.""" + +from abc import ABC, abstractmethod +from typing import Any, Callable, Dict, Optional + +import torch + + +class AscendLinearScheme(ABC): + """Base class for all linear quantization schemes. + + Subclasses must implement get_weight() and apply() methods. + Other methods have default implementations that return empty dicts + or do nothing. + """ + + @abstractmethod + def get_weight(self, input_size: int, output_size: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + """Return weight tensor specifications. + + Args: + input_size: Input dimension of the linear layer. + output_size: Output dimension of the linear layer. + params_dtype: Data type for parameters. + + Returns: + Dictionary mapping parameter names to empty tensors with + the correct shape and dtype. + """ + ... + + def get_pertensor_param(self, params_dtype: torch.dtype) -> Dict[str, Any]: + """Return per-tensor parameter specifications (e.g., input_scale). + + Args: + params_dtype: Data type for parameters. + + Returns: + Dictionary mapping parameter names to empty tensors. + """ + return {} + + def get_perchannel_param(self, output_size: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + """Return per-channel parameter specifications (e.g., weight_scale). + + Args: + output_size: Output dimension of the linear layer. + params_dtype: Data type for parameters. + + Returns: + Dictionary mapping parameter names to empty tensors. + """ + return {} + + def get_pergroup_param(self, + input_size: int, + output_size: int, + params_dtype: torch.dtype, + layer_type: Optional[str] = None) -> Dict[str, Any]: + """Return per-group parameter specifications. + + Args: + input_size: Input dimension of the linear layer. + output_size: Output dimension of the linear layer. + params_dtype: Data type for parameters. + layer_type: Type of layer (e.g., "row" for RowParallelLinear). + + Returns: + Dictionary mapping parameter names to empty tensors. + """ + return {} + + @abstractmethod + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + tp_rank: Optional[int] = 0) -> torch.Tensor: + """Forward computation. + + Args: + layer: The linear layer module. + x: Input tensor. + bias: Optional bias tensor. + tp_rank: Tensor parallel rank. + + Returns: + Output tensor after quantized linear operation. + """ + ... + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + """Post-loading weight processing (transpose, format conversion, etc.). + + Args: + layer: The linear layer module. + """ + pass + + +class AscendMoEScheme(ABC): + """Base class for all MoE quantization schemes. + + Subclasses must implement get_weight(), get_dynamic_quant_param(), + and apply() methods. + """ + + @abstractmethod + def get_weight(self, num_experts: int, + intermediate_size_per_partition: int, hidden_sizes: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + """Return weight tensor specifications for MoE layer. + + Args: + num_experts: Number of experts. + intermediate_size_per_partition: Intermediate size per partition. + hidden_sizes: Hidden dimension size. + params_dtype: Data type for parameters. + + Returns: + Dictionary mapping parameter names to empty tensors. + """ + ... + + @abstractmethod + def get_dynamic_quant_param(self, num_experts: int, + intermediate_size_per_partition: int, + hidden_sizes: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + """Return dynamic quantization parameters for MoE layer. + + Args: + num_experts: Number of experts. + intermediate_size_per_partition: Intermediate size per partition. + hidden_sizes: Hidden dimension size. + params_dtype: Data type for parameters. + + Returns: + Dictionary mapping parameter names to empty tensors. + """ + ... + + @abstractmethod + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, + e_score_correction_bias: Optional[torch.Tensor] = None, + is_prefill: bool = True, + enable_force_load_balance: bool = False, + log2phy: Optional[torch.Tensor] = None, + global_redundant_expert_num: int = 0, + **kwargs, + ) -> torch.Tensor: + """Forward computation for MoE layer. + + Args: + layer: The MoE layer module. + x: Input hidden states. + router_logits: Router logits for expert selection. + top_k: Number of experts to select per token. + renormalize: Whether to renormalize expert weights. + use_grouped_topk: Whether to use grouped top-k selection. + global_num_experts: Total number of experts globally. + expert_map: Mapping from local to global expert indices. + topk_group: Group size for grouped top-k. + num_expert_group: Number of expert groups. + custom_routing_function: Custom routing function. + scoring_func: Scoring function name. + routed_scaling_factor: Scaling factor for routed experts. + e_score_correction_bias: Expert score correction bias. + is_prefill: Whether in prefill phase. + enable_force_load_balance: Whether to force load balancing. + log2phy: Logical to physical expert mapping. + global_redundant_expert_num: Number of redundant experts. + **kwargs: Additional keyword arguments. + + Returns: + Output tensor after MoE computation. + """ + ... + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + """Post-loading weight processing for MoE layer. + + Args: + layer: The MoE layer module. + """ + pass diff --git a/vllm_ascend/quantization/methods/registry.py b/vllm_ascend/quantization/methods/registry.py new file mode 100644 index 00000000000..597664022b6 --- /dev/null +++ b/vllm_ascend/quantization/methods/registry.py @@ -0,0 +1,62 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Any, Dict, Optional, Tuple, Type + +# Registry: maps (quant_type, layer_type) -> SchemeClass +_SCHEME_REGISTRY: Dict[Tuple[str, str], Type[Any]] = {} + + +def register_scheme(quant_type: str, layer_type: str): + """Decorator to register a quantization scheme. + + Args: + quant_type: Quantization type (e.g., "W8A8", "W8A8_DYNAMIC"). + layer_type: Layer type (e.g., "linear", "moe"). + + Returns: + Decorator function that registers the class. + + Example: + @register_scheme("W8A8_DYNAMIC", "linear") + class W8A8DynamicLinearScheme(AscendLinearScheme): + ... + """ + + def decorator(cls: Type[Any]) -> Type[Any]: + key = (quant_type, layer_type) + if key in _SCHEME_REGISTRY: + raise ValueError( + f"Scheme already registered for {quant_type}/{layer_type}: " + f"{_SCHEME_REGISTRY[key].__name__}") + _SCHEME_REGISTRY[key] = cls + return cls + + return decorator + + +def get_scheme_class(quant_type: str, layer_type: str) -> Optional[Type[Any]]: + """Get scheme class for given quant_type and layer_type. + + Args: + quant_type: Quantization type (e.g., "W8A8", "W8A8_DYNAMIC"). + layer_type: Layer type (e.g., "linear", "moe"). + + Returns: + The registered scheme class, or None if not found. + """ + return _SCHEME_REGISTRY.get((quant_type, layer_type)) diff --git a/vllm_ascend/quantization/w4a16.py b/vllm_ascend/quantization/methods/w4a16.py similarity index 95% rename from vllm_ascend/quantization/w4a16.py rename to vllm_ascend/quantization/methods/w4a16.py index c6eb379dc3d..edfd7d0c842 100644 --- a/vllm_ascend/quantization/w4a16.py +++ b/vllm_ascend/quantization/methods/w4a16.py @@ -1,284 +1,285 @@ -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# This file is a part of the vllm-ascend project. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from typing import Any, Callable, Dict, Optional - -import torch -import torch_npu -from vllm.config import get_current_vllm_config -from vllm.forward_context import get_forward_context - -from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.ops.fused_moe.experts_selector import select_experts - - -def unpack_from_int32( - weight: torch.Tensor, - shape: torch.Size, - num_bits: int, - packed_dim: int = 1, -) -> torch.Tensor: - """ - Unpacks quantized weights from int32 format back to original bits. - - :param weight: The packed int32 tensor containing quantized weights - :param shape: Original shape to restore, defaults to None - :param num_bits: The number of bits used for quantization (<= 8) - :param packed_dim: Dimension along which weights are packed (0 or 1), defaults to 1 - :return: Unpacked tensor with int8 dtype after applying offset correction - """ - assert weight.dtype == torch.int32, f"Expecting `weight.dtype` is torch.int32 but got {weight.dtype}." - assert num_bits <= 8, f"Expecting `num_bits` should not be larger than 8 but got {num_bits}." - - pack_factor = 32 // num_bits - mask = (1 << num_bits) - 1 - - if packed_dim == 1: - unpacked_weight = torch.zeros( - (weight.shape[0], weight.shape[1] * pack_factor), - device=weight.device, - dtype=torch.int32, - ) - for i in range(pack_factor): - unpacked_weight[:, i::pack_factor] = (weight >> - (num_bits * i)) & mask - original_row_size = int(shape[1]) - unpacked_weight = unpacked_weight[:, :original_row_size] - else: - unpacked_weight = torch.zeros( - (weight.shape[0] * pack_factor, weight.shape[1]), - device=weight.device, - dtype=torch.int32, - ) - for i in range(pack_factor): - unpacked_weight[i::pack_factor, :] = (weight >> - (num_bits * i)) & mask - original_row_size = int(shape[0]) - unpacked_weight = unpacked_weight[:original_row_size, :] - - offset = pow(2, num_bits) // 2 - unpacked_weight = (unpacked_weight - offset).to(torch.int8) - - return unpacked_weight - - -def pack_to_int32(weight: torch.Tensor) -> torch.Tensor: - """ - Packs quantized weights into int32 format for storage. - - :param weight: The 3D tensor to pack, must be int8 or int32 dtype - :return: Packed tensor with int32 dtype optimized for storage - """ - assert weight.dim( - ) == 3, f"Expecting `weight.dim()` is 3 ([e, n, k] or [e, k, n]) but got {weight.dim()}." - assert weight.dtype in [ - torch.int8, torch.int32 - ], f"Expecting `weight.dtype` is torch.int8 or torch.int32 bug got {weight.dtype}." - - if weight.dtype == torch.int32: - assert weight.shape[ - -1] % 8 == 0, "the last dim of weight needs to be divided by 8." - packed_weight = torch_npu.npu_convert_weight_to_int4pack( - weight.flatten(0, 1)) - packed_weight = packed_weight.view(weight.shape[0], weight.shape[1], - -1) - else: - assert weight.shape[ - -1] % 4 == 0, "the last dim of weight needs to be divided by 4." - packed_weight = weight.view(torch.int32).contiguous() - - return packed_weight - - -class AscendW4A16FusedMoEMethod: - """FusedMoe method for Ascend W4A16. - """ - - def __init__(self) -> None: - self.transpose_weight = True - self.num_bits = 4 # dtype = torch.int4 - self.pack_factor = 8 # pack 8 of torch.int4 tensors to torch.int32 - - vllm_config = get_current_vllm_config() - self.group_size = vllm_config.quant_config.quant_description.get( - "group_size", 32) - ascend_config = get_ascend_config() - self.dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path - - def get_weight( - self, - num_experts: int, - intermediate_size_per_partition: int, - hidden_sizes: int, - params_dtype: torch.dtype, - ) -> Dict[str, Any]: - assert intermediate_size_per_partition % self.pack_factor == 0, f"Expecting `intermediate_size_per_partition` {intermediate_size_per_partition} can be divided by `pack_factor` {self.pack_factor}" - assert hidden_sizes % self.pack_factor == 0, f"Expecting `hidden_sizes` {hidden_sizes} can be divided by `pack_factor` {self.pack_factor}" - - param_dict = {} - - param_dict["w13_weight_packed"] = torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - hidden_sizes // self.pack_factor, - dtype=torch.int32) - param_dict["w2_weight_packed"] = torch.empty( - num_experts, - hidden_sizes, - intermediate_size_per_partition // self.pack_factor, - dtype=torch.int32) - - return param_dict - - def get_dynamic_quant_param( - self, - num_experts: int, - intermediate_size_per_partition: int, - hidden_sizes: int, - params_dtype: torch.dtype, - ) -> Dict[str, Any]: - assert intermediate_size_per_partition % self.group_size == 0, f"Expecting `intermediate_size_per_partition` {intermediate_size_per_partition} can be divided by `group_size` {self.group_size}" - assert hidden_sizes % self.group_size == 0, f"Expecting `hidden_sizes` {hidden_sizes} can be divided by `group_size` {self.group_size}" - - param_dict = {} - - param_dict["w13_weight_scale"] = torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - hidden_sizes // self.group_size, - dtype=torch.bfloat16) - param_dict["w2_weight_scale"] = torch.empty( - num_experts, - hidden_sizes, - intermediate_size_per_partition // self.group_size, - dtype=torch.bfloat16) - param_dict["w13_weight_shape"] = torch.empty(num_experts, - 2, - dtype=torch.int32) - param_dict["w2_weight_shape"] = torch.empty(num_experts, - 2, - dtype=torch.int32) - param_dict["w13_weight_offset"] = torch.zeros( - num_experts, - 2 * intermediate_size_per_partition, - hidden_sizes // self.group_size, - dtype=torch.bfloat16) - param_dict["w2_weight_offset"] = torch.zeros( - num_experts, - hidden_sizes, - intermediate_size_per_partition // self.group_size, - dtype=torch.bfloat16) - - return param_dict - - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool = False, - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0, - e_score_correction_bias: Optional[torch.Tensor] = None, - is_prefill: bool = True, - enable_force_load_balance: bool = True, - log2phy: torch.Tensor = None, - global_redundant_expert_num: int = 0, - shared_experts: Optional[Any] = None, - quantized_x_for_share: Optional[Any] = None, - dynamic_scale_for_share: Optional[Any] = None, - **kwargs, - ) -> torch.Tensor: - assert router_logits.shape[ - 1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)" - - topk_weights, topk_ids = select_experts( - hidden_states=x, - router_logits=router_logits, - top_k=top_k, - use_grouped_topk=use_grouped_topk, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - custom_routing_function=custom_routing_function, - scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias, - global_num_experts=global_num_experts) - - topk_ids = topk_ids.to(torch.int32) - topk_weights = topk_weights.to(x.dtype) - - moe_comm_method = get_forward_context().moe_comm_method - return moe_comm_method.fused_experts( - hidden_states=x, - w1=layer.w13_weight_packed, - w2=layer.w2_weight_packed, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - w1_offset=layer.w13_weight_offset, - w2_offset=layer.w2_weight_offset, - topk_weights=topk_weights, - topk_ids=topk_ids, - use_int4_w4a16=True, - expert_map=expert_map, - log2phy=log2phy, - shared_experts=shared_experts, - quantized_x_for_share=quantized_x_for_share, - dynamic_scale_for_share=dynamic_scale_for_share, - dynamic_eplb=self.dynamic_eplb, - mc2_mask=kwargs.get("mc2_mask", None)) - - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - if self.transpose_weight: - w13_shape = layer.w13_weight_packed.data.shape - w2_shape = layer.w2_weight_packed.data.shape - unpacked_w13_weight = (unpack_from_int32( - layer.w13_weight_packed.data.flatten(0, 1), - torch.Size([ - w13_shape[0] * w13_shape[1], - w13_shape[2] * self.pack_factor - ]), - self.num_bits, - ).view(w13_shape[0], w13_shape[1], - -1).transpose(1, 2).contiguous().int()) - unpacked_w2_weight = (unpack_from_int32( - layer.w2_weight_packed.data.flatten(0, 1), - torch.Size([ - w2_shape[0] * w2_shape[1], w2_shape[2] * self.pack_factor - ]), - self.num_bits, - ).view(w2_shape[0], w2_shape[1], - -1).transpose(1, 2).contiguous().int()) - layer.w13_weight_packed.data = pack_to_int32(unpacked_w13_weight) - layer.w2_weight_packed.data = pack_to_int32(unpacked_w2_weight) - - layer.w13_weight_scale.data = layer.w13_weight_scale.data.transpose( - 1, 2).contiguous() - layer.w2_weight_scale.data = layer.w2_weight_scale.data.transpose( - 1, 2).contiguous() - - layer.w13_weight_offset.data = layer.w13_weight_offset.data.transpose( - 1, 2).contiguous() - layer.w2_weight_offset.data = layer.w2_weight_offset.data.transpose( - 1, 2).contiguous() +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Any, Callable, Dict, Optional + +import torch +import torch_npu +from vllm.config import get_current_vllm_config +from vllm.forward_context import get_forward_context + +from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.ops.fused_moe.experts_selector import select_experts + +from .base import AscendMoEScheme +from .registry import register_scheme + + +def unpack_from_int32( + weight: torch.Tensor, + shape: torch.Size, + num_bits: int, + packed_dim: int = 1, +) -> torch.Tensor: + """Unpacks quantized weights from int32 format back to original bits. + + :param weight: The packed int32 tensor containing quantized weights + :param shape: Original shape to restore, defaults to None + :param num_bits: The number of bits used for quantization (<= 8) + :param packed_dim: Dimension along which weights are packed (0 or 1), defaults to 1 + :return: Unpacked tensor with int8 dtype after applying offset correction + """ + assert weight.dtype == torch.int32, f"Expecting `weight.dtype` is torch.int32 but got {weight.dtype}." + assert num_bits <= 8, f"Expecting `num_bits` should not be larger than 8 but got {num_bits}." + + pack_factor = 32 // num_bits + mask = (1 << num_bits) - 1 + + if packed_dim == 1: + unpacked_weight = torch.zeros( + (weight.shape[0], weight.shape[1] * pack_factor), + device=weight.device, + dtype=torch.int32, + ) + for i in range(pack_factor): + unpacked_weight[:, i::pack_factor] = (weight >> + (num_bits * i)) & mask + original_row_size = int(shape[1]) + unpacked_weight = unpacked_weight[:, :original_row_size] + else: + unpacked_weight = torch.zeros( + (weight.shape[0] * pack_factor, weight.shape[1]), + device=weight.device, + dtype=torch.int32, + ) + for i in range(pack_factor): + unpacked_weight[i::pack_factor, :] = (weight >> + (num_bits * i)) & mask + original_row_size = int(shape[0]) + unpacked_weight = unpacked_weight[:original_row_size, :] + + offset = pow(2, num_bits) // 2 + unpacked_weight = (unpacked_weight - offset).to(torch.int8) + + return unpacked_weight + + +def pack_to_int32(weight: torch.Tensor) -> torch.Tensor: + """Packs quantized weights into int32 format for storage. + + :param weight: The 3D tensor to pack, must be int8 or int32 dtype + :return: Packed tensor with int32 dtype optimized for storage + """ + assert weight.dim( + ) == 3, f"Expecting `weight.dim()` is 3 ([e, n, k] or [e, k, n]) but got {weight.dim()}." + assert weight.dtype in [ + torch.int8, torch.int32 + ], f"Expecting `weight.dtype` is torch.int8 or torch.int32 bug got {weight.dtype}." + + if weight.dtype == torch.int32: + assert weight.shape[ + -1] % 8 == 0, "the last dim of weight needs to be divided by 8." + packed_weight = torch_npu.npu_convert_weight_to_int4pack( + weight.flatten(0, 1)) + packed_weight = packed_weight.view(weight.shape[0], weight.shape[1], + -1) + else: + assert weight.shape[ + -1] % 4 == 0, "the last dim of weight needs to be divided by 4." + packed_weight = weight.view(torch.int32).contiguous() + + return packed_weight + + +@register_scheme("W4A16", "moe") +class AscendW4A16FusedMoEMethod(AscendMoEScheme): + """FusedMoE method for Ascend W4A16.""" + + def __init__(self) -> None: + self.transpose_weight = True + self.num_bits = 4 # dtype = torch.int4 + self.pack_factor = 8 # pack 8 of torch.int4 tensors to torch.int32 + + vllm_config = get_current_vllm_config() + self.group_size = vllm_config.quant_config.quant_description.get( + "group_size", 32) + ascend_config = get_ascend_config() + self.dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path + + def get_weight( + self, + num_experts: int, + intermediate_size_per_partition: int, + hidden_sizes: int, + params_dtype: torch.dtype, + ) -> Dict[str, Any]: + assert intermediate_size_per_partition % self.pack_factor == 0, f"Expecting `intermediate_size_per_partition` {intermediate_size_per_partition} can be divided by `pack_factor` {self.pack_factor}" + assert hidden_sizes % self.pack_factor == 0, f"Expecting `hidden_sizes` {hidden_sizes} can be divided by `pack_factor` {self.pack_factor}" + + param_dict = {} + + param_dict["w13_weight_packed"] = torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_sizes // self.pack_factor, + dtype=torch.int32) + param_dict["w2_weight_packed"] = torch.empty( + num_experts, + hidden_sizes, + intermediate_size_per_partition // self.pack_factor, + dtype=torch.int32) + + return param_dict + + def get_dynamic_quant_param( + self, + num_experts: int, + intermediate_size_per_partition: int, + hidden_sizes: int, + params_dtype: torch.dtype, + ) -> Dict[str, Any]: + assert intermediate_size_per_partition % self.group_size == 0, f"Expecting `intermediate_size_per_partition` {intermediate_size_per_partition} can be divided by `group_size` {self.group_size}" + assert hidden_sizes % self.group_size == 0, f"Expecting `hidden_sizes` {hidden_sizes} can be divided by `group_size` {self.group_size}" + + param_dict = {} + + param_dict["w13_weight_scale"] = torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_sizes // self.group_size, + dtype=torch.bfloat16) + param_dict["w2_weight_scale"] = torch.empty( + num_experts, + hidden_sizes, + intermediate_size_per_partition // self.group_size, + dtype=torch.bfloat16) + param_dict["w13_weight_shape"] = torch.empty(num_experts, + 2, + dtype=torch.int32) + param_dict["w2_weight_shape"] = torch.empty(num_experts, + 2, + dtype=torch.int32) + param_dict["w13_weight_offset"] = torch.zeros( + num_experts, + 2 * intermediate_size_per_partition, + hidden_sizes // self.group_size, + dtype=torch.bfloat16) + param_dict["w2_weight_offset"] = torch.zeros( + num_experts, + hidden_sizes, + intermediate_size_per_partition // self.group_size, + dtype=torch.bfloat16) + + return param_dict + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, + e_score_correction_bias: Optional[torch.Tensor] = None, + is_prefill: bool = True, + enable_force_load_balance: bool = True, + log2phy: Optional[torch.Tensor] = None, + global_redundant_expert_num: int = 0, + shared_experts: Optional[Any] = None, + quantized_x_for_share: Optional[Any] = None, + dynamic_scale_for_share: Optional[Any] = None, + **kwargs, + ) -> torch.Tensor: + assert router_logits.shape[ + 1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)" + + topk_weights, topk_ids = select_experts( + hidden_states=x, + router_logits=router_logits, + top_k=top_k, + use_grouped_topk=use_grouped_topk, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + global_num_experts=global_num_experts) + + topk_ids = topk_ids.to(torch.int32) + topk_weights = topk_weights.to(x.dtype) + + moe_comm_method = get_forward_context().moe_comm_method + return moe_comm_method.fused_experts( + hidden_states=x, + w1=layer.w13_weight_packed, + w2=layer.w2_weight_packed, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + w1_offset=layer.w13_weight_offset, + w2_offset=layer.w2_weight_offset, + topk_weights=topk_weights, + topk_ids=topk_ids, + use_int4_w4a16=True, + expert_map=expert_map, + log2phy=log2phy, + shared_experts=shared_experts, + quantized_x_for_share=quantized_x_for_share, + dynamic_scale_for_share=dynamic_scale_for_share, + dynamic_eplb=self.dynamic_eplb, + mc2_mask=kwargs.get("mc2_mask", None)) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + if self.transpose_weight: + w13_shape = layer.w13_weight_packed.data.shape + w2_shape = layer.w2_weight_packed.data.shape + unpacked_w13_weight = (unpack_from_int32( + layer.w13_weight_packed.data.flatten(0, 1), + torch.Size([ + w13_shape[0] * w13_shape[1], + w13_shape[2] * self.pack_factor + ]), + self.num_bits, + ).view(w13_shape[0], w13_shape[1], + -1).transpose(1, 2).contiguous().int()) + unpacked_w2_weight = (unpack_from_int32( + layer.w2_weight_packed.data.flatten(0, 1), + torch.Size([ + w2_shape[0] * w2_shape[1], w2_shape[2] * self.pack_factor + ]), + self.num_bits, + ).view(w2_shape[0], w2_shape[1], + -1).transpose(1, 2).contiguous().int()) + layer.w13_weight_packed.data = pack_to_int32(unpacked_w13_weight) + layer.w2_weight_packed.data = pack_to_int32(unpacked_w2_weight) + + layer.w13_weight_scale.data = layer.w13_weight_scale.data.transpose( + 1, 2).contiguous() + layer.w2_weight_scale.data = layer.w2_weight_scale.data.transpose( + 1, 2).contiguous() + + layer.w13_weight_offset.data = layer.w13_weight_offset.data.transpose( + 1, 2).contiguous() + layer.w2_weight_offset.data = layer.w2_weight_offset.data.transpose( + 1, 2).contiguous() diff --git a/vllm_ascend/quantization/w4a4_flatquant_dynamic.py b/vllm_ascend/quantization/methods/w4a4_flatquant.py similarity index 93% rename from vllm_ascend/quantization/w4a4_flatquant_dynamic.py rename to vllm_ascend/quantization/methods/w4a4_flatquant.py index f13dae2fa30..773709fbee1 100644 --- a/vllm_ascend/quantization/w4a4_flatquant_dynamic.py +++ b/vllm_ascend/quantization/methods/w4a4_flatquant.py @@ -14,16 +14,21 @@ # See the License for the specific language governing permissions and # limitations under the License. # + import math from typing import Any, Dict, Optional, Tuple import torch import torch_npu +from .base import AscendLinearScheme +from .registry import register_scheme + KRONECKER_QUANT_MAX_BATCH_SIZE = 32768 def pack_int4_weights(weight_tensor: torch.Tensor) -> torch.Tensor: + """Pack int4 weights for NPU.""" original_device = weight_tensor.device weight_tensor_npu = weight_tensor.npu() weight_int4_packed = torch_npu.npu_convert_weight_to_int4pack( @@ -32,6 +37,7 @@ def pack_int4_weights(weight_tensor: torch.Tensor) -> torch.Tensor: def get_decompose_dim(n): + """Get decomposed dimensions for Kronecker quantization.""" a = int(math.sqrt(n)) if a * a < n: a += 1 @@ -53,6 +59,7 @@ def batched_kronecker_quant( right_trans: torch.Tensor, clip_ratio: float, ) -> Tuple[torch.Tensor, torch.Tensor]: + """Batched Kronecker quantization with batch size limit handling.""" batch_tokens = x.shape[0] if batch_tokens <= KRONECKER_QUANT_MAX_BATCH_SIZE: return torch_npu.npu_kronecker_quant(x, @@ -75,7 +82,8 @@ def batched_kronecker_quant( return x_quantized_int4, activation_scale -class AscendW4A4FlatQuantDynamicLinearMethod: +@register_scheme("W4A4_FLATQUANT_DYNAMIC", "linear") +class AscendW4A4FlatQuantDynamicLinearMethod(AscendLinearScheme): """Linear method for Ascend W4A4_FLATQUANT_DYNAMIC. This class implements W4A4 quantization with FlatQuant approach and dynamic activation quantization. @@ -88,8 +96,7 @@ class AscendW4A4FlatQuantDynamicLinearMethod: def __init__(self): self.sym = True - @staticmethod - def get_weight(input_size: int, output_size: int, + def get_weight(self, input_size: int, output_size: int, params_dtype: torch.dtype) -> Dict[str, Any]: if input_size % 8 != 0: raise ValueError( @@ -101,8 +108,7 @@ def get_weight(input_size: int, output_size: int, } return params_dict - @staticmethod - def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]: + def get_pertensor_param(self, params_dtype: torch.dtype) -> Dict[str, Any]: params_dict = {} left_trans_dim, right_trans_dim = get_decompose_dim( AscendW4A4FlatQuantDynamicLinearMethod.input_size) @@ -115,8 +121,8 @@ def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]: params_dict["clip_ratio"] = torch.empty(1, dtype=torch.float32) return params_dict - @staticmethod def get_perchannel_param( + self, output_size: int, params_dtype: torch.dtype, ) -> Dict[str, Any]: @@ -129,15 +135,8 @@ def get_perchannel_param( dtype=torch.float32) return params_dict - def get_pergroup_param(self, - input_size: int, - output_size: int, - params_dtype: torch.dtype, - layer_type: Optional[str] = None) -> Dict[str, Any]: - return {} - - @staticmethod def apply( + self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None, diff --git a/vllm_ascend/quantization/w4a8_dynamic.py b/vllm_ascend/quantization/methods/w4a8.py similarity index 97% rename from vllm_ascend/quantization/w4a8_dynamic.py rename to vllm_ascend/quantization/methods/w4a8.py index 167a42fcb88..04542db1386 100644 --- a/vllm_ascend/quantization/w4a8_dynamic.py +++ b/vllm_ascend/quantization/methods/w4a8.py @@ -29,10 +29,13 @@ from vllm_ascend.ops.fused_moe.experts_selector import select_experts from vllm_ascend.utils import maybe_trans_nz +from .base import AscendLinearScheme, AscendMoEScheme +from .registry import register_scheme -class AscendW4A8DynamicLinearMethod: - """Linear method for Ascend W4A8_DYNAMIC - """ + +@register_scheme("W4A8_DYNAMIC", "linear") +class AscendW4A8DynamicLinearMethod(AscendLinearScheme): + """Linear method for Ascend W4A8_DYNAMIC.""" def __init__(self): vllm_config = get_current_vllm_config() @@ -72,23 +75,12 @@ def get_weight(self, input_size: int, output_size: int, return params_dict - @staticmethod - def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]: - return {} - - @staticmethod - def get_perchannel_param(output_size: int, - params_dtype: torch.dtype) -> Dict[str, Any]: - return {} - def get_pergroup_param(self, input_size: int, output_size: int, params_dtype: torch.dtype, layer_type: Optional[str] = None) -> Dict[str, Any]: - """ - Create per-group quantization parameters. - """ + """Create per-group quantization parameters.""" params_dict = {} params_dict["weight_scale"] = torch.empty(output_size, 1, @@ -121,8 +113,7 @@ def process_scale_second(weight: torch.Tensor, scale: torch.Tensor, per_group_scale: torch.Tensor, is_new_quant: bool = False): - """ - Process the scale for second-level quantization. + """Process the scale for second-level quantization. Args: weight: weight tensor [k, n] (in new version, n is already compressed to n/2) @@ -207,9 +198,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module): layer.weight.data.to(torch.int32)) -class AscendW4A8DynamicFusedMoEMethod: - """FusedMoe method for Ascend W4A8_DYNAMIC. - """ +@register_scheme("W4A8_DYNAMIC", "moe") +class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme): + """FusedMoE method for Ascend W4A8_DYNAMIC.""" def __init__(self): self.ep_group = get_ep_group() @@ -340,7 +331,7 @@ def apply( e_score_correction_bias: Optional[torch.Tensor] = None, is_prefill: bool = True, enable_force_load_balance: bool = False, - log2phy: torch.Tensor = None, + log2phy: Optional[torch.Tensor] = None, global_redundant_expert_num: int = 0, shared_experts: Optional[Any] = None, quantized_x_for_share: Optional[Any] = None, diff --git a/vllm_ascend/quantization/w8a16.py b/vllm_ascend/quantization/methods/w8a16.py similarity index 84% rename from vllm_ascend/quantization/w8a16.py rename to vllm_ascend/quantization/methods/w8a16.py index 1e66c5e8420..97fc0468e3f 100644 --- a/vllm_ascend/quantization/w8a16.py +++ b/vllm_ascend/quantization/methods/w8a16.py @@ -22,17 +22,22 @@ from vllm_ascend.utils import maybe_trans_nz +from .base import AscendLinearScheme +from .registry import register_scheme -class AscendW8A16LinearMethod: - """Linear method for Ascend W8A16. +@register_scheme("W8A16", "linear") +class AscendW8A16LinearMethod(AscendLinearScheme): + """Linear method for Ascend W8A16. + + This scheme uses 8-bit quantized weights with 16-bit activations. """ def __init__(self) -> None: pass - @staticmethod def get_weight( + self, input_size: int, output_size: int, params_dtype: torch.dtype = torch.bfloat16, @@ -42,12 +47,8 @@ def get_weight( } return params_dict - @staticmethod - def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]: - return {} - - @staticmethod def get_perchannel_param( + self, output_size: int, params_dtype: torch.dtype, ) -> Dict[str, Any]: @@ -60,15 +61,8 @@ def get_perchannel_param( dtype=params_dtype) return params_dict - def get_pergroup_param(self, - input_size: int, - output_size: int, - params_dtype: torch.dtype, - layer_type: Optional[str] = None) -> Dict[str, Any]: - return {} - - @staticmethod def apply( + self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None, diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/methods/w8a8_dynamic.py similarity index 94% rename from vllm_ascend/quantization/w8a8_dynamic.py rename to vllm_ascend/quantization/methods/w8a8_dynamic.py index b2e92e6e4e1..0bd770ee4a7 100644 --- a/vllm_ascend/quantization/w8a8_dynamic.py +++ b/vllm_ascend/quantization/methods/w8a8_dynamic.py @@ -32,28 +32,39 @@ zero_experts_compute) from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, maybe_trans_nz +from .base import AscendLinearScheme, AscendMoEScheme +from .registry import register_scheme -class AscendW8A8DynamicLinearMethod: + +def scale_from_float_to_int64(scale): + """Convert float32 scale to int64 representation.""" + import numpy as np + scale = torch.from_numpy( + np.frombuffer(scale.cpu().to(torch.float32).numpy().tobytes(), + dtype=np.int32).astype(np.int64)).to(scale.device) + return scale + + +@register_scheme("W8A8_DYNAMIC", "linear") +class AscendW8A8DynamicLinearMethod(AscendLinearScheme): """Linear method for Ascend W8A8_DYNAMIC. + + This scheme uses dynamic per-token quantization for activations + and per-channel quantization for weights. """ def __init__(self): pass - @staticmethod - def get_weight(input_size: int, output_size: int, + def get_weight(self, input_size: int, output_size: int, params_dtype: torch.dtype) -> Dict[str, Any]: params_dict = { "weight": torch.empty(output_size, input_size, dtype=torch.int8) } return params_dict - @staticmethod - def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]: - return {} - - @staticmethod def get_perchannel_param( + self, output_size: int, params_dtype: torch.dtype, ) -> Dict[str, Any]: @@ -66,15 +77,8 @@ def get_perchannel_param( dtype=params_dtype) return params_dict - def get_pergroup_param(self, - input_size: int, - output_size: int, - params_dtype: torch.dtype, - layer_type: Optional[str] = None) -> Dict[str, Any]: - return {} - - @staticmethod def apply( + self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None, @@ -100,9 +104,9 @@ def process_weights_after_loading(self, layer): layer.weight_offset.data = layer.weight_offset.data.flatten() -class AscendW8A8DynamicFusedMoEMethod: - """FusedMoe method for Ascend W8A8_DYNAMIC. - """ +@register_scheme("W8A8_DYNAMIC", "moe") +class AscendW8A8DynamicFusedMoEMethod(AscendMoEScheme): + """FusedMoE method for Ascend W8A8_DYNAMIC.""" def __init__(self): self.ep_group = get_ep_group() @@ -128,9 +132,8 @@ def __init__(self): except AttributeError: self.moe_all_to_all_group_name = "" - @staticmethod - def get_weight(num_experts: int, intermediate_size_per_partition: int, - hidden_sizes: int, + def get_weight(self, num_experts: int, + intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype) -> Dict[str, Any]: param_dict = {} param_dict["w13_weight"] = torch.empty(num_experts, @@ -144,8 +147,7 @@ def get_weight(num_experts: int, intermediate_size_per_partition: int, dtype=torch.int8) return param_dict - @staticmethod - def get_dynamic_quant_param(num_experts: int, + def get_dynamic_quant_param(self, num_experts: int, intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype) -> Dict[str, Any]: @@ -188,7 +190,7 @@ def apply( e_score_correction_bias: Optional[torch.Tensor] = None, is_prefill: bool = True, enable_force_load_balance: bool = False, - log2phy: torch.Tensor = None, + log2phy: Optional[torch.Tensor] = None, global_redundant_expert_num: int = 0, shared_experts: Optional[Any] = None, quantized_x_for_share: Optional[Any] = None, @@ -345,11 +347,3 @@ def process_weights_after_loading(self, layer): del layer.w2_weight_scale del layer.w2_weight_scale_fp32 torch.npu.empty_cache() - - -def scale_from_float_to_int64(scale): - import numpy as np - scale = torch.from_numpy( - np.frombuffer(scale.cpu().to(torch.float32).numpy().tobytes(), - dtype=np.int32).astype(np.int64)).to(scale.device) - return scale diff --git a/vllm_ascend/quantization/w8a8_pdmix.py b/vllm_ascend/quantization/methods/w8a8_pdmix.py similarity index 56% rename from vllm_ascend/quantization/w8a8_pdmix.py rename to vllm_ascend/quantization/methods/w8a8_pdmix.py index 0fa74f7e9a0..a5fb570a95e 100644 --- a/vllm_ascend/quantization/w8a8_pdmix.py +++ b/vllm_ascend/quantization/methods/w8a8_pdmix.py @@ -1,37 +1,59 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + from typing import Any, Dict, cast import torch from vllm.config import get_current_vllm_config -from .w8a8 import AscendW8A8LinearMethod +from .registry import register_scheme from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod, AscendW8A8DynamicLinearMethod) +from .w8a8_static import AscendW8A8LinearMethod +@register_scheme("W8A8_MIX", "linear") class AscendW8A8PDMixLinearMethod(AscendW8A8DynamicLinearMethod): + """Linear method for W8A8 prefill-decode mix. + + Uses static W8A8 for KV consumer (decode) and dynamic W8A8 for prefill. + """ def __init__(self): self.kv_transfer_config = get_current_vllm_config().kv_transfer_config super().__init__() - @staticmethod - def apply(layer, x, bias=None, tp_rank=0): + def apply(self, layer, x, bias=None, tp_rank=0): if layer.is_kv_consumer: - return AscendW8A8LinearMethod.apply(layer, x, bias, tp_rank) + return AscendW8A8LinearMethod.apply(self, layer, x, bias, tp_rank) else: - return AscendW8A8DynamicLinearMethod.apply(layer, x, bias, tp_rank) + return AscendW8A8DynamicLinearMethod.apply(self, layer, x, bias, + tp_rank) - @staticmethod - def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]: - return AscendW8A8LinearMethod.get_pertensor_param(params_dtype) + def get_pertensor_param(self, params_dtype: torch.dtype) -> Dict[str, Any]: + return AscendW8A8LinearMethod.get_pertensor_param(self, params_dtype) - @staticmethod def get_perchannel_param( + self, output_size: int, params_dtype: torch.dtype, ) -> Dict[str, Any]: return AscendW8A8LinearMethod.get_perchannel_param( - output_size, params_dtype) + self, output_size, params_dtype) def process_weights_after_loading(self, layer): AscendW8A8LinearMethod.process_weights_after_loading( @@ -40,18 +62,19 @@ def process_weights_after_loading(self, layer): layer.is_kv_consumer = self.kv_transfer_config is not None and self.kv_transfer_config.is_kv_consumer +@register_scheme("W8A8_MIX", "moe") class AscendW8A8PDMixFusedMoeMethod(AscendW8A8DynamicFusedMoEMethod): + """FusedMoE method for W8A8 prefill-decode mix.""" def __init__(self): super().__init__() - @staticmethod - def get_dynamic_quant_param(num_experts: int, + def get_dynamic_quant_param(self, num_experts: int, intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype) -> Dict[str, Any]: param_dict = AscendW8A8DynamicFusedMoEMethod.get_dynamic_quant_param( - num_experts, intermediate_size_per_partition, hidden_sizes, + self, num_experts, intermediate_size_per_partition, hidden_sizes, params_dtype) param_dict["w2_deq_scale"] = torch.empty(num_experts, hidden_sizes, diff --git a/vllm_ascend/quantization/w8a8.py b/vllm_ascend/quantization/methods/w8a8_static.py similarity index 88% rename from vllm_ascend/quantization/w8a8.py rename to vllm_ascend/quantization/methods/w8a8_static.py index 8809682e3b3..c848ed9a809 100644 --- a/vllm_ascend/quantization/w8a8.py +++ b/vllm_ascend/quantization/methods/w8a8_static.py @@ -24,27 +24,23 @@ get_ascend_device_type, get_weight_prefetch_method, maybe_trans_nz) +from .base import AscendLinearScheme +from .registry import register_scheme -def quant_per_tensor(in_tensor: torch.Tensor, - input_scale: torch.Tensor, - input_offset: torch.Tensor, - function=False): - return torch_npu.npu_quantize(in_tensor, input_scale, input_offset, - torch.qint8, -1, function) +@register_scheme("W8A8", "linear") +class AscendW8A8LinearMethod(AscendLinearScheme): + """Linear method for Ascend W8A8 static quantization. -class AscendW8A8LinearMethod: - """Linear method for Ascend W8A8. - - Args: - w_sym: whether the linear weight is symmetrically quantized. + This scheme uses static per-tensor quantization for activations + and per-channel quantization for weights. """ def __init__(self) -> None: pass - @staticmethod def get_weight( + self, input_size: int, output_size: int, params_dtype: torch.dtype = torch.bfloat16, @@ -54,15 +50,14 @@ def get_weight( } return params_dict - @staticmethod - def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]: + def get_pertensor_param(self, params_dtype: torch.dtype) -> Dict[str, Any]: params_dict = {} params_dict["input_scale"] = torch.empty(1, dtype=params_dtype) params_dict["input_offset"] = torch.empty(1, dtype=torch.int8) return params_dict - @staticmethod def get_perchannel_param( + self, output_size: int, params_dtype: torch.dtype, ) -> Dict[str, Any]: @@ -82,15 +77,8 @@ def get_perchannel_param( dtype=params_dtype) return params_dict - def get_pergroup_param(self, - input_size: int, - output_size: int, - params_dtype: torch.dtype, - layer_type: Optional[str] = None) -> Dict[str, Any]: - return {} - - @staticmethod def apply( + self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None, diff --git a/vllm_ascend/quantization/modelslim_config.py b/vllm_ascend/quantization/modelslim_config.py new file mode 100644 index 00000000000..30b10787cae --- /dev/null +++ b/vllm_ascend/quantization/modelslim_config.py @@ -0,0 +1,408 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# +"""ModelSlim quantization configuration and model mappings for Ascend. + +This module provides the AscendModelSlimConfig class for parsing quantization +configs generated by the ModelSlim tool, along with model-specific mappings. +""" + +from types import MappingProxyType +from typing import Any, Dict, List, Mapping, Optional + +import torch +from vllm.config import get_current_vllm_config +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.linear import LinearBase +from vllm.model_executor.layers.quantization import \ + register_quantization_config +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + UnquantizedEmbeddingMethod, VocabParallelEmbedding) +from vllm.model_executor.models.utils import WeightsMapper + +from vllm_ascend.ops.fused_moe.fused_moe import AscendUnquantizedFusedMoEMethod +from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod +from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD + +from .methods import get_scheme_class + +logger = init_logger(__name__) + + +# key: model_type +# value: orig_to_new_prefix +QUANT_MODEL_PREFIX_MAPPINGS: Dict[str, Dict[str, str]] = { + "qwen3_vl_moe": { + "visual.": "model.visual.", + "language_model.lm_head.": "lm_head.", + "language_model.model.": "model.language_model.", + }, +} + +# key: model_type +# value: dict of fused module name -> list of original module names +packed_modules_model_mapping: Dict[str, Dict[str, List[str]]] = { + "qwen3_moe": { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + "experts": + ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], + }, + "deepseek_v2": { + "gate_up_proj": ["gate_proj", "up_proj"], + "experts": + ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], + "fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"] + }, + "deepseek_v3": { + "gate_up_proj": ["gate_proj", "up_proj"], + "experts": + ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], + "fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"] + }, + "pangu_ultra_moe": { + "gate_up_proj": ["gate_proj", "up_proj"], + "experts": + ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], + "fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"] + }, + "kimi_k2": { + "gate_up_proj": ["gate_proj", "up_proj"], + "experts": + ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], + "fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"] + }, + "deepseek_v32": { + "gate_up_proj": ["gate_proj", "up_proj"], + "experts": + ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], + "fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"] + }, + # NOTE 1.The quantized MTP layer of deepseek on the NPU is not quantized; + # NOTE 2.The description file generated by the current msmodelslim tool does not have + # MTP layer info. Please manually add it and set the value to FLOAT. + "deepseek_mtp": { + "gate_up_proj": ["gate_proj", "up_proj"], + "experts": + ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"] + }, + "pangu_ultra_moe_mtp": { + "gate_up_proj": ["gate_proj", "up_proj"], + "experts": + ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], + "fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"] + }, + "qwen3_next": { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": ["gate_proj", "up_proj"], + "in_proj": ["in_proj_qkvz", "in_proj_ba"], + "experts": + ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"] + }, + "qwen2_5_vl": { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + }, + "qwen3_vl_moe": { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + "experts": + ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], + }, + "glm4_moe": { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + "experts": + ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"] + }, + "longcat_flash": { + "gate_up_proj": ["gate_proj", "up_proj"], + "experts": + ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], + "fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"] + }, +} + + +def get_packed_modules_mapping(model_type: str) -> Dict[str, List[str]]: + """Get packed modules mapping for a model type. + + Args: + model_type: The model type string (e.g., "deepseek_v3"). + + Returns: + Dictionary mapping fused module names to their component module names. + Returns empty dict if model_type is not found. + """ + return packed_modules_model_mapping.get(model_type, {}) + + +def get_prefix_mapping(model_type: str) -> Dict[str, str]: + """Get prefix mapping for a model type. + + Args: + model_type: The model type string (e.g., "qwen3_vl_moe"). + + Returns: + Dictionary mapping original prefixes to new prefixes. + Returns empty dict if model_type is not found. + """ + return QUANT_MODEL_PREFIX_MAPPINGS.get(model_type, {}) + + +def get_linear_quant_type(quant_description: Dict[str, Any], prefix: str, + packed_modules_mapping: Dict[str, Any]) -> str: + """Determine the quantization type for a linear layer. + + Args: + quant_description: The quantization description dictionary. + prefix: The layer prefix. + packed_modules_mapping: Mapping for packed/fused modules. + + Returns: + The quantization type string (e.g., "W8A8_DYNAMIC"). + """ + proj_name = prefix.split(".")[-1] + if proj_name in packed_modules_mapping: + quant_type = None + shard_prefixes = [ + prefix.replace(proj_name, shard_proj_name) + for shard_proj_name in packed_modules_mapping[proj_name] + ] + for shard_prefix in shard_prefixes: + shard_quant_type = quant_description[shard_prefix + '.weight'] + + if quant_type is None: + quant_type = shard_quant_type + elif shard_quant_type != quant_type: + raise ValueError( + f"Not all shards of {prefix} are quantized with same quant type." + f"Shard {proj_name} uses {shard_quant_type}, but another shard" + f"use {quant_type}. Please check quantization config.") + else: + quant_type = quant_description[prefix + '.weight'] + return quant_type + + +def get_quant_method_modelslim( + quant_description: Dict[str, Any], + prefix: str, + layer_type: str, + packed_modules_mapping: Optional[Dict[str, Any]] = None): + """Get quantization method for ModelSlim models. + + Args: + quant_description: The quantization description dictionary. + prefix: The layer prefix. + layer_type: The type of layer ("linear", "moe", "attention"). + packed_modules_mapping: Mapping for packed/fused modules. + + Returns: + An instance of the appropriate quantization method class. + """ + logger.info_once("Using the vLLM Ascend modelslim Quantization now!") + if packed_modules_mapping is None: + packed_modules_mapping = dict() + # Attention + if '.attn' in prefix and 'fa_quant_type' in quant_description.keys(): + quant_type = quant_description['fa_quant_type'] + # Linear + else: + quant_type = get_linear_quant_type(quant_description, prefix, + packed_modules_mapping) + + # Use registry to get scheme class + method_cls = get_scheme_class(quant_type, layer_type) + if method_cls is not None: + return method_cls() + + raise NotImplementedError( + f"Currently, vLLM Ascend doesn't support {quant_type} for {layer_type}." + ) + + +@register_quantization_config(ASCEND_QUANTIZATION_METHOD) +class AscendModelSlimConfig(QuantizationConfig): + """Config class for Ascend ModelSlim quantization. + + This class is a general class that parses quantization configs + that are supported on Ascend hardware, specifically for models + quantized using the ModelSlim tool. + """ + + def __init__(self, quant_config: Dict[str, Any]): + super().__init__() + self.quant_description = quant_config + # TODO(whx): remove this adaptation after adding "shared_head" + # to prefix of DeepSeekShareHead in vLLM. + extra_quant_dict = {} + for k in self.quant_description.keys(): + if "shared_head" in k: + new_k = k.replace(".shared_head.", ".") + extra_quant_dict[new_k] = self.quant_description[k] + if "weight_packed" in k: + new_k = k.replace("weight_packed", "weight") + extra_quant_dict[new_k] = self.quant_description[k] + self.quant_description.update(extra_quant_dict) + + def __repr__(self) -> str: + return "AscendModelSlimConfig:\n" + super().__repr__() + + @classmethod + def get_name(cls) -> str: + return ASCEND_QUANTIZATION_METHOD + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.int8, torch.float16, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + raise NotImplementedError( + "Ascend hardware dose not support \"get_min_capability\" feature.") + + @classmethod + def get_config_filenames(cls) -> List[str]: + return ["quant_model_description.json"] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "AscendModelSlimConfig": + return cls(config) + + @classmethod + def override_quantization_method(cls, hf_quant_cfg, + user_quant) -> Optional[str]: + if hf_quant_cfg is not None: + quant_method = hf_quant_cfg.get("quant_method", None) + if not quant_method and torch.npu.is_available(): + return ASCEND_QUANTIZATION_METHOD + return None + + def quant_prefix_mapper(self, model_type: str, prefix: str) -> str: + # TODO (Levi-JQ): will be removed when QuantizationConfig.apply_vllm_mapper is implemented + prefix_mapping = QUANT_MODEL_PREFIX_MAPPINGS.get(model_type) + if prefix_mapping: + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix=prefix_mapping) + return hf_to_vllm_mapper._map_name(prefix) + return prefix + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["QuantizeMethodBase"]: + from .wrappers import (AscendEmbeddingMethod, AscendFusedMoEMethod, + AscendKVCacheMethod, AscendLinearMethod) + + vllm_config = get_current_vllm_config() + model_type = vllm_config.model_config.hf_config.model_type + if model_type in packed_modules_model_mapping: + self.packed_modules_mapping = packed_modules_model_mapping[ + model_type] + prefix = self.quant_prefix_mapper(model_type, prefix) + from vllm.attention.layer import Attention + if prefix.startswith("language_model"): + prefix = prefix.split('.', 1)[-1] + if isinstance(layer, LinearBase): + if self.is_layer_skipped_ascend(prefix, + self.packed_modules_mapping): + return AscendUnquantizedLinearMethod() + return AscendLinearMethod(self, prefix, + self.packed_modules_mapping, layer) + elif isinstance(layer, Attention) and \ + 'fa_quant_type' in self.quant_description.keys() and \ + self.quant_description['fa_quant_type'] is not None: + return AscendKVCacheMethod(self, prefix) + elif isinstance(layer, FusedMoE): + if self.is_layer_skipped_ascend(prefix, + self.packed_modules_mapping): + return AscendUnquantizedFusedMoEMethod(layer.moe_config) + return AscendFusedMoEMethod(self, prefix, + self.packed_modules_mapping, layer) + elif isinstance(layer, VocabParallelEmbedding): + if self.is_layer_skipped_ascend(prefix, + self.packed_modules_mapping): + return UnquantizedEmbeddingMethod() + return AscendEmbeddingMethod(self, prefix, + self.packed_modules_mapping, layer) + return None + + def is_layer_skipped_ascend( + self, + prefix: str, + fused_mapping: Mapping[str, List[str]] = MappingProxyType({})): + # adapted from vllm.model_executor.layers.quantization.utils.quant_utils.is_layer_skipped + proj_name = prefix.split(".")[-1] + if proj_name in fused_mapping: + shard_prefixes = [ + prefix.replace(proj_name, shard_proj_name) + for shard_proj_name in fused_mapping[proj_name] + ] + + is_skipped = None + for shard_prefix in shard_prefixes: + is_shard_skipped = self.quant_description[shard_prefix + + '.weight'] == "FLOAT" + + if is_skipped is None: + is_skipped = is_shard_skipped + elif is_shard_skipped != is_skipped: + raise ValueError( + f"Detected some but not all shards of {prefix} " + "are quantized. All shards of fused layers " + "to have the same precision.") + else: + is_skipped = self.quant_description[prefix + '.weight'] == "FLOAT" + + assert is_skipped is not None + return is_skipped + + def get_scaled_act_names(self) -> List[str]: + return [] diff --git a/vllm_ascend/quantization/utils.py b/vllm_ascend/quantization/utils.py deleted file mode 100644 index 71db5269b09..00000000000 --- a/vllm_ascend/quantization/utils.py +++ /dev/null @@ -1,115 +0,0 @@ -from typing import Any, Dict, Optional, Type - -import torch -from vllm.logger import logger - -from vllm_ascend.utils import COMPRESSED_TENSORS_METHOD - -from .w4a4_flatquant_dynamic import AscendW4A4FlatQuantDynamicLinearMethod -from .w4a8_dynamic import (AscendW4A8DynamicFusedMoEMethod, - AscendW4A8DynamicLinearMethod) -from .w4a16 import AscendW4A16FusedMoEMethod -from .w8a8 import AscendW8A8LinearMethod -from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod, - AscendW8A8DynamicLinearMethod) -from .w8a8_pdmix import (AscendW8A8PDMixFusedMoeMethod, - AscendW8A8PDMixLinearMethod) -from .w8a16 import AscendW8A16LinearMethod - -ASCEND_QUANTIZATION_METHOD_MAP: Dict[str, Dict[str, Type[Any]]] = { - "W4A16": { - "moe": AscendW4A16FusedMoEMethod, - }, - "W4A8_DYNAMIC": { - "linear": AscendW4A8DynamicLinearMethod, - "moe": AscendW4A8DynamicFusedMoEMethod, - }, - "W4A4_FLATQUANT_DYNAMIC": { - "linear": AscendW4A4FlatQuantDynamicLinearMethod, - }, - "W8A8": { - "linear": AscendW8A8LinearMethod, - }, - "W8A8_DYNAMIC": { - "linear": AscendW8A8DynamicLinearMethod, - "moe": AscendW8A8DynamicFusedMoEMethod, - }, - "W8A8_MIX": { - "linear": AscendW8A8PDMixLinearMethod, - "moe": AscendW8A8PDMixFusedMoeMethod, - }, - "W8A16": { - "linear": AscendW8A16LinearMethod, - } -} - - -def get_linear_quant_type(quant_description: Dict[str, Any], prefix: str, - packed_modules_mapping: Dict[str, Any]): - proj_name = prefix.split(".")[-1] - if proj_name in packed_modules_mapping: - quant_type = None - shard_prefixes = [ - prefix.replace(proj_name, shard_proj_name) - for shard_proj_name in packed_modules_mapping[proj_name] - ] - for shard_prefix in shard_prefixes: - shard_quant_type = quant_description[shard_prefix + '.weight'] - - if quant_type is None: - quant_type = shard_quant_type - elif shard_quant_type != quant_type: - raise ValueError( - f"Not all shards of {prefix} are quantized with same quant type." - f"Shard {proj_name} uses {shard_quant_type}, but another shard" - f"use {quant_type}. Please check quantization config.") - else: - quant_type = quant_description[prefix + '.weight'] - return quant_type - - -def get_quant_method(quant_description: Dict[str, Any], - prefix: str, - layer_type: str, - packed_modules_mapping: Optional[Dict[str, Any]] = None, - layer: torch.nn.Module = None): - if quant_description.get("quant_method") == COMPRESSED_TENSORS_METHOD: - return get_quant_method_llmcompressor(layer) - - return get_quant_method_modelslim(quant_description, prefix, layer_type, - packed_modules_mapping) - - -def get_quant_method_llmcompressor(layer: torch.nn.Module): - logger.info_once("Using the vLLM Ascend llmcompressor Quantization now!") - if layer.scheme is None: - raise ValueError("A scheme must be defined for each layer") - return layer.scheme - - -def get_quant_method_modelslim( - quant_description: Dict[str, Any], - prefix: str, - layer_type: str, - packed_modules_mapping: Optional[Dict[str, Any]] = None): - logger.info_once("Using the vLLM Ascend modelslim Quantization now!") - if packed_modules_mapping is None: - packed_modules_mapping = dict() - # Attention - if '.attn' in prefix and 'fa_quant_type' in quant_description.keys(): - quant_type = quant_description['fa_quant_type'] - # Linear - else: - quant_type = get_linear_quant_type(quant_description, prefix, - packed_modules_mapping) - if quant_type in ASCEND_QUANTIZATION_METHOD_MAP.keys(): - method_map = ASCEND_QUANTIZATION_METHOD_MAP[quant_type] - if layer_type in method_map.keys(): - method_cls = method_map[layer_type] - return method_cls() - else: - raise NotImplementedError( - f"Currently, vLLM Ascend doesn't support {quant_type} for {layer_type}." - ) - raise NotImplementedError("Currently, vLLM Ascend only supports following quant types:" \ - f"{list(ASCEND_QUANTIZATION_METHOD_MAP.keys())}") diff --git a/vllm_ascend/quantization/quant_config.py b/vllm_ascend/quantization/wrappers.py similarity index 52% rename from vllm_ascend/quantization/quant_config.py rename to vllm_ascend/quantization/wrappers.py index f6a9824161f..193dbdb47f7 100644 --- a/vllm_ascend/quantization/quant_config.py +++ b/vllm_ascend/quantization/wrappers.py @@ -15,24 +15,19 @@ # limitations under the License. # This file is a part of the vllm-ascend project. # -from types import MappingProxyType -from typing import Any, Callable, Dict, List, Mapping, Optional +"""Wrapper classes that delegate to actual quantization scheme implementations.""" + +from typing import Any, Callable, Dict, List, Optional import torch -from vllm.config import get_current_vllm_config from vllm.distributed import get_tensor_model_parallel_rank -from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, +from vllm.model_executor.layers.fused_moe import (FusedMoEMethodBase, FusedMoeWeightScaleSupported) -from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, +from vllm.model_executor.layers.linear import (LinearMethodBase, RowParallelLinear) -from vllm.model_executor.layers.quantization import \ - register_quantization_config -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.quantization.base_config import \ + QuantizeMethodBase from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod -from vllm.model_executor.layers.vocab_parallel_embedding import ( - UnquantizedEmbeddingMethod, VocabParallelEmbedding) -from vllm.model_executor.models.utils import WeightsMapper from vllm.model_executor.parameter import PerTensorScaleParameter from vllm.model_executor.utils import set_weight_attrs @@ -40,285 +35,57 @@ from vllm_ascend.distributed.parallel_state import (get_flashcomm2_otp_group, get_mlp_tp_group, get_otp_group) -from vllm_ascend.ops.fused_moe.fused_moe import AscendUnquantizedFusedMoEMethod -from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod -from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, flashcomm2_enable, +from vllm_ascend.utils import (COMPRESSED_TENSORS_METHOD, flashcomm2_enable, mlp_tp_enable, oproj_tp_enable) -from .utils import get_quant_method - - -@register_quantization_config(ASCEND_QUANTIZATION_METHOD) -class AscendQuantConfig(QuantizationConfig): - """Config class for Ascend - This class is a general class that parse quantization configs - that are supported on ascend hardware. +def get_quant_method(quant_description: Dict[str, Any], + prefix: str, + layer_type: str, + packed_modules_mapping: Optional[Dict[str, Any]] = None, + layer: Optional[torch.nn.Module] = None): + """Get the appropriate quantization method for a layer. + + This is the routing function that dispatches to either ModelSlim or + LLM-Compressor implementations based on the quant_description. + + Args: + quant_description: The quantization description dictionary. + prefix: The layer prefix. + layer_type: The type of layer ("linear", "moe", "attention"). + packed_modules_mapping: Mapping for packed/fused modules. + layer: The layer module (optional). + + Returns: + An instance of the appropriate quantization method class. """ + if quant_description.get("quant_method") == COMPRESSED_TENSORS_METHOD: + from .compressed_tensors_config import get_quant_method_llmcompressor + return get_quant_method_llmcompressor(layer) - def __init__(self, quant_config: Dict[str, Any]): - super().__init__() - self.quant_description = quant_config - # TODO(whx): remove this adaptation after adding "shared_head" - # to prefix of DeepSeekShareHead in vLLM. - extra_quant_dict = {} - for k in self.quant_description.keys(): - if "shared_head" in k: - new_k = k.replace(".shared_head.", ".") - extra_quant_dict[new_k] = self.quant_description[k] - if "weight_packed" in k: - new_k = k.replace("weight_packed", "weight") - extra_quant_dict[new_k] = self.quant_description[k] - self.quant_description.update(extra_quant_dict) - - def __repr__(self) -> str: - return "AscendQuantConfig:\n" + super().__repr__() - - @classmethod - def get_name(cls) -> str: - return ASCEND_QUANTIZATION_METHOD - - @classmethod - def get_supported_act_dtypes(cls) -> List[torch.dtype]: - return [torch.int8, torch.float16, torch.bfloat16] - - @classmethod - def get_min_capability(cls) -> int: - raise NotImplementedError( - "Ascend hardware dose not support \"get_min_capability\" feature.") - - @classmethod - def get_config_filenames(cls) -> List[str]: - return ["quant_model_description.json"] - - @classmethod - def from_config(cls, config: Dict[str, Any]) -> "AscendQuantConfig": - return cls(config) - - @classmethod - def override_quantization_method(cls, hf_quant_cfg, - user_quant) -> Optional[str]: - if hf_quant_cfg is not None: - quant_method = hf_quant_cfg.get("quant_method", None) - if not quant_method and torch.npu.is_available(): - return ASCEND_QUANTIZATION_METHOD - return None - - def quant_prefix_mapper(self, model_type: str, prefix: str) -> str: - # TODO (Levi-JQ): will be removed when QuantizationConfig.apply_vllm_mapper is implemented - prefix_mapping = QUANT_MODEL_PREFIX_MAPPINGS.get(model_type) - if prefix_mapping: - hf_to_vllm_mapper = WeightsMapper( - orig_to_new_prefix=prefix_mapping) - return hf_to_vllm_mapper._map_name(prefix) - return prefix - - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["QuantizeMethodBase"]: - vllm_config = get_current_vllm_config() - model_type = vllm_config.model_config.hf_text_config.model_type - if model_type in packed_modules_model_mapping: - self.packed_modules_mapping = packed_modules_model_mapping[ - model_type] - prefix = self.quant_prefix_mapper(model_type, prefix) - from vllm.attention.layer import Attention - if prefix.startswith("language_model"): - prefix = prefix.split('.', 1)[-1] - if isinstance(layer, LinearBase): - if self.is_layer_skipped_ascend(prefix, - self.packed_modules_mapping): - return AscendUnquantizedLinearMethod() - return AscendLinearMethod(self, prefix, - self.packed_modules_mapping, layer) - elif isinstance(layer, Attention) and \ - 'fa_quant_type' in self.quant_description.keys() and \ - self.quant_description['fa_quant_type'] is not None: - return AscendKVCacheMethod(self, prefix) - elif isinstance(layer, FusedMoE): - if self.is_layer_skipped_ascend(prefix, - self.packed_modules_mapping): - return AscendUnquantizedFusedMoEMethod(layer.moe_config) - return AscendFusedMoEMethod(self, prefix, - self.packed_modules_mapping, layer) - elif isinstance(layer, VocabParallelEmbedding): - if self.is_layer_skipped_ascend(prefix, - self.packed_modules_mapping): - return UnquantizedEmbeddingMethod() - return AscendEmbeddingMethod(self, prefix, - self.packed_modules_mapping, layer) - return None - - def is_layer_skipped_ascend( - self, - prefix: str, - fused_mapping: Mapping[str, List[str]] = MappingProxyType({})): - # adapted from vllm.model_executor.layers.quantization.utils.quant_utils.is_layer_skipped - proj_name = prefix.split(".")[-1] - if proj_name in fused_mapping: - shard_prefixes = [ - prefix.replace(proj_name, shard_proj_name) - for shard_proj_name in fused_mapping[proj_name] - ] - - is_skipped = None - for shard_prefix in shard_prefixes: - is_shard_skipped = self.quant_description[shard_prefix + - '.weight'] == "FLOAT" - - if is_skipped is None: - is_skipped = is_shard_skipped - elif is_shard_skipped != is_skipped: - raise ValueError( - f"Detected some but not all shards of {prefix} " - "are quantized. All shards of fused layers " - "to have the same precision.") - else: - is_skipped = self.quant_description[prefix + '.weight'] == "FLOAT" - - assert is_skipped is not None - return is_skipped - - def get_scaled_act_names(self) -> List[str]: - return [] - - -# key: model_type -# value: orig_to_new_prefix -QUANT_MODEL_PREFIX_MAPPINGS = { - "qwen3_vl_moe": { - "visual.": "model.visual.", - "language_model.lm_head.": "lm_head.", - "language_model.model.": "model.language_model.", - }, -} - -packed_modules_model_mapping = { - "qwen3_moe": { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], - "experts": - ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], - }, - "deepseek_v2": { - "gate_up_proj": ["gate_proj", "up_proj"], - "experts": - ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], - "fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"] - }, - "deepseek_v3": { - "gate_up_proj": ["gate_proj", "up_proj"], - "experts": - ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], - "fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"] - }, - "pangu_ultra_moe": { - "gate_up_proj": ["gate_proj", "up_proj"], - "experts": - ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], - "fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"] - }, - "kimi_k2": { - "gate_up_proj": ["gate_proj", "up_proj"], - "experts": - ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], - "fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"] - }, - "deepseek_v32": { - "gate_up_proj": ["gate_proj", "up_proj"], - "experts": - ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], - "fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"] - }, - # NOTE 1.The quantized MTP layer of deepseek on the NPU is not quantized; - # NOTE 2.The description file generated by the current msmodelslim tool does not have - # MTP layer info. Please manually add it and set the value to FLOAT. - "deepseek_mtp": { - "gate_up_proj": ["gate_proj", "up_proj"], - "experts": - ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"] - }, - "pangu_ultra_moe_mtp": { - "gate_up_proj": ["gate_proj", "up_proj"], - "experts": - ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], - "fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"] - }, - "qwen3_next": { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - "gate_up_proj": ["gate_proj", "up_proj"], - "in_proj": ["in_proj_qkvz", "in_proj_ba"], - "experts": - ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"] - }, - "qwen2_5_vl": { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], - }, - "qwen3_vl_moe": { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], - "experts": - ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], - }, - "glm4_moe": { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], - "experts": - ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"] - }, - "longcat_flash": { - "gate_up_proj": ["gate_proj", "up_proj"], - "experts": - ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], - "fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"] - }, -} + from .modelslim_config import get_quant_method_modelslim + return get_quant_method_modelslim(quant_description, prefix, layer_type, + packed_modules_mapping) class AscendLinearMethod(LinearMethodBase): """Linear method for Ascend quantization. + This wrapper class delegates to the actual quantization scheme implementation + based on the quant_config and prefix. + Args: quant_config: The Ascend quantization config. + prefix: The layer prefix for determining quantization type. + packed_modules_mapping: Mapping for packed/fused modules. + layer: The layer module (optional). """ def __init__(self, - quant_config: AscendQuantConfig, + quant_config: "QuantizeMethodBase", prefix: str, - packed_modules_mapping: Dict[str, Any] | None, - layer: torch.nn.Module = None) -> None: + packed_modules_mapping: Optional[Dict[str, Any]], + layer: Optional[torch.nn.Module] = None) -> None: self.quant_method = get_quant_method(quant_config.quant_description, prefix, "linear", @@ -431,9 +198,11 @@ class AscendKVCacheMethod(BaseKVCacheMethod): Args: quant_config: The Ascend quantization config. + prefix: The layer prefix. """ - def __init__(self, quant_config: AscendQuantConfig, prefix: str) -> None: + def __init__(self, quant_config: "QuantizeMethodBase", + prefix: str) -> None: self.quant_method = get_quant_method(quant_config.quant_description, prefix, "attention") @@ -459,11 +228,14 @@ class AscendFusedMoEMethod(FusedMoEMethodBase): Args: quant_config: The Ascend quantization config. + prefix: The layer prefix. + packed_modules_mapping: Mapping for packed/fused modules. + layer: The layer module. """ - def __init__(self, quant_config: AscendQuantConfig, prefix: str, - packed_modules_mapping: Dict[str, - Any], layer: torch.nn.Module): + def __init__(self, quant_config: "QuantizeMethodBase", prefix: str, + packed_modules_mapping: Optional[Dict[str, Any]], + layer: torch.nn.Module): super().__init__(layer.moe_config) self.quant_method = get_quant_method(quant_config.quant_description, prefix, @@ -524,7 +296,7 @@ def apply( e_score_correction_bias: Optional[torch.Tensor] = None, is_prefill: bool = True, enable_force_load_balance: bool = False, - log2phy: torch.Tensor = None, + log2phy: Optional[torch.Tensor] = None, global_redundant_expert_num=0, **kwargs, ) -> torch.Tensor: @@ -552,12 +324,18 @@ def supports_eplb(self): class AscendEmbeddingMethod(AscendLinearMethod): """Embedding method for Ascend quantization. - Args: - quant_config: The Ascend quantization config. + This is essentially the same as AscendLinearMethod, just with a different name + for clarity when used with VocabParallelEmbedding layers. + + Args: + quant_config: The Ascend quantization config. + prefix: The layer prefix. + packed_modules_mapping: Mapping for packed/fused modules. + layer: The layer module. """ - def __init__(self, quant_config: AscendQuantConfig, prefix: str, - packed_modules_mapping: Dict[str, Any], + def __init__(self, quant_config: "QuantizeMethodBase", prefix: str, + packed_modules_mapping: Optional[Dict[str, Any]], layer: torch.nn.Module) -> None: self.quant_method = get_quant_method(quant_config.quant_description, prefix, From d2c5dd173f3ede037d95f282f77180194c9345f2 Mon Sep 17 00:00:00 2001 From: SlightwindSec Date: Fri, 9 Jan 2026 14:18:17 +0800 Subject: [PATCH 02/20] fix lint Signed-off-by: SlightwindSec --- .../quantization/methods/w8a8_pdmix.py | 76 ++++++++++++------- vllm_ascend/quantization/modelslim_config.py | 2 +- 2 files changed, 51 insertions(+), 27 deletions(-) diff --git a/vllm_ascend/quantization/methods/w8a8_pdmix.py b/vllm_ascend/quantization/methods/w8a8_pdmix.py index a5fb570a95e..d4e27b2159e 100644 --- a/vllm_ascend/quantization/methods/w8a8_pdmix.py +++ b/vllm_ascend/quantization/methods/w8a8_pdmix.py @@ -14,12 +14,20 @@ # See the License for the specific language governing permissions and # limitations under the License. # +"""W8A8 Prefill-Decode Mix quantization methods. -from typing import Any, Dict, cast +This module provides quantization methods that use different strategies +for prefill and decode phases: +- Prefill: Uses dynamic W8A8 quantization +- Decode (KV consumer): Uses static W8A8 quantization +""" + +from typing import Any, Dict, Optional import torch from vllm.config import get_current_vllm_config +from .base import AscendLinearScheme from .registry import register_scheme from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod, AscendW8A8DynamicLinearMethod) @@ -27,54 +35,70 @@ @register_scheme("W8A8_MIX", "linear") -class AscendW8A8PDMixLinearMethod(AscendW8A8DynamicLinearMethod): - """Linear method for W8A8 prefill-decode mix. - - Uses static W8A8 for KV consumer (decode) and dynamic W8A8 for prefill. +class AscendW8A8PDMixLinearMethod(AscendLinearScheme): + """Linear method for W8A8 prefill-decode mix quantization. + + This scheme uses composition to delegate to the appropriate quantization + method based on the execution phase: + - Static W8A8 for KV consumer (decode phase) + - Dynamic W8A8 for prefill phase + + The static method is used for weight/parameter specifications since + it requires more parameters (input_scale, deq_scale, etc.) that are + needed for static quantization during decode. """ def __init__(self): - self.kv_transfer_config = get_current_vllm_config().kv_transfer_config - super().__init__() + self._static_method = AscendW8A8LinearMethod() + self._dynamic_method = AscendW8A8DynamicLinearMethod() - def apply(self, layer, x, bias=None, tp_rank=0): - if layer.is_kv_consumer: - return AscendW8A8LinearMethod.apply(self, layer, x, bias, tp_rank) - else: - return AscendW8A8DynamicLinearMethod.apply(self, layer, x, bias, - tp_rank) + kv_transfer_config = get_current_vllm_config().kv_transfer_config + self._is_kv_consumer = (kv_transfer_config is not None + and kv_transfer_config.is_kv_consumer) + + def get_weight(self, input_size: int, output_size: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + return self._static_method.get_weight(input_size, output_size, + params_dtype) def get_pertensor_param(self, params_dtype: torch.dtype) -> Dict[str, Any]: - return AscendW8A8LinearMethod.get_pertensor_param(self, params_dtype) + return self._static_method.get_pertensor_param(params_dtype) def get_perchannel_param( self, output_size: int, params_dtype: torch.dtype, ) -> Dict[str, Any]: - return AscendW8A8LinearMethod.get_perchannel_param( - self, output_size, params_dtype) + return self._static_method.get_perchannel_param(output_size, + params_dtype) - def process_weights_after_loading(self, layer): - AscendW8A8LinearMethod.process_weights_after_loading( - cast(AscendW8A8LinearMethod, self), layer) + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + tp_rank: Optional[int] = 0, + ) -> torch.Tensor: + if layer.is_kv_consumer: + return self._static_method.apply(layer, x, bias, tp_rank) + else: + return self._dynamic_method.apply(layer, x, bias, tp_rank) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + self._static_method.process_weights_after_loading(layer) layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32) - layer.is_kv_consumer = self.kv_transfer_config is not None and self.kv_transfer_config.is_kv_consumer + layer.is_kv_consumer = self._is_kv_consumer @register_scheme("W8A8_MIX", "moe") class AscendW8A8PDMixFusedMoeMethod(AscendW8A8DynamicFusedMoEMethod): - """FusedMoE method for W8A8 prefill-decode mix.""" - - def __init__(self): - super().__init__() def get_dynamic_quant_param(self, num_experts: int, intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype) -> Dict[str, Any]: - param_dict = AscendW8A8DynamicFusedMoEMethod.get_dynamic_quant_param( - self, num_experts, intermediate_size_per_partition, hidden_sizes, + param_dict = super().get_dynamic_quant_param( + num_experts, intermediate_size_per_partition, hidden_sizes, params_dtype) param_dict["w2_deq_scale"] = torch.empty(num_experts, hidden_sizes, diff --git a/vllm_ascend/quantization/modelslim_config.py b/vllm_ascend/quantization/modelslim_config.py index 30b10787cae..508884829d9 100644 --- a/vllm_ascend/quantization/modelslim_config.py +++ b/vllm_ascend/quantization/modelslim_config.py @@ -200,7 +200,7 @@ def get_prefix_mapping(model_type: str) -> Dict[str, str]: def get_linear_quant_type(quant_description: Dict[str, Any], prefix: str, - packed_modules_mapping: Dict[str, Any]) -> str: + packed_modules_mapping: Dict[str, Any]) -> Optional[str]: """Determine the quantization type for a linear layer. Args: From 58f5902fa7d968e3b2c65ad335380632372b4a96 Mon Sep 17 00:00:00 2001 From: SlightwindSec Date: Fri, 9 Jan 2026 14:28:51 +0800 Subject: [PATCH 03/20] fix lint Signed-off-by: SlightwindSec --- vllm_ascend/quantization/methods/w8a8_pdmix.py | 4 ++-- vllm_ascend/quantization/modelslim_config.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm_ascend/quantization/methods/w8a8_pdmix.py b/vllm_ascend/quantization/methods/w8a8_pdmix.py index d4e27b2159e..6c7cd394f37 100644 --- a/vllm_ascend/quantization/methods/w8a8_pdmix.py +++ b/vllm_ascend/quantization/methods/w8a8_pdmix.py @@ -69,8 +69,8 @@ def get_perchannel_param( output_size: int, params_dtype: torch.dtype, ) -> Dict[str, Any]: - return self._static_method.get_perchannel_param(output_size, - params_dtype) + return self._static_method.get_perchannel_param( + output_size, params_dtype) def apply( self, diff --git a/vllm_ascend/quantization/modelslim_config.py b/vllm_ascend/quantization/modelslim_config.py index 508884829d9..1ab0430ac88 100644 --- a/vllm_ascend/quantization/modelslim_config.py +++ b/vllm_ascend/quantization/modelslim_config.py @@ -45,7 +45,6 @@ logger = init_logger(__name__) - # key: model_type # value: orig_to_new_prefix QUANT_MODEL_PREFIX_MAPPINGS: Dict[str, Dict[str, str]] = { @@ -199,8 +198,9 @@ def get_prefix_mapping(model_type: str) -> Dict[str, str]: return QUANT_MODEL_PREFIX_MAPPINGS.get(model_type, {}) -def get_linear_quant_type(quant_description: Dict[str, Any], prefix: str, - packed_modules_mapping: Dict[str, Any]) -> Optional[str]: +def get_linear_quant_type( + quant_description: Dict[str, Any], prefix: str, + packed_modules_mapping: Dict[str, Any]) -> Optional[str]: """Determine the quantization type for a linear layer. Args: From b8cb56a639fd1e96198d807f2ec553ae20d72feb Mon Sep 17 00:00:00 2001 From: SlightwindSec Date: Fri, 9 Jan 2026 16:05:15 +0800 Subject: [PATCH 04/20] fix circular import Signed-off-by: SlightwindSec --- vllm_ascend/quantization/compressed_tensors_config.py | 5 ++++- vllm_ascend/quantization/modelslim_config.py | 7 +++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/vllm_ascend/quantization/compressed_tensors_config.py b/vllm_ascend/quantization/compressed_tensors_config.py index a15bdfffa0d..1cc9c8d7199 100644 --- a/vllm_ascend/quantization/compressed_tensors_config.py +++ b/vllm_ascend/quantization/compressed_tensors_config.py @@ -37,7 +37,6 @@ should_ignore_layer) from vllm.model_executor.models.utils import WeightsMapper -from vllm_ascend.ops.fused_moe.fused_moe import AscendUnquantizedFusedMoEMethod from vllm_ascend.utils import COMPRESSED_TENSORS_METHOD logger = init_logger(__name__) @@ -184,6 +183,10 @@ def get_quant_method( None, layer) return quant_method if isinstance(layer, FusedMoE): + # Delayed import to avoid circular import + from vllm_ascend.ops.fused_moe.fused_moe import ( + AscendUnquantizedFusedMoEMethod) + layer.ascend_quant_method = COMPRESSED_TENSORS_METHOD # collect schemes quant_scheme = self.get_scheme(layer=layer, layer_name=prefix) diff --git a/vllm_ascend/quantization/modelslim_config.py b/vllm_ascend/quantization/modelslim_config.py index 1ab0430ac88..3e41d9b2a91 100644 --- a/vllm_ascend/quantization/modelslim_config.py +++ b/vllm_ascend/quantization/modelslim_config.py @@ -37,8 +37,6 @@ UnquantizedEmbeddingMethod, VocabParallelEmbedding) from vllm.model_executor.models.utils import WeightsMapper -from vllm_ascend.ops.fused_moe.fused_moe import AscendUnquantizedFusedMoEMethod -from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD from .methods import get_scheme_class @@ -353,6 +351,8 @@ def get_quant_method(self, layer: torch.nn.Module, if isinstance(layer, LinearBase): if self.is_layer_skipped_ascend(prefix, self.packed_modules_mapping): + # Delayed import to avoid circular import + from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod return AscendUnquantizedLinearMethod() return AscendLinearMethod(self, prefix, self.packed_modules_mapping, layer) @@ -363,6 +363,9 @@ def get_quant_method(self, layer: torch.nn.Module, elif isinstance(layer, FusedMoE): if self.is_layer_skipped_ascend(prefix, self.packed_modules_mapping): + # Delayed import to avoid circular import + from vllm_ascend.ops.fused_moe.fused_moe import ( + AscendUnquantizedFusedMoEMethod) return AscendUnquantizedFusedMoEMethod(layer.moe_config) return AscendFusedMoEMethod(self, prefix, self.packed_modules_mapping, layer) From 96a7dc6f23a03187104e15ab90c046d714dd7148 Mon Sep 17 00:00:00 2001 From: SlightwindSec Date: Fri, 9 Jan 2026 18:30:02 +0800 Subject: [PATCH 05/20] add mxfp8 refactor Signed-off-by: SlightwindSec --- .../quantization/compressed_tensors_config.py | 4 ++-- vllm_ascend/quantization/methods/__init__.py | 21 +++++++++++++++++++ .../{w8a8mxfp8.py => methods/w8a8_mxfp8.py} | 21 ++++++++++++------- vllm_ascend/quantization/modelslim_config.py | 7 ++++--- 4 files changed, 41 insertions(+), 12 deletions(-) rename vllm_ascend/quantization/{w8a8mxfp8.py => methods/w8a8_mxfp8.py} (81%) diff --git a/vllm_ascend/quantization/compressed_tensors_config.py b/vllm_ascend/quantization/compressed_tensors_config.py index 1cc9c8d7199..18415131250 100644 --- a/vllm_ascend/quantization/compressed_tensors_config.py +++ b/vllm_ascend/quantization/compressed_tensors_config.py @@ -184,8 +184,8 @@ def get_quant_method( return quant_method if isinstance(layer, FusedMoE): # Delayed import to avoid circular import - from vllm_ascend.ops.fused_moe.fused_moe import ( - AscendUnquantizedFusedMoEMethod) + from vllm_ascend.ops.fused_moe.fused_moe import \ + AscendUnquantizedFusedMoEMethod layer.ascend_quant_method = COMPRESSED_TENSORS_METHOD # collect schemes diff --git a/vllm_ascend/quantization/methods/__init__.py b/vllm_ascend/quantization/methods/__init__.py index bcf73cfd239..8be021e71ad 100644 --- a/vllm_ascend/quantization/methods/__init__.py +++ b/vllm_ascend/quantization/methods/__init__.py @@ -27,6 +27,8 @@ scheme = scheme_cls() """ +from typing import Any + # Import base classes from .base import AscendLinearScheme, AscendMoEScheme # Import registry functions @@ -37,11 +39,26 @@ from .w4a16 import AscendW4A16FusedMoEMethod from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod, AscendW8A8DynamicLinearMethod) +from .w8a8_mxfp8 import AscendW8A8MXFP8DynamicLinearMethod from .w8a8_pdmix import (AscendW8A8PDMixFusedMoeMethod, AscendW8A8PDMixLinearMethod) from .w8a8_static import AscendW8A8LinearMethod from .w8a16 import AscendW8A16LinearMethod + +def is_mx_quant_type(instance: Any) -> bool: + """Checks if the quantization method is a microscaling (MX) type. + + Args: + instance: The quantization method instance to check. + + Returns: + True if the instance is an MX quantization type, False otherwise. + """ + MX_QUANT_TYPES = (AscendW8A8MXFP8DynamicLinearMethod, ) + return isinstance(instance, MX_QUANT_TYPES) + + __all__ = [ # Base classes "AscendLinearScheme", @@ -49,11 +66,15 @@ # Registry functions "register_scheme", "get_scheme_class", + # Utility functions + "is_mx_quant_type", # W8A8 static "AscendW8A8LinearMethod", # W8A8 dynamic "AscendW8A8DynamicLinearMethod", "AscendW8A8DynamicFusedMoEMethod", + # W8A8 MXFP8 + "AscendW8A8MXFP8DynamicLinearMethod", # W8A8 PDMix "AscendW8A8PDMixLinearMethod", "AscendW8A8PDMixFusedMoeMethod", diff --git a/vllm_ascend/quantization/w8a8mxfp8.py b/vllm_ascend/quantization/methods/w8a8_mxfp8.py similarity index 81% rename from vllm_ascend/quantization/w8a8mxfp8.py rename to vllm_ascend/quantization/methods/w8a8_mxfp8.py index 2997a176d70..f47886cca70 100644 --- a/vllm_ascend/quantization/w8a8mxfp8.py +++ b/vllm_ascend/quantization/methods/w8a8_mxfp8.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +"""W8A8 MXFP8 Dynamic quantization scheme for Ascend NPU.""" from typing import Any, Dict, Optional @@ -21,9 +22,17 @@ import torch_npu from vllm.config import get_current_vllm_config +from .base import AscendLinearScheme +from .registry import register_scheme -class AscendW8A8MXFP8DynamicLinearMethod: - """Linear method for Ascend W8A8_DYNAMIC. + +@register_scheme("W8A8_MXFP8", "linear") +class AscendW8A8MXFP8DynamicLinearMethod(AscendLinearScheme): + """Linear method for Ascend W8A8_MXFP8 (Microscaling FP8) quantization. + + This scheme uses microscaling FP8 quantization with per-group scales. + The activation is dynamically quantized to FP8 (E4M3FN format) with + microscaling, and weights are stored in FP8 format with per-group scales. """ model_dtype = None @@ -32,8 +41,7 @@ def __init__(self): self.group_size = vllm_config.quant_config.quant_description.get( "group_size", 32) - @staticmethod - def get_weight(input_size: int, output_size: int, + def get_weight(self, input_size: int, output_size: int, params_dtype: torch.dtype) -> Dict[str, Any]: params_dict = { "weight": @@ -41,12 +49,11 @@ def get_weight(input_size: int, output_size: int, } return params_dict - @staticmethod - def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]: + def get_pertensor_param(self, params_dtype: torch.dtype) -> Dict[str, Any]: return {} - @staticmethod def get_perchannel_param( + self, output_size: int, params_dtype: torch.dtype, ) -> Dict[str, Any]: diff --git a/vllm_ascend/quantization/modelslim_config.py b/vllm_ascend/quantization/modelslim_config.py index 3e41d9b2a91..f119b6419cc 100644 --- a/vllm_ascend/quantization/modelslim_config.py +++ b/vllm_ascend/quantization/modelslim_config.py @@ -352,7 +352,8 @@ def get_quant_method(self, layer: torch.nn.Module, if self.is_layer_skipped_ascend(prefix, self.packed_modules_mapping): # Delayed import to avoid circular import - from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod + from vllm_ascend.ops.linear import \ + AscendUnquantizedLinearMethod return AscendUnquantizedLinearMethod() return AscendLinearMethod(self, prefix, self.packed_modules_mapping, layer) @@ -364,8 +365,8 @@ def get_quant_method(self, layer: torch.nn.Module, if self.is_layer_skipped_ascend(prefix, self.packed_modules_mapping): # Delayed import to avoid circular import - from vllm_ascend.ops.fused_moe.fused_moe import ( - AscendUnquantizedFusedMoEMethod) + from vllm_ascend.ops.fused_moe.fused_moe import \ + AscendUnquantizedFusedMoEMethod return AscendUnquantizedFusedMoEMethod(layer.moe_config) return AscendFusedMoEMethod(self, prefix, self.packed_modules_mapping, layer) From dc1cc7d8bc4b78ebd7d177a59d046885050eb22f Mon Sep 17 00:00:00 2001 From: SlightwindSec Date: Fri, 9 Jan 2026 18:34:31 +0800 Subject: [PATCH 06/20] remove mxfp8 get_perchannel_param Signed-off-by: SlightwindSec --- vllm_ascend/quantization/methods/w8a8_mxfp8.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/vllm_ascend/quantization/methods/w8a8_mxfp8.py b/vllm_ascend/quantization/methods/w8a8_mxfp8.py index f47886cca70..e42abfae638 100644 --- a/vllm_ascend/quantization/methods/w8a8_mxfp8.py +++ b/vllm_ascend/quantization/methods/w8a8_mxfp8.py @@ -14,7 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # -"""W8A8 MXFP8 Dynamic quantization scheme for Ascend NPU.""" from typing import Any, Dict, Optional @@ -49,16 +48,6 @@ def get_weight(self, input_size: int, output_size: int, } return params_dict - def get_pertensor_param(self, params_dtype: torch.dtype) -> Dict[str, Any]: - return {} - - def get_perchannel_param( - self, - output_size: int, - params_dtype: torch.dtype, - ) -> Dict[str, Any]: - return {} - def get_pergroup_param(self, input_size: int, output_size: int, From e78ba76f809df4ec7a22ad4df4958e8a508b4f8e Mon Sep 17 00:00:00 2001 From: SlightwindSec Date: Sat, 10 Jan 2026 10:38:00 +0800 Subject: [PATCH 07/20] fix circular import Signed-off-by: SlightwindSec --- .../ut/quantization/test_modelslim_config.py | 8 ++-- vllm_ascend/ops/fused_moe/fused_moe.py | 18 ++++---- vllm_ascend/quantization/methods/__init__.py | 44 ++++++------------- vllm_ascend/quantization/methods/base.py | 15 +++++++ vllm_ascend/quantization/methods/w4a8.py | 5 ++- .../quantization/methods/w8a8_dynamic.py | 5 ++- 6 files changed, 50 insertions(+), 45 deletions(-) diff --git a/tests/ut/quantization/test_modelslim_config.py b/tests/ut/quantization/test_modelslim_config.py index f290e74259d..58d011b2498 100644 --- a/tests/ut/quantization/test_modelslim_config.py +++ b/tests/ut/quantization/test_modelslim_config.py @@ -92,7 +92,7 @@ def test_get_quant_method_for_linear(self): # Test quantized layer with patch.object(self.ascend_config, 'is_layer_skipped_ascend', return_value=False), \ patch("vllm_ascend.quantization.modelslim_config.get_current_vllm_config", return_value=mock_config), \ - patch('vllm_ascend.quantization.modelslim_config.AscendLinearMethod', return_value=MagicMock()) as mock_ascend_linear: + patch('vllm_ascend.quantization.wrappers.AscendLinearMethod', return_value=MagicMock()) as mock_ascend_linear: method = self.ascend_config.get_quant_method(linear_layer, ".attn") self.assertIs(method, mock_ascend_linear.return_value) @@ -105,7 +105,7 @@ def test_get_quant_method_for_attention(self): mock_config = MagicMock() mock_config.model_config.hf_config.model_type = None with patch("vllm_ascend.quantization.modelslim_config.get_current_vllm_config", return_value=mock_config), \ - patch('vllm_ascend.quantization.modelslim_config.AscendKVCacheMethod', \ + patch('vllm_ascend.quantization.wrappers.AscendKVCacheMethod', \ return_value=MagicMock()) as mock_ascend_kvcache: # Test with fa_quant_type method = self.ascend_config.get_quant_method( @@ -122,7 +122,7 @@ def test_get_quant_method_for_fused_moe(self): # Test skipped layer with patch.object(self.ascend_config, 'is_layer_skipped_ascend', return_value=True), \ patch("vllm_ascend.quantization.modelslim_config.get_current_vllm_config", return_value=mock_config), \ - patch('vllm_ascend.quantization.modelslim_config.AscendUnquantizedFusedMoEMethod', return_value=MagicMock()) as mock_ascend_moe: + patch('vllm_ascend.ops.fused_moe.fused_moe.AscendUnquantizedFusedMoEMethod', return_value=MagicMock()) as mock_ascend_moe: method = self.ascend_config.get_quant_method( fused_moe_layer, "moe_layer") self.assertIs(method, mock_ascend_moe.return_value) @@ -130,7 +130,7 @@ def test_get_quant_method_for_fused_moe(self): # Test quantized layer with patch.object(self.ascend_config, 'is_layer_skipped_ascend', return_value=False), \ patch("vllm_ascend.quantization.modelslim_config.get_current_vllm_config", return_value=mock_config), \ - patch('vllm_ascend.quantization.modelslim_config.AscendFusedMoEMethod', return_value=MagicMock()) as mock_ascend_moe: + patch('vllm_ascend.quantization.wrappers.AscendFusedMoEMethod', return_value=MagicMock()) as mock_ascend_moe: method = self.ascend_config.get_quant_method( fused_moe_layer, "moe_layer") self.assertIs(method, mock_ascend_moe.return_value) diff --git a/vllm_ascend/ops/fused_moe/fused_moe.py b/vllm_ascend/ops/fused_moe/fused_moe.py index e5e661fa1ba..e6ca08281b0 100644 --- a/vllm_ascend/ops/fused_moe/fused_moe.py +++ b/vllm_ascend/ops/fused_moe/fused_moe.py @@ -40,8 +40,6 @@ FusedExpertsResult, setup_moe_comm_method) from vllm_ascend.ops.fused_moe.prepare_finalize import QuantType -from vllm_ascend.quantization.methods import (AscendW4A8DynamicFusedMoEMethod, - AscendW8A8DynamicFusedMoEMethod) from vllm_ascend.utils import (AscendDeviceType, enable_sp, get_ascend_device_type, maybe_trans_nz, npu_stream_switch, shared_expert_dp_enabled, @@ -236,12 +234,16 @@ def _get_quant_type(self) -> QuantType: method = quant_method.quant_method - if isinstance(method, AscendW8A8DynamicFusedMoEMethod): - return QuantType.W8A8 - elif isinstance(method, AscendW4A8DynamicFusedMoEMethod): - return QuantType.W4A8 - else: - return QuantType.NONE + if hasattr(method, "quant_type"): + from vllm_ascend.quantization.methods.base import \ + QuantType as SchemeQuantType + scheme_quant_type = method.quant_type + if scheme_quant_type == SchemeQuantType.W8A8: + return QuantType.W8A8 + elif scheme_quant_type == SchemeQuantType.W4A8: + return QuantType.W4A8 + + return QuantType.NONE def update_expert_map(self, new_expert_map): self._expert_map = new_expert_map diff --git a/vllm_ascend/quantization/methods/__init__.py b/vllm_ascend/quantization/methods/__init__.py index 8be021e71ad..519f861f170 100644 --- a/vllm_ascend/quantization/methods/__init__.py +++ b/vllm_ascend/quantization/methods/__init__.py @@ -29,21 +29,20 @@ from typing import Any +# Import all scheme modules to trigger registration via @register_scheme decorator +# Note: Add new quantization modules here to register them +from . import w4a4_flatquant # noqa: F401 +from . import w4a8 # noqa: F401 +from . import w4a16 # noqa: F401 +from . import w8a8_dynamic # noqa: F401 +from . import w8a8_mxfp8 # noqa: F401 +from . import w8a8_pdmix # noqa: F401 +from . import w8a8_static # noqa: F401 +from . import w8a16 # noqa: F401 # Import base classes -from .base import AscendLinearScheme, AscendMoEScheme +from .base import AscendLinearScheme, AscendMoEScheme, QuantType # Import registry functions from .registry import get_scheme_class, register_scheme -from .w4a4_flatquant import AscendW4A4FlatQuantDynamicLinearMethod -from .w4a8 import (AscendW4A8DynamicFusedMoEMethod, - AscendW4A8DynamicLinearMethod) -from .w4a16 import AscendW4A16FusedMoEMethod -from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod, - AscendW8A8DynamicLinearMethod) -from .w8a8_mxfp8 import AscendW8A8MXFP8DynamicLinearMethod -from .w8a8_pdmix import (AscendW8A8PDMixFusedMoeMethod, - AscendW8A8PDMixLinearMethod) -from .w8a8_static import AscendW8A8LinearMethod -from .w8a16 import AscendW8A16LinearMethod def is_mx_quant_type(instance: Any) -> bool: @@ -55,6 +54,7 @@ def is_mx_quant_type(instance: Any) -> bool: Returns: True if the instance is an MX quantization type, False otherwise. """ + from .w8a8_mxfp8 import AscendW8A8MXFP8DynamicLinearMethod MX_QUANT_TYPES = (AscendW8A8MXFP8DynamicLinearMethod, ) return isinstance(instance, MX_QUANT_TYPES) @@ -63,28 +63,10 @@ def is_mx_quant_type(instance: Any) -> bool: # Base classes "AscendLinearScheme", "AscendMoEScheme", + "QuantType", # Registry functions "register_scheme", "get_scheme_class", # Utility functions "is_mx_quant_type", - # W8A8 static - "AscendW8A8LinearMethod", - # W8A8 dynamic - "AscendW8A8DynamicLinearMethod", - "AscendW8A8DynamicFusedMoEMethod", - # W8A8 MXFP8 - "AscendW8A8MXFP8DynamicLinearMethod", - # W8A8 PDMix - "AscendW8A8PDMixLinearMethod", - "AscendW8A8PDMixFusedMoeMethod", - # W8A16 - "AscendW8A16LinearMethod", - # W4A8 - "AscendW4A8DynamicLinearMethod", - "AscendW4A8DynamicFusedMoEMethod", - # W4A16 - "AscendW4A16FusedMoEMethod", - # W4A4 FlatQuant - "AscendW4A4FlatQuantDynamicLinearMethod", ] diff --git a/vllm_ascend/quantization/methods/base.py b/vllm_ascend/quantization/methods/base.py index ee277b927cc..e18ca84305a 100644 --- a/vllm_ascend/quantization/methods/base.py +++ b/vllm_ascend/quantization/methods/base.py @@ -17,11 +17,19 @@ """Abstract base classes for Ascend quantization schemes.""" from abc import ABC, abstractmethod +from enum import Enum from typing import Any, Callable, Dict, Optional import torch +class QuantType(Enum): + """Quantization type enum for MoE schemes.""" + NONE = 0 + W8A8 = 1 + W4A8 = 2 + + class AscendLinearScheme(ABC): """Base class for all linear quantization schemes. @@ -121,8 +129,15 @@ class AscendMoEScheme(ABC): Subclasses must implement get_weight(), get_dynamic_quant_param(), and apply() methods. + + Attributes: + quant_type: The quantization type for this scheme. Subclasses should + override this class attribute to declare their quant type. """ + # Default quant type - subclasses should override this + quant_type: QuantType = QuantType.NONE + @abstractmethod def get_weight(self, num_experts: int, intermediate_size_per_partition: int, hidden_sizes: int, diff --git a/vllm_ascend/quantization/methods/w4a8.py b/vllm_ascend/quantization/methods/w4a8.py index 04542db1386..3f7e642d867 100644 --- a/vllm_ascend/quantization/methods/w4a8.py +++ b/vllm_ascend/quantization/methods/w4a8.py @@ -29,7 +29,7 @@ from vllm_ascend.ops.fused_moe.experts_selector import select_experts from vllm_ascend.utils import maybe_trans_nz -from .base import AscendLinearScheme, AscendMoEScheme +from .base import AscendLinearScheme, AscendMoEScheme, QuantType from .registry import register_scheme @@ -202,6 +202,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module): class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme): """FusedMoE method for Ascend W4A8_DYNAMIC.""" + # Declare the quantization type for this scheme + quant_type: QuantType = QuantType.W4A8 + def __init__(self): self.ep_group = get_ep_group() diff --git a/vllm_ascend/quantization/methods/w8a8_dynamic.py b/vllm_ascend/quantization/methods/w8a8_dynamic.py index 0bd770ee4a7..eb610f40ef3 100644 --- a/vllm_ascend/quantization/methods/w8a8_dynamic.py +++ b/vllm_ascend/quantization/methods/w8a8_dynamic.py @@ -32,7 +32,7 @@ zero_experts_compute) from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, maybe_trans_nz -from .base import AscendLinearScheme, AscendMoEScheme +from .base import AscendLinearScheme, AscendMoEScheme, QuantType from .registry import register_scheme @@ -108,6 +108,9 @@ def process_weights_after_loading(self, layer): class AscendW8A8DynamicFusedMoEMethod(AscendMoEScheme): """FusedMoE method for Ascend W8A8_DYNAMIC.""" + # Declare the quantization type for this scheme + quant_type: QuantType = QuantType.W8A8 + def __init__(self): self.ep_group = get_ep_group() From 0e6071725839b1e232a1fb3fc51de0856ec066b3 Mon Sep 17 00:00:00 2001 From: SlightwindSec Date: Sat, 10 Jan 2026 11:32:59 +0800 Subject: [PATCH 08/20] fix lit Signed-off-by: SlightwindSec --- vllm_ascend/quantization/methods/__init__.py | 44 +++++++++++--------- 1 file changed, 25 insertions(+), 19 deletions(-) diff --git a/vllm_ascend/quantization/methods/__init__.py b/vllm_ascend/quantization/methods/__init__.py index 519f861f170..cf079738f91 100644 --- a/vllm_ascend/quantization/methods/__init__.py +++ b/vllm_ascend/quantization/methods/__init__.py @@ -29,32 +29,26 @@ from typing import Any -# Import all scheme modules to trigger registration via @register_scheme decorator -# Note: Add new quantization modules here to register them -from . import w4a4_flatquant # noqa: F401 -from . import w4a8 # noqa: F401 -from . import w4a16 # noqa: F401 -from . import w8a8_dynamic # noqa: F401 -from . import w8a8_mxfp8 # noqa: F401 -from . import w8a8_pdmix # noqa: F401 -from . import w8a8_static # noqa: F401 -from . import w8a16 # noqa: F401 # Import base classes from .base import AscendLinearScheme, AscendMoEScheme, QuantType # Import registry functions from .registry import get_scheme_class, register_scheme +# Import all scheme classes for external access +from .w4a4_flatquant import AscendW4A4FlatQuantDynamicLinearMethod +from .w4a8 import (AscendW4A8DynamicFusedMoEMethod, + AscendW4A8DynamicLinearMethod) +from .w4a16 import AscendW4A16FusedMoEMethod +from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod, + AscendW8A8DynamicLinearMethod) +from .w8a8_mxfp8 import AscendW8A8MXFP8DynamicLinearMethod +from .w8a8_pdmix import (AscendW8A8PDMixFusedMoeMethod, + AscendW8A8PDMixLinearMethod) +from .w8a8_static import AscendW8A8LinearMethod +from .w8a16 import AscendW8A16LinearMethod def is_mx_quant_type(instance: Any) -> bool: - """Checks if the quantization method is a microscaling (MX) type. - - Args: - instance: The quantization method instance to check. - - Returns: - True if the instance is an MX quantization type, False otherwise. - """ - from .w8a8_mxfp8 import AscendW8A8MXFP8DynamicLinearMethod + """Checks if the quantization method is a microscaling (MX) type.""" MX_QUANT_TYPES = (AscendW8A8MXFP8DynamicLinearMethod, ) return isinstance(instance, MX_QUANT_TYPES) @@ -69,4 +63,16 @@ def is_mx_quant_type(instance: Any) -> bool: "get_scheme_class", # Utility functions "is_mx_quant_type", + # Scheme classes + "AscendW8A8LinearMethod", + "AscendW8A8DynamicLinearMethod", + "AscendW8A8DynamicFusedMoEMethod", + "AscendW8A8MXFP8DynamicLinearMethod", + "AscendW8A8PDMixLinearMethod", + "AscendW8A8PDMixFusedMoeMethod", + "AscendW8A16LinearMethod", + "AscendW4A8DynamicLinearMethod", + "AscendW4A8DynamicFusedMoEMethod", + "AscendW4A16FusedMoEMethod", + "AscendW4A4FlatQuantDynamicLinearMethod", ] From 03a1af540a227faac931bba3d774d72e296b822a Mon Sep 17 00:00:00 2001 From: SlightwindSec Date: Mon, 12 Jan 2026 11:54:59 +0800 Subject: [PATCH 09/20] add minmax_m2 Signed-off-by: SlightwindSec --- vllm_ascend/quantization/modelslim_config.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/vllm_ascend/quantization/modelslim_config.py b/vllm_ascend/quantization/modelslim_config.py index f119b6419cc..060676734fa 100644 --- a/vllm_ascend/quantization/modelslim_config.py +++ b/vllm_ascend/quantization/modelslim_config.py @@ -167,6 +167,14 @@ ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], "fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"] }, + "minimax_m2": { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "experts": ["experts.0.w1", "experts.0.w2", "experts.0.w3"] + } } @@ -341,6 +349,18 @@ def get_quant_method(self, layer: torch.nn.Module, vllm_config = get_current_vllm_config() model_type = vllm_config.model_config.hf_config.model_type + + if model_type in ["minimax", "minimax_m2"]: + # Adapt to Minimax architecture: update layer names to MoE convention + prefix = prefix.replace("mlp", "block_sparse_moe") + # Normalize the prefix by stripping specific expert indices (e.g., 'experts.0' -> 'experts') + parts = prefix.split('.') + if "experts" in parts and len(parts) > 2: + exp_idx = parts.index("experts") + if exp_idx + 1 < len(parts) and parts[exp_idx + 1].isdigit(): + parts = parts[:exp_idx + 1] + prefix = ".".join(parts) + if model_type in packed_modules_model_mapping: self.packed_modules_mapping = packed_modules_model_mapping[ model_type] From 134a5177e62c7661f3db732f3daae52689af457d Mon Sep 17 00:00:00 2001 From: SlightwindSec Date: Tue, 13 Jan 2026 11:22:18 +0800 Subject: [PATCH 10/20] remove get_quant_method Signed-off-by: SlightwindSec --- .../ut/quantization/test_modelslim_config.py | 10 +- .../quantization/compressed_tensors_config.py | 113 ++++++++-------- vllm_ascend/quantization/modelslim_config.py | 71 +++++++--- vllm_ascend/quantization/wrappers.py | 121 ++++++------------ 4 files changed, 151 insertions(+), 164 deletions(-) diff --git a/tests/ut/quantization/test_modelslim_config.py b/tests/ut/quantization/test_modelslim_config.py index 58d011b2498..cd3ac2686aa 100644 --- a/tests/ut/quantization/test_modelslim_config.py +++ b/tests/ut/quantization/test_modelslim_config.py @@ -90,21 +90,23 @@ def test_get_quant_method_for_linear(self): self.assertIsInstance(method, AscendUnquantizedLinearMethod) # Test quantized layer + mock_scheme = MagicMock() with patch.object(self.ascend_config, 'is_layer_skipped_ascend', return_value=False), \ patch("vllm_ascend.quantization.modelslim_config.get_current_vllm_config", return_value=mock_config), \ + patch("vllm_ascend.quantization.modelslim_config.create_scheme_for_layer", return_value=mock_scheme), \ patch('vllm_ascend.quantization.wrappers.AscendLinearMethod', return_value=MagicMock()) as mock_ascend_linear: method = self.ascend_config.get_quant_method(linear_layer, ".attn") self.assertIs(method, mock_ascend_linear.return_value) - mock_ascend_linear.assert_called_once_with( - self.ascend_config, ".attn", - self.ascend_config.packed_modules_mapping, linear_layer) + mock_ascend_linear.assert_called_once_with(mock_scheme) def test_get_quant_method_for_attention(self): attention_layer = MagicMock(spec=Attention) mock_config = MagicMock() mock_config.model_config.hf_config.model_type = None + mock_scheme = MagicMock() with patch("vllm_ascend.quantization.modelslim_config.get_current_vllm_config", return_value=mock_config), \ + patch("vllm_ascend.quantization.modelslim_config.create_scheme_for_layer", return_value=mock_scheme), \ patch('vllm_ascend.quantization.wrappers.AscendKVCacheMethod', \ return_value=MagicMock()) as mock_ascend_kvcache: # Test with fa_quant_type @@ -128,8 +130,10 @@ def test_get_quant_method_for_fused_moe(self): self.assertIs(method, mock_ascend_moe.return_value) # Test quantized layer + mock_scheme = MagicMock() with patch.object(self.ascend_config, 'is_layer_skipped_ascend', return_value=False), \ patch("vllm_ascend.quantization.modelslim_config.get_current_vllm_config", return_value=mock_config), \ + patch("vllm_ascend.quantization.modelslim_config.create_scheme_for_layer", return_value=mock_scheme), \ patch('vllm_ascend.quantization.wrappers.AscendFusedMoEMethod', return_value=MagicMock()) as mock_ascend_moe: method = self.ascend_config.get_quant_method( fused_moe_layer, "moe_layer") diff --git a/vllm_ascend/quantization/compressed_tensors_config.py b/vllm_ascend/quantization/compressed_tensors_config.py index 18415131250..95724dbd44e 100644 --- a/vllm_ascend/quantization/compressed_tensors_config.py +++ b/vllm_ascend/quantization/compressed_tensors_config.py @@ -17,7 +17,7 @@ # """LLM-Compressor (compressed_tensors) quantization configuration for Ascend.""" -from typing import Any, Optional, cast +from typing import Any, Optional, Union, cast import torch from compressed_tensors.quantization import (QuantizationArgs, @@ -39,6 +39,8 @@ from vllm_ascend.utils import COMPRESSED_TENSORS_METHOD +from .methods import AscendLinearScheme, AscendMoEScheme + logger = init_logger(__name__) @@ -54,21 +56,6 @@ def _remove_quantization_method(): "QuantizationArgs"]]] -def get_quant_method_llmcompressor(layer: torch.nn.Module): - """Get quantization method for LLM-Compressor models. - - Args: - layer: The layer module with a scheme attribute. - - Returns: - The scheme from the layer. - """ - logger.info_once("Using the vLLM Ascend llmcompressor Quantization now!") - if layer.scheme is None: - raise ValueError("A scheme must be defined for each layer") - return layer.scheme - - @register_quantization_config(COMPRESSED_TENSORS_METHOD) class AscendCompressedTensorsConfig(QuantizationConfig): """Config class for LLM-Compressor (compressed_tensors) quantization on Ascend. @@ -165,48 +152,49 @@ def get_quant_method( layer: torch.nn.Module, prefix: str, ) -> Optional["QuantizeMethodBase"]: - from .modelslim_config import AscendModelSlimConfig from .wrappers import AscendFusedMoEMethod, AscendLinearMethod if isinstance(layer, LinearBase): layer.ascend_quant_method = COMPRESSED_TENSORS_METHOD - # collect schemes - quant_scheme = self.get_scheme(layer=layer, layer_name=prefix) - - # choose quantization method - quant_method = UnquantizedLinearMethod() - if quant_scheme is not None: - layer.scheme = quant_scheme - ascend_quant_config = AscendModelSlimConfig( - self.quant_description or {}) - quant_method = AscendLinearMethod(ascend_quant_config, prefix, - None, layer) - return quant_method + # Get the scheme for this layer + scheme = self.get_scheme(layer=layer, layer_name=prefix) + + # Return unquantized method if no scheme found + if scheme is None: + return UnquantizedLinearMethod() + + # Store scheme on layer for reference (optional, for debugging) + layer.scheme = scheme + logger.info_once( + "Using the vLLM Ascend llmcompressor Quantization now!") + return AscendLinearMethod(scheme) + if isinstance(layer, FusedMoE): # Delayed import to avoid circular import from vllm_ascend.ops.fused_moe.fused_moe import \ AscendUnquantizedFusedMoEMethod layer.ascend_quant_method = COMPRESSED_TENSORS_METHOD - # collect schemes - quant_scheme = self.get_scheme(layer=layer, layer_name=prefix) - - # choose quantization method - quant_method = AscendUnquantizedFusedMoEMethod(layer.moe_config) - if quant_scheme is not None: - layer.scheme = quant_scheme - ascend_quant_config = AscendModelSlimConfig( - self.quant_description or {}) - quant_method = AscendFusedMoEMethod( - ascend_quant_config, prefix, - ascend_quant_config.packed_modules_mapping, layer) - return quant_method + # Get the scheme for this layer + scheme = self.get_scheme(layer=layer, layer_name=prefix) + + # Return unquantized method if no scheme found + if scheme is None: + return AscendUnquantizedFusedMoEMethod(layer.moe_config) + + # Store scheme on layer for reference (optional, for debugging) + layer.scheme = scheme + logger.info_once( + "Using the vLLM Ascend llmcompressor Quantization now!") + return AscendFusedMoEMethod(scheme, layer.moe_config) + return None - def get_scheme(self, - layer: torch.nn.Module, - layer_name: Optional[str] = None - ) -> Optional["CompressedTensorsScheme"]: + def get_scheme( + self, + layer: torch.nn.Module, + layer_name: Optional[str] = None + ) -> Optional[Union[AscendLinearScheme, AscendMoEScheme]]: """Get the quantization scheme for a layer. compressed-tensors supports non uniform in the following way: @@ -218,7 +206,11 @@ def get_scheme(self, Detect whether a layer_name is found in any target and use the quantization scheme corresponding to the matched target - to select the CompressedTensorsScheme used for inference. + to select the appropriate Ascend scheme used for inference. + + Returns: + An Ascend quantization scheme instance, or None if the layer + should use unquantized method. """ # Find the "target" in the compressed-tensors config @@ -248,18 +240,25 @@ def get_scheme(self, "Falling back to UnquantizedLinearMethod") return None - else: - # Find the quant_scheme - scheme = self._get_scheme_from_parts( - weight_quant=weight_quant, - input_quant=input_quant, - ) - return scheme + # Find and return the appropriate Ascend scheme + return self._get_scheme_from_parts( + weight_quant=weight_quant, + input_quant=input_quant, + ) def _get_scheme_from_parts( - self, weight_quant: "QuantizationArgs", - input_quant: "QuantizationArgs") -> "CompressedTensorsScheme": - """Determine the appropriate scheme based on quantization args.""" + self, weight_quant: "QuantizationArgs", + input_quant: "QuantizationArgs" + ) -> Union[AscendLinearScheme, AscendMoEScheme]: + """Determine the appropriate Ascend scheme based on quantization args. + + Args: + weight_quant: Weight quantization arguments. + input_quant: Input activation quantization arguments. + + Returns: + An instance of the appropriate Ascend quantization scheme. + """ from .methods import (AscendW4A16FusedMoEMethod, AscendW8A8DynamicLinearMethod, AscendW8A8LinearMethod) diff --git a/vllm_ascend/quantization/modelslim_config.py b/vllm_ascend/quantization/modelslim_config.py index 060676734fa..f00cbb5c146 100644 --- a/vllm_ascend/quantization/modelslim_config.py +++ b/vllm_ascend/quantization/modelslim_config.py @@ -239,12 +239,12 @@ def get_linear_quant_type( return quant_type -def get_quant_method_modelslim( +def get_quant_type_for_layer( quant_description: Dict[str, Any], prefix: str, layer_type: str, - packed_modules_mapping: Optional[Dict[str, Any]] = None): - """Get quantization method for ModelSlim models. + packed_modules_mapping: Optional[Dict[str, Any]] = None) -> str: + """Determine the quantization type for a layer. Args: quant_description: The quantization description dictionary. @@ -253,23 +253,43 @@ def get_quant_method_modelslim( packed_modules_mapping: Mapping for packed/fused modules. Returns: - An instance of the appropriate quantization method class. + The quantization type string (e.g., "W8A8_DYNAMIC"). """ - logger.info_once("Using the vLLM Ascend modelslim Quantization now!") if packed_modules_mapping is None: packed_modules_mapping = dict() # Attention - if '.attn' in prefix and 'fa_quant_type' in quant_description.keys(): - quant_type = quant_description['fa_quant_type'] - # Linear - else: - quant_type = get_linear_quant_type(quant_description, prefix, - packed_modules_mapping) + if layer_type == "attention" and 'fa_quant_type' in quant_description.keys( + ): + return quant_description['fa_quant_type'] + # Linear / MoE + return get_linear_quant_type(quant_description, prefix, + packed_modules_mapping) + + +def create_scheme_for_layer( + quant_description: Dict[str, Any], + prefix: str, + layer_type: str, + packed_modules_mapping: Optional[Dict[str, Any]] = None): + """Create a quantization scheme instance for a layer. + + Args: + quant_description: The quantization description dictionary. + prefix: The layer prefix. + layer_type: The type of layer ("linear", "moe", "attention"). + packed_modules_mapping: Mapping for packed/fused modules. + + Returns: + An instance of the appropriate quantization scheme class. + """ + logger.info_once("Using the vLLM Ascend modelslim Quantization now!") + quant_type = get_quant_type_for_layer(quant_description, prefix, layer_type, + packed_modules_mapping) # Use registry to get scheme class - method_cls = get_scheme_class(quant_type, layer_type) - if method_cls is not None: - return method_cls() + scheme_cls = get_scheme_class(quant_type, layer_type) + if scheme_cls is not None: + return scheme_cls() raise NotImplementedError( f"Currently, vLLM Ascend doesn't support {quant_type} for {layer_type}." @@ -375,12 +395,17 @@ def get_quant_method(self, layer: torch.nn.Module, from vllm_ascend.ops.linear import \ AscendUnquantizedLinearMethod return AscendUnquantizedLinearMethod() - return AscendLinearMethod(self, prefix, - self.packed_modules_mapping, layer) + scheme = create_scheme_for_layer(self.quant_description, prefix, + "linear", + self.packed_modules_mapping) + return AscendLinearMethod(scheme) elif isinstance(layer, Attention) and \ 'fa_quant_type' in self.quant_description.keys() and \ self.quant_description['fa_quant_type'] is not None: - return AscendKVCacheMethod(self, prefix) + scheme = create_scheme_for_layer(self.quant_description, prefix, + "attention", + self.packed_modules_mapping) + return AscendKVCacheMethod(scheme) elif isinstance(layer, FusedMoE): if self.is_layer_skipped_ascend(prefix, self.packed_modules_mapping): @@ -388,14 +413,18 @@ def get_quant_method(self, layer: torch.nn.Module, from vllm_ascend.ops.fused_moe.fused_moe import \ AscendUnquantizedFusedMoEMethod return AscendUnquantizedFusedMoEMethod(layer.moe_config) - return AscendFusedMoEMethod(self, prefix, - self.packed_modules_mapping, layer) + scheme = create_scheme_for_layer(self.quant_description, prefix, + "moe", + self.packed_modules_mapping) + return AscendFusedMoEMethod(scheme, layer.moe_config) elif isinstance(layer, VocabParallelEmbedding): if self.is_layer_skipped_ascend(prefix, self.packed_modules_mapping): return UnquantizedEmbeddingMethod() - return AscendEmbeddingMethod(self, prefix, - self.packed_modules_mapping, layer) + scheme = create_scheme_for_layer(self.quant_description, prefix, + "linear", + self.packed_modules_mapping) + return AscendEmbeddingMethod(scheme) return None def is_layer_skipped_ascend( diff --git a/vllm_ascend/quantization/wrappers.py b/vllm_ascend/quantization/wrappers.py index abf93fee727..91953479d3b 100644 --- a/vllm_ascend/quantization/wrappers.py +++ b/vllm_ascend/quantization/wrappers.py @@ -15,18 +15,28 @@ # limitations under the License. # This file is a part of the vllm-ascend project. # -"""Wrapper classes that delegate to actual quantization scheme implementations.""" +"""Wrapper classes that delegate to actual quantization scheme implementations. -from typing import Any, Callable, Dict, List, Optional +These wrapper classes (AscendLinearMethod, AscendFusedMoEMethod, etc.) implement +the vLLM QuantizeMethodBase interface and delegate the actual quantization +operations to scheme implementations (AscendLinearScheme, AscendMoEScheme). + +The wrapper classes handle: +- Weight creation and registration +- Parameter attribute setting +- Tensor parallel rank handling +- Delegation to the underlying scheme's apply() method +""" + +from typing import Callable, List, Optional, Union import torch from vllm.distributed import get_tensor_model_parallel_rank from vllm.model_executor.layers.fused_moe import (FusedMoEMethodBase, FusedMoeWeightScaleSupported) +from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig from vllm.model_executor.layers.linear import (LinearMethodBase, RowParallelLinear) -from vllm.model_executor.layers.quantization.base_config import \ - QuantizeMethodBase from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.parameter import PerTensorScaleParameter from vllm.model_executor.utils import set_weight_attrs @@ -35,64 +45,24 @@ from vllm_ascend.distributed.parallel_state import (get_flashcomm2_otp_group, get_mlp_tp_group, get_otp_group) -from vllm_ascend.utils import (COMPRESSED_TENSORS_METHOD, flashcomm2_enable, - mlp_tp_enable, oproj_tp_enable) +from vllm_ascend.utils import (flashcomm2_enable, mlp_tp_enable, + oproj_tp_enable) -from .methods import is_mx_quant_type - - -def get_quant_method(quant_description: Dict[str, Any], - prefix: str, - layer_type: str, - packed_modules_mapping: Optional[Dict[str, Any]] = None, - layer: Optional[torch.nn.Module] = None): - """Get the appropriate quantization method for a layer. - - This is the routing function that dispatches to either ModelSlim or - LLM-Compressor implementations based on the quant_description. - - Args: - quant_description: The quantization description dictionary. - prefix: The layer prefix. - layer_type: The type of layer ("linear", "moe", "attention"). - packed_modules_mapping: Mapping for packed/fused modules. - layer: The layer module (optional). - - Returns: - An instance of the appropriate quantization method class. - """ - if quant_description.get("quant_method") == COMPRESSED_TENSORS_METHOD: - from .compressed_tensors_config import get_quant_method_llmcompressor - return get_quant_method_llmcompressor(layer) - - from .modelslim_config import get_quant_method_modelslim - return get_quant_method_modelslim(quant_description, prefix, layer_type, - packed_modules_mapping) +from .methods import AscendLinearScheme, AscendMoEScheme, is_mx_quant_type class AscendLinearMethod(LinearMethodBase): """Linear method for Ascend quantization. - This wrapper class delegates to the actual quantization scheme implementation - based on the quant_config and prefix. + This wrapper class delegates to the actual quantization scheme implementation. + The scheme is determined by the Config class and passed directly to this wrapper. Args: - quant_config: The Ascend quantization config. - prefix: The layer prefix for determining quantization type. - packed_modules_mapping: Mapping for packed/fused modules. - layer: The layer module (optional). + scheme: The quantization scheme instance (e.g., AscendW8A8DynamicLinearMethod). """ - def __init__(self, - quant_config: "QuantizeMethodBase", - prefix: str, - packed_modules_mapping: Optional[Dict[str, Any]], - layer: Optional[torch.nn.Module] = None) -> None: - self.quant_method = get_quant_method(quant_config.quant_description, - prefix, - "linear", - packed_modules_mapping, - layer=layer) + def __init__(self, scheme: AscendLinearScheme) -> None: + self.quant_method = scheme def create_weights( self, @@ -199,15 +169,14 @@ def apply( class AscendKVCacheMethod(BaseKVCacheMethod): """KVCache method for Ascend quantization. + This wrapper class delegates to the actual attention quantization scheme. + Args: - quant_config: The Ascend quantization config. - prefix: The layer prefix. + scheme: The attention quantization scheme instance. """ - def __init__(self, quant_config: "QuantizeMethodBase", - prefix: str) -> None: - self.quant_method = get_quant_method(quant_config.quant_description, - prefix, "attention") + def __init__(self, scheme: AscendLinearScheme) -> None: + self.quant_method = scheme def create_weights(self, layer: torch.nn.Module) -> None: # Different from linear method, there are no weight processing/slicing @@ -229,22 +198,17 @@ def apply(self, layer: torch.nn.Module, query: torch.Tensor, class AscendFusedMoEMethod(FusedMoEMethodBase): """FusedMoE method for Ascend quantization. + This wrapper class delegates to the actual MoE quantization scheme. + Args: - quant_config: The Ascend quantization config. - prefix: The layer prefix. - packed_modules_mapping: Mapping for packed/fused modules. - layer: The layer module. + scheme: The MoE quantization scheme instance. + moe_config: The FusedMoE configuration. """ - def __init__(self, quant_config: "QuantizeMethodBase", prefix: str, - packed_modules_mapping: Optional[Dict[str, Any]], - layer: torch.nn.Module): - super().__init__(layer.moe_config) - self.quant_method = get_quant_method(quant_config.quant_description, - prefix, - "moe", - packed_modules_mapping, - layer=layer) + def __init__(self, scheme: AscendMoEScheme, + moe_config: FusedMoEConfig) -> None: + super().__init__(moe_config) + self.quant_method = scheme def create_weights( self, @@ -331,17 +295,8 @@ class AscendEmbeddingMethod(AscendLinearMethod): for clarity when used with VocabParallelEmbedding layers. Args: - quant_config: The Ascend quantization config. - prefix: The layer prefix. - packed_modules_mapping: Mapping for packed/fused modules. - layer: The layer module. + scheme: The quantization scheme instance. """ - def __init__(self, quant_config: "QuantizeMethodBase", prefix: str, - packed_modules_mapping: Optional[Dict[str, Any]], - layer: torch.nn.Module) -> None: - self.quant_method = get_quant_method(quant_config.quant_description, - prefix, - "linear", - packed_modules_mapping, - layer=layer) + def __init__(self, scheme: AscendLinearScheme) -> None: + self.quant_method = scheme From 95c23a5e6637ac8e96cb4283a375d43e47baf597 Mon Sep 17 00:00:00 2001 From: SlightwindSec Date: Tue, 13 Jan 2026 14:31:35 +0800 Subject: [PATCH 11/20] fix lint Signed-off-by: SlightwindSec --- vllm_ascend/quantization/compressed_tensors_config.py | 5 +---- vllm_ascend/quantization/modelslim_config.py | 4 ++-- vllm_ascend/quantization/wrappers.py | 5 ++--- 3 files changed, 5 insertions(+), 9 deletions(-) diff --git a/vllm_ascend/quantization/compressed_tensors_config.py b/vllm_ascend/quantization/compressed_tensors_config.py index 95724dbd44e..ff59f7cf3b5 100644 --- a/vllm_ascend/quantization/compressed_tensors_config.py +++ b/vllm_ascend/quantization/compressed_tensors_config.py @@ -30,8 +30,6 @@ QUANTIZATION_METHODS, register_quantization_config) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) -from vllm.model_executor.layers.quantization.compressed_tensors.schemes import \ - CompressedTensorsScheme from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( find_matched_target, is_activation_quantization_format, should_ignore_layer) @@ -247,8 +245,7 @@ def get_scheme( ) def _get_scheme_from_parts( - self, weight_quant: "QuantizationArgs", - input_quant: "QuantizationArgs" + self, weight_quant: "QuantizationArgs", input_quant: "QuantizationArgs" ) -> Union[AscendLinearScheme, AscendMoEScheme]: """Determine the appropriate Ascend scheme based on quantization args. diff --git a/vllm_ascend/quantization/modelslim_config.py b/vllm_ascend/quantization/modelslim_config.py index f00cbb5c146..e7add84ec6f 100644 --- a/vllm_ascend/quantization/modelslim_config.py +++ b/vllm_ascend/quantization/modelslim_config.py @@ -283,8 +283,8 @@ def create_scheme_for_layer( An instance of the appropriate quantization scheme class. """ logger.info_once("Using the vLLM Ascend modelslim Quantization now!") - quant_type = get_quant_type_for_layer(quant_description, prefix, layer_type, - packed_modules_mapping) + quant_type = get_quant_type_for_layer(quant_description, prefix, + layer_type, packed_modules_mapping) # Use registry to get scheme class scheme_cls = get_scheme_class(quant_type, layer_type) diff --git a/vllm_ascend/quantization/wrappers.py b/vllm_ascend/quantization/wrappers.py index 91953479d3b..67a533ce99e 100644 --- a/vllm_ascend/quantization/wrappers.py +++ b/vllm_ascend/quantization/wrappers.py @@ -28,7 +28,7 @@ - Delegation to the underlying scheme's apply() method """ -from typing import Callable, List, Optional, Union +from typing import Callable, List, Optional import torch from vllm.distributed import get_tensor_model_parallel_rank @@ -45,8 +45,7 @@ from vllm_ascend.distributed.parallel_state import (get_flashcomm2_otp_group, get_mlp_tp_group, get_otp_group) -from vllm_ascend.utils import (flashcomm2_enable, mlp_tp_enable, - oproj_tp_enable) +from vllm_ascend.utils import flashcomm2_enable, mlp_tp_enable, oproj_tp_enable from .methods import AscendLinearScheme, AscendMoEScheme, is_mx_quant_type From 3f615a35514c687bb8548b0c1748f39307f44350 Mon Sep 17 00:00:00 2001 From: SlightwindSec Date: Tue, 13 Jan 2026 15:55:38 +0800 Subject: [PATCH 12/20] fix mypy type hint Signed-off-by: SlightwindSec --- .../quantization/compressed_tensors_config.py | 51 +++++++++++++++---- vllm_ascend/quantization/methods/__init__.py | 4 +- vllm_ascend/quantization/methods/base.py | 46 +++++++++++++++++ vllm_ascend/quantization/modelslim_config.py | 7 ++- vllm_ascend/quantization/wrappers.py | 8 +-- 5 files changed, 101 insertions(+), 15 deletions(-) diff --git a/vllm_ascend/quantization/compressed_tensors_config.py b/vllm_ascend/quantization/compressed_tensors_config.py index ff59f7cf3b5..f1f13757e01 100644 --- a/vllm_ascend/quantization/compressed_tensors_config.py +++ b/vllm_ascend/quantization/compressed_tensors_config.py @@ -155,17 +155,18 @@ def get_quant_method( if isinstance(layer, LinearBase): layer.ascend_quant_method = COMPRESSED_TENSORS_METHOD # Get the scheme for this layer - scheme = self.get_scheme(layer=layer, layer_name=prefix) + linear_scheme = self._get_linear_scheme(layer=layer, + layer_name=prefix) # Return unquantized method if no scheme found - if scheme is None: + if linear_scheme is None: return UnquantizedLinearMethod() # Store scheme on layer for reference (optional, for debugging) - layer.scheme = scheme + layer.scheme = linear_scheme logger.info_once( "Using the vLLM Ascend llmcompressor Quantization now!") - return AscendLinearMethod(scheme) + return AscendLinearMethod(linear_scheme) if isinstance(layer, FusedMoE): # Delayed import to avoid circular import @@ -174,25 +175,57 @@ def get_quant_method( layer.ascend_quant_method = COMPRESSED_TENSORS_METHOD # Get the scheme for this layer - scheme = self.get_scheme(layer=layer, layer_name=prefix) + moe_scheme = self._get_moe_scheme(layer=layer, layer_name=prefix) # Return unquantized method if no scheme found - if scheme is None: + if moe_scheme is None: return AscendUnquantizedFusedMoEMethod(layer.moe_config) # Store scheme on layer for reference (optional, for debugging) - layer.scheme = scheme + layer.scheme = moe_scheme logger.info_once( "Using the vLLM Ascend llmcompressor Quantization now!") - return AscendFusedMoEMethod(scheme, layer.moe_config) + return AscendFusedMoEMethod(moe_scheme, layer.moe_config) return None + def _get_linear_scheme( + self, + layer: torch.nn.Module, + layer_name: Optional[str] = None) -> Optional[AscendLinearScheme]: + """Get the linear quantization scheme for a layer. + + Returns: + An AscendLinearScheme instance, or None if the layer + should use unquantized method. + """ + scheme = self.get_scheme(layer=layer, layer_name=layer_name) + if scheme is None: + return None + # The scheme should be AscendLinearScheme for linear layers + return cast(AscendLinearScheme, scheme) + + def _get_moe_scheme( + self, + layer: torch.nn.Module, + layer_name: Optional[str] = None) -> Optional[AscendMoEScheme]: + """Get the MoE quantization scheme for a layer. + + Returns: + An AscendMoEScheme instance, or None if the layer + should use unquantized method. + """ + scheme = self.get_scheme(layer=layer, layer_name=layer_name) + if scheme is None: + return None + # The scheme should be AscendMoEScheme for MoE layers + return cast(AscendMoEScheme, scheme) + def get_scheme( self, layer: torch.nn.Module, layer_name: Optional[str] = None - ) -> Optional[Union[AscendLinearScheme, AscendMoEScheme]]: + ) -> Optional["AscendLinearScheme | AscendMoEScheme"]: """Get the quantization scheme for a layer. compressed-tensors supports non uniform in the following way: diff --git a/vllm_ascend/quantization/methods/__init__.py b/vllm_ascend/quantization/methods/__init__.py index cf079738f91..b643dd86c28 100644 --- a/vllm_ascend/quantization/methods/__init__.py +++ b/vllm_ascend/quantization/methods/__init__.py @@ -30,7 +30,8 @@ from typing import Any # Import base classes -from .base import AscendLinearScheme, AscendMoEScheme, QuantType +from .base import (AscendAttentionScheme, AscendLinearScheme, AscendMoEScheme, + QuantType) # Import registry functions from .registry import get_scheme_class, register_scheme # Import all scheme classes for external access @@ -55,6 +56,7 @@ def is_mx_quant_type(instance: Any) -> bool: __all__ = [ # Base classes + "AscendAttentionScheme", "AscendLinearScheme", "AscendMoEScheme", "QuantType", diff --git a/vllm_ascend/quantization/methods/base.py b/vllm_ascend/quantization/methods/base.py index e18ca84305a..9bcec5c2bb8 100644 --- a/vllm_ascend/quantization/methods/base.py +++ b/vllm_ascend/quantization/methods/base.py @@ -124,6 +124,52 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: pass +class AscendAttentionScheme(ABC): + """Base class for all attention quantization schemes. + + Subclasses must implement apply() method. + Other methods have default implementations. + """ + + def create_weights(self, layer: torch.nn.Module) -> None: + """Create weights for attention quantization. + + Args: + layer: The attention layer module. + """ + pass + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + """Post-loading weight processing for attention layer. + + Args: + layer: The attention layer module. + """ + pass + + @abstractmethod + def apply(self, layer: torch.nn.Module, query: torch.Tensor, + key: torch.Tensor, value: torch.Tensor, kv_cache, attn_metadata, + attn_type, scale, output) -> torch.Tensor: + """Forward computation for attention layer. + + Args: + layer: The attention layer module. + query: Query tensor. + key: Key tensor. + value: Value tensor. + kv_cache: KV cache. + attn_metadata: Attention metadata. + attn_type: Attention type. + scale: Scale factor. + output: Output tensor. + + Returns: + Output tensor after attention computation. + """ + ... + + class AscendMoEScheme(ABC): """Base class for all MoE quantization schemes. diff --git a/vllm_ascend/quantization/modelslim_config.py b/vllm_ascend/quantization/modelslim_config.py index e7add84ec6f..a636a444728 100644 --- a/vllm_ascend/quantization/modelslim_config.py +++ b/vllm_ascend/quantization/modelslim_config.py @@ -243,7 +243,8 @@ def get_quant_type_for_layer( quant_description: Dict[str, Any], prefix: str, layer_type: str, - packed_modules_mapping: Optional[Dict[str, Any]] = None) -> str: + packed_modules_mapping: Optional[Dict[str, + Any]] = None) -> Optional[str]: """Determine the quantization type for a layer. Args: @@ -286,6 +287,10 @@ def create_scheme_for_layer( quant_type = get_quant_type_for_layer(quant_description, prefix, layer_type, packed_modules_mapping) + if quant_type is None: + raise ValueError( + f"Could not determine quantization type for layer {prefix}.") + # Use registry to get scheme class scheme_cls = get_scheme_class(quant_type, layer_type) if scheme_cls is not None: diff --git a/vllm_ascend/quantization/wrappers.py b/vllm_ascend/quantization/wrappers.py index 67a533ce99e..e3d6293cf6d 100644 --- a/vllm_ascend/quantization/wrappers.py +++ b/vllm_ascend/quantization/wrappers.py @@ -47,7 +47,8 @@ get_otp_group) from vllm_ascend.utils import flashcomm2_enable, mlp_tp_enable, oproj_tp_enable -from .methods import AscendLinearScheme, AscendMoEScheme, is_mx_quant_type +from .methods import (AscendAttentionScheme, AscendLinearScheme, + AscendMoEScheme, is_mx_quant_type) class AscendLinearMethod(LinearMethodBase): @@ -174,7 +175,7 @@ class AscendKVCacheMethod(BaseKVCacheMethod): scheme: The attention quantization scheme instance. """ - def __init__(self, scheme: AscendLinearScheme) -> None: + def __init__(self, scheme: AscendAttentionScheme) -> None: self.quant_method = scheme def create_weights(self, layer: torch.nn.Module) -> None: @@ -184,8 +185,7 @@ def create_weights(self, layer: torch.nn.Module) -> None: self.quant_method.create_weights(layer) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - if hasattr(self.quant_method, "process_weights_after_loading"): - self.quant_method.process_weights_after_loading(layer) + self.quant_method.process_weights_after_loading(layer) def apply(self, layer: torch.nn.Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, kv_cache, attn_metadata, From 95f4271814690a111d97e16c2839f0b19fa17c8c Mon Sep 17 00:00:00 2001 From: SlightwindSec Date: Wed, 14 Jan 2026 16:51:01 +0800 Subject: [PATCH 13/20] decouple linear and moe in compressed_tensors_config Signed-off-by: SlightwindSec --- .../quantization/compressed_tensors_config.py | 101 ++++++++++++------ 1 file changed, 67 insertions(+), 34 deletions(-) diff --git a/vllm_ascend/quantization/compressed_tensors_config.py b/vllm_ascend/quantization/compressed_tensors_config.py index 83bf493e6a8..49121acc39e 100644 --- a/vllm_ascend/quantization/compressed_tensors_config.py +++ b/vllm_ascend/quantization/compressed_tensors_config.py @@ -212,11 +212,17 @@ def _get_linear_scheme( An AscendLinearScheme instance, or None if the layer should use unquantized method. """ - scheme = self.get_scheme(layer=layer, layer_name=layer_name) - if scheme is None: + weight_quant, input_quant, format = self._get_quant_args( + layer, layer_name) + if weight_quant is None: return None - # The scheme should be AscendLinearScheme for linear layers - return cast(AscendLinearScheme, scheme) + + return self._create_scheme_for_layer_type( + weight_quant=weight_quant, + input_quant=input_quant, + format=format, + layer_type="linear", + ) def _get_moe_scheme( self, @@ -231,18 +237,25 @@ def _get_moe_scheme( # Add FusedMoE to target scheme map if needed self._add_fused_moe_to_target_scheme_map() - scheme = self.get_scheme(layer=layer, layer_name=layer_name) - if scheme is None: + weight_quant, input_quant, format = self._get_quant_args( + layer, layer_name) + if weight_quant is None: return None - # The scheme should be AscendMoEScheme for MoE layers - return cast(AscendMoEScheme, scheme) - def get_scheme( + return self._create_scheme_for_layer_type( + weight_quant=weight_quant, + input_quant=input_quant, + format=format, + layer_type="moe", + ) + + def _get_quant_args( self, layer: torch.nn.Module, layer_name: Optional[str] = None - ) -> Optional["AscendLinearScheme | AscendMoEScheme"]: - """Get the quantization scheme for a layer. + ) -> tuple[Optional["QuantizationArgs"], Optional["QuantizationArgs"], + Optional[str]]: + """Extract quantization arguments for a layer. compressed-tensors supports non uniform in the following way: @@ -252,14 +265,12 @@ def get_scheme( an nn.Module name. Detect whether a layer_name is found in any target and - use the quantization scheme corresponding to the matched target - to select the appropriate Ascend scheme used for inference. + use the quantization scheme corresponding to the matched target. Returns: - An Ascend quantization scheme instance, or None if the layer - should use unquantized method. + A tuple of (weight_quant, input_quant, format). weight_quant is + None if the layer should use unquantized method. """ - scheme_dict = self.get_scheme_dict(layer, layer_name) weight_quant = None input_quant = None @@ -273,14 +284,8 @@ def get_scheme( logger.warning_once("Acceleration for non-quantized schemes is " "not supported by Compressed Tensors. " "Falling back to UnquantizedLinearMethod") - return None - # Find and return the appropriate Ascend scheme - return self._get_scheme_from_parts( - weight_quant=weight_quant, - input_quant=input_quant, - format=format, - ) + return weight_quant, input_quant, format def get_scheme_dict( self, @@ -316,42 +321,70 @@ def get_scheme_dict( return None - def _get_scheme_from_parts( + def _create_scheme_for_layer_type( self, weight_quant: "QuantizationArgs", - input_quant: "QuantizationArgs", - format: str | None = None, + input_quant: Optional["QuantizationArgs"], + format: Optional[str], + layer_type: str, ) -> Union[AscendLinearScheme, AscendMoEScheme]: - """Determine the appropriate Ascend scheme based on quantization args. + """Create the appropriate Ascend scheme based on quantization args and layer type. Args: weight_quant: Weight quantization arguments. input_quant: Input activation quantization arguments. format: Per-layer format, if defined. + layer_type: Type of layer ("linear" or "moe"). Returns: An instance of the appropriate Ascend quantization scheme. """ - from .methods import (AscendW4A16FusedMoEMethod, - AscendW8A8DynamicLinearMethod, - AscendW8A8LinearMethod) + from .methods import get_scheme_class + + # Determine the quantization type + quant_type = self._detect_quant_type(weight_quant, input_quant, format) + # Get the scheme class from registry + scheme_cls = get_scheme_class(quant_type, layer_type) + if scheme_cls is None: + raise NotImplementedError( + f"No compressed-tensors compatible scheme was found for " + f"quant_type={quant_type}, layer_type={layer_type}.") + + return scheme_cls() + + def _detect_quant_type( + self, + weight_quant: "QuantizationArgs", + input_quant: Optional["QuantizationArgs"], + format: Optional[str], + ) -> str: + """Detect the quantization type from quantization arguments. + + Args: + weight_quant: Weight quantization arguments. + input_quant: Input activation quantization arguments. + format: Per-layer format, if defined. + + Returns: + A string representing the quantization type (e.g., "W8A8", "W8A8_DYNAMIC"). + """ # use the per-layer format if defined, otherwise, use global format format = format if format is not None else self.quant_format act_quant_format = is_activation_quantization_format(format) if act_quant_format and input_quant is not None: if self._is_static_tensor_w8a8(weight_quant, input_quant): - return AscendW8A8LinearMethod() + return "W8A8" if self._is_dynamic_token_w8a8(weight_quant, input_quant): - return AscendW8A8DynamicLinearMethod() + return "W8A8_DYNAMIC" if self._is_w4a16(weight_quant, input_quant): - return AscendW4A16FusedMoEMethod() + return "W4A16" raise NotImplementedError( - "No compressed-tensors compatible scheme was found.") + "No compressed-tensors compatible quantization type was found.") def _is_static_tensor_w8a8(self, weight_quant: "QuantizationArgs", input_quant: "QuantizationArgs") -> bool: From e9f8c39d9d70c1280421c5ae2c261b21f9412921 Mon Sep 17 00:00:00 2001 From: SlightwindSec Date: Wed, 14 Jan 2026 17:22:42 +0800 Subject: [PATCH 14/20] rename wrappers.py -> method_adapters.py Signed-off-by: SlightwindSec --- tests/ut/ops/test_fused_moe.py | 21 -- vllm_ascend/ops/linear_op.py | 2 +- .../quantization/compressed_tensors_config.py | 2 +- vllm_ascend/quantization/method_adapters.py | 297 ++++++++++++++++++ vllm_ascend/quantization/modelslim_config.py | 2 +- vllm_ascend/quantization/wrappers.py | 4 - 6 files changed, 300 insertions(+), 28 deletions(-) create mode 100644 vllm_ascend/quantization/method_adapters.py diff --git a/tests/ut/ops/test_fused_moe.py b/tests/ut/ops/test_fused_moe.py index 2d519b92a9d..ffb85a197be 100644 --- a/tests/ut/ops/test_fused_moe.py +++ b/tests/ut/ops/test_fused_moe.py @@ -20,8 +20,6 @@ import torch.nn as nn import torch_npu from pytest_mock import MockerFixture -from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase - from tests.ut.base import TestBase from vllm_ascend.ascend_forward_context import MoECommType from vllm_ascend.ops.fused_moe.experts_selector import select_experts @@ -233,25 +231,6 @@ def __init__(self, shared_experts, num_tokens): self.apply = MagicMock(return_value=(torch.randn(num_tokens, 32))) -class MockFusedMoEMethod(FusedMoEMethodBase): - moe = MagicMock() - - def __init__(self): - super().__init__(self.moe) - - def create_weights(self, layer: torch.nn.Module, num_experts: int, - hidden_size: int, intermediate_size_per_partition: int, - params_dtype: torch.dtype, **extra_weight_attrs): - pass - - def apply(self, hidden_states: torch.Tensor, - expert_weights: torch.Tensor) -> torch.Tensor: - pass - - def get_fused_moe_quant_config(self, layer: torch.nn.Module): - pass - - class TestExpertsSelector: @pytest.mark.parametrize("global_num_experts", [256, 128]) diff --git a/vllm_ascend/ops/linear_op.py b/vllm_ascend/ops/linear_op.py index 78c09a2f099..8c6afffa3a7 100644 --- a/vllm_ascend/ops/linear_op.py +++ b/vllm_ascend/ops/linear_op.py @@ -585,7 +585,7 @@ def matmul_and_reduce(self, input_parallel: torch.Tensor, from vllm.model_executor.layers.linear import UnquantizedLinearMethod from vllm_ascend.quantization.methods import AscendW8A8LinearMethod - from vllm_ascend.quantization.wrappers import AscendLinearMethod + from vllm_ascend.quantization.method_adapters import AscendLinearMethod # For unquant if mmrs_fusion and isinstance(self.layer.quant_method, diff --git a/vllm_ascend/quantization/compressed_tensors_config.py b/vllm_ascend/quantization/compressed_tensors_config.py index 49121acc39e..8aa1af1fc50 100644 --- a/vllm_ascend/quantization/compressed_tensors_config.py +++ b/vllm_ascend/quantization/compressed_tensors_config.py @@ -163,7 +163,7 @@ def get_quant_method( layer: torch.nn.Module, prefix: str, ) -> Optional["QuantizeMethodBase"]: - from .wrappers import AscendFusedMoEMethod, AscendLinearMethod + from .method_adapters import AscendFusedMoEMethod, AscendLinearMethod if isinstance(layer, LinearBase): layer.ascend_quant_method = COMPRESSED_TENSORS_METHOD diff --git a/vllm_ascend/quantization/method_adapters.py b/vllm_ascend/quantization/method_adapters.py new file mode 100644 index 00000000000..d46c2f71695 --- /dev/null +++ b/vllm_ascend/quantization/method_adapters.py @@ -0,0 +1,297 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# +"""Wrapper classes that delegate to actual quantization scheme implementations. + +These wrapper classes (AscendLinearMethod, AscendFusedMoEMethod, etc.) implement +the vLLM QuantizeMethodBase interface and delegate the actual quantization +operations to scheme implementations (AscendLinearScheme, AscendMoEScheme). + +The wrapper classes handle: +- Weight creation and registration +- Parameter attribute setting +- Tensor parallel rank handling +- Delegation to the underlying scheme's apply() method +""" + +from typing import Callable, List, Optional + +import torch +from vllm.distributed import get_tensor_model_parallel_rank +from vllm.model_executor.layers.fused_moe import (FusedMoEMethodBase, + FusedMoeWeightScaleSupported) +from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig +from vllm.model_executor.layers.linear import (LinearMethodBase, + RowParallelLinear) +from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod +from vllm.model_executor.parameter import PerTensorScaleParameter +from vllm.model_executor.utils import set_weight_attrs + +from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.distributed.parallel_state import (get_flashcomm2_otp_group, + get_mlp_tp_group, + get_otp_group) +from vllm_ascend.utils import flashcomm2_enable, mlp_tp_enable, oproj_tp_enable + +from .methods import (AscendAttentionScheme, AscendLinearScheme, + AscendMoEScheme, is_mx_quant_type) + + +class AscendLinearMethod(LinearMethodBase): + """Linear method for Ascend quantization. + + This wrapper class delegates to the actual quantization scheme implementation. + The scheme is determined by the Config class and passed directly to this wrapper. + + Args: + scheme: The quantization scheme instance (e.g., AscendW8A8DynamicLinearMethod). + """ + + def __init__(self, scheme: AscendLinearScheme) -> None: + self.quant_method = scheme + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ) -> None: + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + + weight_dict = self.quant_method.get_weight(input_size_per_partition, + output_size_per_partition, + params_dtype) + + # Extract packing information (if present) + packed_dim = weight_dict.pop("_packed_dim", None) + packed_factor = weight_dict.pop("_packed_factor", None) + + for weight_name, weight_param in weight_dict.items(): + param = torch.nn.Parameter(weight_param, requires_grad=False) + set_weight_attrs(param, {"input_dim": 1, "output_dim": 0}) + + # Set packing attributes if the weight is packed + if packed_dim is not None and packed_factor is not None: + set_weight_attrs(param, { + "packed_dim": packed_dim, + "packed_factor": packed_factor + }) + + layer.register_parameter(weight_name, param) + set_weight_attrs(param, extra_weight_attrs) + + pertensor_dict = self.quant_method.get_pertensor_param(params_dtype) + for pertensor_name, pertensor_param in pertensor_dict.items(): + param = PerTensorScaleParameter(data=pertensor_param, + weight_loader=weight_loader) + # disable warning + param.ignore_warning = True + layer.register_parameter(pertensor_name, param) + param.weight_loader = extra_weight_attrs.get("weight_loader") + + perchannel_dict = self.quant_method.get_perchannel_param( + output_size_per_partition, params_dtype) + for perchannel_name, perchannel_param in perchannel_dict.items(): + param = torch.nn.Parameter(perchannel_param, requires_grad=False) + set_weight_attrs(param, {"output_dim": 0}) + layer.register_parameter(perchannel_name, param) + set_weight_attrs(param, extra_weight_attrs) + + # NOTE: In w4a8 quantization implementation, + # for down_proj and o_proj scale_bias shape is [output_size, 16], + # others are [output_size, 1] + layer_type = "row" if isinstance(layer, + RowParallelLinear) else "others" + + pergroup_dict = self.quant_method.get_pergroup_param( + input_size_per_partition, + output_size_per_partition, + params_dtype, + layer_type=layer_type) + for pergroup_name, pergroup_param in pergroup_dict.items(): + param = torch.nn.Parameter(pergroup_param, requires_grad=False) + set_weight_attrs(param, {"output_dim": 0}) + layer.register_parameter(pergroup_name, param) + set_weight_attrs(param, extra_weight_attrs) + if "weight_scale_second" in pergroup_name or "weight_offset_second" in pergroup_name \ + or is_mx_quant_type(self.quant_method): + setattr(param, "input_dim", 1) + param.input_dim = 1 + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + if hasattr(self.quant_method, "process_weights_after_loading"): + self.quant_method.process_weights_after_loading(layer) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if isinstance(layer, RowParallelLinear): + if layer.prefix.find("o_proj") != -1 and oproj_tp_enable(): + tp_rank = get_otp_group().rank_in_group + elif layer.prefix.find("down_proj") != -1 and mlp_tp_enable(): + tp_rank = get_mlp_tp_group().rank_in_group + elif (layer.prefix.find("o_proj") != -1 or + layer.prefix.find("out_proj") != -1) and flashcomm2_enable(): + if get_ascend_config( + ).flashcomm2_oproj_tensor_parallel_size == 1: + tp_rank = 0 + else: + tp_rank = get_flashcomm2_otp_group().rank_in_group + else: + tp_rank = get_tensor_model_parallel_rank() + else: + tp_rank = 0 + return self.quant_method.apply(layer, x, bias, tp_rank) + + +class AscendKVCacheMethod(BaseKVCacheMethod): + """KVCache method for Ascend quantization. + + This wrapper class delegates to the actual attention quantization scheme. + + Args: + scheme: The attention quantization scheme instance. + """ + + def __init__(self, scheme: AscendAttentionScheme) -> None: + self.quant_method = scheme + + def create_weights(self, layer: torch.nn.Module) -> None: + # Different from linear method, there are no weight processing/slicing + # steps for attention in vllm. So the whole process of create weights + # is hidden into the specific quant method. + self.quant_method.create_weights(layer) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + self.quant_method.process_weights_after_loading(layer) + + def apply(self, layer: torch.nn.Module, query: torch.Tensor, + key: torch.Tensor, value: torch.Tensor, kv_cache, attn_metadata, + attn_type, scale, output) -> torch.Tensor: + return self.quant_method.apply(layer, query, key, value, kv_cache, + attn_metadata, attn_type, scale, output) + + +class AscendFusedMoEMethod(FusedMoEMethodBase): + """FusedMoE method for Ascend quantization. + + This wrapper class delegates to the actual MoE quantization scheme. + + Args: + scheme: The MoE quantization scheme instance. + moe_config: The FusedMoE configuration. + """ + + def __init__(self, scheme: AscendMoEScheme, + moe_config: FusedMoEConfig) -> None: + super().__init__(moe_config) + self.quant_method = scheme + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ) -> None: + weight_param = self.quant_method.get_weight( + num_experts, intermediate_size_per_partition, hidden_size, + params_dtype) + for param_key, param_value in weight_param.items(): + param = torch.nn.Parameter(param_value, requires_grad=False) + layer.register_parameter(param_key, param) + set_weight_attrs(param, extra_weight_attrs) + + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}) + per_group_param = [ + "weight_scale_second", "weight_offset_second", "scale_bias" + ] + ["weight_scale", "weight_offset"] if hasattr( + self.quant_method, + "group_size") and self.quant_method.group_size > 0 else [] + dynamic_quant_param = self.quant_method.get_dynamic_quant_param( + num_experts, intermediate_size_per_partition, hidden_size, + params_dtype) + for param_key, param_value in dynamic_quant_param.items(): + param = torch.nn.Parameter(param_value, requires_grad=False) + layer.register_parameter(param_key, param) + set_weight_attrs(param, extra_weight_attrs) + if any(fields in param_key for fields in per_group_param): + setattr(param, "quant_method", + FusedMoeWeightScaleSupported.GROUP.value) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, + e_score_correction_bias: Optional[torch.Tensor] = None, + is_prefill: bool = True, + enable_force_load_balance: bool = False, + log2phy: Optional[torch.Tensor] = None, + global_redundant_expert_num=0, + **kwargs, + ) -> torch.Tensor: + return self.quant_method.apply( + layer, x, router_logits, top_k, renormalize, use_grouped_topk, + global_num_experts, expert_map, topk_group, num_expert_group, + custom_routing_function, scoring_func, routed_scaling_factor, + e_score_correction_bias, is_prefill, enable_force_load_balance, + log2phy, global_redundant_expert_num, **kwargs) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + if hasattr(self.quant_method, "process_weights_after_loading"): + self.quant_method.process_weights_after_loading(layer) + + @property + def supports_eplb(self): + supports_eplb = getattr(self.quant_method, "supports_eplb", False) + return supports_eplb + + +class AscendEmbeddingMethod(AscendLinearMethod): + """Embedding method for Ascend quantization. + + This is essentially the same as AscendLinearMethod, just with a different name + for clarity when used with VocabParallelEmbedding layers. + + Args: + scheme: The quantization scheme instance. + """ + + def __init__(self, scheme: AscendLinearScheme) -> None: + self.quant_method = scheme diff --git a/vllm_ascend/quantization/modelslim_config.py b/vllm_ascend/quantization/modelslim_config.py index a636a444728..2f2f26942ef 100644 --- a/vllm_ascend/quantization/modelslim_config.py +++ b/vllm_ascend/quantization/modelslim_config.py @@ -369,7 +369,7 @@ def quant_prefix_mapper(self, model_type: str, prefix: str) -> str: def get_quant_method(self, layer: torch.nn.Module, prefix: str) -> Optional["QuantizeMethodBase"]: - from .wrappers import (AscendEmbeddingMethod, AscendFusedMoEMethod, + from .method_adapters import (AscendEmbeddingMethod, AscendFusedMoEMethod, AscendKVCacheMethod, AscendLinearMethod) vllm_config = get_current_vllm_config() diff --git a/vllm_ascend/quantization/wrappers.py b/vllm_ascend/quantization/wrappers.py index e3d6293cf6d..d46c2f71695 100644 --- a/vllm_ascend/quantization/wrappers.py +++ b/vllm_ascend/quantization/wrappers.py @@ -277,10 +277,6 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: if hasattr(self.quant_method, "process_weights_after_loading"): self.quant_method.process_weights_after_loading(layer) - def get_fused_moe_quant_config(self, layer: torch.nn.Module): - # TODO: implement this function - pass - @property def supports_eplb(self): supports_eplb = getattr(self.quant_method, "supports_eplb", False) From 29b8194cf299b39a69a444ad1e1b514131cd0186 Mon Sep 17 00:00:00 2001 From: SlightwindSec Date: Wed, 14 Jan 2026 18:04:37 +0800 Subject: [PATCH 15/20] remove wrappers.py Signed-off-by: SlightwindSec --- .../ut/quantization/test_modelslim_config.py | 6 +- vllm_ascend/quantization/method_adapters.py | 3 + vllm_ascend/quantization/wrappers.py | 297 ------------------ 3 files changed, 6 insertions(+), 300 deletions(-) delete mode 100644 vllm_ascend/quantization/wrappers.py diff --git a/tests/ut/quantization/test_modelslim_config.py b/tests/ut/quantization/test_modelslim_config.py index cd3ac2686aa..667a7c0d8e5 100644 --- a/tests/ut/quantization/test_modelslim_config.py +++ b/tests/ut/quantization/test_modelslim_config.py @@ -94,7 +94,7 @@ def test_get_quant_method_for_linear(self): with patch.object(self.ascend_config, 'is_layer_skipped_ascend', return_value=False), \ patch("vllm_ascend.quantization.modelslim_config.get_current_vllm_config", return_value=mock_config), \ patch("vllm_ascend.quantization.modelslim_config.create_scheme_for_layer", return_value=mock_scheme), \ - patch('vllm_ascend.quantization.wrappers.AscendLinearMethod', return_value=MagicMock()) as mock_ascend_linear: + patch('vllm_ascend.quantization.method_adapters.AscendLinearMethod', return_value=MagicMock()) as mock_ascend_linear: method = self.ascend_config.get_quant_method(linear_layer, ".attn") self.assertIs(method, mock_ascend_linear.return_value) @@ -107,7 +107,7 @@ def test_get_quant_method_for_attention(self): mock_scheme = MagicMock() with patch("vllm_ascend.quantization.modelslim_config.get_current_vllm_config", return_value=mock_config), \ patch("vllm_ascend.quantization.modelslim_config.create_scheme_for_layer", return_value=mock_scheme), \ - patch('vllm_ascend.quantization.wrappers.AscendKVCacheMethod', \ + patch('vllm_ascend.quantization.method_adapters.AscendKVCacheMethod', \ return_value=MagicMock()) as mock_ascend_kvcache: # Test with fa_quant_type method = self.ascend_config.get_quant_method( @@ -134,7 +134,7 @@ def test_get_quant_method_for_fused_moe(self): with patch.object(self.ascend_config, 'is_layer_skipped_ascend', return_value=False), \ patch("vllm_ascend.quantization.modelslim_config.get_current_vllm_config", return_value=mock_config), \ patch("vllm_ascend.quantization.modelslim_config.create_scheme_for_layer", return_value=mock_scheme), \ - patch('vllm_ascend.quantization.wrappers.AscendFusedMoEMethod', return_value=MagicMock()) as mock_ascend_moe: + patch('vllm_ascend.quantization.method_adapters.AscendFusedMoEMethod', return_value=MagicMock()) as mock_ascend_moe: method = self.ascend_config.get_quant_method( fused_moe_layer, "moe_layer") self.assertIs(method, mock_ascend_moe.return_value) diff --git a/vllm_ascend/quantization/method_adapters.py b/vllm_ascend/quantization/method_adapters.py index d46c2f71695..beaf78a99c5 100644 --- a/vllm_ascend/quantization/method_adapters.py +++ b/vllm_ascend/quantization/method_adapters.py @@ -277,6 +277,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: if hasattr(self.quant_method, "process_weights_after_loading"): self.quant_method.process_weights_after_loading(layer) + def get_fused_moe_quant_config(self, layer: torch.nn.Module): + pass + @property def supports_eplb(self): supports_eplb = getattr(self.quant_method, "supports_eplb", False) diff --git a/vllm_ascend/quantization/wrappers.py b/vllm_ascend/quantization/wrappers.py deleted file mode 100644 index d46c2f71695..00000000000 --- a/vllm_ascend/quantization/wrappers.py +++ /dev/null @@ -1,297 +0,0 @@ -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# Copyright 2023 The vLLM team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# This file is a part of the vllm-ascend project. -# -"""Wrapper classes that delegate to actual quantization scheme implementations. - -These wrapper classes (AscendLinearMethod, AscendFusedMoEMethod, etc.) implement -the vLLM QuantizeMethodBase interface and delegate the actual quantization -operations to scheme implementations (AscendLinearScheme, AscendMoEScheme). - -The wrapper classes handle: -- Weight creation and registration -- Parameter attribute setting -- Tensor parallel rank handling -- Delegation to the underlying scheme's apply() method -""" - -from typing import Callable, List, Optional - -import torch -from vllm.distributed import get_tensor_model_parallel_rank -from vllm.model_executor.layers.fused_moe import (FusedMoEMethodBase, - FusedMoeWeightScaleSupported) -from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig -from vllm.model_executor.layers.linear import (LinearMethodBase, - RowParallelLinear) -from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod -from vllm.model_executor.parameter import PerTensorScaleParameter -from vllm.model_executor.utils import set_weight_attrs - -from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.distributed.parallel_state import (get_flashcomm2_otp_group, - get_mlp_tp_group, - get_otp_group) -from vllm_ascend.utils import flashcomm2_enable, mlp_tp_enable, oproj_tp_enable - -from .methods import (AscendAttentionScheme, AscendLinearScheme, - AscendMoEScheme, is_mx_quant_type) - - -class AscendLinearMethod(LinearMethodBase): - """Linear method for Ascend quantization. - - This wrapper class delegates to the actual quantization scheme implementation. - The scheme is determined by the Config class and passed directly to this wrapper. - - Args: - scheme: The quantization scheme instance (e.g., AscendW8A8DynamicLinearMethod). - """ - - def __init__(self, scheme: AscendLinearScheme) -> None: - self.quant_method = scheme - - def create_weights( - self, - layer: torch.nn.Module, - input_size_per_partition: int, - output_partition_sizes: List[int], - input_size: int, - output_size: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ) -> None: - output_size_per_partition = sum(output_partition_sizes) - weight_loader = extra_weight_attrs.get("weight_loader") - - weight_dict = self.quant_method.get_weight(input_size_per_partition, - output_size_per_partition, - params_dtype) - - # Extract packing information (if present) - packed_dim = weight_dict.pop("_packed_dim", None) - packed_factor = weight_dict.pop("_packed_factor", None) - - for weight_name, weight_param in weight_dict.items(): - param = torch.nn.Parameter(weight_param, requires_grad=False) - set_weight_attrs(param, {"input_dim": 1, "output_dim": 0}) - - # Set packing attributes if the weight is packed - if packed_dim is not None and packed_factor is not None: - set_weight_attrs(param, { - "packed_dim": packed_dim, - "packed_factor": packed_factor - }) - - layer.register_parameter(weight_name, param) - set_weight_attrs(param, extra_weight_attrs) - - pertensor_dict = self.quant_method.get_pertensor_param(params_dtype) - for pertensor_name, pertensor_param in pertensor_dict.items(): - param = PerTensorScaleParameter(data=pertensor_param, - weight_loader=weight_loader) - # disable warning - param.ignore_warning = True - layer.register_parameter(pertensor_name, param) - param.weight_loader = extra_weight_attrs.get("weight_loader") - - perchannel_dict = self.quant_method.get_perchannel_param( - output_size_per_partition, params_dtype) - for perchannel_name, perchannel_param in perchannel_dict.items(): - param = torch.nn.Parameter(perchannel_param, requires_grad=False) - set_weight_attrs(param, {"output_dim": 0}) - layer.register_parameter(perchannel_name, param) - set_weight_attrs(param, extra_weight_attrs) - - # NOTE: In w4a8 quantization implementation, - # for down_proj and o_proj scale_bias shape is [output_size, 16], - # others are [output_size, 1] - layer_type = "row" if isinstance(layer, - RowParallelLinear) else "others" - - pergroup_dict = self.quant_method.get_pergroup_param( - input_size_per_partition, - output_size_per_partition, - params_dtype, - layer_type=layer_type) - for pergroup_name, pergroup_param in pergroup_dict.items(): - param = torch.nn.Parameter(pergroup_param, requires_grad=False) - set_weight_attrs(param, {"output_dim": 0}) - layer.register_parameter(pergroup_name, param) - set_weight_attrs(param, extra_weight_attrs) - if "weight_scale_second" in pergroup_name or "weight_offset_second" in pergroup_name \ - or is_mx_quant_type(self.quant_method): - setattr(param, "input_dim", 1) - param.input_dim = 1 - - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - if hasattr(self.quant_method, "process_weights_after_loading"): - self.quant_method.process_weights_after_loading(layer) - - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - if isinstance(layer, RowParallelLinear): - if layer.prefix.find("o_proj") != -1 and oproj_tp_enable(): - tp_rank = get_otp_group().rank_in_group - elif layer.prefix.find("down_proj") != -1 and mlp_tp_enable(): - tp_rank = get_mlp_tp_group().rank_in_group - elif (layer.prefix.find("o_proj") != -1 or - layer.prefix.find("out_proj") != -1) and flashcomm2_enable(): - if get_ascend_config( - ).flashcomm2_oproj_tensor_parallel_size == 1: - tp_rank = 0 - else: - tp_rank = get_flashcomm2_otp_group().rank_in_group - else: - tp_rank = get_tensor_model_parallel_rank() - else: - tp_rank = 0 - return self.quant_method.apply(layer, x, bias, tp_rank) - - -class AscendKVCacheMethod(BaseKVCacheMethod): - """KVCache method for Ascend quantization. - - This wrapper class delegates to the actual attention quantization scheme. - - Args: - scheme: The attention quantization scheme instance. - """ - - def __init__(self, scheme: AscendAttentionScheme) -> None: - self.quant_method = scheme - - def create_weights(self, layer: torch.nn.Module) -> None: - # Different from linear method, there are no weight processing/slicing - # steps for attention in vllm. So the whole process of create weights - # is hidden into the specific quant method. - self.quant_method.create_weights(layer) - - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - self.quant_method.process_weights_after_loading(layer) - - def apply(self, layer: torch.nn.Module, query: torch.Tensor, - key: torch.Tensor, value: torch.Tensor, kv_cache, attn_metadata, - attn_type, scale, output) -> torch.Tensor: - return self.quant_method.apply(layer, query, key, value, kv_cache, - attn_metadata, attn_type, scale, output) - - -class AscendFusedMoEMethod(FusedMoEMethodBase): - """FusedMoE method for Ascend quantization. - - This wrapper class delegates to the actual MoE quantization scheme. - - Args: - scheme: The MoE quantization scheme instance. - moe_config: The FusedMoE configuration. - """ - - def __init__(self, scheme: AscendMoEScheme, - moe_config: FusedMoEConfig) -> None: - super().__init__(moe_config) - self.quant_method = scheme - - def create_weights( - self, - layer: torch.nn.Module, - num_experts: int, - hidden_size: int, - intermediate_size_per_partition: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ) -> None: - weight_param = self.quant_method.get_weight( - num_experts, intermediate_size_per_partition, hidden_size, - params_dtype) - for param_key, param_value in weight_param.items(): - param = torch.nn.Parameter(param_value, requires_grad=False) - layer.register_parameter(param_key, param) - set_weight_attrs(param, extra_weight_attrs) - - extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}) - per_group_param = [ - "weight_scale_second", "weight_offset_second", "scale_bias" - ] + ["weight_scale", "weight_offset"] if hasattr( - self.quant_method, - "group_size") and self.quant_method.group_size > 0 else [] - dynamic_quant_param = self.quant_method.get_dynamic_quant_param( - num_experts, intermediate_size_per_partition, hidden_size, - params_dtype) - for param_key, param_value in dynamic_quant_param.items(): - param = torch.nn.Parameter(param_value, requires_grad=False) - layer.register_parameter(param_key, param) - set_weight_attrs(param, extra_weight_attrs) - if any(fields in param_key for fields in per_group_param): - setattr(param, "quant_method", - FusedMoeWeightScaleSupported.GROUP.value) - - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool = False, - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0, - e_score_correction_bias: Optional[torch.Tensor] = None, - is_prefill: bool = True, - enable_force_load_balance: bool = False, - log2phy: Optional[torch.Tensor] = None, - global_redundant_expert_num=0, - **kwargs, - ) -> torch.Tensor: - return self.quant_method.apply( - layer, x, router_logits, top_k, renormalize, use_grouped_topk, - global_num_experts, expert_map, topk_group, num_expert_group, - custom_routing_function, scoring_func, routed_scaling_factor, - e_score_correction_bias, is_prefill, enable_force_load_balance, - log2phy, global_redundant_expert_num, **kwargs) - - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - if hasattr(self.quant_method, "process_weights_after_loading"): - self.quant_method.process_weights_after_loading(layer) - - @property - def supports_eplb(self): - supports_eplb = getattr(self.quant_method, "supports_eplb", False) - return supports_eplb - - -class AscendEmbeddingMethod(AscendLinearMethod): - """Embedding method for Ascend quantization. - - This is essentially the same as AscendLinearMethod, just with a different name - for clarity when used with VocabParallelEmbedding layers. - - Args: - scheme: The quantization scheme instance. - """ - - def __init__(self, scheme: AscendLinearScheme) -> None: - self.quant_method = scheme From fb73714b2ebf3aac0cf2b53915c097b15864ad23 Mon Sep 17 00:00:00 2001 From: SlightwindSec Date: Thu, 15 Jan 2026 11:11:13 +0800 Subject: [PATCH 16/20] fix type checking Signed-off-by: SlightwindSec --- vllm_ascend/quantization/compressed_tensors_config.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm_ascend/quantization/compressed_tensors_config.py b/vllm_ascend/quantization/compressed_tensors_config.py index 8aa1af1fc50..30834e701f4 100644 --- a/vllm_ascend/quantization/compressed_tensors_config.py +++ b/vllm_ascend/quantization/compressed_tensors_config.py @@ -217,12 +217,13 @@ def _get_linear_scheme( if weight_quant is None: return None - return self._create_scheme_for_layer_type( + scheme = self._create_scheme_for_layer_type( weight_quant=weight_quant, input_quant=input_quant, format=format, layer_type="linear", ) + return cast(AscendLinearScheme, scheme) def _get_moe_scheme( self, @@ -242,12 +243,13 @@ def _get_moe_scheme( if weight_quant is None: return None - return self._create_scheme_for_layer_type( + scheme = self._create_scheme_for_layer_type( weight_quant=weight_quant, input_quant=input_quant, format=format, layer_type="moe", ) + return cast(AscendMoEScheme, scheme) def _get_quant_args( self, From 390afac02d067cb1084db71070eddb26b1bce125 Mon Sep 17 00:00:00 2001 From: SlightwindSec Date: Thu, 15 Jan 2026 16:46:20 +0800 Subject: [PATCH 17/20] rename eplb config Signed-off-by: SlightwindSec --- vllm_ascend/quantization/methods/w4a16.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm_ascend/quantization/methods/w4a16.py b/vllm_ascend/quantization/methods/w4a16.py index c81b840098b..14752b044e6 100644 --- a/vllm_ascend/quantization/methods/w4a16.py +++ b/vllm_ascend/quantization/methods/w4a16.py @@ -117,8 +117,7 @@ def __init__(self) -> None: vllm_config = get_current_vllm_config() self.group_size = vllm_config.quant_config.quant_description.get( "group_size", 32) - ascend_config = get_ascend_config() - self.dynamic_eplb = ascend_config.eplb_config.dynamic_eplb or ascend_config.eplb_config.expert_map_record_path + self.dynamic_eplb = get_ascend_config().eplb_config.dynamic_eplb def get_weight( self, From 814992ee0c41dbdf47e824a384faa8c08b85a27b Mon Sep 17 00:00:00 2001 From: SlightwindSec Date: Tue, 20 Jan 2026 16:00:20 +0800 Subject: [PATCH 18/20] run ci Signed-off-by: SlightwindSec --- vllm_ascend/quantization/method_adapters.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/vllm_ascend/quantization/method_adapters.py b/vllm_ascend/quantization/method_adapters.py index beaf78a99c5..b882e6d8291 100644 --- a/vllm_ascend/quantization/method_adapters.py +++ b/vllm_ascend/quantization/method_adapters.py @@ -15,18 +15,6 @@ # limitations under the License. # This file is a part of the vllm-ascend project. # -"""Wrapper classes that delegate to actual quantization scheme implementations. - -These wrapper classes (AscendLinearMethod, AscendFusedMoEMethod, etc.) implement -the vLLM QuantizeMethodBase interface and delegate the actual quantization -operations to scheme implementations (AscendLinearScheme, AscendMoEScheme). - -The wrapper classes handle: -- Weight creation and registration -- Parameter attribute setting -- Tensor parallel rank handling -- Delegation to the underlying scheme's apply() method -""" from typing import Callable, List, Optional From c2fb6cfbf72e21972de54a29d580441d127407f1 Mon Sep 17 00:00:00 2001 From: SlightwindSec Date: Thu, 22 Jan 2026 19:57:40 +0800 Subject: [PATCH 19/20] rm utils.py Signed-off-by: SlightwindSec --- vllm_ascend/quantization/utils.py | 128 ------------------------------ 1 file changed, 128 deletions(-) delete mode 100644 vllm_ascend/quantization/utils.py diff --git a/vllm_ascend/quantization/utils.py b/vllm_ascend/quantization/utils.py deleted file mode 100644 index 128c1d5d424..00000000000 --- a/vllm_ascend/quantization/utils.py +++ /dev/null @@ -1,128 +0,0 @@ -from typing import Any, Dict, Optional, Type - -import torch -from vllm.logger import logger - -from vllm_ascend.utils import COMPRESSED_TENSORS_METHOD - -from .methods import ( - AscendW4A4FlatQuantDynamicLinearMethod, - AscendW4A4LaosDynamicLinearMethod, - AscendW4A8DynamicFusedMoEMethod, - AscendW4A8DynamicLinearMethod, - AscendW4A16FusedMoEMethod, - AscendW8A8LinearMethod, - AscendW8A8DynamicFusedMoEMethod, - AscendW8A8DynamicLinearMethod, - AscendW8A8PDMixFusedMoeMethod, - AscendW8A8PDMixLinearMethod, - AscendW8A8MXFP8DynamicLinearMethod, - AscendW8A16LinearMethod, - is_mx_quant_type, -) - -ASCEND_QUANTIZATION_METHOD_MAP: Dict[str, Dict[str, Type[Any]]] = { - "W4A16": { - "moe": AscendW4A16FusedMoEMethod, - }, - "W4A8_DYNAMIC": { - "linear": AscendW4A8DynamicLinearMethod, - "moe": AscendW4A8DynamicFusedMoEMethod, - }, - "W4A4_DYNAMIC": { - "linear": AscendW4A4LaosDynamicLinearMethod, - }, - "W4A4_FLATQUANT_DYNAMIC": { - "linear": AscendW4A4FlatQuantDynamicLinearMethod, - }, - "W8A8": { - "linear": AscendW8A8LinearMethod, - }, - "W8A8_DYNAMIC": { - "linear": AscendW8A8DynamicLinearMethod, - "moe": AscendW8A8DynamicFusedMoEMethod, - }, - "W8A8_MIX": { - "linear": AscendW8A8PDMixLinearMethod, - "moe": AscendW8A8PDMixFusedMoeMethod, - }, - "W8A16": { - "linear": AscendW8A16LinearMethod, - }, - "W8A8_MXFP8": { - "linear": AscendW8A8MXFP8DynamicLinearMethod, - }, -} - - -def get_linear_quant_type(quant_description: Dict[str, Any], prefix: str, - packed_modules_mapping: Dict[str, Any]): - proj_name = prefix.split(".")[-1] - if proj_name in packed_modules_mapping: - quant_type = None - shard_prefixes = [ - prefix.replace(proj_name, shard_proj_name) - for shard_proj_name in packed_modules_mapping[proj_name] - ] - for shard_prefix in shard_prefixes: - shard_quant_type = quant_description[shard_prefix + '.weight'] - - if quant_type is None: - quant_type = shard_quant_type - elif shard_quant_type != quant_type: - raise ValueError( - f"Not all shards of {prefix} are quantized with same quant type." - f"Shard {proj_name} uses {shard_quant_type}, but another shard" - f"use {quant_type}. Please check quantization config.") - else: - quant_type = quant_description[prefix + '.weight'] - return quant_type - - -def get_quant_method(quant_description: Dict[str, Any], - prefix: str, - layer_type: str, - packed_modules_mapping: Optional[Dict[str, Any]] = None, - layer: torch.nn.Module = None): - if quant_description.get("quant_method") == COMPRESSED_TENSORS_METHOD: - return get_quant_method_llmcompressor(layer) - - return get_quant_method_modelslim(quant_description, prefix, layer_type, - packed_modules_mapping) - - -def get_quant_method_llmcompressor(layer: torch.nn.Module): - logger.info_once("Using the vLLM Ascend llmcompressor Quantization now!") - if layer.scheme is None: - raise ValueError("A scheme must be defined for each layer") - return layer.scheme - - -def get_quant_method_modelslim( - quant_description: Dict[str, Any], - prefix: str, - layer_type: str, - packed_modules_mapping: Optional[Dict[str, Any]] = None): - logger.info_once("Using the vLLM Ascend modelslim Quantization now!") - if packed_modules_mapping is None: - packed_modules_mapping = dict() - # Attention - if '.attn' in prefix and 'fa_quant_type' in quant_description.keys(): - quant_type = quant_description['fa_quant_type'] - # Linear - else: - quant_type = get_linear_quant_type(quant_description, prefix, - packed_modules_mapping) - if quant_type in ASCEND_QUANTIZATION_METHOD_MAP.keys(): - method_map = ASCEND_QUANTIZATION_METHOD_MAP[quant_type] - if layer_type in method_map.keys(): - method_cls = method_map[layer_type] - return method_cls() - else: - raise NotImplementedError( - f"Currently, vLLM Ascend doesn't support {quant_type} for {layer_type}." - ) - raise NotImplementedError("Currently, vLLM Ascend only supports following quant types:" \ - f"{list(ASCEND_QUANTIZATION_METHOD_MAP.keys())}") - - From 2fa1d6ae7f8154e13fadb8e6f805b19fafa03a52 Mon Sep 17 00:00:00 2001 From: SlightwindSec Date: Fri, 23 Jan 2026 10:24:26 +0800 Subject: [PATCH 20/20] rename W4A4_DYNAMIC Signed-off-by: SlightwindSec --- vllm_ascend/quantization/methods/w4a4_laos_dynamic.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm_ascend/quantization/methods/w4a4_laos_dynamic.py b/vllm_ascend/quantization/methods/w4a4_laos_dynamic.py index 7d35235335c..3fc83b25309 100644 --- a/vllm_ascend/quantization/methods/w4a4_laos_dynamic.py +++ b/vllm_ascend/quantization/methods/w4a4_laos_dynamic.py @@ -24,9 +24,9 @@ from .registry import register_scheme -@register_scheme("W4A4_LAOS_DYNAMIC", "linear") +@register_scheme("W4A4_DYNAMIC", "linear") class AscendW4A4LaosDynamicLinearMethod(AscendLinearScheme): - """Linear method for Ascend W4A4_LAOS_DYNAMIC. + """Linear method for Ascend W4A4_DYNAMIC. This class implements W4A4 quantization with LAOS approach and dynamic activation quantization. - Weight: 4-bit quantization (per-channel) with scale and offset, stored as int8.