diff --git a/.github/workflows/_e2e_test.yaml b/.github/workflows/_e2e_test.yaml index 4ff1ee25f71..96baf4ffc8b 100644 --- a/.github/workflows/_e2e_test.yaml +++ b/.github/workflows/_e2e_test.yaml @@ -266,6 +266,7 @@ jobs: run: | pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_multistream_moe pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W4A8DYNAMIC + pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Kimi_K2_Thinking_W4A16 # pytest -sv tests/e2e/multicard/test_qwen3_moe.py::test_models_distributed_Qwen3_MOE_TP2_WITH_EP # pytest -sv tests/e2e/multicard/test_qwen3_moe.py::test_models_distributed_Qwen3_MOE_W8A8_WITH_EP pytest -sv tests/e2e/multicard/test_data_parallel_tp2.py diff --git a/docs/source/tutorials/index.md b/docs/source/tutorials/index.md index e365b5822e8..b5ec8f4f648 100644 --- a/docs/source/tutorials/index.md +++ b/docs/source/tutorials/index.md @@ -12,6 +12,7 @@ single_npu_qwen3_w4a4 single_node_pd_disaggregation_mooncake multi_npu_qwen3_next multi_npu +multi_npu_kimi-k2-thinking multi_npu_moge multi_npu_qwen3_moe multi_npu_quantization diff --git a/docs/source/tutorials/multi_npu_kimi-k2-thinking.md b/docs/source/tutorials/multi_npu_kimi-k2-thinking.md new file mode 100644 index 00000000000..6a776f45eb3 --- /dev/null +++ b/docs/source/tutorials/multi_npu_kimi-k2-thinking.md @@ -0,0 +1,107 @@ +# Multi-NPU (Kimi-K2-Thinking) + +## Run with Docker + +```{code-block} bash + :substitutions: +# Update the vllm-ascend image +export IMAGE=m.daocloud.io/quay.io/ascend/vllm-ascend:|vllm_ascend_version| +export NAME=vllm-ascend + +# Run the container using the defined variables +# Note: If you are running bridge network with docker, please expose available ports for multiple nodes communication in advance +docker run --rm \ +--name $NAME \ +--net=host \ +--shm-size=1g \ +--device /dev/davinci0 \ +--device 
/dev/davinci1 \ +--device /dev/davinci2 \ +--device /dev/davinci3 \ +--device /dev/davinci4 \ +--device /dev/davinci5 \ +--device /dev/davinci6 \ +--device /dev/davinci7 \ +--device /dev/davinci8 \ +--device /dev/davinci9 \ +--device /dev/davinci10 \ +--device /dev/davinci11 \ +--device /dev/davinci12 \ +--device /dev/davinci13 \ +--device /dev/davinci14 \ +--device /dev/davinci15 \ +--device /dev/davinci_manager \ +--device /dev/devmm_svm \ +--device /dev/hisi_hdc \ +-v /usr/local/dcmi:/usr/local/dcmi \ +-v /usr/local/Ascend/driver/tools/hccn_tool:/usr/local/Ascend/driver/tools/hccn_tool \ +-v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \ +-v /usr/local/Ascend/driver/lib64/:/usr/local/Ascend/driver/lib64/ \ +-v /usr/local/Ascend/driver/version.info:/usr/local/Ascend/driver/version.info \ +-v /etc/ascend_install.info:/etc/ascend_install.info \ +-v /mnt/sfs_turbo/.cache:/home/cache \ +-it $IMAGE bash +``` + +## Verify the Quantized Model +Please be advised to edit the value of `"quantization_config.config_groups.group_0.targets"` from `["Linear"]` into `["MoE"]` in `config.json` of original model downloaded from [Hugging Face](https://huggingface.co/moonshotai/Kimi-K2-Thinking). + +```json +{ + "quantization_config": { + "config_groups": { + "group_0": { + "targets": [ + "MoE" + ] + } + } + } +} +``` + +Your model files look like: + +```bash +. +|-- chat_template.jinja +|-- config.json +|-- configuration_deepseek.py +|-- configuration.json +|-- generation_config.json +|-- model-00001-of-000062.safetensors +|-- ... +|-- model-00062-of-000062.safetensors +|-- model.safetensors.index.json +|-- modeling_deepseek.py +|-- tiktoken.model +|-- tokenization_kimi.py +`-- tokenizer_config.json +``` + +## Online Inference on Multi-NPU + +Run the following script to start the vLLM server on Multi-NPU: + +For an Atlas 800 A3 (64G*16) node, tensor-parallel-size should be at least 16. 
+ +```bash +vllm serve Kimi-K2-Thinking \ +--served-model-name kimi-k2-thinking \ +--tensor-parallel-size 16 \ +--enable_expert_parallel \ +--trust-remote-code \ +--no-enable-prefix-caching +``` + +Once your server is started, you can query the model with input prompts. + +```bash +curl http://localhost:8000/v1/chat/completions -H "Content-Type: application/json" -d '{ + "model": "kimi-k2-thinking", + "messages": [ + {"role": "user", "content": "Who are you?"} + ], + "temperature": 1.0 +}' +``` diff --git a/tests/e2e/multicard/test_offline_inference_distributed.py b/tests/e2e/multicard/test_offline_inference_distributed.py index 67c87332d96..529cc952a62 100644 --- a/tests/e2e/multicard/test_offline_inference_distributed.py +++ b/tests/e2e/multicard/test_offline_inference_distributed.py @@ -49,6 +49,10 @@ "vllm-ascend/DeepSeek-V3.1-W4A8-puring" ] +KIMI_W4A16_MODELS = [ + "vllm-ascend/Kimi-K2-Thinking-Pruning", +] + def test_models_distributed_QwQ(): example_prompts = [ @@ -250,3 +254,24 @@ def test_models_distributed_Qwen_Dense_with_prefetch_mlp_weight(model): quantization="ascend", ) as vllm_model: vllm_model.generate_greedy(example_prompts, max_tokens) + + +@pytest.mark.parametrize("model", KIMI_W4A16_MODELS) +def test_models_distributed_Kimi_K2_Thinking_W4A16(model): + example_prompts = [ + "Hello, my name is", + ] + max_tokens = 5 + + with VllmRunner( + model, + max_model_len=8192, + dtype="auto", + tensor_parallel_size=4, + enable_expert_parallel=True, + compilation_config={ + "cudagraph_mode": "FULL_DECODE_ONLY", + "cudagraph_capture_sizes": [1], + }, + ) as vllm_model: + vllm_model.generate_greedy(example_prompts, max_tokens) diff --git a/tests/ut/quantization/test_w4a16.py b/tests/ut/quantization/test_w4a16.py new file mode 100644 index 00000000000..5d50e738904 --- /dev/null +++ b/tests/ut/quantization/test_w4a16.py @@ -0,0 +1,269 @@ +from unittest.mock import Mock, patch + +import torch + +from tests.ut.base import TestBase +from 
vllm_ascend.quantization.w4a16 import (AscendW4A16FusedMoEMethod, + pack_to_int32, unpack_from_int32) + + +class TestUnpackFromInt32(TestBase): + + def test_unpack_from_int32_packed_dim_1(self): + weight = torch.tensor([[305419896, -1420531520]], dtype=torch.int32) + shape = torch.Size([1, 8]) + num_bits = 4 + + result = unpack_from_int32(weight, shape, num_bits, packed_dim=1) + + self.assertEqual(result.dtype, torch.int8) + self.assertEqual(result.shape, shape) + + def test_unpack_from_int32_packed_dim_0(self): + weight = torch.tensor([[305419896], [-1420531520]], dtype=torch.int32) + shape = torch.Size([8, 1]) + num_bits = 4 + + result = unpack_from_int32(weight, shape, num_bits, packed_dim=0) + + self.assertEqual(result.dtype, torch.int8) + self.assertEqual(result.shape, shape) + + def test_unpack_from_int32_assertions(self): + with self.assertRaises(AssertionError): + weight = torch.tensor([[1, 2]], dtype=torch.int64) + unpack_from_int32(weight, torch.Size([8, 1]), 4) + + with self.assertRaises(AssertionError): + weight = torch.tensor([[1, 2]], dtype=torch.int32) + unpack_from_int32(weight, torch.Size([8, 1]), 16) + + +class TestPackToInt32(TestBase): + + @patch( + "vllm_ascend.quantization.w4a16.torch_npu.npu_convert_weight_to_int4pack" + ) + def test_pack_to_int32_int8(self, mock_npu_convert_weight_to_int4pack): + mock_npu_convert_weight_to_int4pack.return_value = torch.zeros( + (2, 4), dtype=torch.int32) + + weight = torch.zeros((2, 8, 16), dtype=torch.int8) + result = pack_to_int32(weight) + + self.assertEqual(result.dtype, torch.int32) + mock_npu_convert_weight_to_int4pack.assert_not_called() + + self.assertEqual(result.shape, torch.Size([2, 8, 4])) + + @patch( + "vllm_ascend.quantization.w4a16.torch_npu.npu_convert_weight_to_int4pack" + ) + def test_pack_to_int32_int32(self, mock_npu_convert_weight_to_int4pack): + + def mock_convert_weight(weight): + return weight + + mock_npu_convert_weight_to_int4pack.side_effect = mock_convert_weight + weight = 
torch.zeros((2, 8, 8), dtype=torch.int32) + result = pack_to_int32(weight) + + self.assertEqual(result.dtype, torch.int32) + self.assertEqual(result.shape, weight.shape) + + def test_pack_to_int32_assertion_dim(self): + with self.assertRaises(AssertionError): + weight = torch.zeros((8, 8), dtype=torch.int8) + pack_to_int32(weight) + + def test_pack_to_int32_assertion_dtype(self): + with self.assertRaises(AssertionError): + weight = torch.zeros((2, 8, 8), dtype=torch.float32) + pack_to_int32(weight) + + def test_pack_to_int32_assertion_divisible(self): + with self.assertRaises(AssertionError): + weight = torch.zeros((2, 8, 7), dtype=torch.int32) + pack_to_int32(weight) + + with self.assertRaises(AssertionError): + weight = torch.zeros((2, 8, 7), dtype=torch.int8) + pack_to_int32(weight) + + +class TestAscendW4A16FusedMoEMethod(TestBase): + experts = 8 + input_size = 32 + output_size = 128 + group_size = 32 + + @patch("vllm_ascend.quantization.w4a16.get_ascend_config") + @patch("vllm_ascend.quantization.w4a16.get_current_vllm_config") + def setUp(self, mock_get_current_vllm_config, mock_get_ascend_config): + mock_ascend_config = Mock() + mock_ascend_config.dynamic_eplb = False + mock_ascend_config.expert_map_record_path = None + mock_get_ascend_config.return_value = mock_ascend_config + + mock_vllm_config = Mock() + mock_vllm_config.quant_config = Mock(quant_description={ + "group_size": self.group_size, + }) + mock_get_current_vllm_config.return_value = mock_vllm_config + + self.quant_method = AscendW4A16FusedMoEMethod() + + def test_init(self): + self.assertTrue(self.quant_method.transpose_weight) + self.assertEqual(self.quant_method.num_bits, 4) + self.assertEqual(self.quant_method.pack_factor, 8) + self.assertEqual(self.quant_method.group_size, self.group_size) + self.assertFalse(self.quant_method.dynamic_eplb) + + def test_get_weight(self): + param_dict = self.quant_method.get_weight(self.experts, + self.input_size, + self.output_size, + torch.bfloat16) + + 
self.assertEqual(param_dict["w13_weight_packed"].dtype, torch.int32) + expected_w13_shape = (self.experts, 2 * self.input_size, + self.output_size // + self.quant_method.pack_factor) + self.assertEqual(param_dict["w13_weight_packed"].shape, + expected_w13_shape) + + self.assertEqual(param_dict["w2_weight_packed"].dtype, torch.int32) + expected_w2_shape = (self.experts, self.output_size, + self.input_size // self.quant_method.pack_factor) + self.assertEqual(param_dict["w2_weight_packed"].shape, + expected_w2_shape) + + def test_get_dynamic_quant_param(self): + param_dict = self.quant_method.get_dynamic_quant_param( + self.experts, self.input_size, self.output_size, torch.bfloat16) + + self.assertEqual(param_dict["w13_weight_scale"].dtype, torch.bfloat16) + expected_w13_scale_shape = (self.experts, 2 * self.input_size, + self.output_size // self.group_size) + self.assertEqual(param_dict["w13_weight_scale"].shape, + expected_w13_scale_shape) + + self.assertEqual(param_dict["w2_weight_scale"].dtype, torch.bfloat16) + expected_w2_scale_shape = (self.experts, self.output_size, + self.input_size // self.group_size) + self.assertEqual(param_dict["w2_weight_scale"].shape, + expected_w2_scale_shape) + + self.assertEqual(param_dict["w13_weight_shape"].dtype, torch.int32) + self.assertEqual(param_dict["w13_weight_shape"].shape, + (self.experts, 2)) + + self.assertEqual(param_dict["w2_weight_shape"].dtype, torch.int32) + self.assertEqual(param_dict["w2_weight_shape"].shape, + (self.experts, 2)) + + self.assertEqual(param_dict["w13_weight_offset"].dtype, torch.bfloat16) + self.assertEqual(param_dict["w13_weight_offset"].shape, + expected_w13_scale_shape) + + self.assertEqual(param_dict["w2_weight_offset"].dtype, torch.bfloat16) + self.assertEqual(param_dict["w2_weight_offset"].shape, + expected_w2_scale_shape) + + def build_layer(self): + """Build a mock layer for testing""" + layer = torch.nn.Module() + + w13_shape = (self.experts, 2 * self.input_size, + self.output_size // 
self.quant_method.pack_factor) + w2_shape = (self.experts, self.output_size, + self.input_size // self.quant_method.pack_factor) + + layer.w13_weight_packed = torch.nn.Parameter(torch.randint( + -100, 100, w13_shape, dtype=torch.int32), + requires_grad=False) + layer.w2_weight_packed = torch.nn.Parameter(torch.randint( + -100, 100, w2_shape, dtype=torch.int32), + requires_grad=False) + + w13_scale_shape = (self.experts, 2 * self.input_size, + self.output_size // self.group_size) + w2_scale_shape = (self.experts, self.output_size, + self.input_size // self.group_size) + + layer.w13_weight_scale = torch.nn.Parameter(torch.ones( + w13_scale_shape, dtype=torch.bfloat16), + requires_grad=False) + layer.w2_weight_scale = torch.nn.Parameter(torch.ones( + w2_scale_shape, dtype=torch.bfloat16), + requires_grad=False) + + layer.w13_weight_offset = torch.nn.Parameter(torch.zeros( + w13_scale_shape, dtype=torch.bfloat16), + requires_grad=False) + layer.w2_weight_offset = torch.nn.Parameter(torch.zeros( + w2_scale_shape, dtype=torch.bfloat16), + requires_grad=False) + + layer.w13_weight_shape = torch.nn.Parameter(torch.tensor( + [[2 * self.input_size, self.output_size]] * self.experts, + dtype=torch.int32), + requires_grad=False) + layer.w2_weight_shape = torch.nn.Parameter(torch.tensor( + [[self.output_size, self.input_size]] * self.experts, + dtype=torch.int32), + requires_grad=False) + + return layer + + @patch( + "vllm_ascend.quantization.w4a16.torch_npu.npu_convert_weight_to_int4pack" + ) + def test_process_weights_after_loading_with_transpose( + self, mock_npu_convert_weight_to_int4pack): + + def mock_convert_weight(weight): + new_shape = list(weight.shape) + new_shape[-1] = new_shape[-1] // 8 + return torch.zeros(new_shape, dtype=torch.int32) + + mock_npu_convert_weight_to_int4pack.side_effect = mock_convert_weight + + layer = self.build_layer() + self.quant_method.transpose_weight = True + + self.quant_method.process_weights_after_loading(layer) + + 
self.assertEqual(layer.w13_weight_packed.data.shape, + torch.Size([8, 128, 8])) + self.assertEqual(layer.w2_weight_packed.data.shape, + torch.Size([8, 32, 16])) + + self.assertEqual(layer.w13_weight_scale.data.shape, + torch.Size([8, 4, 64])) + self.assertEqual(layer.w2_weight_scale.data.shape, + torch.Size([8, 1, 128])) + self.assertEqual(layer.w13_weight_offset.data.shape, + torch.Size([8, 4, 64])) + self.assertEqual(layer.w2_weight_offset.data.shape, + torch.Size([8, 1, 128])) + + self.assertTrue(layer.w13_weight_scale.data.is_contiguous()) + self.assertTrue(layer.w2_weight_scale.data.is_contiguous()) + self.assertTrue(layer.w13_weight_offset.data.is_contiguous()) + self.assertTrue(layer.w2_weight_offset.data.is_contiguous()) + + def test_process_weights_after_loading_without_transpose(self): + layer = self.build_layer() + self.quant_method.transpose_weight = False + + original_w13_data = layer.w13_weight_packed.data.clone() + original_w2_data = layer.w2_weight_packed.data.clone() + + self.quant_method.process_weights_after_loading(layer) + + self.assertTrue( + torch.equal(layer.w13_weight_packed.data, original_w13_data)) + self.assertTrue( + torch.equal(layer.w2_weight_packed.data, original_w2_data)) diff --git a/vllm_ascend/ops/fused_moe/moe_comm_method.py b/vllm_ascend/ops/fused_moe/moe_comm_method.py index 32604a1f95a..d0afa7bf24e 100644 --- a/vllm_ascend/ops/fused_moe/moe_comm_method.py +++ b/vllm_ascend/ops/fused_moe/moe_comm_method.py @@ -93,12 +93,15 @@ def fused_experts( apply_router_weight_on_input: bool = False, use_int8_w8a8: bool = False, use_int4_w4a8: bool = False, + use_int4_w4a16: bool = False, global_num_experts: Optional[int] = None, expert_map: Optional[torch.Tensor] = None, w1_scale: Optional[list[torch.Tensor]] = None, w2_scale: Optional[list[torch.Tensor]] = None, w1_scale_bias: torch.Tensor = None, w2_scale_bias: torch.Tensor = None, + w1_offset: Optional[torch.Tensor] = None, + w2_offset: Optional[torch.Tensor] = None, # For Cube/Vector 
parallel shared_experts: Optional[Any] = None, quantized_x_for_share: Optional[Any] = None, @@ -147,9 +150,11 @@ def fused_experts( group_list_type=group_list_type, w1_scale_bias=w1_scale_bias, w2_scale_bias=w2_scale_bias, + w1_offset=w1_offset, + w2_offset=w2_offset, topk_scales=topk_scales, with_quant=use_int8_w8a8 - or use_int4_w4a8, + or use_int4_w4a8 or use_int4_w4a16, fusion=use_int8_w8a8, need_trans=need_trans, dynamic_eplb=dynamic_eplb) @@ -275,12 +280,15 @@ def fused_experts( apply_router_weight_on_input: bool = False, use_int8_w8a8: bool = False, use_int4_w4a8: bool = False, + use_int4_w4a16: bool = False, global_num_experts: Optional[int] = None, expert_map: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, w1_scale_bias: torch.Tensor = None, w2_scale_bias: torch.Tensor = None, + w1_offset: Optional[torch.Tensor] = None, + w2_offset: Optional[torch.Tensor] = None, # For Cube/Vector parallel shared_experts: Optional[Any] = None, quantized_x_for_share: Optional[Any] = None, diff --git a/vllm_ascend/ops/fused_moe/moe_mlp.py b/vllm_ascend/ops/fused_moe/moe_mlp.py index 3b182b1714b..3fc12644ea8 100644 --- a/vllm_ascend/ops/fused_moe/moe_mlp.py +++ b/vllm_ascend/ops/fused_moe/moe_mlp.py @@ -68,9 +68,14 @@ def quant_apply_mlp(hidden_states: torch.Tensor, dynamic_scale: torch.Tensor = None, w1_scale_bias: torch.Tensor = None, w2_scale_bias: torch.Tensor = None, + w1_offset: Optional[torch.Tensor] = None, + w2_offset: Optional[torch.Tensor] = None, fusion: bool = False, dynamic_eplb: bool = False) -> torch.Tensor: - if dynamic_scale is None: + if w1_offset is not None: + unquantized_hidden_states = hidden_states + quantized_hidden_states = None + elif dynamic_scale is None: unquantized_hidden_states = hidden_states hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant( hidden_states) @@ -79,6 +84,7 @@ def quant_apply_mlp(hidden_states: torch.Tensor, dispose_tensor(unquantized_hidden_states) 
quantized_hidden_states = None else: + unquantized_hidden_states = None pertoken_scale = dynamic_scale quantized_hidden_states = hidden_states @@ -90,7 +96,7 @@ def quant_apply_mlp(hidden_states: torch.Tensor, weight_prefetch_method.maybe_prefetch_moe_weight_postprocess( hidden_states) is_mc2 = get_forward_context().moe_comm_type == MoECommType.MC2 - if w1_scale_bias is None and is_mc2: + if w1_scale_bias is None and w1_offset is None and is_mc2: if _custom_gmm_swiglu_enabled(fusion, dynamic_eplb): # gmm1: gate_up_proj & act_fn: swiglu hidden_states, swiglu_out_scale, _ = ( @@ -149,6 +155,32 @@ def quant_apply_mlp(hidden_states: torch.Tensor, group_type=0, group_list=group_list, output_dtype=w2_scale[0].dtype)[0] + elif w1_offset is not None: + # gmm1: gate_up_proj + hidden_states = torch_npu.npu_grouped_matmul( + x=[unquantized_hidden_states], + weight=[w1], + antiquant_scale=[w1_scale], + antiquant_offset=[w1_offset], + split_item=2, + group_list_type=group_list_type, + group_type=0, + group_list=group_list, + output_dtype=_output_dtype)[0] + dispose_tensor(unquantized_hidden_states) + # act_fn: swiglu + hidden_states = torch_npu.npu_swiglu(hidden_states) + # gmm2: down_proj + hidden_states = torch_npu.npu_grouped_matmul( + x=[hidden_states], + weight=[w2], + antiquant_scale=[w2_scale], + antiquant_offset=[w2_offset], + split_item=2, + group_list_type=group_list_type, + group_type=0, + group_list=group_list, + output_dtype=_output_dtype)[0] else: if w1_scale_bias is not None: if group_list_type == 0: @@ -269,6 +301,8 @@ def unified_apply_mlp(hidden_states: torch.Tensor, group_list_type: int = 1, w1_scale_bias: torch.Tensor = None, w2_scale_bias: torch.Tensor = None, + w1_offset: Optional[torch.Tensor] = None, + w2_offset: Optional[torch.Tensor] = None, topk_scales: Optional[torch.Tensor] = None, with_quant: bool = False, fusion: bool = False, @@ -286,6 +320,8 @@ def unified_apply_mlp(hidden_states: torch.Tensor, group_list_type=group_list_type, 
w1_scale_bias=w1_scale_bias, w2_scale_bias=w2_scale_bias, + w1_offset=w1_offset, + w2_offset=w2_offset, fusion=fusion, dynamic_eplb=dynamic_eplb) else: diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py index 71b0411c4bf..eb88e959773 100644 --- a/vllm_ascend/ops/rotary_embedding.py +++ b/vllm_ascend/ops/rotary_embedding.py @@ -65,8 +65,8 @@ def _rope_forward_oot( raise NotImplementedError( "Batched rotary embedding is currently not supported on NPU.") else: - if self.cos is not None and \ - self.sin is not None: + if hasattr(self, "cos") and hasattr(self, "sin") and \ + self.cos is not None and self.sin is not None: # If cos and sin are generated outside, use npu_apply_rotary_pos_emb to avoid redundant calculation. # This method requires head_size and rotary_dim equal 128 and neox_style is True query = query.contiguous().view(1, query.shape[0], -1, diff --git a/vllm_ascend/quantization/compressed_tensors/compressed_tensors.py b/vllm_ascend/quantization/compressed_tensors/compressed_tensors.py index f95ff7f0215..774bb00628e 100644 --- a/vllm_ascend/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm_ascend/quantization/compressed_tensors/compressed_tensors.py @@ -4,7 +4,8 @@ from compressed_tensors.quantization import (QuantizationArgs, QuantizationStrategy) from vllm.logger import init_logger -from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.linear import (LinearBase, UnquantizedLinearMethod) from vllm.model_executor.layers.quantization import ( QUANTIZATION_METHODS, register_quantization_config) @@ -16,8 +17,11 @@ find_matched_target, is_activation_quantization_format, should_ignore_layer) -from vllm_ascend.quantization.quant_config import (AscendLinearMethod, +from vllm_ascend.ops.fused_moe.fused_moe import AscendUnquantizedFusedMoEMethod +from vllm_ascend.quantization.quant_config import 
(AscendFusedMoEMethod, + AscendLinearMethod, AscendQuantConfig) +from vllm_ascend.quantization.w4a16 import AscendW4A16FusedMoEMethod from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod from vllm_ascend.utils import COMPRESSED_TENSORS_METHOD @@ -142,7 +146,7 @@ def get_quant_method( quant_scheme = self.get_scheme(layer=layer, layer_name=prefix) # choose quantization method - quant_method: LinearMethodBase = UnquantizedLinearMethod() + quant_method = UnquantizedLinearMethod() if quant_scheme is not None: layer.scheme = quant_scheme ascend_quant_config = AscendQuantConfig(self.quant_description @@ -150,6 +154,21 @@ def get_quant_method( quant_method = AscendLinearMethod(ascend_quant_config, prefix, None, layer) return quant_method + if isinstance(layer, FusedMoE): + layer.ascend_quant_method = COMPRESSED_TENSORS_METHOD + # collect schemes + quant_scheme = self.get_scheme(layer=layer, layer_name=prefix) + + # choose quantization method + quant_method = AscendUnquantizedFusedMoEMethod(layer.moe_config) + if quant_scheme is not None: + layer.scheme = quant_scheme + ascend_quant_config = AscendQuantConfig(self.quant_description + or {}) + quant_method = AscendFusedMoEMethod( + ascend_quant_config, prefix, + ascend_quant_config.packed_modules_mapping, layer) + return quant_method return None def get_scheme(self, @@ -215,6 +234,10 @@ def _get_scheme_from_parts( if self._is_dynamic_token_w8a8(weight_quant, input_quant): return AscendW8A8DynamicLinearMethod() + if weight_quant is not None: + if self._is_w4a16(weight_quant): + return AscendW4A16FusedMoEMethod() + raise NotImplementedError( "No compressed-tensors compatible scheme was found.") @@ -246,6 +269,10 @@ def _is_dynamic_token_w8a8(self, weight_quant: QuantizationArgs, # Only symmetric weight quantization supported. 
return is_8_bits and is_token and is_symmetric and is_dynamic + def _is_w4a16(self, weight_quant: QuantizationArgs) -> bool: + is_4_bits = weight_quant.num_bits == 4 + return is_4_bits + def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"): self.target_scheme_map = hf_to_vllm_mapper.apply_dict( self.target_scheme_map) diff --git a/vllm_ascend/quantization/quant_config.py b/vllm_ascend/quantization/quant_config.py index e19a008b115..6669fd2db6d 100644 --- a/vllm_ascend/quantization/quant_config.py +++ b/vllm_ascend/quantization/quant_config.py @@ -65,6 +65,9 @@ def __init__(self, quant_config: Dict[str, Any]): if "shared_head" in k: new_k = k.replace(".shared_head.", ".") extra_quant_dict[new_k] = self.quant_description[k] + if "weight_packed" in k: + new_k = k.replace("weight_packed", "weight") + extra_quant_dict[new_k] = self.quant_description[k] self.quant_description.update(extra_quant_dict) def __repr__(self) -> str: @@ -200,7 +203,8 @@ def get_scaled_act_names(self) -> List[str]: "kimi_k2": { "gate_up_proj": ["gate_proj", "up_proj"], "experts": - ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"] + ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], + "fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"] }, "deepseek_v32": { "gate_up_proj": ["gate_proj", "up_proj"], @@ -439,7 +443,9 @@ def create_weights( {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}) per_group_param = [ "weight_scale_second", "weight_offset_second", "scale_bias" - ] + ] + ["weight_scale", "weight_offset"] if hasattr( + self.quant_method, + "group_size") and self.quant_method.group_size > 0 else [] dynamic_quant_param = self.quant_method.get_dynamic_quant_param( num_experts, intermediate_size_per_partition, hidden_size, params_dtype) diff --git a/vllm_ascend/quantization/utils.py b/vllm_ascend/quantization/utils.py index be43726e8d9..1162de4720f 100644 --- a/vllm_ascend/quantization/utils.py +++ b/vllm_ascend/quantization/utils.py @@ 
-8,6 +8,7 @@ from .w4a4_flatquant_dynamic import AscendW4A4FlatQuantDynamicLinearMethod from .w4a8_dynamic import (AscendW4A8DynamicFusedMoEMethod, AscendW4A8DynamicLinearMethod) +from .w4a16 import AscendW4A16FusedMoEMethod from .w8a8 import (AscendC8KVCacheMethod, AscendW8A8FusedMoEMethod, AscendW8A8LinearMethod) from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod, @@ -16,6 +17,9 @@ AscendW8A8PDMixLinearMethod) ASCEND_QUANTIZATION_METHOD_MAP: Dict[str, Dict[str, Type[Any]]] = { + "W4A16": { + "moe": AscendW4A16FusedMoEMethod, + }, "W4A8_DYNAMIC": { "linear": AscendW4A8DynamicLinearMethod, "moe": AscendW4A8DynamicFusedMoEMethod, diff --git a/vllm_ascend/quantization/w4a16.py b/vllm_ascend/quantization/w4a16.py new file mode 100644 index 00000000000..d15fa25aaa2 --- /dev/null +++ b/vllm_ascend/quantization/w4a16.py @@ -0,0 +1,284 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+
+from typing import Any, Callable, Dict, Optional
+
+import torch
+import torch_npu
+from vllm.config import get_current_vllm_config
+from vllm.forward_context import get_forward_context
+
+from vllm_ascend.ascend_config import get_ascend_config
+from vllm_ascend.ops.fused_moe.experts_selector import select_experts
+
+
+def unpack_from_int32(
+    weight: torch.Tensor,
+    shape: torch.Size,
+    num_bits: int,
+    packed_dim: int = 1,
+) -> torch.Tensor:
+    """
+    Unpacks quantized weights from int32 format back to original bits.
+
+    :param weight: The packed int32 tensor containing quantized weights
+    :param shape: Original (unpacked) shape to restore
+    :param num_bits: The number of bits used for quantization (<= 8)
+    :param packed_dim: Dimension along which weights are packed (0 or 1), defaults to 1
+    :return: Unpacked tensor with int8 dtype after applying offset correction
+    """
+    assert weight.dtype == torch.int32, f"Expecting `weight.dtype` is torch.int32 but got {weight.dtype}."
+    assert num_bits <= 8, f"Expecting `num_bits` should not be larger than 8 but got {num_bits}."
+
+    pack_factor = 32 // num_bits
+    mask = (1 << num_bits) - 1
+
+    if packed_dim == 1:
+        unpacked_weight = torch.zeros(
+            (weight.shape[0], weight.shape[1] * pack_factor),
+            device=weight.device,
+            dtype=torch.int32,
+        )
+        for i in range(pack_factor):
+            unpacked_weight[:, i::pack_factor] = (weight >>
+                                                  (num_bits * i)) & mask
+        original_row_size = int(shape[1])
+        unpacked_weight = unpacked_weight[:, :original_row_size]
+    else:
+        unpacked_weight = torch.zeros(
+            (weight.shape[0] * pack_factor, weight.shape[1]),
+            device=weight.device,
+            dtype=torch.int32,
+        )
+        for i in range(pack_factor):
+            unpacked_weight[i::pack_factor, :] = (weight >>
+                                                  (num_bits * i)) & mask
+        original_row_size = int(shape[0])
+        unpacked_weight = unpacked_weight[:original_row_size, :]
+
+    offset = pow(2, num_bits) // 2
+    unpacked_weight = (unpacked_weight - offset).to(torch.int8)
+
+    return unpacked_weight
+
+
+def pack_to_int32(weight: torch.Tensor) -> torch.Tensor:
+    """
+    Packs quantized weights into int32 format for storage.
+
+    :param weight: The 3D tensor to pack, must be int8 or int32 dtype
+    :return: Packed tensor with int32 dtype optimized for storage
+    """
+    assert weight.dim(
+    ) == 3, f"Expecting `weight.dim()` is 3 ([e, n, k] or [e, k, n]) but got {weight.dim()}."
+    assert weight.dtype in [
+        torch.int8, torch.int32
+    ], f"Expecting `weight.dtype` is torch.int8 or torch.int32 but got {weight.dtype}."
+
+    if weight.dtype == torch.int32:
+        assert weight.shape[
+            -1] % 8 == 0, "the last dim of weight needs to be divided by 8."
+        packed_weight = torch_npu.npu_convert_weight_to_int4pack(
+            weight.flatten(0, 1))
+        packed_weight = packed_weight.view(weight.shape[0], weight.shape[1],
+                                           -1)
+    else:
+        assert weight.shape[
+            -1] % 4 == 0, "the last dim of weight needs to be divided by 4."
+        packed_weight = weight.view(torch.int32).contiguous()
+
+    return packed_weight
+
+
+class AscendW4A16FusedMoEMethod:
+    """FusedMoe method for Ascend W4A16.
class AscendW4A16FusedMoEMethod:
    """FusedMoE method for Ascend W4A16.

    Weights are int4 values packed 8-per-int32 word, with per-group
    bfloat16 scales and offsets (group size read from the quant config,
    default 32).
    """

    def __init__(self) -> None:
        self.transpose_weight = True
        self.num_bits = 4  # dtype = torch.int4
        self.pack_factor = 8  # pack 8 of torch.int4 tensors to torch.int32

        vllm_config = get_current_vllm_config()
        # Per-group quantization granularity along the reduction dim.
        self.group_size = vllm_config.quant_config.quant_description.get(
            "group_size", 32)
        ascend_config = get_ascend_config()
        # EPLB weight tracking is on when dynamic EPLB is enabled or an
        # expert-map record path is configured.
        self.dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path

    def get_weight(
        self,
        num_experts: int,
        intermediate_size_per_partition: int,
        hidden_sizes: int,
        params_dtype: torch.dtype,
    ) -> Dict[str, Any]:
        """Allocate the packed int32 weight tensors.

        w13 stacks gate and up projections (hence the factor 2); the last
        dim is packed by ``pack_factor``.
        """
        assert intermediate_size_per_partition % self.pack_factor == 0, f"Expecting `intermediate_size_per_partition` {intermediate_size_per_partition} can be divided by `pack_factor` {self.pack_factor}"
        assert hidden_sizes % self.pack_factor == 0, f"Expecting `hidden_sizes` {hidden_sizes} can be divided by `pack_factor` {self.pack_factor}"

        param_dict = {}

        param_dict["w13_weight_packed"] = torch.empty(
            num_experts,
            2 * intermediate_size_per_partition,
            hidden_sizes // self.pack_factor,
            dtype=torch.int32)
        param_dict["w2_weight_packed"] = torch.empty(
            num_experts,
            hidden_sizes,
            intermediate_size_per_partition // self.pack_factor,
            dtype=torch.int32)

        return param_dict

    def get_dynamic_quant_param(
        self,
        num_experts: int,
        intermediate_size_per_partition: int,
        hidden_sizes: int,
        params_dtype: torch.dtype,
    ) -> Dict[str, Any]:
        """Allocate per-group scale/offset tensors (bfloat16) and the
        per-expert original-shape records used when unpacking."""
        assert intermediate_size_per_partition % self.group_size == 0, f"Expecting `intermediate_size_per_partition` {intermediate_size_per_partition} can be divided by `group_size` {self.group_size}"
        assert hidden_sizes % self.group_size == 0, f"Expecting `hidden_sizes` {hidden_sizes} can be divided by `group_size` {self.group_size}"

        param_dict = {}

        param_dict["w13_weight_scale"] = torch.empty(
            num_experts,
            2 * intermediate_size_per_partition,
            hidden_sizes // self.group_size,
            dtype=torch.bfloat16)
        param_dict["w2_weight_scale"] = torch.empty(
            num_experts,
            hidden_sizes,
            intermediate_size_per_partition // self.group_size,
            dtype=torch.bfloat16)
        param_dict["w13_weight_shape"] = torch.empty(num_experts,
                                                     2,
                                                     dtype=torch.int32)
        param_dict["w2_weight_shape"] = torch.empty(num_experts,
                                                    2,
                                                    dtype=torch.int32)
        param_dict["w13_weight_offset"] = torch.zeros(
            num_experts,
            2 * intermediate_size_per_partition,
            hidden_sizes // self.group_size,
            dtype=torch.bfloat16)
        param_dict["w2_weight_offset"] = torch.zeros(
            num_experts,
            hidden_sizes,
            intermediate_size_per_partition // self.group_size,
            dtype=torch.bfloat16)

        return param_dict

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        is_prefill: bool = True,
        enable_force_load_balance: bool = True,
        log2phy: Optional[torch.Tensor] = None,
        global_redundant_expert_num: int = 0,
        shared_experts: Optional[Any] = None,
        quantized_x_for_share: Optional[Any] = None,
        dynamic_scale_for_share: Optional[Any] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Route tokens and run the W4A16 fused experts.

        Selects top-k experts from ``router_logits`` via ``select_experts``,
        then dispatches through the forward-context's ``moe_comm_method``
        with the packed weights, scales and offsets of ``layer``.
        """
        assert router_logits.shape[
            1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)"

        topk_weights, topk_ids = select_experts(
            hidden_states=x,
            router_logits=router_logits,
            top_k=top_k,
            use_grouped_topk=use_grouped_topk,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
            global_num_experts=global_num_experts)

        # Kernel expects int32 ids; weights in the activation dtype.
        topk_ids = topk_ids.to(torch.int32)
        topk_weights = topk_weights.to(x.dtype)

        moe_comm_method = get_forward_context().moe_comm_method
        return moe_comm_method.fused_experts(
            hidden_states=x,
            w1=layer.w13_weight_packed,
            w2=layer.w2_weight_packed,
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            w1_offset=layer.w13_weight_offset,
            w2_offset=layer.w2_weight_offset,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            use_int4_w4a16=True,
            expert_map=expert_map,
            log2phy=log2phy,
            global_redundant_expert_num=global_redundant_expert_num,
            shared_experts=shared_experts,
            quantized_x_for_share=quantized_x_for_share,
            dynamic_scale_for_share=dynamic_scale_for_share,
            dynamic_eplb=self.dynamic_eplb,
            mc2_mask=kwargs.get("mc2_mask", None))

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Repack checkpoint weights into the kernel layout.

        Unpacks int4 values out of the int32 storage, transposes the last
        two dims, re-packs with ``pack_to_int32``, and transposes the
        scale/offset tensors to match the transposed weights.
        """
        if self.transpose_weight:
            w13_shape = layer.w13_weight_packed.data.shape
            w2_shape = layer.w2_weight_packed.data.shape
            # Unpack on the folded [e*n, k*pack_factor] view, then restore
            # the expert dim and swap (n, k).
            unpacked_w13_weight = (unpack_from_int32(
                layer.w13_weight_packed.data.flatten(0, 1),
                torch.Size([
                    w13_shape[0] * w13_shape[1],
                    w13_shape[2] * self.pack_factor
                ]),
                self.num_bits,
            ).view(w13_shape[0], w13_shape[1],
                   -1).transpose(1, 2).contiguous().int())
            unpacked_w2_weight = (unpack_from_int32(
                layer.w2_weight_packed.data.flatten(0, 1),
                torch.Size([
                    w2_shape[0] * w2_shape[1], w2_shape[2] * self.pack_factor
                ]),
                self.num_bits,
            ).view(w2_shape[0], w2_shape[1],
                   -1).transpose(1, 2).contiguous().int())
            layer.w13_weight_packed.data = pack_to_int32(unpacked_w13_weight)
            layer.w2_weight_packed.data = pack_to_int32(unpacked_w2_weight)

            layer.w13_weight_scale.data = layer.w13_weight_scale.data.transpose(
                1, 2).contiguous()
            layer.w2_weight_scale.data = layer.w2_weight_scale.data.transpose(
                1, 2).contiguous()

            layer.w13_weight_offset.data = layer.w13_weight_offset.data.transpose(
                1, 2).contiguous()
            layer.w2_weight_offset.data = layer.w2_weight_offset.data.transpose(
                1, 2).contiguous()