From 2fd1eccf13221e1223666b3214a58dff940e5682 Mon Sep 17 00:00:00 2001 From: yydyzr Date: Mon, 9 Mar 2026 15:49:25 +0800 Subject: [PATCH] Revert "[perf][refactor] Refactor and optimize sfa_v1.py for dsv3.2/glm5 (#6874)" --- .../attention/context_parallel/sfa_cp.py | 43 +- vllm_ascend/attention/sfa_v1.py | 955 +++++++++--------- vllm_ascend/ops/triton/rope.py | 153 --- vllm_ascend/utils.py | 14 - 4 files changed, 502 insertions(+), 663 deletions(-) diff --git a/vllm_ascend/attention/context_parallel/sfa_cp.py b/vllm_ascend/attention/context_parallel/sfa_cp.py index 63228bdb752..7fd3e739412 100644 --- a/vllm_ascend/attention/context_parallel/sfa_cp.py +++ b/vllm_ascend/attention/context_parallel/sfa_cp.py @@ -5,12 +5,10 @@ import torch_npu from vllm.config import VllmConfig from vllm.distributed import get_dcp_group, get_pcp_group -from vllm.triton_utils import HAS_TRITON from vllm_ascend.attention.context_parallel.common_cp import AscendPCPMetadata from vllm_ascend.attention.sfa_v1 import AscendSFAImpl, AscendSFAMetadata, AscendSFAMetadataBuilder from vllm_ascend.attention.utils import AscendCommonAttentionMetadata, enabling_mlapo, split_decodes_and_prefills -from vllm_ascend.ops.triton.rope import rope_forward_triton_siso M = TypeVar("M", bound=AscendSFAMetadata) @@ -275,33 +273,42 @@ def gather_kv_cross_cp(self, kv_cache: torch.Tensor, valid_block_ids: torch.Tens def indexer_select_post_process( self, x: torch.Tensor, - q_c: torch.Tensor, + qr: torch.Tensor, + q: torch.Tensor | None, + k: torch.Tensor, kv_cache: tuple[torch.Tensor, torch.Tensor, torch.Tensor], attn_metadata: M, cos: torch.Tensor, sin: torch.Tensor, actual_seq_lengths_query: torch.Tensor, actual_seq_lengths_key: torch.Tensor, + need_gather_q_kv: bool = False, ): - weights, _ = self.weights_proj(x) + if q is None: + q, _ = self.wq_b(qr) # [b,s,1536] @ [1536,64*128] = [b,s,64*128] + q = q.view(-1, self.n_head, self.head_dim) # [n_toks,64,128] + cos_q, sin_q = cos, sin - q_li, _ = self.wq_b(q_c) # [b,s,1536] @ [1536,64*128] = [b,s,64*128] - q_li = q_li.view(-1, self.n_head, self.head_dim) # [n_toks,64,128] - if HAS_TRITON: - q_li = rope_forward_triton_siso( - q_li, cos, sin, rope_dim=self.qk_rope_head_dim, is_neox_style=self.is_rope_neox_style - ) - else: - q_li_pe, q_li_nope = torch.split( - q_li, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], dim=-1 + q_pe, q_nope = torch.split( + q, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], dim=-1 ) # [b,s,64,64+64] - q_li_pe = q_li_pe.unsqueeze(2) - q_li_pe = torch_npu.npu_rotary_mul(q_li_pe, cos, sin) - q_li_pe = q_li_pe.squeeze(2) - q_li = torch.cat([q_li_pe, q_li_nope], dim=-1) # [b*s,64,128] + q_pe = q_pe.unsqueeze(2) + q_pe = torch_npu.npu_rotary_mul(q_pe, cos_q, sin_q) + q_pe = q_pe.squeeze(2) + q = torch.cat([q_pe, q_nope], dim=-1) # [b*s,64,128] - q = q_li + if kv_cache is not None: + if self.is_kv_producer: + attn_metadata.reshape_cache_event = torch.npu.Event() + torch_npu.npu_scatter_nd_update_( + kv_cache[2].view(-1, k.shape[-1]), attn_metadata.slot_mapping.view(-1, 1), k.view(-1, k.shape[-1]) + ) # b, s, n, d + if self.is_kv_producer: + attn_metadata.reshape_cache_event.record() + + weights, _ = self.weights_proj(x) + weights = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(weights, need_gather_q_kv) key = kv_cache[2] assert attn_metadata.sfa_cp_metadata is not None diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py index 22def96c5e6..64d21ef27d4 100644 --- a/vllm_ascend/attention/sfa_v1.py +++ b/vllm_ascend/attention/sfa_v1.py @@ -30,7 +30,6 @@ transdata, wait_for_kv_layer_from_connector, ) -from vllm_ascend.device.device_op import DeviceOperator from vllm_ascend.distributed.utils import all_gather_async from vllm_ascend.ops.layer_shard_linear import ( is_hidden_layer, @@ -39,7 +38,7 @@ register_all_layers_to_shard_weight_series, ) from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla -from vllm_ascend.ops.triton.rope import rope_forward_triton_siso +from vllm_ascend.ops.triton.rope import rope_forward_triton from vllm_ascend.quantization.methods import AscendW8A8LinearMethod from vllm_ascend.utils import ( ACL_FORMAT_FRACTAL_ND, @@ -47,7 +46,6 @@ dispose_layer, enable_dsa_cp, enable_dsa_cp_with_layer_shard, - enable_dsa_cp_with_o_proj_tp, get_weight_prefetch_method, maybe_trans_nz, ) @@ -395,8 +393,8 @@ def __init__( ascend_config = get_ascend_config() self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp - # The MLAPO operator fuses the pre-processing steps on Q/K/V in MLA into a single operator - # NOTE: it imposes a limit on the number of input tokens and conflicts with FlashComm + # In sfa, prefill and decode have the same calculation formula, + # so do not distinguish between prefill and decode here. self.enable_mlapo = envs.VLLM_ASCEND_ENABLE_MLAPO assert self.indexer is not None, "Indexer is required for DSA." @@ -421,29 +419,21 @@ def __init__( self.is_rope_neox_style = False self.use_torch_npu_lightning_indexer = True - # Effective in SFA when FlashComm is enabled. self.enable_dsa_cp = enable_dsa_cp() - - # Enable layer sharding via DSA-CP on the P node in the PD-disaggregated setup. - self.enable_dsa_cp_with_layer_shard = enable_dsa_cp_with_layer_shard() - - # use original TP o_proj weight in PD mix stage, and full gather - # for o_proj weight for prefill stage. - self.enable_dsa_cp_with_o_proj_tp = enable_dsa_cp_with_o_proj_tp() - + self.enable_dsa_cp_prefill_only = enable_dsa_cp_with_layer_shard() if self.enable_dsa_cp: self.local_num_heads = self.num_heads * self.tp_size - if self.enable_dsa_cp_with_layer_shard: - self.layer_sharding_kwargs = [] - for layer_name in get_ascend_config().layer_sharding or []: - if layer_name in kwargs: - self.layer_sharding_kwargs.append(kwargs[layer_name]) - else: - logger.warning_once( - f"[SFAImpl init] Layer '{layer_name}' not found in kwargs for layer sharding, " - "skipping sharding configuration" - ) - register_all_layers_to_shard_weight_series(self.layer_sharding_kwargs) + if self.enable_dsa_cp_prefill_only: + self.layer_sharding_kwargs = [] + for layer_name in get_ascend_config().layer_sharding or []: + if layer_name in kwargs: + self.layer_sharding_kwargs.append(kwargs[layer_name]) + else: + logger.warning_once( + f"[SFAImpl init] Layer '{layer_name}' not found in kwargs for layer sharding, " + "skipping sharding configuration" + ) + register_all_layers_to_shard_weight_series(self.layer_sharding_kwargs) def process_weights_after_loading(self, act_dtype: torch.dtype): # NOTE: We currently do not support quant kv_b_proj. @@ -479,7 +469,7 @@ def process_weights_after_loading(self, act_dtype: torch.dtype): # Dispose kv_b_proj since it is replaced by W_UV and W_UK_T to save memory dispose_layer(self.kv_b_proj) if self.enable_dsa_cp: - if self.enable_dsa_cp_with_layer_shard: + if self.enable_dsa_cp_prefill_only: for layer in self.layer_sharding_kwargs or []: if is_hidden_layer(layer): post_process_after_loading_for_shard_weight_series(layer) @@ -511,6 +501,100 @@ def process_weights_after_loading(self, act_dtype: torch.dtype): # if mlapo, W_UK_T can't trans nz self.W_UK_T = maybe_trans_nz(self.W_UK_T) + def _v_up_proj(self, x): + num_input_tokens, _, _ = x.shape + if ( + x.dtype in [torch.float16, torch.bfloat16] + and hasattr(torch.ops._C_ascend, "batch_matmul_transpose") + and num_input_tokens <= BMM_TRANS_MAX_SUPPORTED_TOKENS + ): + x = x.view(-1, self.local_num_heads, self.kv_lora_rank) + res = torch.empty((num_input_tokens, self.local_num_heads, self.v_head_dim), dtype=x.dtype, device=x.device) + torch.ops._C_ascend.batch_matmul_transpose(x, self.W_UV, res) + x = res.reshape(-1, self.local_num_heads * self.v_head_dim) + else: + # Convert from (B, N, L) to (N, B, L) + x = x.view(-1, self.local_num_heads, self.kv_lora_rank).transpose(0, 1) + # # Multiply (N, B, L) x (N, L, V) -> (N, B, V) + x = torch.bmm(x, self.W_UV) + # # Convert from (N, B, V) to (B, N * V) + x = x.transpose(0, 1).reshape(-1, self.local_num_heads * self.v_head_dim) + return x + + # Return `ql_nope`, `q_pe` + def _q_proj_and_k_up_proj(self, x): + q_nope, q_pe = ( + self.q_proj(x)[0] + .view(-1, self.local_num_heads, self.qk_head_dim) + .split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + ) + + # Convert from (B, N, P) to (N, B, P) + q_nope = q_nope.transpose(0, 1) + # Multiply (N, B, P) x (N, P, L) -> (N, B, L) + ql_nope = torch.bmm(q_nope, self.W_UK_T) + # Convert from (N, B, L) to (B, N, L) + return ql_nope.transpose(0, 1), q_pe + + def _get_full_kv(self, k, attn_metadata): + return k + + def exec_kv( + self, + kv_no_split: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + kv_cache: tuple, + slots: torch.Tensor, + attn_metadata: M, + ): + B = kv_no_split.shape[0] + N = self.num_kv_heads + S = 1 + # npu_kv_rmsnorm_rope_cache needs [B, N, S, D] + kv_no_split = kv_no_split.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim) + cache_mode = "PA" + + if self.enable_dsa_cp: + _, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache( + kv_no_split, + self.kv_a_layernorm.weight, # type: ignore[union-attr] + cos, + sin, + slots.to(torch.int64), + kv_cache[1], + kv_cache[0], + epsilon=self.kv_a_layernorm.variance_epsilon, # type: ignore[union-attr] + cache_mode=cache_mode, + is_output_kv=True, + ) + return k_pe, k_nope + else: + torch_npu.npu_kv_rmsnorm_rope_cache( + kv_no_split, + self.kv_a_layernorm.weight, # type: ignore[union-attr] + cos, + sin, + slots.to(torch.int64), + kv_cache[1], + kv_cache[0], + epsilon=self.kv_a_layernorm.variance_epsilon, # type: ignore[union-attr] + cache_mode=cache_mode, + ) + return None, None + + def rope_single( + self, + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + ) -> torch.Tensor: + B, N, D = x.shape + S = 1 + x = x.view(B, N, S, D) + x = torch_npu.npu_interleave_rope(x, cos, sin) + return x.view(B, N, D) + # Processing the input parameters for MLAPO by reordering and transposing # QKV(and part of Q) weight, applying RoPE-related dimension transformations, # and handling quantization parameters. @@ -588,330 +672,356 @@ def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype): self.q_proj.quant_bias = None torch.npu.empty_cache() - def forward_mha( + def _sfa_preprocess_decode( self, - q: torch.Tensor, - kv_c_normed: torch.Tensor, - k_pe: torch.Tensor, - kv_c_and_k_pe_cache: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: tuple[torch.Tensor, torch.Tensor, torch.Tensor], attn_metadata: M, - k_scale: torch.Tensor, - output: torch.Tensor, - ) -> None: - raise NotImplementedError("forward_mha is not supported for SFA attention. Use forward() instead.") + need_gather_q_kv: bool, + num_input_tokens: int, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(hidden_states.contiguous(), need_gather_q_kv) + k_nope, k_pe = kv_cache[0], kv_cache[1] + ql_nope = torch.empty( + (num_input_tokens, self.W_UK_T.shape[0], k_nope.shape[-1]), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + q_pe = torch.empty( + (num_input_tokens, self.W_UK_T.shape[0], k_pe.shape[-1]), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + q_c = torch.empty( + (num_input_tokens, self.q_lora_rank), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + torch.ops._C_ascend.mla_preprocess( + hidden_states, + self.wd_qkv, + self.deq_scale_qkv, + self.gamma1, + self.beta1, + self.wu_q, + self.qb_deq_scl, + self.gamma2, + attn_metadata.cos, + attn_metadata.sin, + self.W_UK_T, + k_nope, + k_pe, + attn_metadata.slot_mapping, + quant_scale0=self.quant_scale0, + quant_offset0=self.quant_offset0, + bias0=self.quant_bias_qkv, + quant_scale1=self.quant_scale1, + quant_offset1=self.quant_offset1, + bias1=self.qb_qt_bias, + ctkv_scale=self.ctkv_scale, + q_nope_scale=self.q_nope_scale, + cache_mode="krope_ctkv", + quant_mode="per_tensor_quant_asymm", + enable_inner_out=True, + q_out0=ql_nope, + kv_cache_out0=k_nope, + q_out1=q_pe, + kv_cache_out1=k_pe, + inner_out=q_c, + ) + return hidden_states, ql_nope, q_pe, q_c - def forward_mqa( + def forward( self, - q: torch.Tensor | tuple[torch.Tensor, torch.Tensor], - kv_c_and_k_pe_cache: torch.Tensor, + layer_name, + hidden_states: torch.Tensor, # query in unified attn + kv_cache: tuple[torch.Tensor, torch.Tensor, torch.Tensor], attn_metadata: M, - layer, - ) -> tuple[torch.Tensor, torch.Tensor | None]: - raise NotImplementedError("forward_mqa is not supported for SFA attention. Use forward() instead.") - - def rope_single( - self, - x: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, + need_gather_q_kv: bool = False, + output: torch.Tensor | None = None, ) -> torch.Tensor: - B, N, D = x.shape - S = 1 - x = x.view(B, N, S, D) - x = torch_npu.npu_interleave_rope(x, cos, sin) - return x.view(B, N, D) + assert output is not None, "Output tensor must be provided." + forward_context = get_forward_context() + if attn_metadata is None: + # Profiling run. + if self.enable_dsa_cp_prefill_only and not forward_context.in_profile_run: + for layer in self.layer_sharding_kwargs or []: + if is_hidden_layer(layer): + reach_layer_for_shard_weight_series(layer) + return output.fill_(0) - def _init_o_proj_tp_full_params(self): - """ - Initialize TP-mode and Full-mode parameters for o_proj weight, - preparing for weight switching in PD mix stage. + cos = attn_metadata.cos + sin = attn_metadata.sin + actual_seq_lengths_query = attn_metadata.cum_query_lens + actual_seq_lengths_key = attn_metadata.seq_lens + if self.enable_dsa_cp: + need_gather_q_kv = False + # Inputs and outputs may be padded for CUDA graphs + num_input_tokens = attn_metadata.num_input_tokens + output_padded = output - For PD mix stage: - - Use original TP o_proj weight for decode phase - - Need full-gather o_proj weight from all TP ranks for prefill phase - """ - if AscendSFAImpl.o_proj_full_pool is None: - sample = self.o_proj.weight - AscendSFAImpl.o_proj_full_pool = torch.empty( - (sample.shape[0] * self.tp_size, sample.shape[1]), dtype=sample.dtype, device=sample.device + # all-gather o_proj weight for prefill stage of PD mix node + o_proj_full_handle = None + # if is PD mix stage, using original TP o_proj weight, and also need to full gather for o_proj + # weight for prefill stage. + should_shard_weight = self.enable_dsa_cp_prefill_only or attn_metadata.attn_state not in { + AscendAttentionState.DecodeOnly, + AscendAttentionState.SpecDecoding, + } + + if self.enable_mlapo and num_input_tokens <= MLAPO_MAX_SUPPORTED_TOKENS: + hidden_states, ql_nope, q_pe, q_c = self._sfa_preprocess_decode( + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + need_gather_q_kv=need_gather_q_kv, + num_input_tokens=num_input_tokens, + ) + q, k = self.indexer_select_pre_process( + x=hidden_states, qr=q_c, cos=cos, sin=sin, need_gather_q_kv=need_gather_q_kv + ) + else: + assert self.fused_qkv_a_proj is not None, "q lora is required for DSA." + weight_prefetch_method = get_weight_prefetch_method() + weight_prefetch_method.maybe_prefetch_mla_or_sla_weight_in_current_stream( + inputs=self.fused_qkv_a_proj.weight, dependency=hidden_states + ) + qkv_lora = self.fused_qkv_a_proj(hidden_states)[0] + q_c, kv_no_split = qkv_lora.split( + [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], + dim=-1, ) + assert self.q_a_layernorm is not None, "q_a_layernorm must be initialized" + q_c = self.q_a_layernorm(q_c) + # Process for Flash Comm V1 + if need_gather_q_kv: + q_c = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(q_c.contiguous(), need_gather_q_kv) + kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( + kv_no_split.contiguous(), need_gather_q_kv + ) - # Save TP-mode parameters (original sharded weights) - self.o_proj_tp_weight = self.o_proj.weight.clone().detach() - self.o_proj_tp_aclnn_input_scale = self.o_proj.aclnn_input_scale.clone().detach() - self.o_proj_tp_aclnn_input_scale_reciprocal = self.o_proj.aclnn_input_scale_reciprocal.clone().detach() - self.o_proj_tp_aclnn_input_offset = self.o_proj.aclnn_input_offset.clone().detach() + q, k = self.indexer_select_pre_process( + x=hidden_states, qr=q_c, cos=cos, sin=sin, need_gather_q_kv=need_gather_q_kv + ) - # Initially switch to TP mode for graph capture - self.o_proj.weight.set_(self.o_proj_tp_weight) - self.o_proj.aclnn_input_scale.set_(self.o_proj_tp_aclnn_input_scale) - self.o_proj.aclnn_input_scale_reciprocal.set_(self.o_proj_tp_aclnn_input_scale_reciprocal) - self.o_proj.aclnn_input_offset.set_(self.o_proj_tp_aclnn_input_offset) + wait_for_kv_layer_from_connector(layer_name) - # Precompute Full-mode quantization parameters by repeating TP parameters across all TP ranks - self.o_proj_full_aclnn_input_scale = self.o_proj.aclnn_input_scale.repeat(self.tp_size) - self.o_proj_full_aclnn_input_scale_reciprocal = self.o_proj.aclnn_input_scale_reciprocal.repeat(self.tp_size) - self.o_proj_full_aclnn_input_offset = self.o_proj.aclnn_input_offset.repeat(self.tp_size) + slot_mapping = attn_metadata.slot_mapping + if self.enable_dsa_cp: + assert attn_metadata.dsa_cp_context is not None + slot_mapping = attn_metadata.dsa_cp_context.slot_mapping_cp + actual_seq_lengths_query = attn_metadata.dsa_cp_context.actual_seq_lengths_query + actual_seq_lengths_key = attn_metadata.dsa_cp_context.actual_seq_lengths_key - def _handle_o_proj_weight_switch_and_forward( - self, - attn_output: torch.Tensor, - output: torch.Tensor, - o_proj_full_handle: torch.distributed.Work | None, - should_shard_weight: bool, - ) -> tuple[torch.Tensor, bool]: - """ - Handle o_proj weight switching between TP-mode and Full-mode, and execute forward computation. - """ - # Gather o_proj weight from all TP ranks for Full-mode computation - if should_shard_weight: - # Wait for the completion of o_proj weight all-gather operation - if o_proj_full_handle is not None: - o_proj_full_handle.wait() + k_pe, k_nope = self.exec_kv(kv_no_split, cos, sin, kv_cache, slot_mapping, attn_metadata) - # Switch o_proj to Full-mode (gathered weight from all TP ranks) - self.o_proj.weight.set_(AscendSFAImpl.o_proj_full_pool) - self.o_proj.aclnn_input_scale.set_(self.o_proj_full_aclnn_input_scale) - self.o_proj.aclnn_input_scale_reciprocal.set_(self.o_proj_full_aclnn_input_scale_reciprocal) - self.o_proj.aclnn_input_offset.set_(self.o_proj_full_aclnn_input_offset) + if self.enable_dsa_cp: + assert k_pe is not None + assert k_nope is not None + # support all_gather kv async for communication calculation overlap + fused_kv_no_split, kv_ag_handle = all_gather_async( + torch.cat( + [k_pe.view(-1, k_pe.shape[-1]), k_nope.view(-1, k_nope.shape[-1]), k.view(-1, k.shape[-1])], + dim=1, + ), + get_tp_group(), + async_op=should_shard_weight, + ) - # Apply quantization method and execute forward computation - output[...] = self.o_proj.quant_method.quant_method.apply(self.o_proj, attn_output) + ql_nope, q_pe = self._q_proj_and_k_up_proj(q_c) + q_pe = self.rope_single(q_pe, cos, sin) - # Switch o_proj back to TP-mode for subsequent decode operations - self.o_proj.weight.set_(self.o_proj_tp_weight) - self.o_proj.aclnn_input_scale.set_(self.o_proj_tp_aclnn_input_scale) - self.o_proj.aclnn_input_scale_reciprocal.set_(self.o_proj_tp_aclnn_input_scale_reciprocal) - self.o_proj.aclnn_input_offset.set_(self.o_proj_tp_aclnn_input_offset) + if self.enable_dsa_cp: + if kv_ag_handle is not None: + kv_ag_handle.wait() - return output, False - else: - # For decode scenario: perform all-to-all communication on o_proj input activations - # Reshape for all-to-all: [batch * seq, tp_size, head_dim] -> [tp_size, batch * seq, head_dim] - send = ( - attn_output.view(-1, self.tp_size, self.num_heads * self.v_head_dim) - .permute(1, 0, 2) - .reshape(-1, self.num_heads * self.v_head_dim) - ) + if self.enable_dsa_cp_prefill_only: + for layer in self.layer_sharding_kwargs or []: + if is_hidden_layer(layer): + reach_layer_for_shard_weight_series(layer) + elif should_shard_weight: + _, o_proj_full_handle = all_gather_async( + self.o_proj_tp_weight, get_tp_group(), output=AscendSFAImpl.o_proj_full_pool + ) - attn_output = torch.empty_like(send) - torch.distributed.all_to_all_single(attn_output, send, group=get_tp_group().device_group) + if kv_cache is not None: + assert fused_kv_no_split is not None + k_pe, k_nope, k = fused_kv_no_split.split( + [self.qk_rope_head_dim, self.kv_lora_rank, self.head_dim], dim=-1 + ) + slot_mapping = attn_metadata.slot_mapping.view(-1, 1) + torch_npu.npu_scatter_nd_update_(kv_cache[0].view(-1, k_nope.shape[-1]), slot_mapping, k_nope) + torch_npu.npu_scatter_nd_update_(kv_cache[1].view(-1, k_pe.shape[-1]), slot_mapping, k_pe) - return attn_output, True + k = self._get_full_kv(k, attn_metadata) + if kv_cache is not None: + torch_npu.npu_scatter_nd_update_( + kv_cache[2].view(-1, k.shape[-1]), attn_metadata.slot_mapping.view(-1, 1), k.view(-1, k.shape[-1]) + ) # b, s, n, d - def _get_full_kv(self, k, attn_metadata): - return k + topk_indices = self.indexer_select_post_process( + x=hidden_states, + qr=q_c, + q=q, + k=k, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + cos=cos, + sin=sin, + actual_seq_lengths_query=actual_seq_lengths_query, + actual_seq_lengths_key=actual_seq_lengths_key, + need_gather_q_kv=need_gather_q_kv, + ) - def exec_kv( - self, - kv_no_split: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - kv_cache: tuple, - slots: torch.Tensor, - attn_metadata: M, - ): - B = kv_no_split.shape[0] - N = self.num_kv_heads - S = 1 - # npu_kv_rmsnorm_rope_cache needs [B, N, S, D] - kv_no_split = kv_no_split.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim) - cache_mode = "PA" + attn_output = self._execute_sparse_flash_attention_process( + ql_nope, q_pe, kv_cache, topk_indices, attn_metadata, actual_seq_lengths_query, actual_seq_lengths_key + ) - if self.enable_dsa_cp: - _, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache( - kv_no_split, - self.kv_a_layernorm.weight, # type: ignore[union-attr] - cos, - sin, - slots.to(torch.int64), - kv_cache[1], - kv_cache[0], - epsilon=self.kv_a_layernorm.variance_epsilon, # type: ignore[union-attr] - cache_mode=cache_mode, - is_output_kv=True, - ) - return k_pe, k_nope - else: - torch_npu.npu_kv_rmsnorm_rope_cache( - kv_no_split, - self.kv_a_layernorm.weight, # type: ignore[union-attr] - cos, - sin, - slots.to(torch.int64), - kv_cache[1], - kv_cache[0], - epsilon=self.kv_a_layernorm.variance_epsilon, # type: ignore[union-attr] - cache_mode=cache_mode, + attn_output = self._v_up_proj(attn_output) + weight_prefetch_method = get_weight_prefetch_method() + weight_prefetch_method.maybe_prefetch_mla_or_sla_weight_in_current_stream( + inputs=self.o_proj.weight, + dependency=attn_output, + max_size=MAX_O_PROJ_PREFETCH_SIZE, + linear_layer=self.o_proj, + ) + + if self.enable_dsa_cp and not self.enable_dsa_cp_prefill_only: + # When using SFA-CP with pd mixed, o_proj has two cases: + # 1. prefill: o_proj is a TP weight, we need to all-gather o_proj weight to switch TP=1. + # 2. decode: all-to-all the hidden_state before the o_proj forward. + result, require_o_proj_forward = self._handle_o_proj_weight_switch_and_forward( + attn_output=attn_output, + output=output, + o_proj_full_handle=o_proj_full_handle, + should_shard_weight=should_shard_weight, ) - return None, None + if not require_o_proj_forward: + return result + attn_output = result - # Return `ql_nope`, `q_pe` - def _q_proj_and_k_up_proj(self, x): - q_nope, q_pe = ( - self.q_proj(x)[0] - .view(-1, self.local_num_heads, self.qk_head_dim) - .split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) - ) + output[...] = self.o_proj(attn_output)[0] - # Convert from (B, N, P) to (N, B, P) - q_nope = q_nope.transpose(0, 1) - # Multiply (N, B, P) x (N, P, L) -> (N, B, L) - ql_nope = torch.bmm(q_nope, self.W_UK_T) - # Convert from (N, B, L) to (B, N, L) - return ql_nope.transpose(0, 1), q_pe + maybe_save_kv_layer_to_connector(layer_name, list(kv_cache)) - def _v_up_proj(self, x): - num_input_tokens, _, _ = x.shape - if ( - x.dtype in [torch.float16, torch.bfloat16] - and hasattr(torch.ops._C_ascend, "batch_matmul_transpose") - and num_input_tokens <= BMM_TRANS_MAX_SUPPORTED_TOKENS - ): - x = x.view(-1, self.local_num_heads, self.kv_lora_rank) - res = torch.empty((num_input_tokens, self.local_num_heads, self.v_head_dim), dtype=x.dtype, device=x.device) - torch.ops._C_ascend.batch_matmul_transpose(x, self.W_UV, res) - x = res.reshape(-1, self.local_num_heads * self.v_head_dim) - else: - # Convert from (B, N, L) to (N, B, L) - x = x.view(-1, self.local_num_heads, self.kv_lora_rank).transpose(0, 1) - # # Multiply (N, B, L) x (N, L, V) -> (N, B, V) - x = torch.bmm(x, self.W_UV) - # # Convert from (N, B, V) to (B, N * V) - x = x.transpose(0, 1).reshape(-1, self.local_num_heads * self.v_head_dim) - return x + return output_padded - def _sfa_preprocess_with_mlapo( - self, - hidden_states: torch.Tensor, - kv_cache: tuple[torch.Tensor, torch.Tensor, torch.Tensor], - cos: torch.Tensor, - sin: torch.Tensor, - slot_mapping: torch.Tensor, - num_input_tokens: int, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - k_nope, k_pe = kv_cache[0], kv_cache[1] - ql_nope = torch.empty( - (num_input_tokens, self.W_UK_T.shape[0], k_nope.shape[-1]), - dtype=hidden_states.dtype, - device=hidden_states.device, - ) - q_pe = torch.empty( - (num_input_tokens, self.W_UK_T.shape[0], k_pe.shape[-1]), - dtype=hidden_states.dtype, - device=hidden_states.device, - ) - q_c = torch.empty( - (num_input_tokens, self.q_lora_rank), - dtype=hidden_states.dtype, - device=hidden_states.device, - ) - torch.ops._C_ascend.mla_preprocess( - hidden_states, - self.wd_qkv, - self.deq_scale_qkv, - self.gamma1, - self.beta1, - self.wu_q, - self.qb_deq_scl, - self.gamma2, - cos, - sin, - self.W_UK_T, - k_nope, - k_pe, - slot_mapping, - quant_scale0=self.quant_scale0, - quant_offset0=self.quant_offset0, - bias0=self.quant_bias_qkv, - quant_scale1=self.quant_scale1, - quant_offset1=self.quant_offset1, - bias1=self.qb_qt_bias, - ctkv_scale=self.ctkv_scale, - q_nope_scale=self.q_nope_scale, - cache_mode="krope_ctkv", - quant_mode="per_tensor_quant_asymm", - enable_inner_out=True, - q_out0=ql_nope, - kv_cache_out0=k_nope, - q_out1=q_pe, - kv_cache_out1=k_pe, - inner_out=q_c, + def _execute_sparse_flash_attention_process( + self, ql_nope, q_pe, kv_cache, topk_indices, attn_metadata, actual_seq_lengths_query, actual_seq_lengths_key + ): + block_table = attn_metadata.block_table + kv = kv_cache[0] + key_rope = kv_cache[1] + + attn_output = torch.ops._C_ascend.npu_sparse_flash_attention( + query=ql_nope, + key=kv, + value=kv, + sparse_indices=topk_indices, + scale_value=self.scale, + sparse_block_size=1, + block_table=block_table, + actual_seq_lengths_query=actual_seq_lengths_query, + actual_seq_lengths_kv=actual_seq_lengths_key, + query_rope=q_pe, + key_rope=key_rope, + layout_query="TND", + layout_kv="PA_BSND", + sparse_mode=3, ) - return hidden_states, ql_nope, q_pe, q_c + return attn_output def indexer_select_pre_process( self, x: torch.Tensor, + qr: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, + need_gather_q_kv: bool = False, ): - k_li, _ = self.wk(x) # [b,s,7168] @ [7168,128] = [b,s,128] - k_li = self.k_norm(k_li).unsqueeze(1) - k_li = k_li.view(-1, 1, self.head_dim) + k_proj, _ = self.wk(x) # [b,s,7168] @ [7168,128] = [b,s,128] + k_proj = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(k_proj, need_gather_q_kv) + k = self.k_norm(k_proj).unsqueeze(1) + k = k.view(-1, 1, self.head_dim) if HAS_TRITON: + q, _ = self.wq_b(qr) # [b,s,1536] @ [1536,64*128] = [b,s,64*128] + q = q.view(-1, self.n_head, self.head_dim) # [n_toks,64,128] + cos = cos.view(-1, self.qk_rope_head_dim) sin = sin.view(-1, self.qk_rope_head_dim) - k_li = rope_forward_triton_siso( - k_li, cos, sin, rope_dim=self.qk_rope_head_dim, is_neox_style=self.is_rope_neox_style + q, k = rope_forward_triton( + q, k, cos, sin, rope_dim=self.qk_rope_head_dim, is_neox_style=self.is_rope_neox_style ) else: - k_li_pe, k_li_nope = torch.split( - k_li, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], dim=-1 - ) + k_pe, k_nope = torch.split(k, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], dim=-1) cos = cos.view(-1, 1, 1, self.qk_rope_head_dim) sin = sin.view(-1, 1, 1, self.qk_rope_head_dim) - k_li_pe = k_li_pe.unsqueeze(2) - k_li_pe = torch_npu.npu_interleave_rope(k_li_pe, cos, sin) - k_li_pe = k_li_pe.squeeze(2) + k_pe = k_pe.unsqueeze(2) + k_pe = torch_npu.npu_interleave_rope(k_pe, cos, sin) + k_pe = k_pe.squeeze(2) - k_li = torch.cat([k_li_pe, k_li_nope], dim=-1) # [b*s,128] + k = torch.cat([k_pe, k_nope], dim=-1) # [b*s,128] + q = None - return k_li + return q, k def indexer_select_post_process( self, x: torch.Tensor, - q_c: torch.Tensor, + qr: torch.Tensor, + q: torch.Tensor | None, + k: torch.Tensor, kv_cache: tuple[torch.Tensor, torch.Tensor, torch.Tensor], attn_metadata: M, cos: torch.Tensor, sin: torch.Tensor, actual_seq_lengths_query: torch.Tensor, actual_seq_lengths_key: torch.Tensor, + need_gather_q_kv: bool = False, ): - weights, _ = self.weights_proj(x) + if q is None: + q, _ = self.wq_b(qr) # [b,s,1536] @ [1536,64*128] = [b,s,64*128] + q = q.view(-1, self.n_head, self.head_dim) # [n_toks,64,128] + cos_q, sin_q = cos, sin - q_li, _ = self.wq_b(q_c) # [b,s,1536] @ [1536,64*128] = [b,s,64*128] - q_li = q_li.view(-1, self.n_head, self.head_dim) # [n_toks,64,128] - if HAS_TRITON: - q_li = rope_forward_triton_siso( - q_li, cos, sin, rope_dim=self.qk_rope_head_dim, is_neox_style=self.is_rope_neox_style - ) - else: - q_li_pe, q_li_nope = torch.split( - q_li, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], dim=-1 + q_pe, q_nope = torch.split( + q, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], dim=-1 ) # [b,s,64,64+64] - q_li_pe = q_li_pe.unsqueeze(2) - q_li_pe = torch_npu.npu_rotary_mul(q_li_pe, cos, sin) - q_li_pe = q_li_pe.squeeze(2) - q_li = torch.cat([q_li_pe, q_li_nope], dim=-1) # [b*s,64,128] + q_pe = q_pe.unsqueeze(2) + q_pe = torch_npu.npu_rotary_mul(q_pe, cos_q, sin_q) + q_pe = q_pe.squeeze(2) + q = torch.cat([q_pe, q_nope], dim=-1) # [b*s,64,128] + + if kv_cache is not None: + if self.is_kv_producer: + attn_metadata.reshape_cache_event = torch.npu.Event() + torch_npu.npu_scatter_nd_update_( + kv_cache[2].view(-1, k.shape[-1]), attn_metadata.slot_mapping.view(-1, 1), k.view(-1, k.shape[-1]) + ) # b, s, n, d + if self.is_kv_producer: + attn_metadata.reshape_cache_event.record() + + weights, _ = self.weights_proj(x) + weights = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(weights, need_gather_q_kv) + + key = kv_cache[2] + block_table = attn_metadata.block_table # DSV3.2 currently has graph compilation issues when using torch_npu.npu.lightning_indexer. # So two branches are maintained temporarily. # TODO: torch.ops._C_ascend.npu_lightning_indexer needs to be removed. if self.use_torch_npu_lightning_indexer: topk_indices, _ = torch_npu.npu_lightning_indexer( - query=q_li, - key=kv_cache[2], + query=q, + key=key, weights=weights, actual_seq_lengths_query=actual_seq_lengths_query, actual_seq_lengths_key=actual_seq_lengths_key, - block_table=attn_metadata.block_table, + block_table=block_table, layout_query="TND", layout_key="PA_BSND", sparse_count=2048, @@ -919,12 +1029,12 @@ def indexer_select_post_process( ) else: topk_indices = torch.ops._C_ascend.npu_lightning_indexer( - query=q_li, - key=kv_cache[2], + query=q, + key=key, weights=weights, actual_seq_lengths_query=actual_seq_lengths_query, actual_seq_lengths_key=actual_seq_lengths_key, - block_table=attn_metadata.block_table, + block_table=block_table, layout_query="TND", layout_key="PA_BSND", sparse_count=2048, @@ -932,212 +1042,101 @@ def indexer_select_post_process( ) return topk_indices - def _execute_sparse_flash_attention_process( - self, ql_nope, q_pe, kv_cache, topk_indices, attn_metadata, actual_seq_lengths_query, actual_seq_lengths_key - ): - block_table = attn_metadata.block_table - kv = kv_cache[0] - key_rope = kv_cache[1] - - attn_output = torch.ops._C_ascend.npu_sparse_flash_attention( - query=ql_nope, - key=kv, - value=kv, - sparse_indices=topk_indices, - scale_value=self.scale, - sparse_block_size=1, - block_table=block_table, - actual_seq_lengths_query=actual_seq_lengths_query, - actual_seq_lengths_kv=actual_seq_lengths_key, - query_rope=q_pe, - key_rope=key_rope, - layout_query="TND", - layout_kv="PA_BSND", - sparse_mode=3, - ) - return attn_output - - def forward( - self, - layer_name, - hidden_states: torch.Tensor, # query in unified attn - kv_cache: tuple[torch.Tensor, torch.Tensor, torch.Tensor], - attn_metadata: M, - need_gather_q_kv: bool = False, - output: torch.Tensor | None = None, - ) -> torch.Tensor: - assert output is not None, "Output tensor must be provided." - forward_context = get_forward_context() - if attn_metadata is None: - # Profiling run. - if self.enable_dsa_cp_with_layer_shard and not forward_context.in_profile_run: - for layer in self.layer_sharding_kwargs or []: - if is_hidden_layer(layer): - reach_layer_for_shard_weight_series(layer) - return output.fill_(0) - - cos = attn_metadata.cos - sin = attn_metadata.sin - slot_mapping = attn_metadata.slot_mapping - slot_mapping_cp = None - if self.enable_dsa_cp: - assert attn_metadata.dsa_cp_context is not None - slot_mapping_cp = attn_metadata.dsa_cp_context.slot_mapping_cp - actual_seq_lengths_query = attn_metadata.dsa_cp_context.actual_seq_lengths_query - actual_seq_lengths_key = attn_metadata.dsa_cp_context.actual_seq_lengths_key - else: - actual_seq_lengths_query = attn_metadata.cum_query_lens - actual_seq_lengths_key = attn_metadata.seq_lens - - # Inputs and outputs may be padded for CUDA graphs - num_input_tokens = attn_metadata.num_input_tokens - output_padded = output - - # all-gather o_proj weight for prefill stage of PD mix node - o_proj_full_handle = None - # if is PD mix stage, using original TP o_proj weight, and also need to full gather for o_proj - # weight for prefill stage. - full_gather_o_proj_enabled = self.enable_dsa_cp_with_o_proj_tp and attn_metadata.attn_state not in { - AscendAttentionState.DecodeOnly, - AscendAttentionState.SpecDecoding, - } + def _init_o_proj_tp_full_params(self): + """ + Initialize TP-mode and Full-mode parameters for o_proj weight, + preparing for weight switching in PD mix stage. - # run mlapo ops when dsa-cp is disabled, and ensure that num_tokens satisfies the count limitation - if self.enable_mlapo and num_input_tokens <= MLAPO_MAX_SUPPORTED_TOKENS: - hidden_states, ql_nope, q_pe, q_c = self._sfa_preprocess_with_mlapo( - hidden_states=hidden_states, - kv_cache=kv_cache, - cos=cos, - sin=sin, - slot_mapping=slot_mapping, - num_input_tokens=num_input_tokens, - ) - k_li = self.indexer_select_pre_process(x=hidden_states, cos=cos, sin=sin) - # native - else: - assert self.fused_qkv_a_proj is not None, "q lora is required for DSA." - weight_prefetch_method = get_weight_prefetch_method() - weight_prefetch_method.maybe_prefetch_mla_or_sla_weight_in_current_stream( - inputs=self.fused_qkv_a_proj.weight, dependency=hidden_states - ) - qkv_lora = self.fused_qkv_a_proj(hidden_states)[0] - q_c, kv_no_split = qkv_lora.split( - [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], - dim=-1, + For PD mix stage: + - Use original TP o_proj weight for decode phase + - Need full-gather o_proj weight from all TP ranks for prefill phase + """ + if AscendSFAImpl.o_proj_full_pool is None: + sample = self.o_proj.weight + AscendSFAImpl.o_proj_full_pool = torch.empty( + (sample.shape[0] * self.tp_size, sample.shape[1]), dtype=sample.dtype, device=sample.device ) - assert self.q_a_layernorm is not None, "q_a_layernorm must be initialized" - q_c = self.q_a_layernorm(q_c) - k_li = self.indexer_select_pre_process(x=hidden_states, cos=cos, sin=sin) - - wait_for_kv_layer_from_connector(layer_name) - - if self.enable_dsa_cp: - assert slot_mapping_cp is not None - k_pe, k_nope = self.exec_kv(kv_no_split, cos, sin, kv_cache, slot_mapping_cp, attn_metadata) - else: - k_pe, k_nope = self.exec_kv(kv_no_split, cos, sin, kv_cache, slot_mapping, attn_metadata) - - if self.enable_dsa_cp: - assert k_pe is not None - assert k_nope is not None - async_op = self.enable_dsa_cp_with_layer_shard or full_gather_o_proj_enabled - # support all_gather kv async for communication calculation overlap - fused_kv_no_split, kv_ag_handle = all_gather_async( - torch.cat( - [ - k_pe.view(-1, k_pe.shape[-1]), - k_nope.view(-1, k_nope.shape[-1]), - k_li.view(-1, k_li.shape[-1]), - ], - dim=1, - ), - get_tp_group(), - async_op=async_op, - ) - - ql_nope, q_pe = self._q_proj_and_k_up_proj(q_c) - q_pe = self.rope_single(q_pe, cos, sin) - - if self.enable_dsa_cp: - if kv_ag_handle is not None: - kv_ag_handle.wait() - - if self.enable_dsa_cp_with_layer_shard: - for layer in self.layer_sharding_kwargs or []: - if is_hidden_layer(layer): - reach_layer_for_shard_weight_series(layer) - elif full_gather_o_proj_enabled: - _, o_proj_full_handle = all_gather_async( - self.o_proj_tp_weight, get_tp_group(), output=AscendSFAImpl.o_proj_full_pool - ) + # Save TP-mode parameters (original sharded weights) + self.o_proj_tp_weight = self.o_proj.weight.clone().detach() + self.o_proj_tp_aclnn_input_scale = self.o_proj.aclnn_input_scale.clone().detach() + self.o_proj_tp_aclnn_input_scale_reciprocal = self.o_proj.aclnn_input_scale_reciprocal.clone().detach() + self.o_proj_tp_aclnn_input_offset = self.o_proj.aclnn_input_offset.clone().detach() - if kv_cache is not None: - assert fused_kv_no_split is not None - k_pe, k_nope, k_li = fused_kv_no_split.split( - [self.qk_rope_head_dim, self.kv_lora_rank, self.head_dim], dim=-1 - ) - k_nope = k_nope.view(k_nope.shape[0], 1, -1) - k_pe = k_pe.view(k_pe.shape[0], 1, -1) - DeviceOperator.reshape_and_cache( - key=k_nope[: attn_metadata.num_actual_tokens], - value=k_pe[: attn_metadata.num_actual_tokens], - key_cache=kv_cache[0], - value_cache=kv_cache[1], - slot_mapping=slot_mapping[: attn_metadata.num_actual_tokens], - ) + # Initially switch to TP mode for graph capture + self.o_proj.weight.set_(self.o_proj_tp_weight) + self.o_proj.aclnn_input_scale.set_(self.o_proj_tp_aclnn_input_scale) + self.o_proj.aclnn_input_scale_reciprocal.set_(self.o_proj_tp_aclnn_input_scale_reciprocal) + self.o_proj.aclnn_input_offset.set_(self.o_proj_tp_aclnn_input_offset) - k_li = self._get_full_kv(k_li, attn_metadata) + # Precompute Full-mode quantization parameters by repeating TP parameters across all TP ranks + self.o_proj_full_aclnn_input_scale = self.o_proj.aclnn_input_scale.repeat(self.tp_size) + self.o_proj_full_aclnn_input_scale_reciprocal = self.o_proj.aclnn_input_scale_reciprocal.repeat(self.tp_size) + self.o_proj_full_aclnn_input_offset = self.o_proj.aclnn_input_offset.repeat(self.tp_size) - if kv_cache is not None: - if self.is_kv_producer: - attn_metadata.reshape_cache_event = torch.npu.Event() - torch_npu.npu_scatter_nd_update_( - kv_cache[2].view(-1, k_li.shape[-1]), slot_mapping.view(-1, 1), k_li.view(-1, k_li.shape[-1]) - ) # b, s, n, d - if self.is_kv_producer: - attn_metadata.reshape_cache_event.record() + def _handle_o_proj_weight_switch_and_forward( + self, + attn_output: torch.Tensor, + output: torch.Tensor, + o_proj_full_handle: torch.distributed.Work | None, + should_shard_weight: bool, + ) -> tuple[torch.Tensor, bool]: + """ + Handle o_proj weight switching between TP-mode and Full-mode, and execute forward computation. + """ + # Gather o_proj weight from all TP ranks for Full-mode computation + if should_shard_weight: + # Wait for the completion of o_proj weight all-gather operation + if o_proj_full_handle is not None: + o_proj_full_handle.wait() - topk_indices = self.indexer_select_post_process( - x=hidden_states, - q_c=q_c, - kv_cache=kv_cache, - attn_metadata=attn_metadata, - cos=cos, - sin=sin, - actual_seq_lengths_query=actual_seq_lengths_query, - actual_seq_lengths_key=actual_seq_lengths_key, - ) + # Switch o_proj to Full-mode (gathered weight from all TP ranks) + self.o_proj.weight.set_(AscendSFAImpl.o_proj_full_pool) + self.o_proj.aclnn_input_scale.set_(self.o_proj_full_aclnn_input_scale) + self.o_proj.aclnn_input_scale_reciprocal.set_(self.o_proj_full_aclnn_input_scale_reciprocal) + self.o_proj.aclnn_input_offset.set_(self.o_proj_full_aclnn_input_offset) - attn_output = self._execute_sparse_flash_attention_process( - ql_nope, q_pe, kv_cache, topk_indices, attn_metadata, actual_seq_lengths_query, actual_seq_lengths_key - ) + # Apply quantization method and execute forward computation + output[...] = self.o_proj.quant_method.quant_method.apply(self.o_proj, attn_output) - attn_output = self._v_up_proj(attn_output) - weight_prefetch_method = get_weight_prefetch_method() - weight_prefetch_method.maybe_prefetch_mla_or_sla_weight_in_current_stream( - inputs=self.o_proj.weight, - dependency=attn_output, - max_size=MAX_O_PROJ_PREFETCH_SIZE, - linear_layer=self.o_proj, - ) + # Switch o_proj back to TP-mode for subsequent decode operations + self.o_proj.weight.set_(self.o_proj_tp_weight) + self.o_proj.aclnn_input_scale.set_(self.o_proj_tp_aclnn_input_scale) + self.o_proj.aclnn_input_scale_reciprocal.set_(self.o_proj_tp_aclnn_input_scale_reciprocal) + self.o_proj.aclnn_input_offset.set_(self.o_proj_tp_aclnn_input_offset) - if self.enable_dsa_cp_with_o_proj_tp: - # When using SFA-CP with pd mixed, o_proj has two cases: - # 1. prefill: o_proj is a TP weight, we need to all-gather o_proj weight to switch TP=1. - # 2. decode: all-to-all the hidden_state before the o_proj forward. - result, require_o_proj_forward = self._handle_o_proj_weight_switch_and_forward( - attn_output=attn_output, - output=output, - o_proj_full_handle=o_proj_full_handle, - should_shard_weight=full_gather_o_proj_enabled, + return output, False + else: + # For decode scenario: perform all-to-all communication on o_proj input activations + # Reshape for all-to-all: [batch * seq, tp_size, head_dim] -> [tp_size, batch * seq, head_dim] + send = ( + attn_output.view(-1, self.tp_size, self.num_heads * self.v_head_dim) + .permute(1, 0, 2) + .reshape(-1, self.num_heads * self.v_head_dim) ) - if not require_o_proj_forward: - return result - attn_output = result - output[...] = self.o_proj(attn_output)[0] + attn_output = torch.empty_like(send) + torch.distributed.all_to_all_single(attn_output, send, group=get_tp_group().device_group) - maybe_save_kv_layer_to_connector(layer_name, list(kv_cache)) + return attn_output, True - return output_padded + def forward_mha( + self, + q: torch.Tensor, + kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: M, + k_scale: torch.Tensor, + output: torch.Tensor, + ) -> None: + raise NotImplementedError("forward_mha is not supported for SFA attention. Use forward() instead.") + + def forward_mqa( + self, + q: torch.Tensor | tuple[torch.Tensor, torch.Tensor], + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: M, + layer, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + raise NotImplementedError("forward_mqa is not supported for SFA attention. Use forward() instead.") diff --git a/vllm_ascend/ops/triton/rope.py b/vllm_ascend/ops/triton/rope.py index 909065179a6..ad863e40910 100644 --- a/vllm_ascend/ops/triton/rope.py +++ b/vllm_ascend/ops/triton/rope.py @@ -146,79 +146,6 @@ def _triton_rope( tl.store(k_start_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask) -@triton.jit -def _triton_rope_siso( - qk_ptr, - qk_row_stride, - cos_ptr, - cos_row_stride, - sin_ptr, - sin_row_stride, - cos_sin_ptr, - cos_sin_row_stride, - pos_ptr, - num_tokens, - n_h: tl.constexpr, - hd: tl.constexpr, - rope_dim: tl.constexpr, - pad_n_h: tl.constexpr, - pad_rope_dim: tl.constexpr, - BLOCK_SIZE: tl.constexpr, - IS_NEOX_STYLE: tl.constexpr, - USE_COS_SIN: tl.constexpr, -): - pid = tl.program_id(0).to(tl.int64) - row_block_size = tl.num_programs(0) - - for row_idx in tl.range(pid, num_tokens, row_block_size): - qk_start_ptr = qk_ptr + row_idx * qk_row_stride - - # #################################################################### - # get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for token position - # m of this program instance - # #################################################################### - cos_offsets = tl.arange(0, pad_rope_dim // 2) - sin_offsets = tl.arange(pad_rope_dim // 2, pad_rope_dim) - cos_mask = cos_offsets < (rope_dim // 2) - if USE_COS_SIN: - pos_idx = tl.load(pos_ptr + row_idx).to(tl.int64) - cos_start_ptr = cos_sin_ptr + pos_idx * cos_sin_row_stride - cos_row = tl.load(cos_start_ptr + cos_offsets, mask=cos_mask, other=0).to(tl.float32) - sin_row = tl.load(cos_start_ptr + sin_offsets, mask=cos_mask, other=0).to(tl.float32) - else: - cos_start_ptr = cos_ptr + row_idx * cos_row_stride - sin_start_ptr = sin_ptr + row_idx * sin_row_stride - cos_row = tl.load(cos_start_ptr + cos_offsets, mask=cos_mask, other=0).to(tl.float32) - sin_row = tl.load(sin_start_ptr + cos_offsets, mask=cos_mask, other=0).to(tl.float32) - - # #################################################################### - # Load the left and right half of q and k for the current - # program instance (i.e. for the current token) separately - # #################################################################### - # left half of the head - if IS_NEOX_STYLE: - first_half_offsets = tl.arange(0, pad_n_h)[:, None] * hd + tl.arange(0, pad_rope_dim // 2)[None, :] - else: - first_half_offsets = tl.arange(0, pad_n_h)[:, None] * hd + (2 * tl.arange(0, pad_rope_dim // 2)[None, :]) - - first_mask = (tl.arange(0, pad_n_h)[:, None] < n_h) & ( - tl.arange(0, pad_rope_dim // 2)[None, :] < (rope_dim // 2) - ) - qk_tile_1 = tl.load(qk_start_ptr + first_half_offsets, mask=first_mask, other=0).to(sin_row.dtype) - - # right half of the head - if IS_NEOX_STYLE: - second_half_offsets = first_half_offsets + (rope_dim // 2) - else: - second_half_offsets = first_half_offsets + 1 - second_mask = first_mask - qk_tile_2 = tl.load(qk_start_ptr + second_half_offsets, mask=second_mask, other=0).to(sin_row.dtype) - - # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin] - new_qk_tile_1 = qk_tile_1 * cos_row - qk_tile_2 * sin_row - tl.store(qk_start_ptr + first_half_offsets, new_qk_tile_1, mask=first_mask) - - def rope_forward_triton( q: torch.Tensor, k: torch.Tensor, @@ -310,83 +237,3 @@ def rope_forward_triton( "Please check whether you call rope_forward_triton correctly." ) return q, k - - -def rope_forward_triton_siso( - qk: torch.Tensor, - cos: torch.Tensor = None, - sin: torch.Tensor = None, - cos_sin_cache: torch.Tensor = None, - positions: torch.Tensor = None, - rope_dim: int = -1, - is_neox_style: bool = True, -) -> tuple[torch.Tensor, torch.Tensor]: - if not qk.is_contiguous(): - qk = qk.contiguous() - - num_tokens, n_head, head_dim = qk.shape - assert rope_dim <= head_dim - pad_rope_dim = triton.next_power_of_2(rope_dim) - pad_n_head = triton.next_power_of_2(n_head) - BLOCK_SIZE = pad_n_head - num_vectorcore = get_vectorcore_num() - n_row = min(num_tokens, num_vectorcore) - - if cos_sin_cache is not None and positions is not None: - assert positions.shape[0] == num_tokens - _triton_rope_siso[(n_row,)]( - qk, - qk.stride(0), - None, - None, - None, - None, - cos_sin_cache, - cos_sin_cache.stride(0), - positions, - num_tokens, - n_head, - head_dim, - rope_dim, - pad_n_head, - pad_rope_dim, - BLOCK_SIZE=BLOCK_SIZE, - IS_NEOX_STYLE=is_neox_style, - USE_COS_SIN=True, - ) - elif cos is not None and sin is not None: - assert cos.shape[0] == num_tokens and sin.shape[0] == num_tokens - cos = cos.view(num_tokens, -1) - sin = sin.view(num_tokens, -1) - if rope_dim == -1: - # If rope_dim is not specified, we assume that input cos/sin is not - # duplicated to rope_dim, which means rope_dim == cos.shape[-1] * 2 - rope_dim = cos.shape[-1] * 2 - _triton_rope_siso[(n_row,)]( - qk, - qk.stride(0), - cos, - cos.stride(0), - sin, - sin.stride(0), - None, - None, - None, - num_tokens, - n_head, - head_dim, - rope_dim, - pad_n_head, - pad_rope_dim, - BLOCK_SIZE=BLOCK_SIZE, - IS_NEOX_STYLE=is_neox_style, - USE_COS_SIN=False, - ) - else: - raise ValueError( - "Currently, rope_forward_triton supports passing:\n" - "1. positions and original cos_sin_cache.\n" - "2. cos and sin which are already selected by positions\n" - "Please check whether you call rope_forward_triton correctly." - ) - return qk diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index d74df6dd851..792af587555 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -1114,24 +1114,10 @@ def enable_dsa_cp_with_layer_shard() -> bool: from vllm.config import get_current_vllm_config vllm_config = get_current_vllm_config() - # because the broadcast in layer sharding needs to be overlapped with a heavy compute stream to be - # effectively hidden, it is enabled only during the prefill stage. is_prefill_instance = vllm_config.kv_transfer_config is not None and vllm_config.kv_transfer_config.is_kv_producer return is_prefill_instance -@lru_cache(maxsize=1) -def enable_dsa_cp_with_o_proj_tp() -> bool: - if not enable_dsa_cp(): - return False - from vllm.config import get_current_vllm_config - - vllm_config = get_current_vllm_config() - # if is PD mix stage, using original TP o_proj weight, and also need to - # full gather for o_proj weight for prefill stage. - return vllm_config.kv_transfer_config is None - - def check_gdn_layer(vllm_config) -> bool: """ gdn layer is marked with `linear_attention`.