diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py index a32817dc6ee..e2f93117694 100644 --- a/vllm_ascend/attention/sfa_v1.py +++ b/vllm_ascend/attention/sfa_v1.py @@ -24,6 +24,7 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, trans_rope_weight, transdata, wait_for_kv_layer_from_connector) +from vllm_ascend.distributed.utils import all_gather_async from vllm_ascend.ops.layer_shard_linear import ( is_hidden_layer, post_process_after_loading_for_shard_weight_series, reach_layer_for_shard_weight_series, @@ -226,7 +227,10 @@ def build( cos = cos[local_start:local_end_with_pad] sin = sin[local_start:local_end_with_pad] - slot_mapping_cp = slot_mapping[local_start:local_end_with_pad] + slot_mapping_cp = torch.full(size=(num_tokens_per_device, ), + fill_value=-1, + dtype=slot_mapping.dtype, + device=slot_mapping.device) assert cos.shape[0] == num_tokens_per_device, \ f"cos.shape[0] must be equal to num_tokens_per_device, \ got {cos.shape[0]} and {num_tokens_per_device}" @@ -503,7 +507,6 @@ def exec_kv( sin: torch.Tensor, kv_cache: Tuple, slots: torch.Tensor, - slots_cp: Optional[torch.Tensor], ): B = kv_no_split.shape[0] N = self.num_kv_heads @@ -514,30 +517,19 @@ def exec_kv( cache_mode = "PA" if self.enable_sfa_cp: - assert slots_cp is not None _, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache( kv_no_split, self.kv_a_layernorm.weight, cos, sin, - slots_cp.to(torch.int64), + slots.to(torch.int64), kv_cache[1], kv_cache[0], epsilon=self.kv_a_layernorm.variance_epsilon, cache_mode=cache_mode, is_output_kv=True, ) - # TODO: Temporarily adapt SFA-CP and replace it later with PCP. 
--clrs97 - k_pe = get_tp_group().all_gather(k_pe, 0) - k_nope = get_tp_group().all_gather(k_nope, 0) - - if kv_cache is not None: - torch_npu.npu_scatter_nd_update_( - kv_cache[0].view(-1, k_nope.shape[-1]), slots.view(-1, 1), - k_nope.view(-1, k_nope.shape[-1])) - torch_npu.npu_scatter_nd_update_( - kv_cache[1].view(-1, k_pe.shape[-1]), slots.view(-1, 1), - k_pe.view(-1, k_pe.shape[-1])) + return k_pe, k_nope else: torch_npu.npu_kv_rmsnorm_rope_cache( kv_no_split, @@ -550,6 +542,7 @@ def exec_kv( epsilon=self.kv_a_layernorm.variance_epsilon, cache_mode=cache_mode, ) + return None, None def rope_single( self, @@ -742,6 +735,7 @@ def forward( if is_hidden_layer(layer): reach_layer_for_shard_weight_series(layer) return output.fill_(0) + has_prefill = attn_metadata.has_prefill cos = attn_metadata.cos sin = attn_metadata.sin @@ -762,6 +756,12 @@ def forward( need_gather_q_kv=need_gather_q_kv, num_input_tokens=attn_metadata.num_input_tokens, ) + q, k = self.indexer_select_pre_process( + x=hidden_states, + qr=q_c, + cos=cos, + sin=sin, + need_gather_q_kv=need_gather_q_kv) else: assert self.fused_qkv_a_proj is not None, "q lora is required for DSA." 
maybe_npu_prefetch(inputs=self.fused_qkv_a_proj.weight, @@ -780,31 +780,67 @@ def forward( kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( kv_no_split.contiguous(), need_gather_q_kv) + q, k = self.indexer_select_pre_process( + x=hidden_states, + qr=q_c, + cos=cos, + sin=sin, + need_gather_q_kv=need_gather_q_kv) + if has_prefill: wait_for_kv_layer_from_connector(layer_name) slot_mapping = attn_metadata.slot_mapping - slot_mapping_cp = None if self.enable_sfa_cp: assert attn_metadata.sfa_cp_context is not None - slot_mapping_cp = attn_metadata.sfa_cp_context.slot_mapping_cp + slot_mapping = attn_metadata.sfa_cp_context.slot_mapping_cp actual_seq_lengths_query = attn_metadata.sfa_cp_context.actual_seq_lengths_query actual_seq_lengths_key = attn_metadata.sfa_cp_context.actual_seq_lengths_key - self.exec_kv(kv_no_split, cos, sin, kv_cache, slot_mapping, - slot_mapping_cp) + k_pe, k_nope = self.exec_kv(kv_no_split, cos, sin, kv_cache, + slot_mapping) - if self.enable_sfa_cp and attn_metadata.sfa_cp_context is not None: - for layer in (self.layer_sharding_kwargs or []): - if is_hidden_layer(layer): - reach_layer_for_shard_weight_series(layer) + if self.enable_sfa_cp: + assert k_pe is not None + assert k_nope is not None + # support all_gather kv async for communication calculation overlap + fused_kv_no_split, kv_ag_handle = all_gather_async( + torch.cat([ + k_pe.view(-1, k_pe.shape[-1]), + k_nope.view(-1, k_nope.shape[-1]), + k.view(-1, k.shape[-1]) + ], + dim=1), get_tp_group()) ql_nope, q_pe = self._q_proj_and_k_up_proj(q_c) q_pe = self.rope_single(q_pe, cos, sin) - topk_indices = self.indexer_select( + if self.enable_sfa_cp: + if kv_ag_handle is not None: + kv_ag_handle.wait() + for layer in (self.layer_sharding_kwargs or []): + if is_hidden_layer(layer): + reach_layer_for_shard_weight_series(layer) + + if kv_cache is not None: + assert fused_kv_no_split is not None + k_pe, k_nope, k = fused_kv_no_split.split([ + self.qk_rope_head_dim, self.kv_lora_rank, 
self.head_dim + ], + dim=-1) + slot_mapping = attn_metadata.slot_mapping.view(-1, 1) + torch_npu.npu_scatter_nd_update_( + kv_cache[0].view(-1, k_nope.shape[-1]), slot_mapping, + k_nope) + torch_npu.npu_scatter_nd_update_( + kv_cache[1].view(-1, k_pe.shape[-1]), slot_mapping, + k_pe) + + topk_indices = self.indexer_select_post_process( x=hidden_states, qr=q_c, + q=q, + k=k, kv_cache=kv_cache, attn_metadata=attn_metadata, cos=cos, @@ -812,6 +848,7 @@ def forward( actual_seq_lengths_query=actual_seq_lengths_query, actual_seq_lengths_key=actual_seq_lengths_key, need_gather_q_kv=need_gather_q_kv) + attn_output = torch.ops._C_ascend.npu_sparse_flash_attention( query=ql_nope, key=kv_cache[0], @@ -828,6 +865,7 @@ def forward( layout_kv="PA_BSND", sparse_mode=3, ) + attn_output = self._v_up_proj(attn_output, has_prefill) maybe_npu_prefetch(inputs=self.o_proj.weight, dependency=attn_output, @@ -836,22 +874,14 @@ def forward( output[...] = self.o_proj(attn_output)[0] return output_padded - def indexer_select( + def indexer_select_pre_process( self, x: torch.Tensor, qr: torch.Tensor, - kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], - attn_metadata: M, cos: torch.Tensor, sin: torch.Tensor, - actual_seq_lengths_query: torch.Tensor, - actual_seq_lengths_key: torch.Tensor, need_gather_q_kv: bool = False, ): - # q process in new stream - q, _ = self.wq_b(qr) # [b,s,1536] @ [1536,64*128] = [b,s,64*128] - q = q.view(-1, self.n_head, self.head_dim) # [n_toks,64,128] - k_proj, _ = self.wk(x) # [b,s,7168] @ [7168,128] = [b,s,128] k_proj = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( k_proj, need_gather_q_kv) @@ -859,6 +889,9 @@ def indexer_select( k = k.view(-1, 1, self.head_dim) if HAS_TRITON: + q, _ = self.wq_b(qr) # [b,s,1536] @ [1536,64*128] = [b,s,64*128] + q = q.view(-1, self.n_head, self.head_dim) # [n_toks,64,128] + cos = cos.view(-1, self.qk_rope_head_dim) sin = sin.view(-1, self.qk_rope_head_dim) q, k = rope_forward_triton(q, @@ -868,6 +901,38 @@ def 
indexer_select( rope_dim=self.qk_rope_head_dim, is_neox_style=True) else: + k_pe, k_nope = torch.split( + k, + [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], + dim=-1) # [b,s,64+64] + + k_pe = k_pe.unsqueeze(2) + k_pe = torch_npu.npu_interleave_rope(k_pe, cos, sin) + k_pe = k_pe.squeeze(2) + + k = torch.cat([k_pe, k_nope], dim=-1) # [b*s,128] + q = None + + return q, k + + def indexer_select_post_process( + self, + x: torch.Tensor, + qr: torch.Tensor, + q: Optional[torch.Tensor], + k: torch.Tensor, + kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], + attn_metadata: M, + cos: torch.Tensor, + sin: torch.Tensor, + actual_seq_lengths_query: torch.Tensor, + actual_seq_lengths_key: torch.Tensor, + need_gather_q_kv: bool = False, + ): + if q is None: + q, _ = self.wq_b(qr) # [b,s,1536] @ [1536,64*128] = [b,s,64*128] + q = q.view(-1, self.n_head, self.head_dim) # [n_toks,64,128] + cos_q, sin_q = cos, sin cos = cos.view(-1, 1, 1, self.qk_rope_head_dim) sin = sin.view(-1, 1, 1, self.qk_rope_head_dim) @@ -882,20 +947,6 @@ def indexer_select( q_pe = q_pe.squeeze(2) q = torch.cat([q_pe, q_nope], dim=-1) # [b*s,64,128] - k_pe, k_nope = torch.split( - k, - [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], - dim=-1) # [b,s,64+64] - - k_pe = k_pe.unsqueeze(2) - k_pe = torch_npu.npu_rotary_mul(k_pe, cos, sin) - k_pe = k_pe.squeeze(2) - - k = torch.cat([k_pe, k_nope], dim=-1) # [b*s,128] - - if self.enable_sfa_cp: - k = get_tp_group().all_gather(k, 0) - if kv_cache is not None: torch_npu.npu_scatter_nd_update_(kv_cache[2].view(-1, k.shape[-1]), attn_metadata.slot_mapping.view( diff --git a/vllm_ascend/distributed/utils.py b/vllm_ascend/distributed/utils.py index 70c57d288ec..3a624de209d 100644 --- a/vllm_ascend/distributed/utils.py +++ b/vllm_ascend/distributed/utils.py @@ -1,8 +1,9 @@ import os +from typing import Optional import torch import torch.distributed as dist -from vllm.distributed.parallel_state import get_dp_group +from 
vllm.distributed.parallel_state import GroupCoordinator, get_dp_group from vllm.forward_context import get_forward_context from vllm_ascend.distributed.parallel_state import (get_fc3_quant_x_group, @@ -90,3 +91,21 @@ def fc3_all_gather_and_maybe_unpad_impl(x: torch.Tensor, ) -> torch.Tensor: offset += num_tokens_dp x = result return x + + + def all_gather_async(input: torch.Tensor, + group: GroupCoordinator, + output: Optional[torch.Tensor] = None, + async_op: bool = True): + if group.world_size == 1: + return input, None + if output is None: + input_size = input.size() + output_size = (input_size[0] * group.world_size, ) + input_size[1:] + output = torch.empty(output_size, + dtype=input.dtype, + device=input.device) + return output, dist.all_gather_into_tensor(output, + input, + group=group.device_group, + async_op=async_op) diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 9874661ecc6..9994433e182 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -1165,20 +1165,16 @@ def get_instance(*args, **kwargs): @lru_cache(maxsize=1) def enable_dsa_cp() -> bool: from vllm.config import get_current_vllm_config - vllm_config = get_current_vllm_config() - if vllm_config is None: - return False - - model_config = getattr(vllm_config, "model_config", None) - if model_config is None: - return False - - hf_text_config = getattr(model_config, "hf_text_config", None) - if hf_text_config is None: - return False - - return hasattr(hf_text_config, "index_topk") + vllm_config = get_current_vllm_config() + if vllm_config is None: + return False + is_ds_v32 = hasattr( vllm_config.model_config, "hf_text_config") and hasattr( vllm_config.model_config.hf_text_config, "index_topk") + if is_ds_v32 and enable_sp(): + return True + return False @lru_cache(maxsize=1)