From 447266b12f373816d742610fbc61f4cc58af1c76 Mon Sep 17 00:00:00 2001 From: pichangping <1337510399@qq.com> Date: Fri, 13 Mar 2026 15:27:35 +0800 Subject: [PATCH 01/13] fa_quant Co-authored-by: kunpengW-code <1289706727@qq.com> Co-authored-by: linsheng1 <1950916997@qq.com> Signed-off-by: pichangping <1337510399@qq.com> --- vllm_ascend/attention/mla_v1.py | 330 ++++++++++++------ vllm_ascend/attention/utils.py | 16 + .../kv_p2p/mooncake_layerwise_connector.py | 215 +++++++++--- vllm_ascend/ops/mla.py | 1 + vllm_ascend/patch/worker/__init__.py | 1 + .../patch/worker/patch_weight_utils.py | 82 +++++ vllm_ascend/quantization/methods/__init__.py | 2 + vllm_ascend/quantization/methods/kv_c8.py | 66 ++++ vllm_ascend/quantization/modelslim_config.py | 132 ++++++- vllm_ascend/worker/model_runner_v1.py | 50 ++- 10 files changed, 743 insertions(+), 152 deletions(-) create mode 100644 vllm_ascend/patch/worker/patch_weight_utils.py create mode 100644 vllm_ascend/quantization/methods/kv_c8.py diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index c68d4ff015d..47154468bb6 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -27,6 +27,7 @@ AscendCommonAttentionMetadata, ascend_chunked_prefill_workspace_size, enable_cp, + enabling_fa_quant, enabling_mlapo, maybe_save_kv_layer_to_connector, split_decodes_and_prefills, @@ -47,7 +48,7 @@ register_all_layers_to_shard_weight_series, ) from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla -from vllm_ascend.quantization.methods import AscendW8A8LinearMethod +from vllm_ascend.quantization.methods.w8a8_static import AscendW8A8LinearMethod from vllm_ascend.utils import ACL_FORMAT_FRACTAL_ND, get_weight_prefetch_method, maybe_trans_nz, weak_ref_tensors from vllm_ascend.worker.npu_input_batch import NPUInputBatch @@ -649,6 +650,7 @@ class DecodeMLAPreprocessResult(NamedTuple): k_nope: torch.Tensor | None = None k_pe: torch.Tensor | None = None decode_q_wo_k_up: torch.Tensor | None = None + dequant_scale_q_nope: torch.Tensor | None = None class PrefillMLAPreprocessResult(NamedTuple): @@ -716,6 +718,9 @@ def __init__( self.is_kv_producer = ( self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer ) + self.layer_name = kwargs.get("layer_name") + self.fa_quant_layer = enabling_fa_quant(self.vllm_config, self.layer_name) + self.dtype = torch.int8 if self.fa_quant_layer else self.vllm_config.model_config.dtype self.layer_sharding_kwargs = [] for layer_name in get_ascend_config().layer_sharding or []: if layer_name in kwargs: @@ -766,6 +771,8 @@ def update_graph_params( actual_seq_lengths, attn_output, softmax_lse, + dequant_scale_q_nope, + fak_descale_float, ) = param seq_lens_list = forward_context.attn_metadata[key].decode.seq_lens_list if speculative_config and speculative_config.method == "mtp" and not _EXTRA_CTX.is_draft_model: @@ -784,27 +791,54 @@ def update_graph_params( seq_lens_list = seq_lens_list + [0] * (num_tokens - len(seq_lens_list)) torch.npu.graph_task_update_begin(update_stream, handle) - torch_npu.npu_fused_infer_attention_score.out( - q_nope, - k_nope, - k_nope, - query_rope=q_pe, - key_rope=k_pe, - num_heads=num_heads, - num_key_value_heads=num_kv_heads, - input_layout=input_layout, - atten_mask=attn_mask, - sparse_mode=sparse_mode, - scale=scale, - antiquant_mode=0, - antiquant_scale=None, - block_table=block_table, - block_size=block_size, - actual_seq_lengths_kv=seq_lens_list, - actual_seq_lengths=actual_seq_lengths, - workspace=graph_params.workspaces.get(num_tokens), - out=[attn_output, softmax_lse], - ) + if dequant_scale_q_nope is None: + torch_npu.npu_fused_infer_attention_score.out( + q_nope, + k_nope, + k_nope, + query_rope=q_pe, + key_rope=k_pe, + num_heads=num_heads, + num_key_value_heads=num_kv_heads, + input_layout=input_layout, + atten_mask=attn_mask, + sparse_mode=sparse_mode, + scale=scale, + antiquant_mode=0, + antiquant_scale=None, + block_table=block_table, + block_size=block_size, + actual_seq_lengths_kv=seq_lens_list, + actual_seq_lengths=actual_seq_lengths, + workspace=graph_params.workspaces.get(num_tokens), + out=[attn_output, softmax_lse], + ) + else: + torch_npu.npu_fused_infer_attention_score_v2.out( + q_nope, + k_nope, + k_nope, + query_rope=q_pe, + key_rope=k_pe, + num_query_heads=num_heads, + num_key_value_heads=num_kv_heads, + input_layout=input_layout, + atten_mask=attn_mask, + sparse_mode=sparse_mode, + softmax_scale=scale, + query_quant_mode=3, + key_quant_mode=0, + value_quant_mode=0, + dequant_scale_query=dequant_scale_q_nope, + dequant_scale_key=fak_descale_float, + dequant_scale_value=fak_descale_float, + block_table=block_table, + block_size=block_size, + actual_seq_kvlen=seq_lens_list, + actual_seq_qlen=actual_seq_lengths, + workspace=graph_params.workspaces.get(num_tokens), + out=[attn_output, softmax_lse], + ) torch.npu.graph_task_update_end(update_stream) event.record(update_stream) @@ -878,6 +912,8 @@ def process_weights_after_loading(self, act_dtype: torch.dtype): ) if self.enable_mlapo: self._process_weights_for_fused_mlapo(act_dtype) + elif self.fa_quant_layer: + self._process_weights_for_fused_fa_quant() else: # if mlapo, W_UK_T can't trans nz self.W_UK_T = maybe_trans_nz(self.W_UK_T) @@ -886,6 +922,32 @@ def process_weights_after_loading(self, act_dtype: torch.dtype): if is_hidden_layer(layer): post_process_after_loading_for_shard_weight_series(layer) + def _process_weights_for_fused_fa_quant(self): + self.gamma1 = self.q_a_layernorm.weight.data # type: ignore[union-attr] + self.gamma2 = self.kv_a_layernorm.weight.data # type: ignore[union-attr] + + wu_q = self.q_proj.weight.data + + self.wu_q = wu_q + + q_a_proj_fa3 = self.fused_qkv_a_proj.weight.data[..., : self.q_lora_rank].contiguous() # type: ignore[union-attr] + + self.wd_q = q_a_proj_fa3 + + kv_a_proj_fa3 = self.fused_qkv_a_proj.weight.data[..., self.q_lora_rank :].contiguous() # type: ignore[union-attr] + + self.wd_kv = kv_a_proj_fa3 + + self.dequant_scale_w_uq_qr = self.q_proj.weight_scale.data.view(1, -1).to(torch.float) + q_a_proj_deq_scl = self.fused_qkv_a_proj.weight_scale[: self.q_lora_rank].contiguous() # type: ignore[union-attr] + self.dequant_scale_w_dq = q_a_proj_deq_scl.view(1, -1).to(torch.float) + kv_a_proj_deq_scl = self.fused_qkv_a_proj.weight_scale[self.q_lora_rank :].contiguous() # type: ignore[union-attr] + self.dequant_scale_w_dkv_kr = kv_a_proj_deq_scl.view(1, -1).to(torch.float) + + layer = self.vllm_config.compilation_config.static_forward_context[self.layer_name] + self.quant_kscale = layer.quant_kscale + self.fak_descale_float = layer.fak_descale_float + def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype): kv_a_proj_wt = self.fused_qkv_a_proj.weight.data[..., self.q_lora_rank :].contiguous() # type: ignore[union-attr] q_a_proj_wt = self.fused_qkv_a_proj.weight.data[..., : self.q_lora_rank].contiguous() # type: ignore[union-attr] @@ -1166,6 +1228,7 @@ def _forward_decode( k_pe: torch.Tensor, block_size: int, attn_metadata: AscendMLAMetadata, + dequant_scale_q_nope=None, ) -> torch.Tensor: decode_meta = attn_metadata.decode assert decode_meta is not None @@ -1173,7 +1236,15 @@ def _forward_decode( # shape of knope/k_pe for npu graph mode should be: # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim] actual_seq_lengths = None - if self.enable_kv_nz: + if self.fa_quant_layer: + nz_fmt_last_dim = 16 + k_nope = k_nope.view( + -1, self.num_kv_heads, self.kv_lora_rank // (nz_fmt_last_dim * 2), block_size, nz_fmt_last_dim * 2 + ) + k_pe = k_pe.view( + -1, self.num_kv_heads, self.qk_rope_head_dim // nz_fmt_last_dim, block_size, nz_fmt_last_dim + ) + elif self.enable_kv_nz: nz_fmt_last_dim = 16 k_nope = k_nope.view( -1, self.num_kv_heads, self.kv_lora_rank // nz_fmt_last_dim, block_size, nz_fmt_last_dim @@ -1208,6 +1279,35 @@ def _forward_decode( sparse_mode = 3 attn_mask = attn_metadata.decode.attn_mask # type:ignore actual_seq_lengths = decode_meta.actual_seq_lengths_q + elif self.fa_quant_layer: + attn_mask = None + input_layout = "BSND_NBSD" + q_nope = q_nope.view(num_tokens, 1, self.num_heads, -1).contiguous() + q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1).contiguous() + dequant_scale_q_nope = dequant_scale_q_nope.view(num_tokens, 1, self.num_heads) + sparse_mode = 0 + actual_seq_lengths = None + common_kwargs_v2 = { + "query_rope": q_pe, + "key_rope": k_pe, + "num_query_heads": self.num_heads, + "num_key_value_heads": self.num_kv_heads, + "input_layout": input_layout, + "atten_mask": attn_mask, + "sparse_mode": sparse_mode, + "softmax_scale": self.scale, + "query_quant_mode": 3, + "key_quant_mode": 0, + "value_quant_mode": 0, + "dequant_scale_query": dequant_scale_q_nope, + "dequant_scale_key": self.fak_descale_float, + "dequant_scale_value": self.fak_descale_float, + "block_table": decode_meta.block_table, + "block_size": block_size, + "actual_seq_qlen": actual_seq_lengths, + "actual_seq_kvlen": decode_meta.seq_lens_list, + } + attn_output_shape = (self.num_heads, num_tokens, 1, self.kv_lora_rank) else: # The output layout is set to NBSD to eliminate the need for a # transpose operation after attention. @@ -1255,48 +1355,55 @@ def _forward_decode( graph_params.events[num_tokens].append(event) workspace = graph_params.workspaces.get(num_tokens) + attn_output = torch.empty(attn_output_shape, dtype=q_pe.dtype, device=q_pe.device) + softmax_lse = torch.empty(num_tokens, dtype=q_pe.dtype, device=q_pe.device) + attn_params = ( + weak_ref_tensors(q_nope), + weak_ref_tensors(k_nope), + weak_ref_tensors(q_pe), + weak_ref_tensors(k_pe), + self.num_heads, + self.num_kv_heads, + input_layout, + weak_ref_tensors(attn_mask) if attn_mask is not None else None, + sparse_mode, + self.scale, + decode_meta.block_table, + block_size, + decode_meta.seq_lens_list, + actual_seq_lengths, + weak_ref_tensors(attn_output), + weak_ref_tensors(softmax_lse), + ) + if self.fa_quant_layer: + get_max_workspace_func = torch_npu._npu_fused_infer_attention_score_v2_get_max_workspace + fused_infer_attention_func = torch_npu.npu_fused_infer_attention_score_v2.out + attn_params = attn_params + (dequant_scale_q_nope, self.fak_descale_float) + common_kwargs = common_kwargs_v2 + else: + get_max_workspace_func = torch_npu._npu_fused_infer_attention_score_get_max_workspace + fused_infer_attention_func = torch_npu.npu_fused_infer_attention_score.out + attn_params = attn_params + (None, None) + if workspace is None: - workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace( - q_nope, k_nope, k_nope, **common_kwargs - ) - if _EXTRA_CTX.is_draft_model: + workspace = get_max_workspace_func(q_nope, k_nope, k_nope, **common_kwargs) + if forward_context.is_draft_model: update_draft_graph_params_workspaces(num_tokens, workspace) else: update_graph_params_workspaces(num_tokens, workspace) - attn_output = torch.empty(attn_output_shape, dtype=q_nope.dtype, device=q_nope.device) - softmax_lse = torch.empty(num_tokens, dtype=q_nope.dtype, device=q_nope.device) - - graph_params.attn_params[num_tokens].append( - ( - weak_ref_tensors(q_nope), - weak_ref_tensors(k_nope), - weak_ref_tensors(q_pe), - weak_ref_tensors(k_pe), - self.num_heads, - self.num_kv_heads, - input_layout, - weak_ref_tensors(attn_mask) if attn_mask is not None else None, - sparse_mode, - self.scale, - decode_meta.block_table, - block_size, - decode_meta.seq_lens_list, - actual_seq_lengths, - weak_ref_tensors(attn_output), - weak_ref_tensors(softmax_lse), - ) - ) + graph_params.attn_params[num_tokens].append(attn_params) torch.npu.graph_task_group_begin(stream) - torch_npu.npu_fused_infer_attention_score.out( + fused_infer_attention_func( q_nope, k_nope, k_nope, **common_kwargs, workspace=workspace, out=[attn_output, softmax_lse] ) handle = torch.npu.graph_task_group_end(stream) graph_params.handles[num_tokens].append(handle) + elif self.fa_quant_layer: + attn_output, _ = torch_npu.npu_fused_infer_attention_score_v2(q_nope, k_nope, k_nope, **common_kwargs_v2) else: attn_output, _ = torch_npu.npu_fused_infer_attention_score(q_nope, k_nope, k_nope, **common_kwargs) - return self._v_up_proj(attn_output) def reorg_decode_q(self, decode_q_nope, decode_q_pe): @@ -1311,55 +1418,81 @@ def _mla_preprocess_only_decode(self, hidden_states, kv_cache, attn_metadata): sin = attn_metadata.decode.sin.view(cos_shape[0], cos_shape[-1]) decode_k_nope, decode_k_pe = kv_cache[0], kv_cache[1] - decode_q_nope = torch.empty( - (hidden_states.shape[0], self.W_UK_T.shape[0], decode_k_nope.shape[-1]), - dtype=hidden_states.dtype, - device=hidden_states.device, - ) - decode_q_pe = torch.empty( - (hidden_states.shape[0], self.W_UK_T.shape[0], decode_k_pe.shape[-1]), - dtype=hidden_states.dtype, - device=hidden_states.device, - ) + dequant_scale_q_nope = None + if self.fa_quant_layer: + quantized_x, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states) + decode_q_nope, decode_q_pe, decode_k_nope, decode_k_pe, dequant_scale_q_nope = torch_npu.npu_mla_prolog_v2( + quantized_x, + self.wd_q, + self.wu_q, + self.W_UK_T, + self.wd_kv, + self.gamma1, + self.gamma2, + sin, + cos, + attn_metadata.slot_mapping[:bsz].to(torch.int64), + decode_k_nope, + decode_k_pe, + dequant_scale_x=pertoken_scale.view(-1, 1), + dequant_scale_w_dq=self.dequant_scale_w_dq, + dequant_scale_w_uq_qr=self.dequant_scale_w_uq_qr, + dequant_scale_w_dkv_kr=self.dequant_scale_w_dkv_kr, + quant_scale_ckv=self.quant_kscale, + cache_mode="PA_NZ", + ) + else: + decode_q_nope = torch.empty( + (hidden_states.shape[0], self.W_UK_T.shape[0], decode_k_nope.shape[-1]), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + decode_q_pe = torch.empty( + (hidden_states.shape[0], self.W_UK_T.shape[0], decode_k_pe.shape[-1]), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) - torch.ops._C_ascend.mla_preprocess( - hidden_states, - self.wd_qkv, - self.deq_scale_qkv, - self.gamma1, - self.beta1, - self.wu_q, - self.qb_deq_scl, - self.gamma2, - cos, - sin, - self.W_UK_T, - decode_k_nope, - decode_k_pe, - attn_metadata.slot_mapping[:bsz], - quant_scale0=self.quant_scale0, - quant_offset0=self.quant_offset0, - bias0=self.quant_bias_qkv, - quant_scale1=self.quant_scale1, - quant_offset1=self.quant_offset1, - bias1=self.qb_qt_bias, - ctkv_scale=self.ctkv_scale, - q_nope_scale=self.q_nope_scale, - cache_mode="nzcache" if self.enable_kv_nz else "krope_ctkv", - quant_mode="per_tensor_quant_asymm", - q_out0=decode_q_nope, - kv_cache_out0=decode_k_nope, - q_out1=decode_q_pe, - kv_cache_out1=decode_k_pe, - enable_inner_out=False, - inner_out=torch.tensor([], device=hidden_states.device), - ) - decode_q_nope = decode_q_nope.view(bsz, self.num_heads, self.kv_lora_rank) - decode_q_pe = decode_q_pe.view(bsz, self.num_heads, -1) + torch.ops._C_ascend.mla_preprocess( + hidden_states, + self.wd_qkv, + self.deq_scale_qkv, + self.gamma1, + self.beta1, + self.wu_q, + self.qb_deq_scl, + self.gamma2, + cos, + sin, + self.W_UK_T, + decode_k_nope, + decode_k_pe, + attn_metadata.slot_mapping[:bsz], + quant_scale0=self.quant_scale0, + quant_offset0=self.quant_offset0, + bias0=self.quant_bias_qkv, + quant_scale1=self.quant_scale1, + quant_offset1=self.quant_offset1, + bias1=self.qb_qt_bias, + ctkv_scale=self.ctkv_scale, + q_nope_scale=self.q_nope_scale, + cache_mode="nzcache" if self.enable_kv_nz else "krope_ctkv", + quant_mode="per_tensor_quant_asymm", + q_out0=decode_q_nope, + kv_cache_out0=decode_k_nope, + q_out1=decode_q_pe, + kv_cache_out1=decode_k_pe, + enable_inner_out=False, + inner_out=torch.tensor([], device=hidden_states.device), + ) + decode_q_nope = decode_q_nope.view(bsz, self.num_heads, self.kv_lora_rank) + decode_q_pe = decode_q_pe.view(bsz, self.num_heads, -1) decode_q_nope, decode_q_pe = self.reorg_decode_q(decode_q_nope, decode_q_pe) - decode_preprocess_res = DecodeMLAPreprocessResult(decode_q_nope, decode_q_pe, decode_k_nope, decode_k_pe) + decode_preprocess_res = DecodeMLAPreprocessResult( + decode_q_nope, decode_q_pe, decode_k_nope, decode_k_pe, dequant_scale_q_nope=dequant_scale_q_nope + ) return decode_preprocess_res, None def mla_preprocess_prefill(self, q_c, kv_no_split, kv_cache, attn_metadata): @@ -1506,7 +1639,7 @@ def forward( o_proj_input = torch.empty(o_proj_input_shape, dtype=hidden_states.dtype, device=hidden_states.device) # MLA Preprocess - if self.enable_mlapo and attn_metadata.num_decode_tokens <= MLAPO_MAX_SUPPORTED_TOKENS: + if self.fa_quant_layer or (self.enable_mlapo and attn_metadata.num_decode_tokens <= MLAPO_MAX_SUPPORTED_TOKENS): hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( hidden_states.contiguous(), need_gather_q_kv ) @@ -1542,6 +1675,7 @@ def forward( prefill_preprocess_res.value, kv_cache, attn_metadata, + decode_preprocess_res.dequant_scale_q_nope, ) o_proj_input[num_decode_tokens:num_actual_tokens] = output_prefill diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py index 946d5c66d4e..f4c19727a88 100644 --- a/vllm_ascend/attention/utils.py +++ b/vllm_ascend/attention/utils.py @@ -331,3 +331,19 @@ def enabling_mlapo(vllm_config: VllmConfig) -> bool: and not vllm_config.kv_transfer_config.is_kv_producer ) return bool(envs.VLLM_ASCEND_ENABLE_MLAPO and is_decode_instance) + + +def enabling_fa_quant(vllm_config: VllmConfig, layer_name) -> bool: + is_decode_instance = ( + vllm_config.kv_transfer_config is not None + and vllm_config.kv_transfer_config.is_kv_consumer + and not vllm_config.kv_transfer_config.is_kv_producer + ) + quant_config = vllm_config.quant_config + enable_fa_quant = quant_config.enable_fa_quant if quant_config is not None else False + fa_quant_layer = False + if is_decode_instance and enable_fa_quant: + id = "".join(re.findall(r"\.(\d+)\.", layer_name)) + if int(id) in quant_config.kvcache_quant_layers: + fa_quant_layer = True + return fa_quant_layer diff --git a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py index 415bb6c9b76..74fce722373 100644 --- a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py +++ b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py @@ -124,6 +124,9 @@ class SendTask: # pd_head_ratio > 1 use k_cache: torch.Tensor | None = None v_cache: torch.Tensor | None = None + # kv cache quantization layer use + k_quant_cache: torch.Tensor | None = None + v_quant_cache: torch.Tensor | None = None layer_idx: int = 0 layer_name: str = "" # trans block info @@ -210,6 +213,9 @@ def __init__( use_mla: bool, k_buffer: torch.Tensor, v_buffer: torch.Tensor, + enable_kv_quant: bool, + k_quant_buffer: torch.Tensor | None, + v_quant_buffer: torch.Tensor | None, resharding_stream: torch.npu.Stream, callback_func: Callable[..., None] = lambda x: None, ): @@ -232,6 +238,9 @@ def __init__( self.send_queue = queue.Queue[SendTask]() self.k_buffer = k_buffer self.v_buffer = v_buffer + self.enable_kv_quant = enable_kv_quant + self.k_quant_buffer = k_quant_buffer + self.v_quant_buffer = v_quant_buffer self.ready_event = ready_event self.callback_func = callback_func @@ -325,19 +334,43 @@ def get_transfer_meta(self, send_task: SendTask, req_id: str, req_meta: ReqMeta, grouped_remote_block_ids, grouped_local_block_ids = group_concurrent_contiguous( remote_block_ids, local_block_ids ) - for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate( - zip(layer_local_kv_base_addr, layer_remote_kv_base_addr) - ): - block_len = block_lens[k] - for group_remote_block_id, group_local_block_id in zip( - grouped_remote_block_ids, grouped_local_block_ids +# kv cache quantization scenario + if self.enable_kv_quant and send_task.k_quant_cache is not None: + assert len(block_lens) == 2, "Quantization block length must be 2!" + quant_block_lens = [block_lens[0] // 2, block_lens[1]] + layer_local_quant_kv_addr = [self.k_quant_buffer.data_ptr(), self.v_quant_buffer.data_ptr()] + rearrange_block_ids = send_task.group_rearrange_block_ids[layer_group_idx] + # eg:[5,6,7,9] -> {5:0, 6:1, 7:2, 9:3} + rearrange_block_dict = { + value: index + for index, value in enumerate(rearrange_block_ids) # type:ignore + } + for block_len, src_layer_base_addr, dst_layer_base_addr in zip( + quant_block_lens, layer_local_quant_kv_addr, layer_remote_kv_base_addr ): - src = src_layer_base_addr + group_local_block_id[0] * block_len - dst = dst_layer_base_addr + group_remote_block_id[0] * block_len - length = len(group_local_block_id) * block_len - src_list.append(src) - dst_list.append(dst) - length_list.append(length) + for group_remote_block_id, group_local_block_id in zip( + grouped_remote_block_ids, grouped_local_block_ids + ): + src = src_layer_base_addr + rearrange_block_dict[group_local_block_id[0]] * block_len + dst = dst_layer_base_addr + group_remote_block_id[0] * block_len + length = len(group_local_block_id) * block_len + src_list.append(src) + dst_list.append(dst) + length_list.append(length) + else: + for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate( + zip(layer_local_kv_base_addr, layer_remote_kv_base_addr) + ): + block_len = block_lens[k] + for group_remote_block_id, group_local_block_id in zip( + grouped_remote_block_ids, grouped_local_block_ids + ): + src = src_layer_base_addr + group_local_block_id[0] * block_len + dst = dst_layer_base_addr + group_remote_block_id[0] * block_len + length = len(group_local_block_id) * block_len + src_list.append(src) + dst_list.append(dst) + length_list.append(length) else: rearrange_block_ids = send_task.group_rearrange_block_ids[layer_group_idx] rearrange_block_dict = { @@ -380,6 +413,14 @@ def _transfer_kv_cache(self, send_task: SendTask): value = value.view(-1, key.shape[-1]) # type:ignore self.k_buffer[: key.shape[0]].copy_(key) # [:4, 128] -> self.v_buffer[: value.shape[0]].copy_(value) + if send_task.k_quant_cache is not None: + with npu_stream_switch(self.resharding_stream): + key_quant = send_task.k_quant_cache + key_quant = key_quant.view(-1, key_quant.shape[-1]) # type:ignore + self.k_quant_buffer[: key_quant.shape[0]].copy_(key_quant) + value_quant = send_task.v_quant_cache + value_quant = value_quant.view(-1, value_quant.shape[-1]) # type:ignore + self.v_quant_buffer[: value_quant.shape[0]].copy_(value_quant) # Merge transmission tasks of the same session session_meta: dict[str, TransferMeta] = {} @@ -395,7 +436,9 @@ def _transfer_kv_cache(self, send_task: SendTask): session_meta[session_id].length.extend(length_list) session_meta[session_id].req_ids.append(req_id) - if self.pd_head_ratio == 1: + if send_task.k_quant_cache is not None: + self.resharding_stream.synchronize() + elif self.pd_head_ratio == 1: """ Note: Due to a bug in ADXL, calling current_event.synchronize() may occasionally hang. This issue will be fixed in CANN version 8.5.rc1. @@ -628,7 +671,7 @@ def wait_for_layer_load(self, layer_name: str) -> None: self.connector_worker.wait_for_layer_load(layer_name) def save_kv_layer( - self, layer_name: str, kv_layer: torch.Tensor, attn_metadata: "AttentionMetadata", **kwargs + self, layer_name: str, kv_layer: list[torch.Tensor], attn_metadata: "AttentionMetadata", **kwargs ) -> None: """MooncakeLayerwiseConnector does not save explicitly.""" assert self.connector_worker is not None @@ -962,10 +1005,13 @@ def __init__(self, vllm_config: VllmConfig, kv_cache_config: KVCacheConfig, engi self.layer_metadata: dict[str, LayerMetadata] = {} self.attn_resharding_group_idx = set[int]() + self.enable_kv_quant = ( + vllm_config.quant_config.enable_fa_quant if vllm_config.quant_config is not None else False + ) self.pd_head_ratio = get_ascend_config().pd_head_ratio self.num_head_replica = get_ascend_config().num_head_replica self.resharding_stream = None - if self.pd_head_ratio > 1: + if self.pd_head_ratio > 1 or self.enable_kv_quant: self.resharding_stream = torch.npu.Stream() self.remote_poller = zmq.Poller() # type: ignore @@ -985,11 +1031,15 @@ def __init__(self, vllm_config: VllmConfig, kv_cache_config: KVCacheConfig, engi self.timeout = 1.0 # seconds self.k_buffer: torch.Tensor | None = None self.v_buffer: torch.Tensor | None = None + self.k_quant_buffer: torch.Tensor | None = None + self.v_quant_buffer: torch.Tensor | None = None - def create_kv_buffer(self, first_kv_cache): + def create_kv_buffer(self, first_kv_cache_tuple): + alignment = 2 * 1024 * 1024 + buffer_list = [] + first_kv_cache = first_kv_cache_tuple[0] if self.pd_head_ratio > 1: # regesit kv buffer for tp inequal - alignment = 2 * 1024 * 1024 self.k_buffer = torch.zeros( first_kv_cache.numel() + alignment, dtype=first_kv_cache.dtype, device=first_kv_cache.device ) @@ -1002,18 +1052,34 @@ def create_kv_buffer(self, first_kv_cache): self.v_buffer = align_memory(self.v_buffer, alignment)[: first_kv_cache.numel()].view( -1, first_kv_cache.shape[-1] ) + buffer_list.append(self.k_buffer) + buffer_list.append(self.v_buffer) + if self.enable_kv_quant: + quant_k_cache_numel = first_kv_cache_tuple[0].numel() // 2 + self.k_quant_buffer = torch.zeros( + quant_k_cache_numel + alignment, dtype=torch.int8, device=first_kv_cache.device + ) + self.k_quant_buffer = align_memory(self.k_quant_buffer, alignment)[:quant_k_cache_numel].view( + -1, first_kv_cache.shape[-1] + ) + quant_v_cache_numel = first_kv_cache_tuple[1].numel() + self.v_quant_buffer = torch.zeros( + quant_v_cache_numel + alignment, dtype=first_kv_cache.dtype, device=first_kv_cache.device + ) + self.v_quant_buffer = align_memory(self.v_quant_buffer, alignment)[:quant_v_cache_numel].view( + -1, first_kv_cache_tuple[1].shape[-1] + ) + buffer_list.append(self.k_quant_buffer) + buffer_list.append(self.v_quant_buffer) - for tensor in (self.k_buffer, self.v_buffer): - assert tensor.data_ptr() % alignment == 0, ( - "The address of the registered kv cache should be aligned to 2M" - ) - ret_value = self.engine.register_memory(tensor.data_ptr(), tensor.numel()) - logger.info( - f"Register memory for prefill when pd head ratio > 1 {tensor.data_ptr()} " - f"{tensor.numel()} {ret_value=}" - ) - if ret_value != 0: - raise RuntimeError("Mooncake memory registration failed. ") + for tensor in buffer_list: + assert tensor.data_ptr() % alignment == 0, "The address of the registered kv cache should be aligned to 2M" + ret_value = self.engine.register_memory(tensor.data_ptr(), tensor.numel()) + logger.info( + f"Register memory for prefill when pd head ratio > 1 {tensor.data_ptr()} {tensor.numel()} {ret_value=}" + ) + if ret_value != 0: + raise RuntimeError("Mooncake memory registration failed. ") def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): """Register the KV Cache data.""" @@ -1042,8 +1108,8 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): ptrs = [] lengths = [] - use_resharding_buffer = False - resharding_buffer = None + use_kv_buffer = False + kv_buffer = None for layer_name, kv_cache_tuple in kv_caches.items(): if isinstance(kv_cache_tuple, (list, tuple)) is False: kv_cache_tuple = [kv_cache_tuple] @@ -1051,12 +1117,13 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): layer_kv_cache_spec = kv_cache_groups[layer_kv_group_id].kv_cache_spec if isinstance(layer_kv_cache_spec, UniformTypeKVCacheSpecs): layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[layer_name] - if self.pd_head_ratio > 1 and (isinstance(layer_kv_cache_spec, (FullAttentionSpec, SlidingWindowSpec))): + if ( + self.pd_head_ratio > 1 and (isinstance(layer_kv_cache_spec, (FullAttentionSpec, SlidingWindowSpec))) + ) or self.enable_kv_quant: self.attn_resharding_group_idx.add(layer_kv_group_id) - if use_resharding_buffer is False: - use_resharding_buffer = True - resharding_buffer = kv_cache_tuple[0] - self.resharding_stream = torch.npu.Stream() + if use_kv_buffer is False: + use_kv_buffer = True + kv_buffer = kv_cache_tuple single_layer_meta = LayerMetadata([], [], [], []) for single_kv_cache in kv_cache_tuple: block_start_rank = 1 @@ -1092,8 +1159,8 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): lengths.append(kv_cache_tensor.size) global_te.register_buffer(ptrs, lengths) - if use_resharding_buffer: - self.create_kv_buffer(resharding_buffer) + if use_kv_buffer: + self.create_kv_buffer(kv_buffer) num_attn_module = 2 if self.vllm_config.model_config.hf_text_config.model_type == "longcat_flash" else 1 mtp_layer_name = "" @@ -1133,6 +1200,9 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): use_mla=self.use_mla, k_buffer=self.k_buffer, v_buffer=self.v_buffer, + enable_kv_quant=self.enable_kv_quant, + k_quant_buffer=self.k_quant_buffer, + v_quant_buffer=self.v_quant_buffer, resharding_stream=self.resharding_stream, callback_func=self.send_done_send_signal, ) @@ -1380,7 +1450,7 @@ def start_load_kv(self, metadata: MooncakeLayerwiseConnectorMetadata): metadata.requests[req_id] = update_metadata[req_id] # update send task trans block info - if self.pd_head_ratio != 1: + if self.pd_head_ratio != 1 or self.enable_kv_quant: send_task = metadata.send_task send_task.group_rearrange_block_ids = [[] for _ in range(self.num_kv_cache_groups)] send_task.group_num_blocks = [0 for _ in range(self.num_kv_cache_groups)] @@ -1388,7 +1458,7 @@ def start_load_kv(self, metadata: MooncakeLayerwiseConnectorMetadata): send_task.group_block_table = [None for _ in range(self.num_kv_cache_groups)] send_task.group_block_len_tensor = [None for _ in range(self.num_kv_cache_groups)] send_task.group_seq_start_tensor = [None for _ in range(self.num_kv_cache_groups)] - device = self.k_buffer.device # type: ignore + device = self.k_buffer.device if self.k_buffer is not None else self.k_quant_buffer.device # type: ignore for i in self.attn_resharding_group_idx: send_task.group_rearrange_block_ids[i].extend( sorted( @@ -1415,7 +1485,7 @@ def start_load_kv(self, metadata: MooncakeLayerwiseConnectorMetadata): def save_kv_layer( self, layer_name: str, - kv_layer: tuple[torch.Tensor, torch.Tensor], + kv_layer: list[torch.Tensor], attn_metadata: "AttentionMetadata", connector_metadata: MooncakeLayerwiseConnectorMetadata, **kwargs, @@ -1490,12 +1560,51 @@ def save_kv_layer( values = values.reshape(-1, *kv_layer[1].shape[2:]) (keys, values) = kv_alltoall_and_rearrange(self.pd_head_ratio, keys, values) + quant_keys = None + quant_values = None + if self.enable_kv_quant and self.current_layer in self.vllm_config.quant_config.kvcache_quant_layers: + assert self.resharding_stream is not None + with npu_stream_switch(self.resharding_stream): + reshape_cache_event.wait() + device = self.k_quant_buffer.device # type: ignore + layer = self.vllm_config.compilation_config.static_forward_context[layer_name] + # Initialize buffers + # [num_tokens, kv_head, head_dim] + quant_key = torch.empty( + (send_task.group_num_tokens[layer_group_idx], *kv_layer[0].size()[-2:]), + dtype=kv_layer[0].dtype, + device=device, + ) + quant_values = torch.empty( + (send_task.group_num_tokens[layer_group_idx], *kv_layer[1].size()[-2:]), + dtype=kv_layer[1].dtype, + device=device, + ) + + # Load cache data into buffers + torch_npu.atb.npu_paged_cache_load( + kv_layer[0], + kv_layer[1], + send_task.group_block_table[layer_group_idx], + send_task.group_block_len_tensor[layer_group_idx], + seq_starts=send_task.group_seq_start_tensor[layer_group_idx], + key=quant_key, + value=quant_values, + ) + quant_keys = torch.ops.vllm.quantize( + quant_key, layer.fak_descale, layer.fak_descale_reciprocal, layer.fak_offset + ) + quant_keys = self.trans_nd_to_nz(quant_keys, layer_group_idx) + quant_values = self.trans_nd_to_nz(quant_values, layer_group_idx) + assert self.kv_send_layer_thread is not None assert reshape_cache_event is not None layer_send_task = SendTask( wait_event=reshape_cache_event, k_cache=keys, v_cache=values, + k_quant_cache=quant_keys, + v_quant_cache=quant_values, layer_idx=self.current_layer, layer_name=layer_name, group_rearrange_block_ids=send_task.group_rearrange_block_ids, @@ -1510,6 +1619,32 @@ def save_kv_layer( self.kv_send_layer_thread.send_queue.put(layer_send_task) self.current_layer += 1 + def trans_nd_to_nz(self, cache_tensor: torch.Tensor, layer_group_idx: int): + head_num, head_dim = cache_tensor.shape[-2], cache_tensor.shape[-1] + cache_tensor = cache_tensor.view(-1, self.block_size[layer_group_idx], head_num * head_dim) + + batch = cache_tensor.shape[:-2] + a, b = cache_tensor.shape[-2], cache_tensor.shape[-1] + + dtype = cache_tensor.dtype + if dtype == torch.int8: + a0, b0 = 16, 32 + else: + a0, b0 = 16, 16 + + nz_shape = list(batch) + [math.ceil(b / b0), math.ceil(a / a0), a0, b0] + + # Generate the axis order for the transpose operation. + offset = len(cache_tensor.shape) - 2 + base = [2, 0, 1, 3] + array_trans = [i for i in range(offset)] + [i + offset for i in base] + # Perform shape transformation and transpose operation. + *_, n1, m1, m0, n0 = nz_shape + cache_tensor = cache_tensor.reshape(nz_shape[:-4] + [m1, m0, n1, n0]) + cache_tensor = cache_tensor.permute(*array_trans) + cache_tensor = cache_tensor.reshape(-1, head_num, head_dim) + return cache_tensor + def _get_remote_socket(self, remote_host: str, remote_handshake_port: int) -> zmq.Socket: # type: ignore """Get a socket to the remote host.""" remote_path = make_zmq_path("tcp", remote_host, remote_handshake_port) diff --git a/vllm_ascend/ops/mla.py b/vllm_ascend/ops/mla.py index c7ae60465c4..de2e2fa4705 100644 --- a/vllm_ascend/ops/mla.py +++ b/vllm_ascend/ops/mla.py @@ -120,6 +120,7 @@ def __init__( kv_a_proj_with_mqa=mla_modules.kv_a_proj_with_mqa, kv_a_layernorm=mla_modules.kv_a_layernorm, o_proj=mla_modules.o_proj, + layer_name=f"{prefix}.attn", ) original_process_weights = self.mla_attn.process_weights_after_loading diff --git a/vllm_ascend/patch/worker/__init__.py b/vllm_ascend/patch/worker/__init__.py index a847cac29c1..22af4065cf0 100644 --- a/vllm_ascend/patch/worker/__init__.py +++ b/vllm_ascend/patch/worker/__init__.py @@ -26,6 +26,7 @@ import vllm_ascend.patch.worker.patch_qwen3_5 # noqa # isort: off +import vllm_ascend.patch.worker.patch_weight_utils # noqa import vllm_ascend.patch.platform.patch_sched_yield # noqa import vllm_ascend.patch.worker.patch_unquantized_gemm # noqa import vllm_ascend.patch.worker.patch_bert # noqa diff --git a/vllm_ascend/patch/worker/patch_weight_utils.py b/vllm_ascend/patch/worker/patch_weight_utils.py new file mode 100644 index 00000000000..26fddd04344 --- /dev/null +++ b/vllm_ascend/patch/worker/patch_weight_utils.py @@ -0,0 +1,82 @@ +import sys + +from vllm.model_executor.model_loader.weight_utils import maybe_remap_kv_scale_name + + +class ImportPatchDecorator: + """Import patch decorator""" + + _patches = {} + + @classmethod + def register(cls, module_name): + """Decorator for registering module patches""" + + def decorator(func): + cls._patches[module_name] = func + return func + + return decorator + + @classmethod + def apply_patches(cls): + """Apply all patches""" + for module_name, patch_func in cls._patches.items(): + if module_name in sys.modules: + module = sys.modules[module_name] + try: + patch_func(module) + except Exception as e: + logger.error(f"Patch application failed {module_name}: {e}") + + +@ImportPatchDecorator.register("vllm.model_executor.models.deepseek_v2") +def patch_deepseek(module): + ori_maybe_remap_kv_scale_name = maybe_remap_kv_scale_name + + def new_remap(name: str, params_dict: dict): + name = ori_maybe_remap_kv_scale_name(name, params_dict) + + replace_scale_names = ["fa_q.scale", "fa_k.scale", "fa_v.scale", "fa_q.offset", "fa_k.offset", "fa_v.offset"] + + for scale_name in replace_scale_names: + if name.endswith(scale_name): + remap_name = name.replace(scale_name, f"mla_attn.mla_attn.{scale_name}") + if remap_name in params_dict: + return remap_name + else: + return remap_name.replace(".mla_attn", "") + + return name + + if hasattr(module, "maybe_remap_kv_scale_name"): + module._original_maybe_remap_kv_scale_name = module.maybe_remap_kv_scale_name + module.maybe_remap_kv_scale_name = new_remap + + +@ImportPatchDecorator.register("vllm.model_executor.model_loader.weight_utils") +def patch_weight_utils(module): + if "vllm.model_executor.models.deepseek_v2" in sys.modules: + deepseek = sys.modules["vllm.model_executor.models.deepseek_v2"] + if hasattr(deepseek, "maybe_remap_kv_scale_name"): + module.maybe_remap_kv_scale_name = deepseek.maybe_remap_kv_scale_name + + +original_import = __builtins__["__import__"] + + +def patched_import(name, globals=None, locals=None, fromlist=(), level=0): + module = original_import(name, globals, locals, fromlist, level) + + if name in ImportPatchDecorator._patches: + try: + ImportPatchDecorator._patches[name](module) + except Exception as e: + logger.error(f"Patch application failed during import {name}: {e}") + + return module + + +__builtins__["__import__"] = patched_import + +ImportPatchDecorator.apply_patches() diff --git a/vllm_ascend/quantization/methods/__init__.py b/vllm_ascend/quantization/methods/__init__.py index 3884029557b..84b6d773285 100644 --- a/vllm_ascend/quantization/methods/__init__.py +++ b/vllm_ascend/quantization/methods/__init__.py @@ -36,6 +36,7 @@ from .registry import get_scheme_class, register_scheme # Import all scheme classes for external access +from .kv_c8 import AscendFAQuantAttentionMethod from .w4a4_flatquant import AscendW4A4FlatQuantDynamicLinearMethod from .w4a4_laos_dynamic import AscendW4A4LaosDynamicLinearMethod from .w4a8 import AscendW4A8DynamicFusedMoEMethod, AscendW4A8DynamicLinearMethod @@ -77,4 +78,5 @@ def is_mx_quant_type(instance: Any) -> bool: "AscendW4A16FusedMoEMethod", "AscendW4A4FlatQuantDynamicLinearMethod", "AscendW4A4LaosDynamicLinearMethod", + "AscendFAQuantAttentionMethod", ] diff --git a/vllm_ascend/quantization/methods/kv_c8.py b/vllm_ascend/quantization/methods/kv_c8.py new file mode 100644 index 00000000000..21d090a6abd --- /dev/null +++ b/vllm_ascend/quantization/methods/kv_c8.py @@ -0,0 +1,66 @@ +import torch +from vllm.config import get_current_vllm_config +from vllm.distributed import get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size + +from .registry import register_scheme + + +def weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor): + """fa_q weight loader.""" + if param.numel() == 1 and loaded_weight.numel() == 1: + param.data.fill_(loaded_weight.item()) + else: + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + shard_size = loaded_weight.shape[0] // tp_size + loaded_weight = loaded_weight.narrow(0, shard_size * tp_rank, shard_size) + assert param.size() == loaded_weight.size(), ( + f"Attempted to load weight ({loaded_weight.size()}) into parameter ({param.size()}) when TP is ({tp_size})" + ) + + param.data.copy_(loaded_weight) + + +@register_scheme("FAKQuant", "attention") +class AscendFAQuantAttentionMethod: + def __init__(self): + self.transpose_weight = True + self.printFlag = False + vllm_config = get_current_vllm_config() + config = vllm_config.model_config.hf_config + self.kv_lora_rank = getattr(config, "kv_lora_rank", 0) + self.qk_rope_head_dim = getattr(config, "qk_rope_head_dim", 0) + + def create_weights(self, layer: torch.nn.Module) -> None: + extra_module_names = ["fa_q", "fa_k", "fa_v"] + for name in extra_module_names: + setattr(layer, name, torch.nn.Module()) + params_dict = {} + dtype = torch.get_default_dtype() + layer.num_kv_heads = 1 + params_dict["fa_q.scale"] = torch.empty((layer.num_heads, 1), dtype=dtype) + params_dict["fa_k.scale"] = torch.empty((layer.num_kv_heads, 1), dtype=dtype) + params_dict["fa_v.scale"] = torch.empty((layer.num_kv_heads, 1), dtype=dtype) + params_dict["fa_q.offset"] = torch.empty((layer.num_heads, 1), dtype=torch.int8) + params_dict["fa_k.offset"] = torch.empty((layer.num_kv_heads, 1), dtype=torch.int8) + params_dict["fa_v.offset"] = torch.empty((layer.num_kv_heads, 1), dtype=torch.int8) + + for name, weight in params_dict.items(): + module_name, weight_name = name.rsplit(".", 1) + module = getattr(layer, module_name) + weight_param = torch.nn.Parameter(weight, requires_grad=False) + module.register_parameter(weight_name, weight_param) + # When loading weights, segment them according to TP + weight_param.weight_loader = weight_loader + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + fa_k_scale = torch.squeeze(layer.fa_k.scale).unsqueeze(0) + layer.fak_descale_float = torch.nn.Parameter(fa_k_scale.to(torch.float), requires_grad=False) + layer.fak_descale = torch.nn.Parameter(fa_k_scale, requires_grad=False) + layer.fak_descale_reciprocal = 1.0 / torch.nn.Parameter(fa_k_scale, requires_grad=False) + fa_k_offset = torch.squeeze(layer.fa_k.offset).unsqueeze(0) + layer.fak_offset = torch.nn.Parameter(fa_k_offset.to(layer.fak_descale.dtype), requires_grad=False) + + repeated_quant_kscale = fa_k_scale.repeat(self.kv_lora_rank) + layer.quant_kscale = repeated_quant_kscale.view(1, self.kv_lora_rank) + layer.quant_kscale = 1.0 / torch.nn.Parameter(layer.quant_kscale.to(torch.float), requires_grad=False) diff --git a/vllm_ascend/quantization/modelslim_config.py b/vllm_ascend/quantization/modelslim_config.py index 0945d1a3683..6cc520556b2 100644 --- a/vllm_ascend/quantization/modelslim_config.py +++ b/vllm_ascend/quantization/modelslim_config.py @@ -21,6 +21,10 @@ configs generated by the ModelSlim tool, along with model-specific mappings. """ +import glob +import json +import os +import re from collections.abc import Mapping from types import MappingProxyType from typing import Any, Optional @@ -28,6 +32,7 @@ import torch from vllm.config import get_current_vllm_config from vllm.logger import init_logger +from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.linear import LinearBase from vllm.model_executor.layers.quantization import register_quantization_config @@ -41,6 +46,8 @@ logger = init_logger(__name__) +MODELSLIM_CONFIG_FILENAME = "quant_model_description.json" + # key: model_type # value: vLLM prefix -> HF prefix mapping (used to convert vLLM layer names to HF format # for looking up keys in quant_model_description.json) @@ -397,9 +404,9 @@ class AscendModelSlimConfig(QuantizationConfig): quantized using the ModelSlim tool. """ - def __init__(self, quant_config: dict[str, Any]): + def __init__(self, quant_config: dict[str, Any] | None = None): super().__init__() - self.quant_description = quant_config + self.quant_description = quant_config if quant_config is not None else {} # TODO(whx): remove this adaptation after adding "shared_head" # to prefix of DeepSeekShareHead in vLLM. extra_quant_dict = {} @@ -415,6 +422,7 @@ def __init__(self, quant_config: dict[str, Any]): self.model_type: str | None = None self.hf_to_vllm_mapper: WeightsMapper | None = None self.vllm_to_hf_mapper: WeightsMapper | None = None + self._apply_extra_quant_adaptations() def __repr__(self) -> str: return "AscendModelSlimConfig:\n" + super().__repr__() @@ -433,7 +441,7 @@ def get_min_capability(cls) -> int: @classmethod def get_config_filenames(cls) -> list[str]: - return ["quant_model_description.json"] + return [""] @classmethod def from_config(cls, config: dict[str, Any]) -> "AscendModelSlimConfig": @@ -553,11 +561,7 @@ def get_quant_method(self, layer: torch.nn.Module, prefix: str) -> Optional["Qua return AscendUnquantizedLinearMethod() scheme = create_scheme_for_layer(self.quant_description, prefix, "linear", self.packed_modules_mapping) return AscendLinearMethod(scheme) - elif ( - isinstance(layer, Attention) - and "fa_quant_type" in self.quant_description - and self.quant_description["fa_quant_type"] is not None - ): + elif isinstance(layer, AttentionLayerBase) and self.is_fa_quant_layer(prefix): scheme = create_scheme_for_layer(self.quant_description, prefix, "attention", self.packed_modules_mapping) return AscendKVCacheMethod(scheme) elif isinstance(layer, FusedMoE): @@ -604,5 +608,117 @@ def is_layer_skipped_ascend(self, prefix: str, fused_mapping: Mapping[str, list[ assert is_skipped is not None return is_skipped + def is_fa_quant_layer(self, prefix): + if self.enable_fa_quant: + _id = int("".join(re.findall(r"\.(\d+)\.", prefix))) + if _id in self.kvcache_quant_layers: + return True + return False + + def maybe_update_config(self, model_name: str) -> None: + """Load the ModelSlim quantization config from model directory. + + This method is called by vllm after get_quant_config() returns + successfully. Since we return an empty list from get_config_filenames() + to bypass vllm's built-in file lookup, we do the actual config loading + here and provide user-friendly error messages when the config is missing. + + Args: + model_name: Path to the model directory or model name. + """ + # If quant_description is already populated (e.g. from from_config()), + # there is nothing to do. + if self.quant_description: + return + + # Try to find and load the ModelSlim config file + if os.path.isdir(model_name): + config_path = os.path.join(model_name, MODELSLIM_CONFIG_FILENAME) + if os.path.isfile(config_path): + with open(config_path) as f: + self.quant_description = json.load(f) + self._apply_extra_quant_adaptations() + self._add_kvcache_quant_metadata() + return + + # Check if there are any json files at all to help diagnose + json_files = glob.glob(os.path.join(model_name, "*.json")) + json_names = [os.path.basename(f) for f in json_files] + else: + json_names = [] + + # Config file not found - raise a friendly error message + raise ValueError( + "\n" + + "=" * 80 + + "\n" + + "ERROR: ModelSlim Quantization Config Not Found\n" + + "=" * 80 + + "\n" + + "\n" + + f"You have enabled '--quantization {ASCEND_QUANTIZATION_METHOD}' " + + "(ModelSlim quantization),\n" + + f"but the model at '{model_name}' does not contain the required\n" + + f"quantization config file ('{MODELSLIM_CONFIG_FILENAME}').\n" + + "\n" + + "This usually means the model weights are NOT quantized by " + + "ModelSlim.\n" + + "\n" + + "Please choose one of the following solutions:\n" + + "\n" + + " Solution 1: Remove the quantization option " + + "(for float/unquantized models)\n" + + " " + + "-" * 58 + + "\n" + + f" Remove '--quantization {ASCEND_QUANTIZATION_METHOD}' from " + + "your command if you want to\n" + + " run the model with the original (float) weights.\n" + + "\n" + + " Example:\n" + + f" vllm serve {model_name}\n" + + "\n" + + " Solution 2: Quantize your model weights with ModelSlim first\n" + + " " + + "-" * 58 + + "\n" + + " Use the ModelSlim tool to quantize your model weights " + + "before deployment.\n" + + " After quantization, the model directory should contain " + + f"'{MODELSLIM_CONFIG_FILENAME}'.\n" + + " For more information, please refer to:\n" + + " https://gitee.com/ascend/msit/tree/master/msmodelslim\n" + + "\n" + + (f" (Found JSON files in model directory: {json_names})\n" if json_names else "") + + "=" * 80 + ) + + def _apply_extra_quant_adaptations(self) -> None: + """Apply extra adaptations to the quant_description dict. + + This handles known key transformations such as shared_head and + weight_packed mappings. + """ + extra_quant_dict = {} + for k in self.quant_description: + if "shared_head" in k: + new_k = k.replace(".shared_head.", ".") + extra_quant_dict[new_k] = self.quant_description[k] + if "weight_packed" in k: + new_k = k.replace("weight_packed", "weight") + extra_quant_dict[new_k] = self.quant_description[k] + self.quant_description.update(extra_quant_dict) + def get_scaled_act_names(self) -> list[str]: return [] + + def _add_kvcache_quant_metadata(self): + fa_quant_type = self.quant_description.get("fa_quant_type", "") + self.enable_fa_quant = fa_quant_type != "" + self.kvcache_quant_layers = [] + if self.enable_fa_quant: + for key in self.quant_description: + if "fa_k.scale" in key: + _id = "".join(re.findall(r"\.(\d+)\.", key)) + self.kvcache_quant_layers.append(int(_id)) + diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 7631e2a97b1..b2c1e1a2644 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -393,6 +393,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): self.query_lens: torch.Tensor | None = None self.cpu_slot_mapping = None self.sampling_done_event: torch.npu.Event | None = None + self.kvbytes = {} @property def use_cp(self) -> bool: @@ -2665,10 +2666,16 @@ def _allocate_kv_cache_tensors(self, kv_cache_config: KVCacheConfig) -> dict[str # For deepseek mla, we need to spilt cache tensor accrodding to the nope head dim # and rope head dim. if self.model_config.use_mla: - head_size = ( - self.model_config.hf_text_config.qk_rope_head_dim - + self.model_config.hf_text_config.kv_lora_rank - ) + if layer_name in self.kvbytes: + head_size = ( + self.model_config.hf_text_config.qk_rope_head_dim * self.kvbytes[layer_name][1] + + self.model_config.hf_text_config.kv_lora_rank * self.kvbytes[layer_name][0] + ) + else: + head_size = ( + self.model_config.hf_text_config.qk_rope_head_dim + + self.model_config.hf_text_config.kv_lora_rank + ) dsa_k_cache_factor = None dsa_k_cache_size = None @@ -2683,6 +2690,13 @@ def _allocate_kv_cache_tensors(self, kv_cache_config: KVCacheConfig) -> dict[str sparse_sum_head_size / ratio for ratio in self._get_sparse_kv_cache_ratio() ] dsa_k_cache_size = int(kv_cache_tensor.size // dsa_k_cache_factor) + elif layer_name in self.kvbytes: + k_tensor_split_factor = head_size / ( + self.model_config.hf_text_config.kv_lora_rank * self.kvbytes[layer_name][0] + ) + v_tensor_split_factor = head_size / ( + self.model_config.hf_text_config.qk_rope_head_dim * self.kvbytes[layer_name][1] + ) else: # for other deepseek models, use MLAAttentionSpec k_tensor_split_factor = head_size / self.model_config.hf_text_config.kv_lora_rank @@ -2832,8 +2846,12 @@ def _reshape_kv_cache_tensors( num_kv_heads, self.model_config.hf_text_config.qk_rope_head_dim, ] - k_cache = raw_k_tensor.view(dtype).view(k_shape) - v_cache = raw_v_tensor.view(dtype).view(v_shape) + if layer_name in self.kvbytes: + k_cache = raw_k_tensor.view(dtype).view(k_shape) + v_cache = raw_v_tensor.view(self.vllm_config.model_config.dtype).view(v_shape) + else: + k_cache = raw_k_tensor.view(dtype).view(k_shape) + v_cache = raw_v_tensor.view(dtype).view(v_shape) if self.use_sparse and raw_dsa_k_tensor is not None: index_head_dim = self._get_sparse_kv_cache_ratio()[-1] @@ -3078,6 +3096,11 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: # NOTE: Must process Attention/MLAAttention before MambaBase to maintain # ordering expected by graph parameter update logic in attention backends. mamba_layers: dict[str, MambaBase] = {} + + def dtype_to_bytes(dtype: torch.dtype) -> int: + """将 torch.dtype 转换为字节数""" + return torch.tensor([], dtype=dtype).element_size() + attn_layer_names = set() for layer_name, attn_module in attn_layers.items(): if isinstance(attn_module, Attention): @@ -3108,6 +3131,21 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: dtype=self.kv_cache_dtype, cache_dtype_str=self.vllm_config.cache_config.cache_dtype, ) + elif getattr(attn_module.impl, "fa_quant_layer", False): + block_size = self.vllm_config.cache_config.block_size + head_size = attn_module.head_size + attn_module.qk_rope_head_dim + kv_cache_spec[layer_name] = MLAAttentionSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=head_size, + dtype=attn_module.impl.dtype, + cache_dtype_str=None, + ) + if layer_name not in self.kvbytes: + self.kvbytes[layer_name] = [ + dtype_to_bytes(attn_module.impl.dtype), + dtype_to_bytes(self.vllm_config.model_config.dtype), + ] elif spec := attn_module.get_kv_cache_spec(self.vllm_config): kv_cache_spec[layer_name] = spec From a6306bf37c4867dfa618f70b76ca21b9d26d4554 Mon Sep 17 00:00:00 2001 From: Wang Kunpeng <1289706727@qq.com> Date: Sat, 14 Mar 2026 14:48:37 +0800 Subject: [PATCH 02/13] fix conflict Signed-off-by: Wang Kunpeng <1289706727@qq.com> --- vllm_ascend/quantization/modelslim_config.py | 38 +++++++++++++--- vllm_ascend/worker/model_runner_v1.py | 46 +++++++++++++++++--- 2 files changed, 73 insertions(+), 11 deletions(-) diff --git a/vllm_ascend/quantization/modelslim_config.py b/vllm_ascend/quantization/modelslim_config.py index c682856bec2..4c3bc19a4a4 100644 --- a/vllm_ascend/quantization/modelslim_config.py +++ b/vllm_ascend/quantization/modelslim_config.py @@ -24,6 +24,7 @@ import glob import json import os +import re from collections.abc import Mapping from types import MappingProxyType from typing import Any, Optional @@ -31,12 +32,14 @@ import torch from vllm.config import get_current_vllm_config from vllm.logger import init_logger +from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.linear import LinearBase from vllm.model_executor.layers.quantization import register_quantization_config from vllm.model_executor.layers.quantization.base_config import QuantizationConfig, QuantizeMethodBase from vllm.model_executor.layers.vocab_parallel_embedding import UnquantizedEmbeddingMethod, VocabParallelEmbedding from vllm.model_executor.models.utils import WeightsMapper +from vllm.utils.torch_utils import get_dtype_size from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD @@ -522,11 +525,7 @@ def get_quant_method(self, layer: torch.nn.Module, prefix: str) -> Optional["Qua return AscendUnquantizedLinearMethod() scheme = create_scheme_for_layer(self.quant_description, prefix, "linear", self.packed_modules_mapping) return AscendLinearMethod(scheme) - elif ( - isinstance(layer, Attention) - and "fa_quant_type" in self.quant_description - and self.quant_description["fa_quant_type"] is not None - ): + elif isinstance(layer, AttentionLayerBase) and self.is_fa_quant_layer(prefix): scheme = create_scheme_for_layer(self.quant_description, prefix, "attention", self.packed_modules_mapping) return AscendKVCacheMethod(scheme) elif isinstance(layer, FusedMoE): @@ -573,6 +572,24 @@ def is_layer_skipped_ascend(self, prefix: str, fused_mapping: Mapping[str, list[ assert is_skipped is not None return is_skipped + def is_fa_quant_layer(self, prefix): + if self.enable_fa_quant: + _id = int("".join(re.findall(r"\.(\d+)\.", prefix))) + if _id in self.kvcache_quant_layers: + return True + return False + + def get_kv_quant_dtype(self, layer_name, cache_dtype, model_config): + if self.enable_fa_quant and self.is_fa_quant_layer(layer_name): + ori_dtype = model_config.dtype + quant_dtype = torch.int8 + # For MLA models like deepseek, we only quantify K cache to ensure accuracy + if model_config.use_mla: + return quant_dtype, ori_dtype + else: + return quant_dtype, quant_dtype + return cache_dtype, cache_dtype + def maybe_update_config(self, model_name: str, revision: str | None = None) -> None: """Load the ModelSlim quantization config from model directory. @@ -606,6 +623,7 @@ def maybe_update_config(self, model_name: str, revision: str | None = None) -> N with open(config_path) as f: self.quant_description = json.load(f) self._apply_extra_quant_adaptations() + self._add_kvcache_quant_metadata() return # Collect diagnostic info for the error message @@ -678,3 +696,13 @@ def _apply_extra_quant_adaptations(self) -> None: def get_scaled_act_names(self) -> list[str]: return [] + + def _add_kvcache_quant_metadata(self): + fa_quant_type = self.quant_description.get("fa_quant_type", "") + self.enable_fa_quant = fa_quant_type != "" + self.kvcache_quant_layers = [] + if self.enable_fa_quant: + for key in self.quant_description: + if "fa_k.scale" in key: + _id = "".join(re.findall(r"\.(\d+)\.", key)) + self.kvcache_quant_layers.append(int(_id)) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 8780b4d0fe8..695beffc157 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -407,6 +407,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): self.query_lens: torch.Tensor | None = None self.cpu_slot_mapping = None self.sampling_done_event: torch.npu.Event | None = None + self.kvbytes = {} if vllm_version_is("0.17.0"): # self.cudagraph_batch_sizes sorts in ascending order. @@ -2661,10 +2662,16 @@ def _allocate_kv_cache_tensors(self, kv_cache_config: KVCacheConfig) -> dict[str # For deepseek mla, we need to spilt cache tensor accrodding to the nope head dim # and rope head dim. if self.model_config.use_mla: - head_size = ( - self.model_config.hf_text_config.qk_rope_head_dim - + self.model_config.hf_text_config.kv_lora_rank - ) + if layer_name in self.kvbytes: + head_size = ( + self.model_config.hf_text_config.qk_rope_head_dim * self.kvbytes[layer_name][1] + + self.model_config.hf_text_config.kv_lora_rank * self.kvbytes[layer_name][0] + ) + else: + head_size = ( + self.model_config.hf_text_config.qk_rope_head_dim + + self.model_config.hf_text_config.kv_lora_rank + ) if not self.model_config.use_mla: # for non-mla model, use FullAttentionSpec @@ -2678,6 +2685,13 @@ def _allocate_kv_cache_tensors(self, kv_cache_config: KVCacheConfig) -> dict[str v_tensor_split_factor = sparse_kv_cache_ratio[1] dsa_k_tensor_split_factor = sparse_kv_cache_ratio[2] dsa_k_scale_tensor_split_factor = sparse_kv_cache_ratio[3] + elif layer_name in self.kvbytes: + k_tensor_split_factor = head_size / ( + self.model_config.hf_text_config.kv_lora_rank * self.kvbytes[layer_name][0] + ) + v_tensor_split_factor = head_size / ( + self.model_config.hf_text_config.qk_rope_head_dim * self.kvbytes[layer_name][1] + ) else: # for other deepseek models, use MLAAttentionSpec k_tensor_split_factor = head_size / self.model_config.hf_text_config.kv_lora_rank @@ -2858,8 +2872,13 @@ def _reshape_kv_cache_tensors( num_kv_heads, self.model_config.hf_text_config.qk_rope_head_dim, ] - k_cache = raw_k_tensor.view(kv_cache_spec.dtype).view(k_shape) - v_cache = raw_v_tensor.view(kv_cache_spec.dtype).view(v_shape) + k_cache_dtype = v_cache_dtype = kv_cache_spec.dtype + if self.vllm_config.quant_config is not None: + k_cache_dtype, v_cache_dtype = self.vllm_config.quant_config.get_kv_quant_dtype( + layer_name, kv_cache_spec.dtype, self.model_config + ) + k_cache = raw_k_tensor.view(k_cache_dtype).view(k_shape) + v_cache = raw_v_tensor.view(v_cache_dtype).view(v_shape) if self.use_sparse: dsa_k_cache_shape = ( @@ -3157,6 +3176,21 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: cache_dtype_str=self.vllm_config.cache_config.cache_dtype, cache_sparse_c8=self.use_sparse_c8_indexer, ) + elif getattr(attn_module.impl, "fa_quant_layer", False): + block_size = self.vllm_config.cache_config.block_size + head_size = attn_module.head_size + attn_module.qk_rope_head_dim + kv_cache_spec[layer_name] = AscendMLAAttentionSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=head_size, + dtype=attn_module.impl.dtype, + cache_dtype_str=None, + ) + if layer_name not in self.kvbytes: + self.kvbytes[layer_name] = [ + dtype_to_bytes(attn_module.impl.dtype), + dtype_to_bytes(self.vllm_config.model_config.dtype), + ] elif spec := attn_module.get_kv_cache_spec(self.vllm_config): assert isinstance(spec, MLAAttentionSpec) from vllm.v1.kv_cache_interface import MLAAttentionSpec as AscendMLAAttentionSpec From 9467161c6a261ce1431f3699664d447517047292 Mon Sep 17 00:00:00 2001 From: Wang Kunpeng <1289706727@qq.com> Date: Sat, 14 Mar 2026 17:26:21 +0800 Subject: [PATCH 03/13] refactor Signed-off-by: Wang Kunpeng <1289706727@qq.com> --- vllm_ascend/attention/mla_v1.py | 138 +++++++----------- vllm_ascend/attention/utils.py | 5 +- .../kv_p2p/mooncake_layerwise_connector.py | 2 +- .../patch/worker/patch_weight_utils.py | 3 + vllm_ascend/quantization/methods/__init__.py | 6 +- vllm_ascend/quantization/modelslim_config.py | 12 +- vllm_ascend/utils.py | 8 + vllm_ascend/worker/model_runner_v1.py | 61 +++----- 8 files changed, 97 insertions(+), 138 deletions(-) diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 47154468bb6..2506c007084 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -791,54 +791,36 @@ def update_graph_params( seq_lens_list = seq_lens_list + [0] * (num_tokens - len(seq_lens_list)) torch.npu.graph_task_update_begin(update_stream, handle) - if dequant_scale_q_nope is None: - torch_npu.npu_fused_infer_attention_score.out( - q_nope, - k_nope, - k_nope, - query_rope=q_pe, - key_rope=k_pe, - num_heads=num_heads, - num_key_value_heads=num_kv_heads, - input_layout=input_layout, - atten_mask=attn_mask, - sparse_mode=sparse_mode, - scale=scale, - antiquant_mode=0, - antiquant_scale=None, - block_table=block_table, - block_size=block_size, - actual_seq_lengths_kv=seq_lens_list, - actual_seq_lengths=actual_seq_lengths, - workspace=graph_params.workspaces.get(num_tokens), - out=[attn_output, softmax_lse], - ) - else: - torch_npu.npu_fused_infer_attention_score_v2.out( - q_nope, - k_nope, - k_nope, - query_rope=q_pe, - key_rope=k_pe, - num_query_heads=num_heads, - num_key_value_heads=num_kv_heads, - input_layout=input_layout, - atten_mask=attn_mask, - sparse_mode=sparse_mode, - softmax_scale=scale, - query_quant_mode=3, - key_quant_mode=0, - value_quant_mode=0, - dequant_scale_query=dequant_scale_q_nope, - dequant_scale_key=fak_descale_float, - dequant_scale_value=fak_descale_float, - block_table=block_table, - block_size=block_size, - actual_seq_kvlen=seq_lens_list, - actual_seq_qlen=actual_seq_lengths, - workspace=graph_params.workspaces.get(num_tokens), - out=[attn_output, softmax_lse], - ) + extra_args = {} + if dequant_scale_q_nope is not None: + extra_args = { + "query_quant_mode": 3, + "key_quant_mode": 0, + "value_quant_mode": 0, + "dequant_scale_query": dequant_scale_q_nope, + "dequant_scale_key": fak_descale_float, + "dequant_scale_value": fak_descale_float, + } + torch_npu.npu_fused_infer_attention_score_v2.out( + q_nope, + k_nope, + k_nope, + query_rope=q_pe, + key_rope=k_pe, + num_query_heads=num_heads, + num_key_value_heads=num_kv_heads, + input_layout=input_layout, + atten_mask=attn_mask, + sparse_mode=sparse_mode, + softmax_scale=scale, + block_table=block_table, + block_size=block_size, + actual_seq_kvlen=seq_lens_list, + actual_seq_qlen=actual_seq_lengths, + workspace=graph_params.workspaces.get(num_tokens), + out=[attn_output, softmax_lse], + **extra_args, + ) torch.npu.graph_task_update_end(update_stream) event.record(update_stream) @@ -1287,26 +1269,6 @@ def _forward_decode( dequant_scale_q_nope = dequant_scale_q_nope.view(num_tokens, 1, self.num_heads) sparse_mode = 0 actual_seq_lengths = None - common_kwargs_v2 = { - "query_rope": q_pe, - "key_rope": k_pe, - "num_query_heads": self.num_heads, - "num_key_value_heads": self.num_kv_heads, - "input_layout": input_layout, - "atten_mask": attn_mask, - "sparse_mode": sparse_mode, - "softmax_scale": self.scale, - "query_quant_mode": 3, - "key_quant_mode": 0, - "value_quant_mode": 0, - "dequant_scale_query": dequant_scale_q_nope, - "dequant_scale_key": self.fak_descale_float, - "dequant_scale_value": self.fak_descale_float, - "block_table": decode_meta.block_table, - "block_size": block_size, - "actual_seq_qlen": actual_seq_lengths, - "actual_seq_kvlen": decode_meta.seq_lens_list, - } attn_output_shape = (self.num_heads, num_tokens, 1, self.kv_lora_rank) else: # The output layout is set to NBSD to eliminate the need for a @@ -1329,19 +1291,27 @@ def _forward_decode( common_kwargs = { "query_rope": q_pe, "key_rope": k_pe, - "num_heads": self.num_heads, + "num_query_heads": self.num_heads, "num_key_value_heads": self.num_kv_heads, "input_layout": input_layout, "atten_mask": attn_mask, "sparse_mode": sparse_mode, - "scale": self.scale, - "antiquant_mode": 0, - "antiquant_scale": None, + "softmax_scale": self.scale, "block_table": decode_meta.block_table, "block_size": block_size, - "actual_seq_lengths": actual_seq_lengths, - "actual_seq_lengths_kv": decode_meta.seq_lens_list, + "actual_seq_qlen": actual_seq_lengths, + "actual_seq_kvlen": decode_meta.seq_lens_list, } + if self.fa_quant_layer: + extra_fa_args = { + "query_quant_mode": 3, + "key_quant_mode": 0, + "value_quant_mode": 0, + "dequant_scale_query": dequant_scale_q_nope, + "dequant_scale_key": self.fak_descale_float, + "dequant_scale_value": self.fak_descale_float, + } + common_kwargs.update(extra_fa_args) if _EXTRA_CTX.is_draft_model: graph_params = get_draft_graph_params() else: @@ -1376,18 +1346,15 @@ def _forward_decode( weak_ref_tensors(softmax_lse), ) if self.fa_quant_layer: - get_max_workspace_func = torch_npu._npu_fused_infer_attention_score_v2_get_max_workspace - fused_infer_attention_func = torch_npu.npu_fused_infer_attention_score_v2.out attn_params = attn_params + (dequant_scale_q_nope, self.fak_descale_float) - common_kwargs = common_kwargs_v2 else: - get_max_workspace_func = torch_npu._npu_fused_infer_attention_score_get_max_workspace - fused_infer_attention_func = torch_npu.npu_fused_infer_attention_score.out attn_params = attn_params + (None, None) if workspace is None: - workspace = get_max_workspace_func(q_nope, k_nope, k_nope, **common_kwargs) - if forward_context.is_draft_model: + workspace = torch_npu._npu_fused_infer_attention_score_v2_get_max_workspace( + q_nope, k_nope, k_nope, **common_kwargs + ) + if _EXTRA_CTX.is_draft_model: update_draft_graph_params_workspaces(num_tokens, workspace) else: update_graph_params_workspaces(num_tokens, workspace) @@ -1395,15 +1362,14 @@ def _forward_decode( graph_params.attn_params[num_tokens].append(attn_params) torch.npu.graph_task_group_begin(stream) - fused_infer_attention_func( + torch_npu.npu_fused_infer_attention_score_v2.out( q_nope, k_nope, k_nope, **common_kwargs, workspace=workspace, out=[attn_output, softmax_lse] ) handle = torch.npu.graph_task_group_end(stream) graph_params.handles[num_tokens].append(handle) - elif self.fa_quant_layer: - attn_output, _ = torch_npu.npu_fused_infer_attention_score_v2(q_nope, k_nope, k_nope, **common_kwargs_v2) else: - attn_output, _ = torch_npu.npu_fused_infer_attention_score(q_nope, k_nope, k_nope, **common_kwargs) + attn_output, _ = torch_npu.npu_fused_infer_attention_score_v2(q_nope, k_nope, k_nope, **common_kwargs) + return self._v_up_proj(attn_output) def reorg_decode_q(self, decode_q_nope, decode_q_pe): @@ -1659,6 +1625,7 @@ def forward( decode_preprocess_res.k_pe, kv_cache[0].shape[1], attn_metadata, + decode_preprocess_res.dequant_scale_q_nope, ) o_proj_input[:num_decode_tokens] = output_decode @@ -1675,7 +1642,6 @@ def forward( prefill_preprocess_res.value, kv_cache, attn_metadata, - decode_preprocess_res.dequant_scale_q_nope, ) o_proj_input[num_decode_tokens:num_actual_tokens] = output_prefill diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py index f4c19727a88..07428081646 100644 --- a/vllm_ascend/attention/utils.py +++ b/vllm_ascend/attention/utils.py @@ -1,3 +1,4 @@ +import re from dataclasses import dataclass, field from functools import lru_cache from typing import Any @@ -343,7 +344,7 @@ def enabling_fa_quant(vllm_config: VllmConfig, layer_name) -> bool: enable_fa_quant = quant_config.enable_fa_quant if quant_config is not None else False fa_quant_layer = False if is_decode_instance and enable_fa_quant: - id = "".join(re.findall(r"\.(\d+)\.", layer_name)) - if int(id) in quant_config.kvcache_quant_layers: + layer_id_str = "".join(re.findall(r"\.(\d+)\.", layer_name)) + if layer_id_str.isdigit() and int(layer_id_str) in quant_config.kvcache_quant_layers: fa_quant_layer = True return fa_quant_layer diff --git a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py index 74fce722373..091cdb32593 100644 --- a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py +++ b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py @@ -334,7 +334,7 @@ def get_transfer_meta(self, send_task: SendTask, req_id: str, req_meta: ReqMeta, grouped_remote_block_ids, grouped_local_block_ids = group_concurrent_contiguous( remote_block_ids, local_block_ids ) -# kv cache quantization scenario + # kv cache quantization scenario if self.enable_kv_quant and send_task.k_quant_cache is not None: assert len(block_lens) == 2, "Quantization block length must be 2!" quant_block_lens = [block_lens[0] // 2, block_lens[1]] diff --git a/vllm_ascend/patch/worker/patch_weight_utils.py b/vllm_ascend/patch/worker/patch_weight_utils.py index 26fddd04344..63d7a1ad802 100644 --- a/vllm_ascend/patch/worker/patch_weight_utils.py +++ b/vllm_ascend/patch/worker/patch_weight_utils.py @@ -1,7 +1,10 @@ import sys +from vllm.logger import init_logger from vllm.model_executor.model_loader.weight_utils import maybe_remap_kv_scale_name +logger = init_logger(__name__) + class ImportPatchDecorator: """Import patch decorator""" diff --git a/vllm_ascend/quantization/methods/__init__.py b/vllm_ascend/quantization/methods/__init__.py index 84b6d773285..2e91ba2dcde 100644 --- a/vllm_ascend/quantization/methods/__init__.py +++ b/vllm_ascend/quantization/methods/__init__.py @@ -32,11 +32,11 @@ # Import base classes from .base import AscendAttentionScheme, AscendLinearScheme, AscendMoEScheme, QuantType -# Import registry functions -from .registry import get_scheme_class, register_scheme - # Import all scheme classes for external access from .kv_c8 import AscendFAQuantAttentionMethod + +# Import registry functions +from .registry import get_scheme_class, register_scheme from .w4a4_flatquant import AscendW4A4FlatQuantDynamicLinearMethod from .w4a4_laos_dynamic import AscendW4A4LaosDynamicLinearMethod from .w4a8 import AscendW4A8DynamicFusedMoEMethod, AscendW4A8DynamicLinearMethod diff --git a/vllm_ascend/quantization/modelslim_config.py b/vllm_ascend/quantization/modelslim_config.py index 4c3bc19a4a4..196826e79bf 100644 --- a/vllm_ascend/quantization/modelslim_config.py +++ b/vllm_ascend/quantization/modelslim_config.py @@ -39,9 +39,8 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig, QuantizeMethodBase from vllm.model_executor.layers.vocab_parallel_embedding import UnquantizedEmbeddingMethod, VocabParallelEmbedding from vllm.model_executor.models.utils import WeightsMapper -from vllm.utils.torch_utils import get_dtype_size -from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD +from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD, calc_split_factor from .methods import get_scheme_class @@ -512,8 +511,6 @@ def get_quant_method(self, layer: torch.nn.Module, prefix: str) -> Optional["Qua self.packed_modules_mapping = packed_modules_model_mapping[model_type] prefix = self.quant_prefix_mapper(model_type, prefix) - from vllm.model_executor.layers.attention import Attention - if model_type != "kimi_k2": if prefix.startswith("language_model"): prefix = prefix.split(".", 1)[-1] @@ -590,6 +587,13 @@ def get_kv_quant_dtype(self, layer_name, cache_dtype, model_config): return quant_dtype, quant_dtype return cache_dtype, cache_dtype + def get_kv_quant_split_factor(self, layer_name, kv_head_dim_list): + if self.enable_fa_quant and self.is_fa_quant_layer(layer_name): + k_quant_head_dim = kv_head_dim_list[0] + v_quant_head_dim = kv_head_dim_list[1] * 2 + kv_head_dim_list = [k_quant_head_dim, v_quant_head_dim] + return calc_split_factor(kv_head_dim_list) + def maybe_update_config(self, model_name: str, revision: str | None = None) -> None: """Load the ModelSlim quantization config from model directory. diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 66d94475e3c..79b7817504a 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -1211,3 +1211,11 @@ def get_rope_dim(vllm_config): rope_dim = int(model_config.hf_text_config.rotary_dim) return rope_dim + + +def calc_split_factor(num_list: list[int]): + total = sum(num_list) + split_factor_list = [] + for num in num_list: + split_factor_list.append(total / num) + return split_factor_list diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 695beffc157..08e0d059ad9 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -114,6 +114,7 @@ from vllm_ascend.spec_decode.ngram_proposer import AscendNgramProposer from vllm_ascend.spec_decode.suffix_proposer import AscendSuffixDecodingProposer from vllm_ascend.utils import ( + calc_split_factor, check_gdn_layer, enable_sp, enable_sp_by_pass, @@ -407,7 +408,6 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): self.query_lens: torch.Tensor | None = None self.cpu_slot_mapping = None self.sampling_done_event: torch.npu.Event | None = None - self.kvbytes = {} if vllm_version_is("0.17.0"): # self.cudagraph_batch_sizes sorts in ascending order. @@ -2661,18 +2661,6 @@ def _allocate_kv_cache_tensors(self, kv_cache_config: KVCacheConfig) -> dict[str # as it only support the 0-dim of kv_cache is `num_blocks`. # For deepseek mla, we need to spilt cache tensor accrodding to the nope head dim # and rope head dim. - if self.model_config.use_mla: - if layer_name in self.kvbytes: - head_size = ( - self.model_config.hf_text_config.qk_rope_head_dim * self.kvbytes[layer_name][1] - + self.model_config.hf_text_config.kv_lora_rank * self.kvbytes[layer_name][0] - ) - else: - head_size = ( - self.model_config.hf_text_config.qk_rope_head_dim - + self.model_config.hf_text_config.kv_lora_rank - ) - if not self.model_config.use_mla: # for non-mla model, use FullAttentionSpec k_tensor_split_factor = 2.0 @@ -2685,17 +2673,16 @@ def _allocate_kv_cache_tensors(self, kv_cache_config: KVCacheConfig) -> dict[str v_tensor_split_factor = sparse_kv_cache_ratio[1] dsa_k_tensor_split_factor = sparse_kv_cache_ratio[2] dsa_k_scale_tensor_split_factor = sparse_kv_cache_ratio[3] - elif layer_name in self.kvbytes: - k_tensor_split_factor = head_size / ( - self.model_config.hf_text_config.kv_lora_rank * self.kvbytes[layer_name][0] - ) - v_tensor_split_factor = head_size / ( - self.model_config.hf_text_config.qk_rope_head_dim * self.kvbytes[layer_name][1] - ) else: - # for other deepseek models, use MLAAttentionSpec - k_tensor_split_factor = head_size / self.model_config.hf_text_config.kv_lora_rank - v_tensor_split_factor = head_size / self.model_config.hf_text_config.qk_rope_head_dim + kv_head_dim_list = [ + self.model_config.hf_text_config.kv_lora_rank, + self.model_config.hf_text_config.qk_rope_head_dim + ] + if self.is_kv_consumer and self.vllm_config.quant_config is not None: + k_tensor_split_factor, v_tensor_split_factor = (self.vllm_config.quant_config. + get_kv_quant_split_factor(layer_name, kv_head_dim_list)) + else: + k_tensor_split_factor, v_tensor_split_factor = calc_split_factor(kv_head_dim_list) k_tensor_size = int(kv_cache_tensor.size // k_tensor_split_factor) v_tensor_size = int(kv_cache_tensor.size // v_tensor_split_factor) @@ -2873,7 +2860,7 @@ def _reshape_kv_cache_tensors( self.model_config.hf_text_config.qk_rope_head_dim, ] k_cache_dtype = v_cache_dtype = kv_cache_spec.dtype - if self.vllm_config.quant_config is not None: + if self.is_kv_consumer and self.vllm_config.quant_config is not None: k_cache_dtype, v_cache_dtype = self.vllm_config.quant_config.get_kv_quant_dtype( layer_name, kv_cache_spec.dtype, self.model_config ) @@ -3176,30 +3163,20 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: cache_dtype_str=self.vllm_config.cache_config.cache_dtype, cache_sparse_c8=self.use_sparse_c8_indexer, ) - elif getattr(attn_module.impl, "fa_quant_layer", False): - block_size = self.vllm_config.cache_config.block_size - head_size = attn_module.head_size + attn_module.qk_rope_head_dim - kv_cache_spec[layer_name] = AscendMLAAttentionSpec( - block_size=block_size, - num_kv_heads=attn_module.num_kv_heads, - head_size=head_size, - dtype=attn_module.impl.dtype, - cache_dtype_str=None, - ) - if layer_name not in self.kvbytes: - self.kvbytes[layer_name] = [ - dtype_to_bytes(attn_module.impl.dtype), - dtype_to_bytes(self.vllm_config.model_config.dtype), - ] elif spec := attn_module.get_kv_cache_spec(self.vllm_config): assert isinstance(spec, MLAAttentionSpec) from vllm.v1.kv_cache_interface import MLAAttentionSpec as AscendMLAAttentionSpec + if getattr(attn_module.impl, "fa_quant_layer", False): + head_size = attn_module.head_size + attn_module.qk_rope_head_dim + dtype, cache_dtype_str = attn_module.impl.dtype, None + else: + head_size, dtype, cache_dtype_str = spec.head_size, spec.dtype, spec.cache_dtype_str kv_cache_spec[layer_name] = AscendMLAAttentionSpec( block_size=spec.block_size, num_kv_heads=spec.num_kv_heads, - head_size=spec.head_size, - dtype=spec.dtype, - cache_dtype_str=spec.cache_dtype_str, + head_size=head_size, + dtype=dtype, + cache_dtype_str=cache_dtype_str, ) elif isinstance(attn_module, MambaBase): From 670bb8903cc1736b5ff5a87cf0784bc99983d0f3 Mon Sep 17 00:00:00 2001 From: Wang Kunpeng <1289706727@qq.com> Date: Sat, 14 Mar 2026 18:08:01 +0800 Subject: [PATCH 04/13] fix ci Signed-off-by: Wang Kunpeng <1289706727@qq.com> --- vllm_ascend/attention/context_parallel/mla_cp.py | 1 + vllm_ascend/attention/mla_v1.py | 4 ++-- vllm_ascend/patch/worker/patch_weight_utils.py | 5 +++-- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/vllm_ascend/attention/context_parallel/mla_cp.py b/vllm_ascend/attention/context_parallel/mla_cp.py index 2ea6c5298d0..aa9a0e0f8e0 100644 --- a/vllm_ascend/attention/context_parallel/mla_cp.py +++ b/vllm_ascend/attention/context_parallel/mla_cp.py @@ -602,6 +602,7 @@ def _forward_decode( k_pe: torch.Tensor, block_size: int, attn_metadata: AscendMLAMetadata, + dequant_scale_q_nope=None, ) -> torch.Tensor: decode_meta = attn_metadata.decode assert decode_meta is not None diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 2506c007084..e8b75985306 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -1346,9 +1346,9 @@ def _forward_decode( weak_ref_tensors(softmax_lse), ) if self.fa_quant_layer: - attn_params = attn_params + (dequant_scale_q_nope, self.fak_descale_float) + attn_params = attn_params + (dequant_scale_q_nope, self.fak_descale_float) # type: ignore else: - attn_params = attn_params + (None, None) + attn_params = attn_params + (None, None) # type: ignore if workspace is None: workspace = torch_npu._npu_fused_infer_attention_score_v2_get_max_workspace( diff --git a/vllm_ascend/patch/worker/patch_weight_utils.py b/vllm_ascend/patch/worker/patch_weight_utils.py index 63d7a1ad802..f2e50c43af6 100644 --- a/vllm_ascend/patch/worker/patch_weight_utils.py +++ b/vllm_ascend/patch/worker/patch_weight_utils.py @@ -1,4 +1,5 @@ import sys +from typing import Any from vllm.logger import init_logger from vllm.model_executor.model_loader.weight_utils import maybe_remap_kv_scale_name @@ -9,7 +10,7 @@ class ImportPatchDecorator: """Import patch decorator""" - _patches = {} + _patches: dict[str, Any] = {} @classmethod def register(cls, module_name): @@ -65,7 +66,7 @@ def patch_weight_utils(module): module.maybe_remap_kv_scale_name = deepseek.maybe_remap_kv_scale_name -original_import = __builtins__["__import__"] +original_import = __builtins__["__import__"] # type: ignore def patched_import(name, globals=None, locals=None, fromlist=(), level=0): From 7cd6ddf032da991f671fb45232a00931b33be161 Mon Sep 17 00:00:00 2001 From: Wang Kunpeng <1289706727@qq.com> Date: Sat, 14 Mar 2026 20:19:09 +0800 Subject: [PATCH 05/13] refactor Signed-off-by: Wang Kunpeng <1289706727@qq.com> --- vllm_ascend/attention/mla_v1.py | 5 ++- vllm_ascend/attention/utils.py | 17 ---------- .../kv_p2p/mooncake_layerwise_connector.py | 31 +++++-------------- vllm_ascend/quantization/modelslim_config.py | 12 +++++-- vllm_ascend/utils.py | 25 +++++++++++++++ 5 files changed, 46 insertions(+), 44 deletions(-) diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index e8b75985306..261dad01586 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -719,7 +719,10 @@ def __init__( self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer ) self.layer_name = kwargs.get("layer_name") - self.fa_quant_layer = enabling_fa_quant(self.vllm_config, self.layer_name) + quant_config = self.vllm_config.quant_config + self.fa_quant_layer = ( + quant_config.enabling_fa_quant(self.vllm_config, self.layer_name) if quant_config is not None else False + ) self.dtype = torch.int8 if self.fa_quant_layer else self.vllm_config.model_config.dtype self.layer_sharding_kwargs = [] for layer_name in get_ascend_config().layer_sharding or []: diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py index 07428081646..946d5c66d4e 100644 --- a/vllm_ascend/attention/utils.py +++ b/vllm_ascend/attention/utils.py @@ -1,4 +1,3 @@ -import re from dataclasses import dataclass, field from functools import lru_cache from typing import Any @@ -332,19 +331,3 @@ def enabling_mlapo(vllm_config: VllmConfig) -> bool: and not vllm_config.kv_transfer_config.is_kv_producer ) return bool(envs.VLLM_ASCEND_ENABLE_MLAPO and is_decode_instance) - - -def enabling_fa_quant(vllm_config: VllmConfig, layer_name) -> bool: - is_decode_instance = ( - vllm_config.kv_transfer_config is not None - and vllm_config.kv_transfer_config.is_kv_consumer - and not vllm_config.kv_transfer_config.is_kv_producer - ) - quant_config = vllm_config.quant_config - enable_fa_quant = quant_config.enable_fa_quant if quant_config is not None else False - fa_quant_layer = False - if is_decode_instance and enable_fa_quant: - layer_id_str = "".join(re.findall(r"\.(\d+)\.", layer_name)) - if layer_id_str.isdigit() and int(layer_id_str) in quant_config.kvcache_quant_layers: - fa_quant_layer = True - return fa_quant_layer diff --git a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py index 091cdb32593..b91e6d9bbeb 100644 --- a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py +++ b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py @@ -66,7 +66,7 @@ kv_alltoall_and_rearrange, parallel_info, ) -from vllm_ascend.utils import npu_stream_switch +from vllm_ascend.utils import npu_stream_switch, trans_nd_to_nz # isort: off if TYPE_CHECKING: @@ -1594,8 +1594,8 @@ def save_kv_layer( quant_keys = torch.ops.vllm.quantize( quant_key, layer.fak_descale, layer.fak_descale_reciprocal, layer.fak_offset ) - quant_keys = self.trans_nd_to_nz(quant_keys, layer_group_idx) - quant_values = self.trans_nd_to_nz(quant_values, layer_group_idx) + quant_keys = self.get_nz_cache(quant_keys, layer_group_idx) + quant_values = self.get_nz_cache(quant_values, layer_group_idx) assert self.kv_send_layer_thread is not None assert reshape_cache_event is not None @@ -1619,29 +1619,12 @@ def save_kv_layer( self.kv_send_layer_thread.send_queue.put(layer_send_task) self.current_layer += 1 - def trans_nd_to_nz(self, cache_tensor: torch.Tensor, layer_group_idx: int): + # NOTE: Due to the FIA operator constraints, the expected kv cache is ND format, NZ shape, + # while the npu_format_cast method only modifies the memory layout, we manually convert it to NZ shape here + def get_nz_cache(self, cache_tensor: torch.Tensor, layer_group_idx: int): head_num, head_dim = cache_tensor.shape[-2], cache_tensor.shape[-1] cache_tensor = cache_tensor.view(-1, self.block_size[layer_group_idx], head_num * head_dim) - - batch = cache_tensor.shape[:-2] - a, b = cache_tensor.shape[-2], cache_tensor.shape[-1] - - dtype = cache_tensor.dtype - if dtype == torch.int8: - a0, b0 = 16, 32 - else: - a0, b0 = 16, 16 - - nz_shape = list(batch) + [math.ceil(b / b0), math.ceil(a / a0), a0, b0] - - # Generate the axis order for the transpose operation. - offset = len(cache_tensor.shape) - 2 - base = [2, 0, 1, 3] - array_trans = [i for i in range(offset)] + [i + offset for i in base] - # Perform shape transformation and transpose operation. - *_, n1, m1, m0, n0 = nz_shape - cache_tensor = cache_tensor.reshape(nz_shape[:-4] + [m1, m0, n1, n0]) - cache_tensor = cache_tensor.permute(*array_trans) + cache_tensor = trans_nd_to_nz(cache_tensor) cache_tensor = cache_tensor.reshape(-1, head_num, head_dim) return cache_tensor diff --git a/vllm_ascend/quantization/modelslim_config.py b/vllm_ascend/quantization/modelslim_config.py index 196826e79bf..93b4d3ae7ec 100644 --- a/vllm_ascend/quantization/modelslim_config.py +++ b/vllm_ascend/quantization/modelslim_config.py @@ -571,11 +571,19 @@ def is_layer_skipped_ascend(self, prefix: str, fused_mapping: Mapping[str, list[ def is_fa_quant_layer(self, prefix): if self.enable_fa_quant: - _id = int("".join(re.findall(r"\.(\d+)\.", prefix))) - if _id in self.kvcache_quant_layers: + layer_id_str = "".join(re.findall(r"\.(\d+)\.", prefix)) + if layer_id_str.isdigit() and int(layer_id_str) in self.kvcache_quant_layers: return True return False + def enabling_fa_quant(self, vllm_config, layer_name) -> bool: + is_decode_instance = ( + vllm_config.kv_transfer_config is not None + and vllm_config.kv_transfer_config.is_kv_consumer + and not vllm_config.kv_transfer_config.is_kv_producer + ) + return bool(is_decode_instance and self.is_fa_quant_layer(layer_name)) + def get_kv_quant_dtype(self, layer_name, cache_dtype, model_config): if self.enable_fa_quant and self.is_fa_quant_layer(layer_name): ori_dtype = model_config.dtype diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 79b7817504a..458b95cdec4 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -1219,3 +1219,28 @@ def calc_split_factor(num_list: list[int]): for num in num_list: split_factor_list.append(total / num) return split_factor_list + + +# NOTE: The last two dimensions of ND are transferred to NZ +def trans_nd_to_nz(cache_tensor: torch.Tensor): + assert len(cache_tensor.shape) >= 2 + batch = cache_tensor.shape[:-2] + a, b = cache_tensor.shape[-2], cache_tensor.shape[-1] + + dtype = cache_tensor.dtype + if dtype == torch.int8: + a0, b0 = 16, 32 + else: + a0, b0 = 16, 16 + + nz_shape = list(batch) + [math.ceil(b / b0), math.ceil(a / a0), a0, b0] + + # Generate the axis order for the transpose operation. + offset = len(cache_tensor.shape) - 2 + base = [2, 0, 1, 3] + array_trans = [i for i in range(offset)] + [i + offset for i in base] + # Perform shape transformation and transpose operation. + *_, n1, m1, m0, n0 = nz_shape + cache_tensor = cache_tensor.reshape(nz_shape[:-4] + [m1, m0, n1, n0]) + cache_tensor = cache_tensor.permute(*array_trans) + return cache_tensor From 3c0778ebee66035a5d54092d15ca3dd4c97e0830 Mon Sep 17 00:00:00 2001 From: Wang Kunpeng <1289706727@qq.com> Date: Sat, 14 Mar 2026 20:36:57 +0800 Subject: [PATCH 06/13] fix ci Signed-off-by: Wang Kunpeng <1289706727@qq.com> --- vllm_ascend/attention/mla_v1.py | 1 - .../kv_transfer/kv_p2p/mooncake_layerwise_connector.py | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 261dad01586..a3ac404f120 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -27,7 +27,6 @@ AscendCommonAttentionMetadata, ascend_chunked_prefill_workspace_size, enable_cp, - enabling_fa_quant, enabling_mlapo, maybe_save_kv_layer_to_connector, split_decodes_and_prefills, diff --git a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py index b91e6d9bbeb..ea82322b00b 100644 --- a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py +++ b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py @@ -1031,6 +1031,7 @@ def __init__(self, vllm_config: VllmConfig, kv_cache_config: KVCacheConfig, engi self.timeout = 1.0 # seconds self.k_buffer: torch.Tensor | None = None self.v_buffer: torch.Tensor | None = None + # TODO(kunpengW-code): Reuse k_buffer, v_buffer self.k_quant_buffer: torch.Tensor | None = None self.v_quant_buffer: torch.Tensor | None = None From 89ef7bfd174ccf6322d5fe6bf1acc2f46dda1f34 Mon Sep 17 00:00:00 2001 From: Wang Kunpeng <1289706727@qq.com> Date: Sat, 14 Mar 2026 21:33:45 +0800 Subject: [PATCH 07/13] fix ci Signed-off-by: Wang Kunpeng <1289706727@qq.com> --- tests/ut/attention/test_mla_v1.py | 21 ++++++++++--------- .../ut/quantization/test_modelslim_config.py | 4 ++++ vllm_ascend/patch/__init__.py | 14 +++++++++++++ vllm_ascend/quantization/modelslim_config.py | 1 + 4 files changed, 30 insertions(+), 10 deletions(-) diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py index 8fe78566bfc..ddfd2101951 100755 --- a/tests/ut/attention/test_mla_v1.py +++ b/tests/ut/attention/test_mla_v1.py @@ -806,6 +806,7 @@ def setUp(self, get_current_vllm_config, mock_tp): attn_type=None, kv_sharing_target_layer_name=None, **kwargs) + self.impl.fa_quant_layer = False def test_init(self): self.assertEqual(self.impl.num_heads, 256) @@ -931,9 +932,9 @@ def test_compute_prefill_context(self, mock_ring, mock_load): @patch('vllm_ascend.ascend_forward_context.get_forward_context') @patch("vllm_ascend.attention.mla_v1.AscendMLAImpl._v_up_proj") - @patch("torch_npu.npu_fused_infer_attention_score") + @patch("torch_npu.npu_fused_infer_attention_score_v2") def test_forward_decode_without_graph(self, - mock_npu_fused_infer_attention_score, + mock_npu_fused_infer_attention_score_v2, mock_up_proj, mock_get_forward_context): num_tokens = 100 @@ -949,8 +950,8 @@ def test_forward_decode_without_graph(self, metadata = MagicMock() metadata.decode = MagicMock() metadata.decode.block_table = MagicMock() - metadata.decode.seq_lens = 10 - mock_npu_fused_infer_attention_score.return_value = [ + metadata.decode.actual_seq_lengths = 10 + mock_npu_fused_infer_attention_score_v2.return_value = [ torch.randn(num_tokens, self.impl.num_heads, self.impl.kv_lora_rank), None ] @@ -964,7 +965,7 @@ def test_forward_decode_without_graph(self, self.assertEqual(result.shape[1], self.impl.num_heads) self.assertEqual(result.shape[2], self.impl.v_head_dim) mock_up_proj.assert_called_once() - mock_npu_fused_infer_attention_score.assert_called_once() + mock_npu_fused_infer_attention_score_v2.assert_called_once() @patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad") @patch("vllm_ascend.attention.mla_v1.get_weight_prefetch_method", @@ -1096,8 +1097,8 @@ def test_exec_kv_decode(self, mock_kv_rmsnorm_rope_cache): self.assertEqual(k_nope.shape[-1], self.impl.kv_lora_rank) @patch('vllm_ascend.ascend_forward_context.get_forward_context') - @patch("torch_npu.npu_fused_infer_attention_score") - def test_forward_decode(self, mock_npu_fused_infer_attention_score, + @patch("torch_npu.npu_fused_infer_attention_score_v2") + def test_forward_decode(self, mock_npu_fused_infer_attention_score_v2, mock_get_forward_context): B = 2 N = self.impl.num_kv_heads @@ -1114,11 +1115,11 @@ def test_forward_decode(self, mock_npu_fused_infer_attention_score, attn_metadata = MagicMock() attn_metadata.attn_state = AscendAttentionState.SpecDecoding attn_metadata.decode = MagicMock() - attn_metadata.decode.actual_seq_lengths_q = MagicMock() - attn_metadata.decode.seq_lens_list = MagicMock() + attn_metadata.decode.actual_seq_qlen = MagicMock() + attn_metadata.decode.actual_seq_kvlen = MagicMock() self.impl.enable_kv_nz = True - mock_npu_fused_infer_attention_score.return_value = [ + mock_npu_fused_infer_attention_score_v2.return_value = [ torch.randn(B, N, self.impl.kv_lora_rank), None ] mock_get_forward_context.return_value = MagicMock(capturing=False) diff --git a/tests/ut/quantization/test_modelslim_config.py b/tests/ut/quantization/test_modelslim_config.py index 556c8a4acd3..745b4736264 100644 --- a/tests/ut/quantization/test_modelslim_config.py +++ b/tests/ut/quantization/test_modelslim_config.py @@ -24,6 +24,7 @@ def setUp(self): self.sample_config = { "weight": "INT8", "fa_quant_type": "C8", + "layers.1.fa_k.scale": "C8", "layer1.weight": "INT8", "layer2.weight": "FLOAT", "fused_layer.weight": "FLOAT", @@ -119,6 +120,9 @@ def test_get_quant_method_for_attention(self): # Test with fa_quant_type method = self.ascend_config.get_quant_method( attention_layer, ".attn") + self.assertIs(method, None) + method = self.ascend_config.get_quant_method( + attention_layer, "layers.1.attn") self.assertIs(method, mock_ascend_kvcache.return_value) def test_get_quant_method_for_fused_moe(self): diff --git a/vllm_ascend/patch/__init__.py b/vllm_ascend/patch/__init__.py index 4eb74013467..1163e77b064 100644 --- a/vllm_ascend/patch/__init__.py +++ b/vllm_ascend/patch/__init__.py @@ -506,3 +506,17 @@ # Rotary quant is a unique feature of vllm-ascend. # Future Plan: # Remove this patch when vllm supports rotary quant or pluggable `MultiTokenPredictorLayer`. +# +# ** 22. File: worker/patch_weight_utils.py** +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# 1. `vllm.model_executor.models.deepseek_v2.DeepseekV2ForCausalLM.load_weights` +# Why: +# The C8 weight quantized by modelslim will modify the model structure, +# and the scale and offset required for kvcache quantization will increase. +# In addition, the names of the quantization parameters are different from +# those in the community. +# How: +# we have enhanced the maybe_remap_kv_scale_name function. +# Future Plan: +# The maybe_remap_kv_scale_name function of the community is reconstructed to support +# multiple backends. diff --git a/vllm_ascend/quantization/modelslim_config.py b/vllm_ascend/quantization/modelslim_config.py index 93b4d3ae7ec..82b3d279b94 100644 --- a/vllm_ascend/quantization/modelslim_config.py +++ b/vllm_ascend/quantization/modelslim_config.py @@ -440,6 +440,7 @@ def __init__(self, quant_config: dict[str, Any] | None = None): new_k = k.replace("weight_packed", "weight") extra_quant_dict[new_k] = self.quant_description[k] self.quant_description.update(extra_quant_dict) + self._add_kvcache_quant_metadata() def __repr__(self) -> str: return "AscendModelSlimConfig:\n" + super().__repr__() From aa1e27e6d1778b233c07f6dca3c5d9a1d0b7b63b Mon Sep 17 00:00:00 2001 From: Wang Kunpeng <1289706727@qq.com> Date: Sat, 14 Mar 2026 21:48:09 +0800 Subject: [PATCH 08/13] add ut Signed-off-by: Wang Kunpeng <1289706727@qq.com> --- tests/ut/quantization/test_kv_c8.py | 479 ++++++++++++++++++++++++++++ 1 file changed, 479 insertions(+) create mode 100644 tests/ut/quantization/test_kv_c8.py diff --git a/tests/ut/quantization/test_kv_c8.py b/tests/ut/quantization/test_kv_c8.py new file mode 100644 index 00000000000..e29a0c5c630 --- /dev/null +++ b/tests/ut/quantization/test_kv_c8.py @@ -0,0 +1,479 @@ +import unittest +import torch +import torch.nn as nn +from unittest.mock import Mock, patch + + +class TestWeightLoader(unittest.TestCase): + """Test cases for weight_loader function in kv_c8.py""" + + def setUp(self): + """Set up test environment before each test""" + # Import the module under test + from vllm_ascend.quantization.methods.kv_c8 import weight_loader + self.weight_loader = weight_loader + + # Mock distributed functions + self.tp_rank_patch = patch( + "vllm_ascend.quantization.methods.kv_c8.get_tensor_model_parallel_rank" + ) + self.tp_size_patch = patch( + "vllm_ascend.quantization.methods.kv_c8.get_tensor_model_parallel_world_size" + ) + self.mock_tp_rank = self.tp_rank_patch.start() + self.mock_tp_size = self.tp_size_patch.start() + + def tearDown(self): + """Clean up after each test""" + self.tp_rank_patch.stop() + self.tp_size_patch.stop() + + def test_weight_loader_single_element(self): + """Test weight_loader when both tensors contain a single element""" + # Create tensors with single element + param = torch.tensor([0.0]) + loaded_weight = torch.tensor([5.0]) + + # Call weight_loader + self.weight_loader(param, loaded_weight) + + # Verify the value was filled correctly + self.assertEqual(param.item(), 5.0) + self.assertEqual(param.dtype, torch.float32) + + def test_weight_loader_single_element_int(self): + """Test weight_loader with integer tensors""" + param = torch.tensor([0], dtype=torch.int32) + loaded_weight = torch.tensor([10], dtype=torch.int32) + + self.weight_loader(param, loaded_weight) + + self.assertEqual(param.item(), 10) + + def test_weight_loader_tp_sharding_first_rank(self): + """Test weight_loader with tensor parallelism sharding for first rank""" + # Configure mocks for rank 0 of 4 + self.mock_tp_rank.return_value = 0 + self.mock_tp_size.return_value = 4 + + # Create test tensors + param = torch.zeros(2, 5) # Target param shape (2,5) + loaded_weight = torch.ones(8, 5) # Full weight (8,5) + + # Mock narrow to track the call + with patch.object(loaded_weight, 'narrow', wraps=loaded_weight.narrow) as mock_narrow: + self.weight_loader(param, loaded_weight) + + # Verify narrow was called correctly: narrow(dim=0, start=0, length=2) + mock_narrow.assert_called_once_with(0, 0, 2) + + # Verify data was copied + self.assertTrue(torch.all(param == 1)) + + def test_weight_loader_tp_sharding_middle_rank(self): + """Test weight_loader with tensor parallelism sharding for middle rank""" + # Configure mocks for rank 2 of 4 + self.mock_tp_rank.return_value = 2 + self.mock_tp_size.return_value = 4 + + param = torch.zeros(2, 5) + loaded_weight = torch.ones(8, 5) + + with patch.object(loaded_weight, 'narrow', wraps=loaded_weight.narrow) as mock_narrow: + self.weight_loader(param, loaded_weight) + + # Verify narrow was called correctly: start = shard_size * rank = 2 * 2 = 4 + mock_narrow.assert_called_once_with(0, 4, 2) + + self.assertTrue(torch.all(param == 1)) + + def test_weight_loader_tp_sharding_last_rank(self): + """Test weight_loader with tensor parallelism sharding for last rank""" + # Configure mocks for rank 3 of 4 + self.mock_tp_rank.return_value = 3 + self.mock_tp_size.return_value = 4 + + param = torch.zeros(2, 5) + loaded_weight = torch.ones(8, 5) + + with patch.object(loaded_weight, 'narrow', wraps=loaded_weight.narrow) as mock_narrow: + self.weight_loader(param, loaded_weight) + + # Verify narrow was called correctly: start = 2 * 3 = 6 + mock_narrow.assert_called_once_with(0, 6, 2) + + self.assertTrue(torch.all(param == 1)) + + def test_weight_loader_shape_mismatch(self): + """Test weight_loader raises assertion error on shape mismatch""" + self.mock_tp_rank.return_value = 0 + self.mock_tp_size.return_value = 2 + + param = torch.zeros(2, 3) + loaded_weight = torch.ones(4, 4) # Different shape after sharding + + # Mock narrow to return tensor with wrong shape + with patch.object(loaded_weight, 'narrow', return_value=torch.ones(2, 4)): + with self.assertRaises(AssertionError) as context: + self.weight_loader(param, loaded_weight) + + # Verify error message contains expected information + self.assertIn("Attempted to load weight", str(context.exception)) + self.assertIn("into parameter", str(context.exception)) + + def test_weight_loader_with_different_dtypes(self): + """Test weight_loader handles different dtypes correctly""" + self.mock_tp_rank.return_value = 0 + self.mock_tp_size.return_value = 1 # No sharding + + param = torch.zeros(2, 3, dtype=torch.float32) + loaded_weight = torch.ones(2, 3, dtype=torch.float16) + + self.weight_loader(param, loaded_weight) + + # Verify data was copied and converted + self.assertTrue(torch.all(param == 1)) + self.assertEqual(param.dtype, torch.float32) + + +class TestAscendFAQuantAttentionMethodInit(unittest.TestCase): + """Test cases for AscendFAQuantAttentionMethod initialization""" + + def setUp(self): + """Set up test environment""" + # Mock vllm_config + self.config_patch = patch("vllm_ascend.quantization.methods.kv_c8.get_current_vllm_config") + self.mock_get_config = self.config_patch.start() + + # Create mock config with attributes + self.mock_config = Mock() + self.mock_hf_config = Mock() + self.mock_hf_config.kv_lora_rank = 128 + self.mock_hf_config.qk_rope_head_dim = 64 + self.mock_config.model_config.hf_config = self.mock_hf_config + self.mock_get_config.return_value = self.mock_config + + # Import the class after patching + from vllm_ascend.quantization.methods.kv_c8 import AscendFAQuantAttentionMethod + self.method_class = AscendFAQuantAttentionMethod + + def tearDown(self): + """Clean up after each test""" + self.config_patch.stop() + + def test_init_with_full_config(self): + """Test initialization when config has all attributes""" + method = self.method_class() + + self.assertTrue(method.transpose_weight) + self.assertFalse(method.printFlag) + self.assertEqual(method.kv_lora_rank, 128) + self.assertEqual(method.qk_rope_head_dim, 64) + + def test_init_without_kv_lora_rank(self): + """Test initialization when config lacks kv_lora_rank""" + delattr(self.mock_hf_config, "kv_lora_rank") + + method = self.method_class() + + self.assertEqual(method.kv_lora_rank, 0) + self.assertEqual(method.qk_rope_head_dim, 64) + + def test_init_without_qk_rope_head_dim(self): + """Test initialization when config lacks qk_rope_head_dim""" + delattr(self.mock_hf_config, "qk_rope_head_dim") + + method = self.method_class() + + self.assertEqual(method.kv_lora_rank, 128) + self.assertEqual(method.qk_rope_head_dim, 0) + + def test_init_without_both_attributes(self): + """Test initialization when config lacks both attributes""" + delattr(self.mock_hf_config, "kv_lora_rank") + delattr(self.mock_hf_config, "qk_rope_head_dim") + + method = self.method_class() + + self.assertEqual(method.kv_lora_rank, 0) + self.assertEqual(method.qk_rope_head_dim, 0) + + +class TestAscendFAQuantAttentionMethodCreateWeights(unittest.TestCase): + """Test cases for create_weights method""" + + def setUp(self): + """Set up test environment""" + # Mock vllm_config + self.config_patch = patch("vllm_ascend.quantization.methods.kv_c8.get_current_vllm_config") + self.mock_get_config = self.config_patch.start() + + self.mock_config = Mock() + self.mock_hf_config = Mock() + self.mock_hf_config.kv_lora_rank = 128 + self.mock_hf_config.qk_rope_head_dim = 64 + self.mock_config.model_config.hf_config = self.mock_hf_config + self.mock_get_config.return_value = self.mock_config + + # Import the class + from vllm_ascend.quantization.methods.kv_c8 import AscendFAQuantAttentionMethod + self.method_class = AscendFAQuantAttentionMethod + + # Mock torch functions + self.default_dtype_patch = patch("torch.get_default_dtype", return_value=torch.float32) + self.mock_default_dtype = self.default_dtype_patch.start() + + # Create a real nn.Module for testing + self.layer = nn.Module() + self.layer.num_heads = 32 + self.layer.num_kv_heads = 8 + + def tearDown(self): + """Clean up after each test""" + self.config_patch.stop() + self.default_dtype_patch.stop() + + def test_create_weights_adds_submodules(self): + """Test that create_weights adds fa_q, fa_k, fa_v submodules""" + method = self.method_class() + + with patch("torch.empty") as mock_empty: + mock_empty.return_value = torch.zeros(1, 1) + + method.create_weights(self.layer) + + # Verify submodules were added + self.assertTrue(hasattr(self.layer, "fa_q")) + self.assertTrue(hasattr(self.layer, "fa_k")) + self.assertTrue(hasattr(self.layer, "fa_v")) + + # Verify they are instances of nn.Module + self.assertIsInstance(self.layer.fa_q, nn.Module) + self.assertIsInstance(self.layer.fa_k, nn.Module) + self.assertIsInstance(self.layer.fa_v, nn.Module) + + def test_create_weights_sets_num_kv_heads_to_one(self): + """Test that num_kv_heads is set to 1""" + method = self.method_class() + + with patch("torch.empty") as mock_empty: + mock_empty.return_value = torch.zeros(1, 1) + + method.create_weights(self.layer) + + self.assertEqual(self.layer.num_kv_heads, 1) + + def test_create_weights_creates_correct_tensors(self): + """Test that create_weights creates tensors with correct shapes and dtypes""" + method = self.method_class() + + # Track torch.empty calls + empty_calls = [] + + def mock_empty(size, dtype=None): + empty_calls.append((size, dtype)) + return torch.zeros(size, dtype=dtype if dtype else torch.float32) + + with patch("torch.empty", side_effect=mock_empty): + method.create_weights(self.layer) + + # Verify tensor creations + expected_calls = [ + ((32, 1), torch.float32), # fa_q.scale + ((1, 1), torch.float32), # fa_k.scale + ((1, 1), torch.float32), # fa_v.scale + ((32, 1), torch.int8), # fa_q.offset + ((1, 1), torch.int8), # fa_k.offset + ((1, 1), torch.int8), # fa_v.offset + ] + + # Compare without considering order + self.assertEqual(len(empty_calls), len(expected_calls)) + for call in expected_calls: + self.assertIn(call, empty_calls) + + def test_create_weights_registers_parameters(self): + """Test that create_weights registers parameters with correct attributes""" + method = self.method_class() + + # Create real tensors for testing + def create_tensor(*args, **kwargs): + size = args[0] if args else kwargs.get('size', (1,)) + dtype = kwargs.get('dtype', torch.float32) + return torch.zeros(*size, dtype=dtype) + + with patch("torch.empty", side_effect=create_tensor): + method.create_weights(self.layer) + + # Import weight_loader for comparison + from vllm_ascend.quantization.methods.kv_c8 import weight_loader + + # Verify each parameter exists and has weight_loader + self.assertTrue(hasattr(self.layer.fa_q, "scale")) + self.assertTrue(hasattr(self.layer.fa_q.scale, "weight_loader")) + self.assertEqual(self.layer.fa_q.scale.weight_loader, weight_loader) + self.assertFalse(self.layer.fa_q.scale.requires_grad) + + self.assertTrue(hasattr(self.layer.fa_k, "scale")) + self.assertTrue(hasattr(self.layer.fa_k.scale, "weight_loader")) + + self.assertTrue(hasattr(self.layer.fa_v, "scale")) + self.assertTrue(hasattr(self.layer.fa_v.scale, "weight_loader")) + + self.assertTrue(hasattr(self.layer.fa_q, "offset")) + self.assertTrue(hasattr(self.layer.fa_q.offset, "weight_loader")) + self.assertEqual(self.layer.fa_q.offset.dtype, torch.int8) + + +class TestAscendFAQuantAttentionMethodProcessWeights(unittest.TestCase): + """Test cases for process_weights_after_loading method""" + + def setUp(self): + """Set up test environment""" + # Mock vllm_config + self.config_patch = patch("vllm_ascend.quantization.methods.kv_c8.get_current_vllm_config") + self.mock_get_config = self.config_patch.start() + + self.mock_config = Mock() + self.mock_hf_config = Mock() + self.mock_hf_config.kv_lora_rank = 64 + self.mock_hf_config.qk_rope_head_dim = 32 + self.mock_config.model_config.hf_config = self.mock_hf_config + self.mock_get_config.return_value = self.mock_config + + # Import the class + from vllm_ascend.quantization.methods.kv_c8 import AscendFAQuantAttentionMethod + self.method_class = AscendFAQuantAttentionMethod + + # Create method instance with real layer + self.method = self.method_class() + + # Create a real nn.Module for testing + self.layer = nn.Module() + + # Create real tensors for fa_k + self.fa_k_scale = torch.tensor([[2.0, 3.0, 4.0]], dtype=torch.float16) # Shape (1,3) + self.fa_k_offset = torch.tensor([[1, 2, 3]], dtype=torch.int8) # Shape (1,3) + + # Create fa_k module with parameters + self.layer.fa_k = nn.Module() + self.layer.fa_k.scale = nn.Parameter(self.fa_k_scale, requires_grad=False) + self.layer.fa_k.offset = nn.Parameter(self.fa_k_offset, requires_grad=False) + + def tearDown(self): + """Clean up after each test""" + self.config_patch.stop() + + def test_process_weights_with_single_value_scale(self): + """Test process_weights with single value scale""" + # Create new layer with single value scale + layer = nn.Module() + layer.fa_k = nn.Module() + layer.fa_k.scale = nn.Parameter(torch.tensor([[2.0]], dtype=torch.float16), requires_grad=False) + layer.fa_k.offset = nn.Parameter(torch.tensor([[1]], dtype=torch.int8), requires_grad=False) + + self.method.kv_lora_rank = 4 + self.method.process_weights_after_loading(layer) + + self.assertEqual(layer.quant_kscale.shape, (1, 4)) + self.assertEqual(layer.quant_kscale.dtype, torch.float32) + + +class TestIntegration(unittest.TestCase): + """Integration tests for the complete kv_c8 functionality""" + + def setUp(self): + """Set up test environment""" + # Mock vllm_config + self.config_patch = patch("vllm_ascend.quantization.methods.kv_c8.get_current_vllm_config") + self.mock_get_config = self.config_patch.start() + + self.mock_config = Mock() + self.mock_hf_config = Mock() + self.mock_hf_config.kv_lora_rank = 64 + self.mock_hf_config.qk_rope_head_dim = 32 + self.mock_config.model_config.hf_config = self.mock_hf_config + self.mock_get_config.return_value = self.mock_config + + # Mock distributed functions + self.tp_rank_patch = patch( + "vllm_ascend.quantization.methods.kv_c8.get_tensor_model_parallel_rank" + ) + self.tp_size_patch = patch( + "vllm_ascend.quantization.methods.kv_c8.get_tensor_model_parallel_world_size" + ) + self.mock_tp_rank = self.tp_rank_patch.start() + self.mock_tp_size = self.tp_size_patch.start() + + def tearDown(self): + """Clean up after each test""" + self.config_patch.stop() + self.tp_rank_patch.stop() + self.tp_size_patch.stop() + + def test_complete_workflow(self): + """Test complete workflow from weight creation to processing""" + from vllm_ascend.quantization.methods.kv_c8 import AscendFAQuantAttentionMethod + + # Create method instance + method = AscendFAQuantAttentionMethod() + + # Create real layer + layer = nn.Module() + layer.num_heads = 32 + layer.num_kv_heads = 8 + + # Step 1: Create weights + method.create_weights(layer) + + # Verify weights were created with correct structure + self.assertTrue(hasattr(layer, "fa_q")) + self.assertTrue(hasattr(layer, "fa_k")) + self.assertTrue(hasattr(layer, "fa_v")) + + self.assertTrue(hasattr(layer.fa_q, "scale")) + self.assertTrue(hasattr(layer.fa_q, "offset")) + self.assertTrue(hasattr(layer.fa_k, "scale")) + self.assertTrue(hasattr(layer.fa_k, "offset")) + self.assertTrue(hasattr(layer.fa_v, "scale")) + self.assertTrue(hasattr(layer.fa_v, "offset")) + + # Step 2: Simulate weight loading + self.mock_tp_rank.return_value = 0 + self.mock_tp_size.return_value = 1 + + # Create dummy weights + q_scale = torch.randn(32, 1) + k_scale = torch.randn(1, 1) + v_scale = torch.randn(1, 1) + q_offset = torch.randint(-128, 127, (32, 1), dtype=torch.int8) + k_offset = torch.randint(-128, 127, (1, 1), dtype=torch.int8) + v_offset = torch.randint(-128, 127, (1, 1), dtype=torch.int8) + + # Load weights using weight_loader + from vllm_ascend.quantization.methods.kv_c8 import weight_loader + + with torch.no_grad(): + weight_loader(layer.fa_q.scale, q_scale) + weight_loader(layer.fa_k.scale, k_scale) + weight_loader(layer.fa_v.scale, v_scale) + weight_loader(layer.fa_q.offset, q_offset) + weight_loader(layer.fa_k.offset, k_offset) + weight_loader(layer.fa_v.offset, v_offset) + + # Verify weights were loaded correctly + self.assertTrue(torch.all(layer.fa_q.scale == q_scale)) + self.assertTrue(torch.all(layer.fa_k.scale == k_scale)) + self.assertTrue(torch.all(layer.fa_v.scale == v_scale)) + + # Step 3: Process after loading + method.process_weights_after_loading(layer) + + # Verify processed parameters + self.assertTrue(hasattr(layer, "fak_descale")) + self.assertTrue(hasattr(layer, "fak_offset")) + self.assertTrue(hasattr(layer, "quant_kscale")) + + +if __name__ == "__main__": + unittest.main(verbosity=2) \ No newline at end of file From 4922349f10e0a1d1a6853df83d6eee8189d947fa Mon Sep 17 00:00:00 2001 From: Wang Kunpeng <1289706727@qq.com> Date: Mon, 16 Mar 2026 10:23:52 +0800 Subject: [PATCH 09/13] fix ci Signed-off-by: Wang Kunpeng <1289706727@qq.com> --- vllm_ascend/worker/model_runner_v1.py | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index e4f3166ff52..01caa5d03c2 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -2826,7 +2826,6 @@ def _reshape_kv_cache_tensors( kv_cache_shape = self.attn_backend.get_kv_cache_shape( num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size ) - dtype = kv_cache_spec.dtype if not self.model_config.use_mla: k_shape = kv_cache_shape[1:] v_shape = k_shape @@ -2861,7 +2860,7 @@ def _reshape_kv_cache_tensors( kv_cache_spec.num_kv_heads, index_head_dim, ) - dsa_k_cache = raw_dsa_k_tensor.view(dtype).view(dsa_k_cache_shape) + dsa_k_cache = raw_dsa_k_tensor.view(kv_cache_spec.dtype).view(dsa_k_cache_shape) kv_caches[layer_name] = (k_cache, v_cache, dsa_k_cache) else: kv_caches[layer_name] = (k_cache, v_cache) @@ -3142,22 +3141,17 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: dtype=self.kv_cache_dtype, cache_dtype_str=self.vllm_config.cache_config.cache_dtype, ) - elif spec := attn_module.get_kv_cache_spec(self.vllm_config): - kv_cache_spec[layer_name] = spec - assert isinstance(spec, MLAAttentionSpec) - from vllm.v1.kv_cache_interface import MLAAttentionSpec as AscendMLAAttentionSpec - if getattr(attn_module.impl, "fa_quant_layer", False): - head_size = attn_module.head_size + attn_module.qk_rope_head_dim - dtype, cache_dtype_str = attn_module.impl.dtype, None - else: - head_size, dtype, cache_dtype_str = spec.head_size, spec.dtype, spec.cache_dtype_str - kv_cache_spec[layer_name] = AscendMLAAttentionSpec( - block_size=spec.block_size, - num_kv_heads=spec.num_kv_heads, + elif getattr(attn_module.impl, "fa_quant_layer", False): + head_size = attn_module.head_size + attn_module.qk_rope_head_dim + kv_cache_spec[layer_name] = MLAAttentionSpec( + block_size=self.block_size, + num_kv_heads=attn_module.num_kv_heads, head_size=head_size, - dtype=dtype, - cache_dtype_str=cache_dtype_str, + dtype=attn_module.impl.dtype, + cache_dtype_str=None, ) + elif spec := attn_module.get_kv_cache_spec(self.vllm_config): + kv_cache_spec[layer_name] = spec elif isinstance(attn_module, MambaBase): mamba_layers[layer_name] = attn_module From c3b82857c9f4fe3d7ef9d7ef2aad5a07aa45f723 Mon Sep 17 00:00:00 2001 From: Wang Kunpeng <1289706727@qq.com> Date: Mon, 16 Mar 2026 10:32:33 +0800 Subject: [PATCH 10/13] fix ci Signed-off-by: Wang Kunpeng <1289706727@qq.com> --- vllm_ascend/worker/model_runner_v1.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 01caa5d03c2..7b5f4591f6a 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -2671,12 +2671,6 @@ def _allocate_kv_cache_tensors(self, kv_cache_config: KVCacheConfig) -> dict[str # as it only support the 0-dim of kv_cache is `num_blocks`. # For deepseek mla, we need to spilt cache tensor accrodding to the nope head dim # and rope head dim. - if self.model_config.use_mla: - head_size = ( - self.model_config.hf_text_config.qk_rope_head_dim - + self.model_config.hf_text_config.kv_lora_rank - ) - dsa_k_cache_factor = None dsa_k_cache_size = None if not self.model_config.use_mla: @@ -2693,11 +2687,12 @@ def _allocate_kv_cache_tensors(self, kv_cache_config: KVCacheConfig) -> dict[str else: kv_head_dim_list = [ self.model_config.hf_text_config.kv_lora_rank, - self.model_config.hf_text_config.qk_rope_head_dim + self.model_config.hf_text_config.qk_rope_head_dim, ] if self.is_kv_consumer and self.vllm_config.quant_config is not None: - k_tensor_split_factor, v_tensor_split_factor = (self.vllm_config.quant_config. - get_kv_quant_split_factor(layer_name, kv_head_dim_list)) + k_tensor_split_factor, v_tensor_split_factor = ( + self.vllm_config.quant_config.get_kv_quant_split_factor(layer_name, kv_head_dim_list) + ) else: k_tensor_split_factor, v_tensor_split_factor = calc_split_factor(kv_head_dim_list) From 512c62ee1a16a9c81c131cfe1ef8151c325d2f90 Mon Sep 17 00:00:00 2001 From: Wang Kunpeng <1289706727@qq.com> Date: Mon, 16 Mar 2026 16:57:13 +0800 Subject: [PATCH 11/13] fix ci Signed-off-by: Wang Kunpeng <1289706727@qq.com> --- vllm_ascend/worker/model_runner_v1.py | 39 ++++++++++++++++++--------- 1 file changed, 26 insertions(+), 13 deletions(-) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 859cef73801..3e1399bfb70 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -117,6 +117,7 @@ from vllm_ascend.spec_decode.ngram_proposer import AscendNgramProposer from vllm_ascend.spec_decode.suffix_proposer import AscendSuffixDecodingProposer from vllm_ascend.utils import ( + calc_split_factor, check_gdn_layer, enable_sp, enable_sp_by_pass, @@ -2683,12 +2684,6 @@ def _allocate_kv_cache_tensors(self, kv_cache_config: KVCacheConfig) -> dict[str # as it only support the 0-dim of kv_cache is `num_blocks`. # For deepseek mla, we need to spilt cache tensor accrodding to the nope head dim # and rope head dim. - if self.model_config.use_mla: - head_size = ( - self.model_config.hf_text_config.qk_rope_head_dim - + self.model_config.hf_text_config.kv_lora_rank - ) - if not self.model_config.use_mla: # for non-mla model, use FullAttentionSpec k_tensor_split_factor = 2.0 @@ -2703,8 +2698,16 @@ def _allocate_kv_cache_tensors(self, kv_cache_config: KVCacheConfig) -> dict[str dsa_k_scale_tensor_split_factor = sparse_kv_cache_ratio[3] else: # for other deepseek models, use MLAAttentionSpec - k_tensor_split_factor = head_size / self.model_config.hf_text_config.kv_lora_rank - v_tensor_split_factor = head_size / self.model_config.hf_text_config.qk_rope_head_dim + kv_head_dim_list = [ + self.model_config.hf_text_config.kv_lora_rank, + self.model_config.hf_text_config.qk_rope_head_dim, + ] + if self.is_kv_consumer and self.vllm_config.quant_config is not None: + k_tensor_split_factor, v_tensor_split_factor = ( + self.vllm_config.quant_config.get_kv_quant_split_factor(layer_name, kv_head_dim_list) + ) + else: + k_tensor_split_factor, v_tensor_split_factor = calc_split_factor(kv_head_dim_list) k_tensor_size = int(kv_cache_tensor.size // k_tensor_split_factor) v_tensor_size = int(kv_cache_tensor.size // v_tensor_split_factor) @@ -2881,8 +2884,13 @@ def _reshape_kv_cache_tensors( num_kv_heads, self.model_config.hf_text_config.qk_rope_head_dim, ] - k_cache = raw_k_tensor.view(kv_cache_spec.dtype).view(k_shape) - v_cache = raw_v_tensor.view(kv_cache_spec.dtype).view(v_shape) + k_cache_dtype = v_cache_dtype = kv_cache_spec.dtype + if self.is_kv_consumer and self.vllm_config.quant_config is not None: + k_cache_dtype, v_cache_dtype = self.vllm_config.quant_config.get_kv_quant_dtype( + layer_name, kv_cache_spec.dtype, self.model_config + ) + k_cache = raw_k_tensor.view(k_cache_dtype).view(k_shape) + v_cache = raw_v_tensor.view(v_cache_dtype).view(v_shape) if self.use_sparse: dsa_k_cache_shape = ( @@ -3199,12 +3207,17 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: elif spec := attn_module.get_kv_cache_spec(self.vllm_config): assert isinstance(spec, MLAAttentionSpec) from vllm.v1.kv_cache_interface import MLAAttentionSpec as AscendMLAAttentionSpec + if getattr(attn_module.impl, "fa_quant_layer", False): + head_size = attn_module.head_size + attn_module.qk_rope_head_dim + dtype, cache_dtype_str = attn_module.impl.dtype, None + else: + head_size, dtype, cache_dtype_str = spec.head_size, spec.dtype, spec.cache_dtype_str kv_cache_spec[layer_name] = AscendMLAAttentionSpec( block_size=spec.block_size, num_kv_heads=spec.num_kv_heads, - head_size=spec.head_size, - dtype=spec.dtype, - cache_dtype_str=spec.cache_dtype_str, + head_size=head_size, + dtype=dtype, + cache_dtype_str=cache_dtype_str, ) elif isinstance(attn_module, MambaBase): From ef821bdfe66a1356fb6bbbc7edf86d576d77792f Mon Sep 17 00:00:00 2001 From: Wang Kunpeng <1289706727@qq.com> Date: Mon, 16 Mar 2026 18:14:22 +0800 Subject: [PATCH 12/13] fix ci Signed-off-by: Wang Kunpeng <1289706727@qq.com> --- vllm_ascend/quantization/methods/kv_c8.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm_ascend/quantization/methods/kv_c8.py b/vllm_ascend/quantization/methods/kv_c8.py index 21d090a6abd..8a700484405 100644 --- a/vllm_ascend/quantization/methods/kv_c8.py +++ b/vllm_ascend/quantization/methods/kv_c8.py @@ -37,7 +37,6 @@ def create_weights(self, layer: torch.nn.Module) -> None: setattr(layer, name, torch.nn.Module()) params_dict = {} dtype = torch.get_default_dtype() - layer.num_kv_heads = 1 params_dict["fa_q.scale"] = torch.empty((layer.num_heads, 1), dtype=dtype) params_dict["fa_k.scale"] = torch.empty((layer.num_kv_heads, 1), dtype=dtype) params_dict["fa_v.scale"] = torch.empty((layer.num_kv_heads, 1), dtype=dtype) From 2ad6a4a98a0ee80136ed3ec31907d1f9c780c376 Mon Sep 17 00:00:00 2001 From: Wang Kunpeng <1289706727@qq.com> Date: Mon, 16 Mar 2026 19:34:53 +0800 Subject: [PATCH 13/13] fix ci Signed-off-by: Wang Kunpeng <1289706727@qq.com> --- tests/ut/quantization/test_kv_c8.py | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/tests/ut/quantization/test_kv_c8.py b/tests/ut/quantization/test_kv_c8.py index e29a0c5c630..cfafefc9143 100644 --- a/tests/ut/quantization/test_kv_c8.py +++ b/tests/ut/quantization/test_kv_c8.py @@ -226,7 +226,7 @@ def setUp(self): # Create a real nn.Module for testing self.layer = nn.Module() self.layer.num_heads = 32 - self.layer.num_kv_heads = 8 + self.layer.num_kv_heads = 1 def tearDown(self): """Clean up after each test""" @@ -252,17 +252,6 @@ def test_create_weights_adds_submodules(self): self.assertIsInstance(self.layer.fa_k, nn.Module) self.assertIsInstance(self.layer.fa_v, nn.Module) - def test_create_weights_sets_num_kv_heads_to_one(self): - """Test that num_kv_heads is set to 1""" - method = self.method_class() - - with patch("torch.empty") as mock_empty: - mock_empty.return_value = torch.zeros(1, 1) - - method.create_weights(self.layer) - - self.assertEqual(self.layer.num_kv_heads, 1) - def test_create_weights_creates_correct_tensors(self): """Test that create_weights creates tensors with correct shapes and dtypes""" method = self.method_class() @@ -421,7 +410,7 @@ def test_complete_workflow(self): # Create real layer layer = nn.Module() layer.num_heads = 32 - layer.num_kv_heads = 8 + layer.num_kv_heads = 1 # Step 1: Create weights method.create_weights(layer)