diff --git a/vllm_ascend/attention/context_parallel/sfa_cp.py b/vllm_ascend/attention/context_parallel/sfa_cp.py index 7fd3e739412..63228bdb752 100644 --- a/vllm_ascend/attention/context_parallel/sfa_cp.py +++ b/vllm_ascend/attention/context_parallel/sfa_cp.py @@ -5,10 +5,12 @@ import torch_npu from vllm.config import VllmConfig from vllm.distributed import get_dcp_group, get_pcp_group +from vllm.triton_utils import HAS_TRITON from vllm_ascend.attention.context_parallel.common_cp import AscendPCPMetadata from vllm_ascend.attention.sfa_v1 import AscendSFAImpl, AscendSFAMetadata, AscendSFAMetadataBuilder from vllm_ascend.attention.utils import AscendCommonAttentionMetadata, enabling_mlapo, split_decodes_and_prefills +from vllm_ascend.ops.triton.rope import rope_forward_triton_siso M = TypeVar("M", bound=AscendSFAMetadata) @@ -273,42 +275,33 @@ def gather_kv_cross_cp(self, kv_cache: torch.Tensor, valid_block_ids: torch.Tens def indexer_select_post_process( self, x: torch.Tensor, - qr: torch.Tensor, - q: torch.Tensor | None, - k: torch.Tensor, + q_c: torch.Tensor, kv_cache: tuple[torch.Tensor, torch.Tensor, torch.Tensor], attn_metadata: M, cos: torch.Tensor, sin: torch.Tensor, actual_seq_lengths_query: torch.Tensor, actual_seq_lengths_key: torch.Tensor, - need_gather_q_kv: bool = False, ): - if q is None: - q, _ = self.wq_b(qr) # [b,s,1536] @ [1536,64*128] = [b,s,64*128] - q = q.view(-1, self.n_head, self.head_dim) # [n_toks,64,128] - cos_q, sin_q = cos, sin + weights, _ = self.weights_proj(x) - q_pe, q_nope = torch.split( - q, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], dim=-1 + q_li, _ = self.wq_b(q_c) # [b,s,1536] @ [1536,64*128] = [b,s,64*128] + q_li = q_li.view(-1, self.n_head, self.head_dim) # [n_toks,64,128] + if HAS_TRITON: + q_li = rope_forward_triton_siso( + q_li, cos, sin, rope_dim=self.qk_rope_head_dim, is_neox_style=self.is_rope_neox_style + ) + else: + q_li_pe, q_li_nope = torch.split( + q_li, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], dim=-1 ) # [b,s,64,64+64] - q_pe = q_pe.unsqueeze(2) - q_pe = torch_npu.npu_rotary_mul(q_pe, cos_q, sin_q) - q_pe = q_pe.squeeze(2) - q = torch.cat([q_pe, q_nope], dim=-1) # [b*s,64,128] + q_li_pe = q_li_pe.unsqueeze(2) + q_li_pe = torch_npu.npu_rotary_mul(q_li_pe, cos, sin) + q_li_pe = q_li_pe.squeeze(2) + q_li = torch.cat([q_li_pe, q_li_nope], dim=-1) # [b*s,64,128] - if kv_cache is not None: - if self.is_kv_producer: - attn_metadata.reshape_cache_event = torch.npu.Event() - torch_npu.npu_scatter_nd_update_( - kv_cache[2].view(-1, k.shape[-1]), attn_metadata.slot_mapping.view(-1, 1), k.view(-1, k.shape[-1]) - ) # b, s, n, d - if self.is_kv_producer: - attn_metadata.reshape_cache_event.record() - - weights, _ = self.weights_proj(x) - weights = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(weights, need_gather_q_kv) + q = q_li key = kv_cache[2] assert attn_metadata.sfa_cp_metadata is not None diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py index 64d21ef27d4..22def96c5e6 100644 --- a/vllm_ascend/attention/sfa_v1.py +++ b/vllm_ascend/attention/sfa_v1.py @@ -30,6 +30,7 @@ transdata, wait_for_kv_layer_from_connector, ) +from vllm_ascend.device.device_op import DeviceOperator from vllm_ascend.distributed.utils import all_gather_async from vllm_ascend.ops.layer_shard_linear import ( is_hidden_layer, @@ -38,7 +39,7 @@ register_all_layers_to_shard_weight_series, ) from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla -from vllm_ascend.ops.triton.rope import rope_forward_triton +from vllm_ascend.ops.triton.rope import rope_forward_triton_siso from vllm_ascend.quantization.methods import AscendW8A8LinearMethod from vllm_ascend.utils import ( ACL_FORMAT_FRACTAL_ND, @@ -46,6 +47,7 @@ dispose_layer, enable_dsa_cp, enable_dsa_cp_with_layer_shard, + enable_dsa_cp_with_o_proj_tp, get_weight_prefetch_method, maybe_trans_nz, ) @@ -393,8 +395,8 @@ def __init__( ascend_config = get_ascend_config() self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp - # In sfa, prefill and decode have the same calculation formula, - # so do not distinguish between prefill and decode here. + # The MLAPO operator fuses the pre-processing steps on Q/K/V in MLA into a single operator + # NOTE: it imposes a limit on the number of input tokens and conflicts with FlashComm self.enable_mlapo = envs.VLLM_ASCEND_ENABLE_MLAPO assert self.indexer is not None, "Indexer is required for DSA." @@ -419,21 +421,29 @@ def __init__( self.is_rope_neox_style = False self.use_torch_npu_lightning_indexer = True + # Effective in SFA when FlashComm is enabled. self.enable_dsa_cp = enable_dsa_cp() - self.enable_dsa_cp_prefill_only = enable_dsa_cp_with_layer_shard() + + # Enable layer sharding via DSA-CP on the P node in the PD-disaggregated setup. + self.enable_dsa_cp_with_layer_shard = enable_dsa_cp_with_layer_shard() + + # use original TP o_proj weight in PD mix stage, and full gather + # for o_proj weight for prefill stage. + self.enable_dsa_cp_with_o_proj_tp = enable_dsa_cp_with_o_proj_tp() + if self.enable_dsa_cp: self.local_num_heads = self.num_heads * self.tp_size - if self.enable_dsa_cp_prefill_only: - self.layer_sharding_kwargs = [] - for layer_name in get_ascend_config().layer_sharding or []: - if layer_name in kwargs: - self.layer_sharding_kwargs.append(kwargs[layer_name]) - else: - logger.warning_once( - f"[SFAImpl init] Layer '{layer_name}' not found in kwargs for layer sharding, " - "skipping sharding configuration" - ) - register_all_layers_to_shard_weight_series(self.layer_sharding_kwargs) + if self.enable_dsa_cp_with_layer_shard: + self.layer_sharding_kwargs = [] + for layer_name in get_ascend_config().layer_sharding or []: + if layer_name in kwargs: + self.layer_sharding_kwargs.append(kwargs[layer_name]) + else: + logger.warning_once( + f"[SFAImpl init] Layer '{layer_name}' not found in kwargs for layer sharding, " + "skipping sharding configuration" + ) + register_all_layers_to_shard_weight_series(self.layer_sharding_kwargs) def process_weights_after_loading(self, act_dtype: torch.dtype): # NOTE: We currently do not support quant kv_b_proj. @@ -469,7 +479,7 @@ def process_weights_after_loading(self, act_dtype: torch.dtype): # Dispose kv_b_proj since it is replaced by W_UV and W_UK_T to save memory dispose_layer(self.kv_b_proj) if self.enable_dsa_cp: - if self.enable_dsa_cp_prefill_only: + if self.enable_dsa_cp_with_layer_shard: for layer in self.layer_sharding_kwargs or []: if is_hidden_layer(layer): post_process_after_loading_for_shard_weight_series(layer) @@ -501,100 +511,6 @@ def process_weights_after_loading(self, act_dtype: torch.dtype): # if mlapo, W_UK_T can't trans nz self.W_UK_T = maybe_trans_nz(self.W_UK_T) - def _v_up_proj(self, x): - num_input_tokens, _, _ = x.shape - if ( - x.dtype in [torch.float16, torch.bfloat16] - and hasattr(torch.ops._C_ascend, "batch_matmul_transpose") - and num_input_tokens <= BMM_TRANS_MAX_SUPPORTED_TOKENS - ): - x = x.view(-1, self.local_num_heads, self.kv_lora_rank) - res = torch.empty((num_input_tokens, self.local_num_heads, self.v_head_dim), dtype=x.dtype, device=x.device) - torch.ops._C_ascend.batch_matmul_transpose(x, self.W_UV, res) - x = res.reshape(-1, self.local_num_heads * self.v_head_dim) - else: - # Convert from (B, N, L) to (N, B, L) - x = x.view(-1, self.local_num_heads, self.kv_lora_rank).transpose(0, 1) - # # Multiply (N, B, L) x (N, L, V) -> (N, B, V) - x = torch.bmm(x, self.W_UV) - # # Convert from (N, B, V) to (B, N * V) - x = x.transpose(0, 1).reshape(-1, self.local_num_heads * self.v_head_dim) - return x - - # Return `ql_nope`, `q_pe` - def _q_proj_and_k_up_proj(self, x): - q_nope, q_pe = ( - self.q_proj(x)[0] - .view(-1, self.local_num_heads, self.qk_head_dim) - .split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) - ) - - # Convert from (B, N, P) to (N, B, P) - q_nope = q_nope.transpose(0, 1) - # Multiply (N, B, P) x (N, P, L) -> (N, B, L) - ql_nope = torch.bmm(q_nope, self.W_UK_T) - # Convert from (N, B, L) to (B, N, L) - return ql_nope.transpose(0, 1), q_pe - - def _get_full_kv(self, k, attn_metadata): - return k - - def exec_kv( - self, - kv_no_split: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - kv_cache: tuple, - slots: torch.Tensor, - attn_metadata: M, - ): - B = kv_no_split.shape[0] - N = self.num_kv_heads - S = 1 - # npu_kv_rmsnorm_rope_cache needs [B, N, S, D] - kv_no_split = kv_no_split.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim) - cache_mode = "PA" - - if self.enable_dsa_cp: - _, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache( - kv_no_split, - self.kv_a_layernorm.weight, # type: ignore[union-attr] - cos, - sin, - slots.to(torch.int64), - kv_cache[1], - kv_cache[0], - epsilon=self.kv_a_layernorm.variance_epsilon, # type: ignore[union-attr] - cache_mode=cache_mode, - is_output_kv=True, - ) - return k_pe, k_nope - else: - torch_npu.npu_kv_rmsnorm_rope_cache( - kv_no_split, - self.kv_a_layernorm.weight, # type: ignore[union-attr] - cos, - sin, - slots.to(torch.int64), - kv_cache[1], - kv_cache[0], - epsilon=self.kv_a_layernorm.variance_epsilon, # type: ignore[union-attr] - cache_mode=cache_mode, - ) - return None, None - - def rope_single( - self, - x: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - ) -> torch.Tensor: - B, N, D = x.shape - S = 1 - x = x.view(B, N, S, D) - x = torch_npu.npu_interleave_rope(x, cos, sin) - return x.view(B, N, D) - # Processing the input parameters for MLAPO by reordering and transposing # QKV(and part of Q) weight, applying RoPE-related dimension transformations, # and handling quantization parameters. @@ -672,356 +588,330 @@ def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype): self.q_proj.quant_bias = None torch.npu.empty_cache() - def _sfa_preprocess_decode( + def forward_mha( self, - hidden_states: torch.Tensor, - kv_cache: tuple[torch.Tensor, torch.Tensor, torch.Tensor], + q: torch.Tensor, + kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: M, - need_gather_q_kv: bool, - num_input_tokens: int, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(hidden_states.contiguous(), need_gather_q_kv) - k_nope, k_pe = kv_cache[0], kv_cache[1] - ql_nope = torch.empty( - (num_input_tokens, self.W_UK_T.shape[0], k_nope.shape[-1]), - dtype=hidden_states.dtype, - device=hidden_states.device, - ) - q_pe = torch.empty( - (num_input_tokens, self.W_UK_T.shape[0], k_pe.shape[-1]), - dtype=hidden_states.dtype, - device=hidden_states.device, - ) - q_c = torch.empty( - (num_input_tokens, self.q_lora_rank), - dtype=hidden_states.dtype, - device=hidden_states.device, - ) - torch.ops._C_ascend.mla_preprocess( - hidden_states, - self.wd_qkv, - self.deq_scale_qkv, - self.gamma1, - self.beta1, - self.wu_q, - self.qb_deq_scl, - self.gamma2, - attn_metadata.cos, - attn_metadata.sin, - self.W_UK_T, - k_nope, - k_pe, - attn_metadata.slot_mapping, - quant_scale0=self.quant_scale0, - quant_offset0=self.quant_offset0, - bias0=self.quant_bias_qkv, - quant_scale1=self.quant_scale1, - quant_offset1=self.quant_offset1, - bias1=self.qb_qt_bias, - ctkv_scale=self.ctkv_scale, - q_nope_scale=self.q_nope_scale, - cache_mode="krope_ctkv", - quant_mode="per_tensor_quant_asymm", - enable_inner_out=True, - q_out0=ql_nope, - kv_cache_out0=k_nope, - q_out1=q_pe, - kv_cache_out1=k_pe, - inner_out=q_c, - ) - return hidden_states, ql_nope, q_pe, q_c + k_scale: torch.Tensor, + output: torch.Tensor, + ) -> None: + raise NotImplementedError("forward_mha is not supported for SFA attention. Use forward() instead.") - def forward( + def forward_mqa( self, - layer_name, - hidden_states: torch.Tensor, # query in unified attn - kv_cache: tuple[torch.Tensor, torch.Tensor, torch.Tensor], + q: torch.Tensor | tuple[torch.Tensor, torch.Tensor], + kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: M, - need_gather_q_kv: bool = False, - output: torch.Tensor | None = None, - ) -> torch.Tensor: - assert output is not None, "Output tensor must be provided." - forward_context = get_forward_context() - if attn_metadata is None: - # Profiling run. - if self.enable_dsa_cp_prefill_only and not forward_context.in_profile_run: - for layer in self.layer_sharding_kwargs or []: - if is_hidden_layer(layer): - reach_layer_for_shard_weight_series(layer) - return output.fill_(0) - - cos = attn_metadata.cos - sin = attn_metadata.sin - actual_seq_lengths_query = attn_metadata.cum_query_lens - actual_seq_lengths_key = attn_metadata.seq_lens - if self.enable_dsa_cp: - need_gather_q_kv = False - # Inputs and outputs may be padded for CUDA graphs - num_input_tokens = attn_metadata.num_input_tokens - output_padded = output + layer, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + raise NotImplementedError("forward_mqa is not supported for SFA attention. Use forward() instead.") - # all-gather o_proj weight for prefill stage of PD mix node - o_proj_full_handle = None - # if is PD mix stage, using original TP o_proj weight, and also need to full gather for o_proj - # weight for prefill stage. - should_shard_weight = self.enable_dsa_cp_prefill_only or attn_metadata.attn_state not in { - AscendAttentionState.DecodeOnly, - AscendAttentionState.SpecDecoding, - } + def rope_single( + self, + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + ) -> torch.Tensor: + B, N, D = x.shape + S = 1 + x = x.view(B, N, S, D) + x = torch_npu.npu_interleave_rope(x, cos, sin) + return x.view(B, N, D) - if self.enable_mlapo and num_input_tokens <= MLAPO_MAX_SUPPORTED_TOKENS: - hidden_states, ql_nope, q_pe, q_c = self._sfa_preprocess_decode( - hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, - need_gather_q_kv=need_gather_q_kv, - num_input_tokens=num_input_tokens, - ) - q, k = self.indexer_select_pre_process( - x=hidden_states, qr=q_c, cos=cos, sin=sin, need_gather_q_kv=need_gather_q_kv - ) - else: - assert self.fused_qkv_a_proj is not None, "q lora is required for DSA." - weight_prefetch_method = get_weight_prefetch_method() - weight_prefetch_method.maybe_prefetch_mla_or_sla_weight_in_current_stream( - inputs=self.fused_qkv_a_proj.weight, dependency=hidden_states - ) - qkv_lora = self.fused_qkv_a_proj(hidden_states)[0] - q_c, kv_no_split = qkv_lora.split( - [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], - dim=-1, - ) - assert self.q_a_layernorm is not None, "q_a_layernorm must be initialized" - q_c = self.q_a_layernorm(q_c) - # Process for Flash Comm V1 - if need_gather_q_kv: - q_c = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(q_c.contiguous(), need_gather_q_kv) - kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( - kv_no_split.contiguous(), need_gather_q_kv - ) + def _init_o_proj_tp_full_params(self): + """ + Initialize TP-mode and Full-mode parameters for o_proj weight, + preparing for weight switching in PD mix stage. - q, k = self.indexer_select_pre_process( - x=hidden_states, qr=q_c, cos=cos, sin=sin, need_gather_q_kv=need_gather_q_kv + For PD mix stage: + - Use original TP o_proj weight for decode phase + - Need full-gather o_proj weight from all TP ranks for prefill phase + """ + if AscendSFAImpl.o_proj_full_pool is None: + sample = self.o_proj.weight + AscendSFAImpl.o_proj_full_pool = torch.empty( + (sample.shape[0] * self.tp_size, sample.shape[1]), dtype=sample.dtype, device=sample.device ) - wait_for_kv_layer_from_connector(layer_name) + # Save TP-mode parameters (original sharded weights) + self.o_proj_tp_weight = self.o_proj.weight.clone().detach() + self.o_proj_tp_aclnn_input_scale = self.o_proj.aclnn_input_scale.clone().detach() + self.o_proj_tp_aclnn_input_scale_reciprocal = self.o_proj.aclnn_input_scale_reciprocal.clone().detach() + self.o_proj_tp_aclnn_input_offset = self.o_proj.aclnn_input_offset.clone().detach() - slot_mapping = attn_metadata.slot_mapping - if self.enable_dsa_cp: - assert attn_metadata.dsa_cp_context is not None - slot_mapping = attn_metadata.dsa_cp_context.slot_mapping_cp - actual_seq_lengths_query = attn_metadata.dsa_cp_context.actual_seq_lengths_query - actual_seq_lengths_key = attn_metadata.dsa_cp_context.actual_seq_lengths_key + # Initially switch to TP mode for graph capture + self.o_proj.weight.set_(self.o_proj_tp_weight) + self.o_proj.aclnn_input_scale.set_(self.o_proj_tp_aclnn_input_scale) + self.o_proj.aclnn_input_scale_reciprocal.set_(self.o_proj_tp_aclnn_input_scale_reciprocal) + self.o_proj.aclnn_input_offset.set_(self.o_proj_tp_aclnn_input_offset) - k_pe, k_nope = self.exec_kv(kv_no_split, cos, sin, kv_cache, slot_mapping, attn_metadata) + # Precompute Full-mode quantization parameters by repeating TP parameters across all TP ranks + self.o_proj_full_aclnn_input_scale = self.o_proj.aclnn_input_scale.repeat(self.tp_size) + self.o_proj_full_aclnn_input_scale_reciprocal = self.o_proj.aclnn_input_scale_reciprocal.repeat(self.tp_size) + self.o_proj_full_aclnn_input_offset = self.o_proj.aclnn_input_offset.repeat(self.tp_size) - if self.enable_dsa_cp: - assert k_pe is not None - assert k_nope is not None - # support all_gather kv async for communication calculation overlap - fused_kv_no_split, kv_ag_handle = all_gather_async( - torch.cat( - [k_pe.view(-1, k_pe.shape[-1]), k_nope.view(-1, k_nope.shape[-1]), k.view(-1, k.shape[-1])], - dim=1, - ), - get_tp_group(), - async_op=should_shard_weight, - ) + def _handle_o_proj_weight_switch_and_forward( + self, + attn_output: torch.Tensor, + output: torch.Tensor, + o_proj_full_handle: torch.distributed.Work | None, + should_shard_weight: bool, + ) -> tuple[torch.Tensor, bool]: + """ + Handle o_proj weight switching between TP-mode and Full-mode, and execute forward computation. + """ + # Gather o_proj weight from all TP ranks for Full-mode computation + if should_shard_weight: + # Wait for the completion of o_proj weight all-gather operation + if o_proj_full_handle is not None: + o_proj_full_handle.wait() - ql_nope, q_pe = self._q_proj_and_k_up_proj(q_c) - q_pe = self.rope_single(q_pe, cos, sin) + # Switch o_proj to Full-mode (gathered weight from all TP ranks) + self.o_proj.weight.set_(AscendSFAImpl.o_proj_full_pool) + self.o_proj.aclnn_input_scale.set_(self.o_proj_full_aclnn_input_scale) + self.o_proj.aclnn_input_scale_reciprocal.set_(self.o_proj_full_aclnn_input_scale_reciprocal) + self.o_proj.aclnn_input_offset.set_(self.o_proj_full_aclnn_input_offset) - if self.enable_dsa_cp: - if kv_ag_handle is not None: - kv_ag_handle.wait() + # Apply quantization method and execute forward computation + output[...] = self.o_proj.quant_method.quant_method.apply(self.o_proj, attn_output) - if self.enable_dsa_cp_prefill_only: - for layer in self.layer_sharding_kwargs or []: - if is_hidden_layer(layer): - reach_layer_for_shard_weight_series(layer) - elif should_shard_weight: - _, o_proj_full_handle = all_gather_async( - self.o_proj_tp_weight, get_tp_group(), output=AscendSFAImpl.o_proj_full_pool - ) + # Switch o_proj back to TP-mode for subsequent decode operations + self.o_proj.weight.set_(self.o_proj_tp_weight) + self.o_proj.aclnn_input_scale.set_(self.o_proj_tp_aclnn_input_scale) + self.o_proj.aclnn_input_scale_reciprocal.set_(self.o_proj_tp_aclnn_input_scale_reciprocal) + self.o_proj.aclnn_input_offset.set_(self.o_proj_tp_aclnn_input_offset) - if kv_cache is not None: - assert fused_kv_no_split is not None - k_pe, k_nope, k = fused_kv_no_split.split( - [self.qk_rope_head_dim, self.kv_lora_rank, self.head_dim], dim=-1 - ) - slot_mapping = attn_metadata.slot_mapping.view(-1, 1) - torch_npu.npu_scatter_nd_update_(kv_cache[0].view(-1, k_nope.shape[-1]), slot_mapping, k_nope) - torch_npu.npu_scatter_nd_update_(kv_cache[1].view(-1, k_pe.shape[-1]), slot_mapping, k_pe) + return output, False + else: + # For decode scenario: perform all-to-all communication on o_proj input activations + # Reshape for all-to-all: [batch * seq, tp_size, head_dim] -> [tp_size, batch * seq, head_dim] + send = ( + attn_output.view(-1, self.tp_size, self.num_heads * self.v_head_dim) + .permute(1, 0, 2) + .reshape(-1, self.num_heads * self.v_head_dim) + ) - k = self._get_full_kv(k, attn_metadata) - if kv_cache is not None: - torch_npu.npu_scatter_nd_update_( - kv_cache[2].view(-1, k.shape[-1]), attn_metadata.slot_mapping.view(-1, 1), k.view(-1, k.shape[-1]) - ) # b, s, n, d + attn_output = torch.empty_like(send) + torch.distributed.all_to_all_single(attn_output, send, group=get_tp_group().device_group) - topk_indices = self.indexer_select_post_process( - x=hidden_states, - qr=q_c, - q=q, - k=k, - kv_cache=kv_cache, - attn_metadata=attn_metadata, - cos=cos, - sin=sin, - actual_seq_lengths_query=actual_seq_lengths_query, - actual_seq_lengths_key=actual_seq_lengths_key, - need_gather_q_kv=need_gather_q_kv, - ) + return attn_output, True - attn_output = self._execute_sparse_flash_attention_process( - ql_nope, q_pe, kv_cache, topk_indices, attn_metadata, actual_seq_lengths_query, actual_seq_lengths_key - ) + def _get_full_kv(self, k, attn_metadata): + return k - attn_output = self._v_up_proj(attn_output) - weight_prefetch_method = get_weight_prefetch_method() - weight_prefetch_method.maybe_prefetch_mla_or_sla_weight_in_current_stream( - inputs=self.o_proj.weight, - dependency=attn_output, - max_size=MAX_O_PROJ_PREFETCH_SIZE, - linear_layer=self.o_proj, - ) + def exec_kv( + self, + kv_no_split: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + kv_cache: tuple, + slots: torch.Tensor, + attn_metadata: M, + ): + B = kv_no_split.shape[0] + N = self.num_kv_heads + S = 1 + # npu_kv_rmsnorm_rope_cache needs [B, N, S, D] + kv_no_split = kv_no_split.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim) + cache_mode = "PA" - if self.enable_dsa_cp and not self.enable_dsa_cp_prefill_only: - # When using SFA-CP with pd mixed, o_proj has two cases: - # 1. prefill: o_proj is a TP weight, we need to all-gather o_proj weight to switch TP=1. - # 2. decode: all-to-all the hidden_state before the o_proj forward. - result, require_o_proj_forward = self._handle_o_proj_weight_switch_and_forward( - attn_output=attn_output, - output=output, - o_proj_full_handle=o_proj_full_handle, - should_shard_weight=should_shard_weight, + if self.enable_dsa_cp: + _, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache( + kv_no_split, + self.kv_a_layernorm.weight, # type: ignore[union-attr] + cos, + sin, + slots.to(torch.int64), + kv_cache[1], + kv_cache[0], + epsilon=self.kv_a_layernorm.variance_epsilon, # type: ignore[union-attr] + cache_mode=cache_mode, + is_output_kv=True, ) - if not require_o_proj_forward: - return result - attn_output = result - - output[...] = self.o_proj(attn_output)[0] + return k_pe, k_nope + else: + torch_npu.npu_kv_rmsnorm_rope_cache( + kv_no_split, + self.kv_a_layernorm.weight, # type: ignore[union-attr] + cos, + sin, + slots.to(torch.int64), + kv_cache[1], + kv_cache[0], + epsilon=self.kv_a_layernorm.variance_epsilon, # type: ignore[union-attr] + cache_mode=cache_mode, + ) + return None, None - maybe_save_kv_layer_to_connector(layer_name, list(kv_cache)) + # Return `ql_nope`, `q_pe` + def _q_proj_and_k_up_proj(self, x): + q_nope, q_pe = ( + self.q_proj(x)[0] + .view(-1, self.local_num_heads, self.qk_head_dim) + .split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + ) - return output_padded + # Convert from (B, N, P) to (N, B, P) + q_nope = q_nope.transpose(0, 1) + # Multiply (N, B, P) x (N, P, L) -> (N, B, L) + ql_nope = torch.bmm(q_nope, self.W_UK_T) + # Convert from (N, B, L) to (B, N, L) + return ql_nope.transpose(0, 1), q_pe - def _execute_sparse_flash_attention_process( - self, ql_nope, q_pe, kv_cache, topk_indices, attn_metadata, actual_seq_lengths_query, actual_seq_lengths_key - ): - block_table = attn_metadata.block_table - kv = kv_cache[0] - key_rope = kv_cache[1] + def _v_up_proj(self, x): + num_input_tokens, _, _ = x.shape + if ( + x.dtype in [torch.float16, torch.bfloat16] + and hasattr(torch.ops._C_ascend, "batch_matmul_transpose") + and num_input_tokens <= BMM_TRANS_MAX_SUPPORTED_TOKENS + ): + x = x.view(-1, self.local_num_heads, self.kv_lora_rank) + res = torch.empty((num_input_tokens, self.local_num_heads, self.v_head_dim), dtype=x.dtype, device=x.device) + torch.ops._C_ascend.batch_matmul_transpose(x, self.W_UV, res) + x = res.reshape(-1, self.local_num_heads * self.v_head_dim) + else: + # Convert from (B, N, L) to (N, B, L) + x = x.view(-1, self.local_num_heads, self.kv_lora_rank).transpose(0, 1) + # # Multiply (N, B, L) x (N, L, V) -> (N, B, V) + x = torch.bmm(x, self.W_UV) + # # Convert from (N, B, V) to (B, N * V) + x = x.transpose(0, 1).reshape(-1, self.local_num_heads * self.v_head_dim) + return x - attn_output = torch.ops._C_ascend.npu_sparse_flash_attention( - query=ql_nope, - key=kv, - value=kv, - sparse_indices=topk_indices, - scale_value=self.scale, - sparse_block_size=1, - block_table=block_table, - actual_seq_lengths_query=actual_seq_lengths_query, - actual_seq_lengths_kv=actual_seq_lengths_key, - query_rope=q_pe, - key_rope=key_rope, - layout_query="TND", - layout_kv="PA_BSND", - sparse_mode=3, + def _sfa_preprocess_with_mlapo( + self, + hidden_states: torch.Tensor, + kv_cache: tuple[torch.Tensor, torch.Tensor, torch.Tensor], + cos: torch.Tensor, + sin: torch.Tensor, + slot_mapping: torch.Tensor, + num_input_tokens: int, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + k_nope, k_pe = kv_cache[0], kv_cache[1] + ql_nope = torch.empty( + (num_input_tokens, self.W_UK_T.shape[0], k_nope.shape[-1]), + dtype=hidden_states.dtype, + device=hidden_states.device, ) - return attn_output + q_pe = torch.empty( + (num_input_tokens, self.W_UK_T.shape[0], k_pe.shape[-1]), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + q_c = torch.empty( + (num_input_tokens, self.q_lora_rank), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + torch.ops._C_ascend.mla_preprocess( + hidden_states, + self.wd_qkv, + self.deq_scale_qkv, + self.gamma1, + self.beta1, + self.wu_q, + self.qb_deq_scl, + self.gamma2, + cos, + sin, + self.W_UK_T, + k_nope, + k_pe, + slot_mapping, + quant_scale0=self.quant_scale0, + quant_offset0=self.quant_offset0, + bias0=self.quant_bias_qkv, + quant_scale1=self.quant_scale1, + quant_offset1=self.quant_offset1, + bias1=self.qb_qt_bias, + ctkv_scale=self.ctkv_scale, + q_nope_scale=self.q_nope_scale, + cache_mode="krope_ctkv", + quant_mode="per_tensor_quant_asymm", + enable_inner_out=True, + q_out0=ql_nope, + kv_cache_out0=k_nope, + q_out1=q_pe, + kv_cache_out1=k_pe, + inner_out=q_c, + ) + return hidden_states, ql_nope, q_pe, q_c def indexer_select_pre_process( self, x: torch.Tensor, - qr: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, - need_gather_q_kv: bool = False, ): - k_proj, _ = self.wk(x) # [b,s,7168] @ [7168,128] = [b,s,128] - k_proj = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(k_proj, need_gather_q_kv) - k = self.k_norm(k_proj).unsqueeze(1) - k = k.view(-1, 1, self.head_dim) + k_li, _ = self.wk(x) # [b,s,7168] @ [7168,128] = [b,s,128] + k_li = self.k_norm(k_li).unsqueeze(1) + k_li = k_li.view(-1, 1, self.head_dim) if HAS_TRITON: - q, _ = self.wq_b(qr) # [b,s,1536] @ [1536,64*128] = [b,s,64*128] - q = q.view(-1, self.n_head, self.head_dim) # [n_toks,64,128] - cos = cos.view(-1, self.qk_rope_head_dim) sin = sin.view(-1, self.qk_rope_head_dim) - q, k = rope_forward_triton( - q, k, cos, sin, rope_dim=self.qk_rope_head_dim, is_neox_style=self.is_rope_neox_style + k_li = rope_forward_triton_siso( + k_li, cos, sin, rope_dim=self.qk_rope_head_dim, is_neox_style=self.is_rope_neox_style ) else: - k_pe, k_nope = torch.split(k, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], dim=-1) + k_li_pe, k_li_nope = torch.split( + k_li, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], dim=-1 + ) cos = cos.view(-1, 1, 1, self.qk_rope_head_dim) sin = sin.view(-1, 1, 1, self.qk_rope_head_dim) - k_pe = k_pe.unsqueeze(2) - k_pe = torch_npu.npu_interleave_rope(k_pe, cos, sin) - k_pe = k_pe.squeeze(2) + k_li_pe = k_li_pe.unsqueeze(2) + k_li_pe = torch_npu.npu_interleave_rope(k_li_pe, cos, sin) + k_li_pe = k_li_pe.squeeze(2) - k = torch.cat([k_pe, k_nope], dim=-1) # [b*s,128] - q = None + k_li = torch.cat([k_li_pe, k_li_nope], dim=-1) # [b*s,128] - return q, k + return k_li def indexer_select_post_process( self, x: torch.Tensor, - qr: torch.Tensor, - q: torch.Tensor | None, - k: torch.Tensor, + q_c: torch.Tensor, kv_cache: tuple[torch.Tensor, torch.Tensor, torch.Tensor], attn_metadata: M, cos: torch.Tensor, sin: torch.Tensor, actual_seq_lengths_query: torch.Tensor, actual_seq_lengths_key: torch.Tensor, - need_gather_q_kv: bool = False, ): - if q is None: - q, _ = self.wq_b(qr) # [b,s,1536] @ [1536,64*128] = [b,s,64*128] - q = q.view(-1, self.n_head, self.head_dim) # [n_toks,64,128] - cos_q, sin_q = cos, sin + weights, _ = self.weights_proj(x) - q_pe, q_nope = torch.split( - q, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], dim=-1 + q_li, _ = self.wq_b(q_c) # [b,s,1536] @ [1536,64*128] = [b,s,64*128] + q_li = q_li.view(-1, self.n_head, self.head_dim) # [n_toks,64,128] + if HAS_TRITON: + q_li = rope_forward_triton_siso( + q_li, cos, sin, rope_dim=self.qk_rope_head_dim, is_neox_style=self.is_rope_neox_style + ) + else: + q_li_pe, q_li_nope = torch.split( + q_li, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], dim=-1 ) # [b,s,64,64+64] - q_pe = q_pe.unsqueeze(2) - q_pe = torch_npu.npu_rotary_mul(q_pe, cos_q, sin_q) - q_pe = q_pe.squeeze(2) - q = torch.cat([q_pe, q_nope], dim=-1) # [b*s,64,128] - - if kv_cache is not None: - if self.is_kv_producer: - attn_metadata.reshape_cache_event = torch.npu.Event() - torch_npu.npu_scatter_nd_update_( - kv_cache[2].view(-1, k.shape[-1]), attn_metadata.slot_mapping.view(-1, 1), k.view(-1, k.shape[-1]) - ) # b, s, n, d - if self.is_kv_producer: - attn_metadata.reshape_cache_event.record() - - weights, _ = self.weights_proj(x) - weights = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(weights, need_gather_q_kv) - - key = kv_cache[2] - block_table = attn_metadata.block_table + q_li_pe = q_li_pe.unsqueeze(2) + q_li_pe = torch_npu.npu_rotary_mul(q_li_pe, cos, sin) + q_li_pe = q_li_pe.squeeze(2) + q_li = torch.cat([q_li_pe, q_li_nope], dim=-1) # [b*s,64,128] # DSV3.2 currently has graph compilation issues when using torch_npu.npu.lightning_indexer. # So two branches are maintained temporarily. # TODO: torch.ops._C_ascend.npu_lightning_indexer needs to be removed. if self.use_torch_npu_lightning_indexer: topk_indices, _ = torch_npu.npu_lightning_indexer( - query=q, - key=key, + query=q_li, + key=kv_cache[2], weights=weights, actual_seq_lengths_query=actual_seq_lengths_query, actual_seq_lengths_key=actual_seq_lengths_key, - block_table=block_table, + block_table=attn_metadata.block_table, layout_query="TND", layout_key="PA_BSND", sparse_count=2048, @@ -1029,12 +919,12 @@ def indexer_select_post_process( ) else: topk_indices = torch.ops._C_ascend.npu_lightning_indexer( - query=q, - key=key, + query=q_li, + key=kv_cache[2], weights=weights, actual_seq_lengths_query=actual_seq_lengths_query, actual_seq_lengths_key=actual_seq_lengths_key, - block_table=block_table, + block_table=attn_metadata.block_table, layout_query="TND", layout_key="PA_BSND", sparse_count=2048, @@ -1042,101 +932,212 @@ def indexer_select_post_process( ) return topk_indices - def _init_o_proj_tp_full_params(self): - """ - Initialize TP-mode and Full-mode parameters for o_proj weight, - preparing for weight switching in PD mix stage. + def _execute_sparse_flash_attention_process( + self, ql_nope, q_pe, kv_cache, topk_indices, attn_metadata, actual_seq_lengths_query, actual_seq_lengths_key + ): + block_table = attn_metadata.block_table + kv = kv_cache[0] + key_rope = kv_cache[1] - For PD mix stage: - - Use original TP o_proj weight for decode phase - - Need full-gather o_proj weight from all TP ranks for prefill phase - """ - if AscendSFAImpl.o_proj_full_pool is None: - sample = self.o_proj.weight - AscendSFAImpl.o_proj_full_pool = torch.empty( - (sample.shape[0] * self.tp_size, sample.shape[1]), dtype=sample.dtype, device=sample.device + attn_output = torch.ops._C_ascend.npu_sparse_flash_attention( + query=ql_nope, + key=kv, + value=kv, + sparse_indices=topk_indices, + scale_value=self.scale, + sparse_block_size=1, + block_table=block_table, + actual_seq_lengths_query=actual_seq_lengths_query, + actual_seq_lengths_kv=actual_seq_lengths_key, + query_rope=q_pe, + key_rope=key_rope, + layout_query="TND", + layout_kv="PA_BSND", + sparse_mode=3, + ) + return attn_output + + def forward( + self, + layer_name, + hidden_states: torch.Tensor, # query in unified attn + kv_cache: tuple[torch.Tensor, torch.Tensor, torch.Tensor], + attn_metadata: M, + need_gather_q_kv: bool = False, + output: torch.Tensor | None = None, + ) -> torch.Tensor: + assert output is not None, "Output tensor must be provided." + forward_context = get_forward_context() + if attn_metadata is None: + # Profiling run. + if self.enable_dsa_cp_with_layer_shard and not forward_context.in_profile_run: + for layer in self.layer_sharding_kwargs or []: + if is_hidden_layer(layer): + reach_layer_for_shard_weight_series(layer) + return output.fill_(0) + + cos = attn_metadata.cos + sin = attn_metadata.sin + slot_mapping = attn_metadata.slot_mapping + slot_mapping_cp = None + if self.enable_dsa_cp: + assert attn_metadata.dsa_cp_context is not None + slot_mapping_cp = attn_metadata.dsa_cp_context.slot_mapping_cp + actual_seq_lengths_query = attn_metadata.dsa_cp_context.actual_seq_lengths_query + actual_seq_lengths_key = attn_metadata.dsa_cp_context.actual_seq_lengths_key + else: + actual_seq_lengths_query = attn_metadata.cum_query_lens + actual_seq_lengths_key = attn_metadata.seq_lens + + # Inputs and outputs may be padded for CUDA graphs + num_input_tokens = attn_metadata.num_input_tokens + output_padded = output + + # all-gather o_proj weight for prefill stage of PD mix node + o_proj_full_handle = None + # if is PD mix stage, using original TP o_proj weight, and also need to full gather for o_proj + # weight for prefill stage. + full_gather_o_proj_enabled = self.enable_dsa_cp_with_o_proj_tp and attn_metadata.attn_state not in { + AscendAttentionState.DecodeOnly, + AscendAttentionState.SpecDecoding, + } + + # run mlapo ops when dsa-cp is disabled, and ensure that num_tokens satisfies the count limitation + if self.enable_mlapo and num_input_tokens <= MLAPO_MAX_SUPPORTED_TOKENS: + hidden_states, ql_nope, q_pe, q_c = self._sfa_preprocess_with_mlapo( + hidden_states=hidden_states, + kv_cache=kv_cache, + cos=cos, + sin=sin, + slot_mapping=slot_mapping, + num_input_tokens=num_input_tokens, + ) + k_li = self.indexer_select_pre_process(x=hidden_states, cos=cos, sin=sin) + # native + else: + assert self.fused_qkv_a_proj is not None, "q lora is required for DSA." + weight_prefetch_method = get_weight_prefetch_method() + weight_prefetch_method.maybe_prefetch_mla_or_sla_weight_in_current_stream( + inputs=self.fused_qkv_a_proj.weight, dependency=hidden_states + ) + qkv_lora = self.fused_qkv_a_proj(hidden_states)[0] + q_c, kv_no_split = qkv_lora.split( + [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], + dim=-1, ) + assert self.q_a_layernorm is not None, "q_a_layernorm must be initialized" + q_c = self.q_a_layernorm(q_c) - # Save TP-mode parameters (original sharded weights) - self.o_proj_tp_weight = self.o_proj.weight.clone().detach() - self.o_proj_tp_aclnn_input_scale = self.o_proj.aclnn_input_scale.clone().detach() - self.o_proj_tp_aclnn_input_scale_reciprocal = self.o_proj.aclnn_input_scale_reciprocal.clone().detach() - self.o_proj_tp_aclnn_input_offset = self.o_proj.aclnn_input_offset.clone().detach() + k_li = self.indexer_select_pre_process(x=hidden_states, cos=cos, sin=sin) - # Initially switch to TP mode for graph capture - self.o_proj.weight.set_(self.o_proj_tp_weight) - self.o_proj.aclnn_input_scale.set_(self.o_proj_tp_aclnn_input_scale) - self.o_proj.aclnn_input_scale_reciprocal.set_(self.o_proj_tp_aclnn_input_scale_reciprocal) - self.o_proj.aclnn_input_offset.set_(self.o_proj_tp_aclnn_input_offset) + wait_for_kv_layer_from_connector(layer_name) - # Precompute Full-mode quantization parameters by repeating TP parameters across all TP ranks - self.o_proj_full_aclnn_input_scale = self.o_proj.aclnn_input_scale.repeat(self.tp_size) - self.o_proj_full_aclnn_input_scale_reciprocal = self.o_proj.aclnn_input_scale_reciprocal.repeat(self.tp_size) - self.o_proj_full_aclnn_input_offset = self.o_proj.aclnn_input_offset.repeat(self.tp_size) + if self.enable_dsa_cp: + assert slot_mapping_cp is not None + k_pe, k_nope = self.exec_kv(kv_no_split, cos, sin, kv_cache, slot_mapping_cp, attn_metadata) + else: + k_pe, k_nope = self.exec_kv(kv_no_split, cos, sin, kv_cache, slot_mapping, attn_metadata) - def _handle_o_proj_weight_switch_and_forward( - self, - attn_output: torch.Tensor, - output: torch.Tensor, - o_proj_full_handle: torch.distributed.Work | None, - should_shard_weight: bool, - ) -> tuple[torch.Tensor, bool]: - """ - Handle o_proj weight switching between TP-mode and Full-mode, and execute forward computation. - """ - # Gather o_proj weight from all TP ranks for Full-mode computation - if should_shard_weight: - # Wait for the completion of o_proj weight all-gather operation - if o_proj_full_handle is not None: - o_proj_full_handle.wait() + if self.enable_dsa_cp: + assert k_pe is not None + assert k_nope is not None + async_op = self.enable_dsa_cp_with_layer_shard or full_gather_o_proj_enabled + # support all_gather kv async for communication calculation overlap + fused_kv_no_split, kv_ag_handle = all_gather_async( + torch.cat( + [ + k_pe.view(-1, k_pe.shape[-1]), + k_nope.view(-1, k_nope.shape[-1]), + k_li.view(-1, k_li.shape[-1]), + ], + dim=1, + ), + get_tp_group(), + async_op=async_op, + ) - # Switch o_proj to Full-mode (gathered weight from all TP ranks) - self.o_proj.weight.set_(AscendSFAImpl.o_proj_full_pool) - self.o_proj.aclnn_input_scale.set_(self.o_proj_full_aclnn_input_scale) - self.o_proj.aclnn_input_scale_reciprocal.set_(self.o_proj_full_aclnn_input_scale_reciprocal) - self.o_proj.aclnn_input_offset.set_(self.o_proj_full_aclnn_input_offset) + ql_nope, q_pe = self._q_proj_and_k_up_proj(q_c) + q_pe = self.rope_single(q_pe, cos, sin) - # Apply quantization method and execute forward computation - output[...] = self.o_proj.quant_method.quant_method.apply(self.o_proj, attn_output) + if self.enable_dsa_cp: + if kv_ag_handle is not None: + kv_ag_handle.wait() - # Switch o_proj back to TP-mode for subsequent decode operations - self.o_proj.weight.set_(self.o_proj_tp_weight) - self.o_proj.aclnn_input_scale.set_(self.o_proj_tp_aclnn_input_scale) - self.o_proj.aclnn_input_scale_reciprocal.set_(self.o_proj_tp_aclnn_input_scale_reciprocal) - self.o_proj.aclnn_input_offset.set_(self.o_proj_tp_aclnn_input_offset) + if self.enable_dsa_cp_with_layer_shard: + for layer in self.layer_sharding_kwargs or []: + if is_hidden_layer(layer): + reach_layer_for_shard_weight_series(layer) + elif full_gather_o_proj_enabled: + _, o_proj_full_handle = all_gather_async( + self.o_proj_tp_weight, get_tp_group(), output=AscendSFAImpl.o_proj_full_pool + ) - return output, False - else: - # For decode scenario: perform all-to-all communication on o_proj input activations - # Reshape for all-to-all: [batch * seq, tp_size, head_dim] -> [tp_size, batch * seq, head_dim] - send = ( - attn_output.view(-1, self.tp_size, self.num_heads * self.v_head_dim) - .permute(1, 0, 2) - .reshape(-1, self.num_heads * self.v_head_dim) - ) + if kv_cache is not None: + assert fused_kv_no_split is not None + k_pe, k_nope, k_li = fused_kv_no_split.split( + [self.qk_rope_head_dim, self.kv_lora_rank, self.head_dim], dim=-1 + ) + k_nope = k_nope.view(k_nope.shape[0], 1, -1) + k_pe = k_pe.view(k_pe.shape[0], 1, -1) + DeviceOperator.reshape_and_cache( + key=k_nope[: attn_metadata.num_actual_tokens], + value=k_pe[: attn_metadata.num_actual_tokens], + key_cache=kv_cache[0], + value_cache=kv_cache[1], + slot_mapping=slot_mapping[: attn_metadata.num_actual_tokens], + ) - attn_output = torch.empty_like(send) - torch.distributed.all_to_all_single(attn_output, send, group=get_tp_group().device_group) + k_li = self._get_full_kv(k_li, attn_metadata) - return attn_output, True + if kv_cache is not None: + if self.is_kv_producer: + attn_metadata.reshape_cache_event = torch.npu.Event() + torch_npu.npu_scatter_nd_update_( + kv_cache[2].view(-1, k_li.shape[-1]), slot_mapping.view(-1, 1), k_li.view(-1, k_li.shape[-1]) + ) # b, s, n, d + if self.is_kv_producer: + attn_metadata.reshape_cache_event.record() - def forward_mha( - self, - q: torch.Tensor, - kv_c_normed: torch.Tensor, - k_pe: torch.Tensor, - kv_c_and_k_pe_cache: torch.Tensor, - attn_metadata: M, - k_scale: torch.Tensor, - output: torch.Tensor, - ) -> None: - raise NotImplementedError("forward_mha is not supported for SFA attention. Use forward() instead.") + topk_indices = self.indexer_select_post_process( + x=hidden_states, + q_c=q_c, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + cos=cos, + sin=sin, + actual_seq_lengths_query=actual_seq_lengths_query, + actual_seq_lengths_key=actual_seq_lengths_key, + ) - def forward_mqa( - self, - q: torch.Tensor | tuple[torch.Tensor, torch.Tensor], - kv_c_and_k_pe_cache: torch.Tensor, - attn_metadata: M, - layer, - ) -> tuple[torch.Tensor, torch.Tensor | None]: - raise NotImplementedError("forward_mqa is not supported for SFA attention. Use forward() instead.") + attn_output = self._execute_sparse_flash_attention_process( + ql_nope, q_pe, kv_cache, topk_indices, attn_metadata, actual_seq_lengths_query, actual_seq_lengths_key + ) + + attn_output = self._v_up_proj(attn_output) + weight_prefetch_method = get_weight_prefetch_method() + weight_prefetch_method.maybe_prefetch_mla_or_sla_weight_in_current_stream( + inputs=self.o_proj.weight, + dependency=attn_output, + max_size=MAX_O_PROJ_PREFETCH_SIZE, + linear_layer=self.o_proj, + ) + + if self.enable_dsa_cp_with_o_proj_tp: + # When using SFA-CP with pd mixed, o_proj has two cases: + # 1. prefill: o_proj is a TP weight, we need to all-gather o_proj weight to switch TP=1. + # 2. decode: all-to-all the hidden_state before the o_proj forward. + result, require_o_proj_forward = self._handle_o_proj_weight_switch_and_forward( + attn_output=attn_output, + output=output, + o_proj_full_handle=o_proj_full_handle, + should_shard_weight=full_gather_o_proj_enabled, + ) + if not require_o_proj_forward: + return result + attn_output = result + + output[...] = self.o_proj(attn_output)[0] + + maybe_save_kv_layer_to_connector(layer_name, list(kv_cache)) + + return output_padded diff --git a/vllm_ascend/ops/triton/rope.py b/vllm_ascend/ops/triton/rope.py index ad863e40910..909065179a6 100644 --- a/vllm_ascend/ops/triton/rope.py +++ b/vllm_ascend/ops/triton/rope.py @@ -146,6 +146,79 @@ def _triton_rope( tl.store(k_start_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask) +@triton.jit +def _triton_rope_siso( + qk_ptr, + qk_row_stride, + cos_ptr, + cos_row_stride, + sin_ptr, + sin_row_stride, + cos_sin_ptr, + cos_sin_row_stride, + pos_ptr, + num_tokens, + n_h: tl.constexpr, + hd: tl.constexpr, + rope_dim: tl.constexpr, + pad_n_h: tl.constexpr, + pad_rope_dim: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + IS_NEOX_STYLE: tl.constexpr, + USE_COS_SIN: tl.constexpr, +): + pid = tl.program_id(0).to(tl.int64) + row_block_size = tl.num_programs(0) + + for row_idx in tl.range(pid, num_tokens, row_block_size): + qk_start_ptr = qk_ptr + row_idx * qk_row_stride + + # #################################################################### + # get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for token position + # m of this program instance + # #################################################################### + cos_offsets = tl.arange(0, pad_rope_dim // 2) + sin_offsets = tl.arange(pad_rope_dim // 2, pad_rope_dim) + cos_mask = cos_offsets < (rope_dim // 2) + if USE_COS_SIN: + pos_idx = tl.load(pos_ptr + row_idx).to(tl.int64) + cos_start_ptr = cos_sin_ptr + pos_idx * cos_sin_row_stride + cos_row = tl.load(cos_start_ptr + cos_offsets, mask=cos_mask, other=0).to(tl.float32) + sin_row = tl.load(cos_start_ptr + sin_offsets, mask=cos_mask, other=0).to(tl.float32) + else: + cos_start_ptr = cos_ptr + row_idx * cos_row_stride + sin_start_ptr = sin_ptr + row_idx * sin_row_stride + cos_row = tl.load(cos_start_ptr + cos_offsets, mask=cos_mask, other=0).to(tl.float32) + sin_row = tl.load(sin_start_ptr + cos_offsets, mask=cos_mask, other=0).to(tl.float32) + + # #################################################################### + # Load the left and right half of q and k for the current + # program instance (i.e. for the current token) separately + # #################################################################### + # left half of the head + if IS_NEOX_STYLE: + first_half_offsets = tl.arange(0, pad_n_h)[:, None] * hd + tl.arange(0, pad_rope_dim // 2)[None, :] + else: + first_half_offsets = tl.arange(0, pad_n_h)[:, None] * hd + (2 * tl.arange(0, pad_rope_dim // 2)[None, :]) + + first_mask = (tl.arange(0, pad_n_h)[:, None] < n_h) & ( + tl.arange(0, pad_rope_dim // 2)[None, :] < (rope_dim // 2) + ) + qk_tile_1 = tl.load(qk_start_ptr + first_half_offsets, mask=first_mask, other=0).to(sin_row.dtype) + + # right half of the head + if IS_NEOX_STYLE: + second_half_offsets = first_half_offsets + (rope_dim // 2) + else: + second_half_offsets = first_half_offsets + 1 + second_mask = first_mask + qk_tile_2 = tl.load(qk_start_ptr + second_half_offsets, mask=second_mask, other=0).to(sin_row.dtype) + + # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin] + new_qk_tile_1 = qk_tile_1 * cos_row - qk_tile_2 * sin_row + tl.store(qk_start_ptr + first_half_offsets, new_qk_tile_1, mask=first_mask) + + def rope_forward_triton( q: torch.Tensor, k: torch.Tensor, @@ -237,3 +310,83 @@ def rope_forward_triton( "Please check whether you call rope_forward_triton correctly." ) return q, k + + +def rope_forward_triton_siso( + qk: torch.Tensor, + cos: torch.Tensor = None, + sin: torch.Tensor = None, + cos_sin_cache: torch.Tensor = None, + positions: torch.Tensor = None, + rope_dim: int = -1, + is_neox_style: bool = True, +) -> tuple[torch.Tensor, torch.Tensor]: + if not qk.is_contiguous(): + qk = qk.contiguous() + + num_tokens, n_head, head_dim = qk.shape + assert rope_dim <= head_dim + pad_rope_dim = triton.next_power_of_2(rope_dim) + pad_n_head = triton.next_power_of_2(n_head) + BLOCK_SIZE = pad_n_head + num_vectorcore = get_vectorcore_num() + n_row = min(num_tokens, num_vectorcore) + + if cos_sin_cache is not None and positions is not None: + assert positions.shape[0] == num_tokens + _triton_rope_siso[(n_row,)]( + qk, + qk.stride(0), + None, + None, + None, + None, + cos_sin_cache, + cos_sin_cache.stride(0), + positions, + num_tokens, + n_head, + head_dim, + rope_dim, + pad_n_head, + pad_rope_dim, + BLOCK_SIZE=BLOCK_SIZE, + IS_NEOX_STYLE=is_neox_style, + USE_COS_SIN=True, + ) + elif cos is not None and sin is not None: + assert cos.shape[0] == num_tokens and sin.shape[0] == num_tokens + cos = cos.view(num_tokens, -1) + sin = sin.view(num_tokens, -1) + if rope_dim == -1: + # If rope_dim is not specified, we assume that input cos/sin is not + # duplicated to rope_dim, which means rope_dim == cos.shape[-1] * 2 + rope_dim = cos.shape[-1] * 2 + _triton_rope_siso[(n_row,)]( + qk, + qk.stride(0), + cos, + cos.stride(0), + sin, + sin.stride(0), + None, + None, + None, + num_tokens, + n_head, + head_dim, + rope_dim, + pad_n_head, + pad_rope_dim, + BLOCK_SIZE=BLOCK_SIZE, + IS_NEOX_STYLE=is_neox_style, + USE_COS_SIN=False, + ) + else: + raise ValueError( + "Currently, rope_forward_triton supports passing:\n" + "1. positions and original cos_sin_cache.\n" + "2. cos and sin which are already selected by positions\n" + "Please check whether you call rope_forward_triton correctly." + ) + return qk diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index de0e5d12acc..c96975473f2 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -1129,10 +1129,24 @@ def enable_dsa_cp_with_layer_shard() -> bool: from vllm.config import get_current_vllm_config vllm_config = get_current_vllm_config() + # because the broadcast in layer sharding needs to be overlapped with a heavy compute stream to be + # effectively hidden, it is enabled only during the prefill stage. is_prefill_instance = vllm_config.kv_transfer_config is not None and vllm_config.kv_transfer_config.is_kv_producer return is_prefill_instance +@lru_cache(maxsize=1) +def enable_dsa_cp_with_o_proj_tp() -> bool: + if not enable_dsa_cp(): + return False + from vllm.config import get_current_vllm_config + + vllm_config = get_current_vllm_config() + # if is PD mix stage, using original TP o_proj weight, and also need to + # full gather for o_proj weight for prefill stage. + return vllm_config.kv_transfer_config is None + + def check_gdn_layer(vllm_config) -> bool: """ gdn layer is marked with `linear_attention`.