diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py index 6588686eb57..ea1a35630f4 100644 --- a/vllm_ascend/attention/sfa_v1.py +++ b/vllm_ascend/attention/sfa_v1.py @@ -6,12 +6,11 @@ import vllm.envs as envs_vllm from torch import nn from vllm.attention.backends.abstract import AttentionBackend, MLAAttentionImpl -from vllm.config import CUDAGraphMode, VllmConfig, get_current_vllm_config +from vllm.config import VllmConfig, get_current_vllm_config from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group from vllm.forward_context import get_forward_context from vllm.logger import logger -from vllm.model_executor.layers.linear import (ReplicatedLinear, - UnquantizedLinearMethod) +from vllm.model_executor.layers.linear import UnquantizedLinearMethod from vllm.triton_utils import HAS_TRITON from vllm.v1.attention.backends.mla.common import MLACommonMetadataBuilder from vllm.v1.attention.backends.utils import AttentionCGSupport @@ -24,6 +23,7 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, trans_rope_weight, transdata, wait_for_kv_layer_from_connector) +from vllm_ascend.distributed.utils import all_gather_async from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla from vllm_ascend.ops.shared_weight_layer import ( is_hidden_layer, post_process_after_loading_for_shared_weight_series, @@ -33,7 +33,8 @@ from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, _round_up, dispose_layer, - enable_sp, maybe_trans_nz, replace_layer) + enable_dsa_cp, enable_dsa_cp_with_shard, + maybe_trans_nz) from vllm_ascend.worker.npu_input_batch import NPUInputBatch if TYPE_CHECKING: @@ -148,14 +149,14 @@ def __init__( got {self.decode_threshold}" self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim - self.enable_sfa_cp = enable_sp() and \ - 
hasattr(self.model_config.hf_config, "index_topk") - assert not ( - self.enable_sfa_cp - and self.vllm_config.compilation_config.cudagraph_mode - == CUDAGraphMode.FULL_DECODE_ONLY - ), "FlashComm1 is not compatible with FULL_DECODE_ONLY. Please set graph_mode to 'piecewise' or disable FlashComm1." + self.enable_sfa_cp = enable_dsa_cp() + max_num_reqs = vllm_config.scheduler_config.max_num_seqs + self.actual_seq_lengths_query = torch.zeros(max_num_reqs + 1, + dtype=torch.int32, + device=device) + self.actual_seq_lengths_key = torch.empty_like( + self.actual_seq_lengths_query) @classmethod def get_cudagraph_support( @@ -211,7 +212,6 @@ def build( pad_size = num_tokens_pad - cos.shape[0] assert cos.shape == sin.shape, \ f"cos.shape must be equal to sin.shape, got {cos.shape} and {sin.shape}" - if pad_size > 0: cos = nn.functional.pad(cos, (0, 0, 0, 0, 0, 0, 0, pad_size)) sin = nn.functional.pad(sin, (0, 0, 0, 0, 0, 0, 0, pad_size)) @@ -221,12 +221,16 @@ def build( slot_mapping = nn.functional.pad(slot_mapping, (0, pad_size_slot), value=-1) + slot_mapping_cp = torch.full(size=(num_tokens_per_device, ), + fill_value=-1, + dtype=slot_mapping.dtype, + device=slot_mapping.device) else: slot_mapping = slot_mapping[:num_tokens_pad] + slot_mapping_cp = slot_mapping[local_start:local_end_with_pad] cos = cos[local_start:local_end_with_pad] sin = sin[local_start:local_end_with_pad] - slot_mapping_cp = slot_mapping[local_start:local_end_with_pad] assert cos.shape[0] == num_tokens_per_device, \ f"cos.shape[0] must be equal to num_tokens_per_device, \ got {cos.shape[0]} and {num_tokens_per_device}" @@ -237,8 +241,9 @@ def build( f"slot_mapping.shape[0] must be equal to num_tokens_pad, \ got {slot_mapping.shape[0]} and {num_tokens_pad}" - actual_seq_lengths_query = torch.empty_like(cum_query_lens) - actual_seq_lengths_key = torch.empty_like(seq_lens) + actual_seq_lengths_query = self.actual_seq_lengths_query + actual_seq_lengths_key = self.actual_seq_lengths_key + num_segs = 
cum_query_lens.shape[0] last_token = 0 cum = 0 @@ -247,20 +252,23 @@ def build( global_end = cum_query_lens[i].item() last_token = global_end - local_start = max(global_start, local_start) - local_end = min(global_end, local_end_with_pad) - num_local_tokens = local_end - local_start + req_local_start = max(global_start, local_start) + req_local_end = min(global_end, local_end_with_pad) + num_local_tokens = req_local_end - req_local_start if num_local_tokens > 0: cum += num_local_tokens actual_seq_lengths_query[i] = cum - offset = global_end - local_end + offset = global_end - req_local_end actual_seq_lengths_key[i] = seq_lens[i].item() - offset else: actual_seq_lengths_query[i] = cum actual_seq_lengths_key[i] = 0 + actual_seq_lengths_query = actual_seq_lengths_query[:num_reqs] + actual_seq_lengths_key = actual_seq_lengths_key[:num_reqs] + sfa_cp_context = SfaCpContext( num_tokens=num_tokens, num_tokens_pad=num_tokens_pad, @@ -310,6 +318,7 @@ def build_for_graph_capture( class AscendSFAImpl(MLAAttentionImpl): + o_proj_full_pool: Optional[torch.Tensor] = None """ NOTE: Please read the comment at the top of the file before trying to understand this class @@ -365,28 +374,27 @@ def __init__( assert self.indexer is not None, "Indexer is required for DSA." - self.enable_sfa_cp = enable_sp() + self.enable_sfa_cp = enable_dsa_cp() + self.enable_sfa_cp_with_shard = enable_dsa_cp_with_shard() self.local_num_heads = self.num_heads self.vllm_config = get_current_vllm_config() if self.enable_sfa_cp: self.local_num_heads = self.num_heads * self.tp_size - - #TODO: Temporarily adapt sfa-cp, remove after adapting near PCP. 
--clrs97 - self._replace_linear_class_for_sfa_cp() - from vllm_ascend.distributed.parallel_state import \ - get_shared_weight_group - if is_hidden_layer(self.vllm_config, self.q_proj): - register_layer_to_shared_weight_series( - series_name="q_proj", - group=get_shared_weight_group(), - layer=self.q_proj, - prefetch_step=1) - if is_hidden_layer(self.vllm_config, self.o_proj): - register_layer_to_shared_weight_series( - series_name="o_proj", - group=get_shared_weight_group(), - layer=self.o_proj, - prefetch_step=1) + if self.enable_sfa_cp_with_shard: + from vllm_ascend.distributed.parallel_state import \ + get_shared_weight_group + if is_hidden_layer(self.vllm_config, self.q_proj): + register_layer_to_shared_weight_series( + series_name="q_proj", + group=get_shared_weight_group(), + layer=self.q_proj, + prefetch_step=1) + if is_hidden_layer(self.vllm_config, self.o_proj): + register_layer_to_shared_weight_series( + series_name="o_proj", + group=get_shared_weight_group(), + layer=self.o_proj, + prefetch_step=1) # indexer param self.n_head: int = self.indexer.n_head # 64 @@ -433,12 +441,45 @@ def process_weights_after_loading(self, act_dtype: torch.dtype): dispose_layer(self.kv_b_proj) if self.enable_sfa_cp: - if is_hidden_layer(self.vllm_config, self.q_proj): - post_process_after_loading_for_shared_weight_series( - self.q_proj) - if is_hidden_layer(self.vllm_config, self.o_proj): - post_process_after_loading_for_shared_weight_series( - self.o_proj) + # if prefill-only and enable sfa cp with shard, we need to post-process shared weight here. + if self.enable_sfa_cp_with_shard: + if is_hidden_layer(self.vllm_config, self.q_proj): + post_process_after_loading_for_shared_weight_series( + self.q_proj) + if is_hidden_layer(self.vllm_config, self.o_proj): + post_process_after_loading_for_shared_weight_series( + self.o_proj) + else: + # if is PD mix stage, using original TP o_proj weight, and also need to full gather for o_proj weight for prefill stage. 
+ if AscendSFAImpl.o_proj_full_pool is None: + sample = self.o_proj.weight + AscendSFAImpl.o_proj_full_pool = torch.empty( + (sample.shape[0] * self.tp_size, sample.shape[1]), + dtype=sample.dtype, + device=sample.device) + # we should save parameters for tp mode + self.o_proj_tp_weight = self.o_proj.weight.clone().detach() + self.o_proj_tp_aclnn_input_scale = self.o_proj.aclnn_input_scale.clone( + ).detach() + self.o_proj_tp_aclnn_input_scale_reciprocal = self.o_proj.aclnn_input_scale_reciprocal.clone( + ).detach() + self.o_proj_tp_aclnn_input_offset = self.o_proj.aclnn_input_offset.clone( + ).detach() + # initially switch to tp mode for graph capture + self.o_proj.weight.set_(self.o_proj_tp_weight) + self.o_proj.aclnn_input_scale.set_( + self.o_proj_tp_aclnn_input_scale) + self.o_proj.aclnn_input_scale_reciprocal.set_( + self.o_proj_tp_aclnn_input_scale_reciprocal) + self.o_proj.aclnn_input_offset.set_( + self.o_proj_tp_aclnn_input_offset) + # we should also save parameters for full mode + self.o_proj_full_aclnn_input_scale = self.o_proj.aclnn_input_scale.repeat( + self.tp_size) + self.o_proj_full_aclnn_input_scale_reciprocal = self.o_proj.aclnn_input_scale_reciprocal.repeat( + self.tp_size) + self.o_proj_full_aclnn_input_offset = self.o_proj.aclnn_input_offset.repeat( + self.tp_size) if self.enable_mlapo: quant_method = getattr( @@ -513,7 +554,6 @@ def exec_kv( sin: torch.Tensor, kv_cache: Tuple, slots: torch.Tensor, - slots_cp: Optional[torch.Tensor], ): B = kv_no_split.shape[0] N = self.num_kv_heads @@ -524,30 +564,19 @@ def exec_kv( cache_mode = "PA" if self.enable_sfa_cp: - assert slots_cp is not None _, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache( kv_no_split, self.kv_a_layernorm.weight, cos, sin, - slots_cp.to(torch.int64), + slots.to(torch.int64), kv_cache[1], kv_cache[0], epsilon=self.kv_a_layernorm.variance_epsilon, cache_mode=cache_mode, is_output_kv=True, ) - #TODO: Temporarily adapt SFA-CP and replace it later with PCP. 
--clrs97 - k_pe = get_tp_group().all_gather(k_pe, 0) - k_nope = get_tp_group().all_gather(k_nope, 0) - - if kv_cache is not None: - torch_npu.npu_scatter_nd_update_( - kv_cache[0].view(-1, k_nope.shape[-1]), slots.view(-1, 1), - k_nope.view(-1, k_nope.shape[-1])) - torch_npu.npu_scatter_nd_update_( - kv_cache[1].view(-1, k_pe.shape[-1]), slots.view(-1, 1), - k_pe.view(-1, k_pe.shape[-1])) + return k_pe, k_nope else: torch_npu.npu_kv_rmsnorm_rope_cache( kv_no_split, @@ -560,6 +589,7 @@ def exec_kv( epsilon=self.kv_a_layernorm.variance_epsilon, cache_mode=cache_mode, ) + return None, None def rope_single( self, @@ -743,10 +773,11 @@ def forward( if attn_metadata is None: # Profiling run. if self.enable_sfa_cp and not forward_context.in_profile_run: - if is_hidden_layer(self.vllm_config, self.q_proj): - reach_layer_for_shared_weight_series(self.q_proj) - if is_hidden_layer(self.vllm_config, self.o_proj): - reach_layer_for_shared_weight_series(self.o_proj) + if self.enable_sfa_cp_with_shard: + if is_hidden_layer(self.vllm_config, self.q_proj): + reach_layer_for_shared_weight_series(self.q_proj) + if is_hidden_layer(self.vllm_config, self.o_proj): + reach_layer_for_shared_weight_series(self.o_proj) return output.fill_(0) has_prefill = attn_metadata.has_prefill cos = attn_metadata.cos @@ -757,6 +788,12 @@ def forward( need_gather_q_kv = False # Inputs and outputs may be padded for CUDA graphs output_padded = output + # all-gather o_proj weight for prefill stage of PD mix node + o_proj_full_work = None + # if is PD mix stage, using original TP o_proj weight, and also need to full gather for o_proj weight for prefill stage. + should_shard_weight = self.enable_sfa_cp_with_shard or attn_metadata.attn_state not in { + AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding + } # TODO(zzzzwwjj): In sfa, prefill and decode have the same calculation formula, # so `has_prefill` here is not necessary. 
@@ -768,6 +805,13 @@ def forward( need_gather_q_kv=need_gather_q_kv, num_input_tokens=attn_metadata.num_input_tokens, ) + # split indexer_select into `necessary` and `optional` process, such that the optional process can be skipped in short sequence case. + q, k = self.indexer_select_necessary_process( + x=hidden_states, + qr=q_c, + cos=cos, + sin=sin, + need_gather_q_kv=need_gather_q_kv) else: assert self.fused_qkv_a_proj is not None, "q lora is required for DSA." maybe_npu_prefetch(inputs=self.fused_qkv_a_proj.weight, @@ -786,32 +830,82 @@ def forward( kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( kv_no_split.contiguous(), need_gather_q_kv) + # Early indexer_select k for communication overlap + q, k = self.indexer_select_necessary_process( + x=hidden_states, + qr=q_c, + cos=cos, + sin=sin, + need_gather_q_kv=need_gather_q_kv) + if has_prefill: wait_for_kv_layer_from_connector(layer_name) slot_mapping = attn_metadata.slot_mapping - slot_mapping_cp = None if self.enable_sfa_cp: assert attn_metadata.sfa_cp_context is not None - slot_mapping_cp = attn_metadata.sfa_cp_context.slot_mapping_cp + slot_mapping = attn_metadata.sfa_cp_context.slot_mapping_cp actual_seq_lengths_query = attn_metadata.sfa_cp_context.actual_seq_lengths_query actual_seq_lengths_key = attn_metadata.sfa_cp_context.actual_seq_lengths_key - self.exec_kv(kv_no_split, cos, sin, kv_cache, slot_mapping, - slot_mapping_cp) + k_pe, k_nope = self.exec_kv(kv_no_split, cos, sin, kv_cache, + slot_mapping) - if self.enable_sfa_cp and attn_metadata.sfa_cp_context is not None: - if is_hidden_layer(self.vllm_config, self.q_proj): - reach_layer_for_shared_weight_series(self.q_proj) - if is_hidden_layer(self.vllm_config, self.o_proj): - reach_layer_for_shared_weight_series(self.o_proj) + ag_no_split = None # all-gather k_pe and k_nope and indexer's k for communication overlap + ag_work = None # async work handle + if self.enable_sfa_cp: + assert k_pe is not None + assert k_nope is not None + 
ag_no_split, ag_work = all_gather_async( + torch.cat([ + k_pe.view(-1, k_pe.shape[-1]), + k_nope.view(-1, k_nope.shape[-1]), + k.view(-1, k.shape[-1]) + ], + dim=1), + get_tp_group(), + async_op=should_shard_weight) # only prefill stage need ql_nope, q_pe = self._q_proj_and_k_up_proj(q_c) q_pe = self.rope_single(q_pe, cos, sin) - topk_indices = self.indexer_select( + if self.enable_sfa_cp: + if ag_work is not None: + ag_work.wait() + if self.enable_sfa_cp_with_shard: # broadcast o_proj weight for prefill Node + if is_hidden_layer(self.vllm_config, self.q_proj): + reach_layer_for_shared_weight_series(self.q_proj) + if is_hidden_layer(self.vllm_config, self.o_proj): + reach_layer_for_shared_weight_series(self.o_proj) + elif should_shard_weight: # all-gather o_proj weight for prefill stage of PD mix node + _, o_proj_full_work = all_gather_async( + self.o_proj_tp_weight, + get_tp_group(), + output=AscendSFAImpl.o_proj_full_pool) + if kv_cache is not None: + assert ag_no_split is not None + k_pe, k_nope, k = ag_no_split.split([ + self.qk_rope_head_dim, self.kv_lora_rank, self.head_dim + ], + dim=-1) + slot_mapping = attn_metadata.slot_mapping.view(-1, 1) + torch_npu.npu_scatter_nd_update_( + kv_cache[0].view(-1, k_nope.shape[-1]), slot_mapping, + k_nope) + torch_npu.npu_scatter_nd_update_( + kv_cache[1].view(-1, k_pe.shape[-1]), slot_mapping, + k_pe) + + if kv_cache is not None: + torch_npu.npu_scatter_nd_update_(kv_cache[2].view(-1, k.shape[-1]), + attn_metadata.slot_mapping.view( + -1, 1), + k.view(-1, + k.shape[-1])) # b, s, n, d + topk_indices = self.indexer_select_optional_process( x=hidden_states, qr=q_c, + q=q, kv_cache=kv_cache, attn_metadata=attn_metadata, cos=cos, @@ -840,25 +934,54 @@ def forward( dependency=attn_output, max_size=MAX_O_PROJ_PREFETCH_SIZE, enabled=self.enable_prefetch) + + if self.enable_sfa_cp and not self.enable_sfa_cp_with_shard: + # Gather o_proj weight from all tp ranks + if should_shard_weight: + # wait for all-gather o_proj weight + if 
o_proj_full_work is not None: + o_proj_full_work.wait() + # switch o_proj into full mode + self.o_proj.weight.set_(AscendSFAImpl.o_proj_full_pool) + self.o_proj.aclnn_input_scale.set_( + self.o_proj_full_aclnn_input_scale) + self.o_proj.aclnn_input_scale_reciprocal.set_( + self.o_proj_full_aclnn_input_scale_reciprocal) + self.o_proj.aclnn_input_offset.set_( + self.o_proj_full_aclnn_input_offset) + # apply o_proj quant method + output[...] = self.o_proj.quant_method.quant_method.apply( + self.o_proj, attn_output) + # switch o_proj back to tp mode + self.o_proj.weight.set_(self.o_proj_tp_weight) + self.o_proj.aclnn_input_scale.set_( + self.o_proj_tp_aclnn_input_scale) + self.o_proj.aclnn_input_scale_reciprocal.set_( + self.o_proj_tp_aclnn_input_scale_reciprocal) + self.o_proj.aclnn_input_offset.set_( + self.o_proj_tp_aclnn_input_offset) + return output_padded + else: + # Alltoall for o_proj input activations in decode scenario + send = attn_output.view( + -1, self.tp_size, + self.num_heads * self.v_head_dim).permute(1, 0, 2).reshape( + -1, self.num_heads * self.v_head_dim) + attn_output = torch.empty_like(send) + torch.distributed.all_to_all_single( + attn_output, send, group=get_tp_group().device_group) + output[...] 
= self.o_proj(attn_output)[0] return output_padded - def indexer_select( + def indexer_select_necessary_process( self, x: torch.Tensor, qr: torch.Tensor, - kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], - attn_metadata: M, cos: torch.Tensor, sin: torch.Tensor, - actual_seq_lengths_query: torch.Tensor, - actual_seq_lengths_key: torch.Tensor, need_gather_q_kv: bool = False, ): - # q process in new stream - q, _ = self.wq_b(qr) # [b,s,1536] @ [1536,64*128] = [b,s,64*128] - q = q.view(-1, self.n_head, self.head_dim) # [n_toks,64,128] - k_proj, _ = self.wk(x) # [b,s,7168] @ [7168,128] = [b,s,128] k_proj = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( k_proj, need_gather_q_kv) @@ -866,6 +989,9 @@ def indexer_select( k = k.view(-1, 1, self.head_dim) if HAS_TRITON: + q, _ = self.wq_b(qr) # [b,s,1536] @ [1536,64*128] = [b,s,64*128] + q = q.view(-1, self.n_head, self.head_dim) # [n_toks,64,128] + cos = cos.view(-1, self.qk_rope_head_dim) sin = sin.view(-1, self.qk_rope_head_dim) q, k = rope_forward_triton(q, @@ -875,6 +1001,37 @@ def indexer_select( rope_dim=self.qk_rope_head_dim, is_neox_style=True) else: + k_pe, k_nope = torch.split( + k, + [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], + dim=-1) # [b,s,64+64] + + k_pe = k_pe.unsqueeze(2) + k_pe = torch_npu.npu_interleave_rope(k_pe, cos, sin) + k_pe = k_pe.squeeze(2) + + k = torch.cat([k_pe, k_nope], dim=-1) # [b*s,128] + q = None + + return q, k + + def indexer_select_optional_process( + self, + x: torch.Tensor, + qr: torch.Tensor, + q: Optional[torch.Tensor], + kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], + attn_metadata: M, + cos: torch.Tensor, + sin: torch.Tensor, + actual_seq_lengths_query: torch.Tensor, + actual_seq_lengths_key: torch.Tensor, + need_gather_q_kv: bool = False, + ): + if q is None: + q, _ = self.wq_b(qr) # [b,s,1536] @ [1536,64*128] = [b,s,64*128] + q = q.view(-1, self.n_head, self.head_dim) # [n_toks,64,128] + cos_q, sin_q = cos, sin cos = cos.view(-1, 1, 
1, self.qk_rope_head_dim) sin = sin.view(-1, 1, 1, self.qk_rope_head_dim) @@ -889,33 +1046,11 @@ def indexer_select( q_pe = q_pe.squeeze(2) q = torch.cat([q_pe, q_nope], dim=-1) # [b*s,64,128] - k_pe, k_nope = torch.split( - k, - [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], - dim=-1) # [b,s,64+64] - - k_pe = k_pe.unsqueeze(2) - k_pe = torch_npu.npu_interleave_rope(k_pe, cos, sin) - k_pe = k_pe.squeeze(2) - - k = torch.cat([k_pe, k_nope], dim=-1) # [b*s,128] - - if self.enable_sfa_cp: - k = get_tp_group().all_gather(k, 0) - - if kv_cache is not None: - torch_npu.npu_scatter_nd_update_(kv_cache[2].view(-1, k.shape[-1]), - attn_metadata.slot_mapping.view( - -1, 1), - k.view(-1, - k.shape[-1])) # b, s, n, d - weights, _ = self.weights_proj(x) weights = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( weights, need_gather_q_kv) block_table = attn_metadata.block_tables - topk_indices = torch.ops._C_ascend.npu_lightning_indexer( query=q, key=kv_cache[2], @@ -928,42 +1063,3 @@ def indexer_select( sparse_count=2048, sparse_mode=3) return topk_indices - - def _replace_linear_class_for_sfa_cp(self): - - vllm_config = get_current_vllm_config() - # Dispose tensor from the original q_proj - dispose_layer(self.q_proj) - # Construct the new q_proj using ReplicatedLinear - new_q_proj = ReplicatedLinear(self.q_lora_rank, - self.local_num_heads * self.qk_head_dim, - bias=False, - quant_config=vllm_config.quant_config, - prefix=self.q_proj.prefix) - # Replace the q_proj with the new one - replace_layer(self.q_proj, new_q_proj) - - # Dispose tensor from the original kv_b_proj - dispose_layer(self.kv_b_proj) - # Construct the new kv_b_proj using ReplicatedLinear - new_kv_b_proj = ReplicatedLinear( - self.kv_lora_rank, - self.local_num_heads * (self.qk_nope_head_dim + self.v_head_dim), - bias=False, - quant_config=vllm_config.quant_config, - prefix=self.kv_b_proj.prefix) - # Replace the kv_b_proj with the new one - replace_layer(self.kv_b_proj, new_kv_b_proj) - - # 
Dispose tensor from the original o_proj - dispose_layer(self.o_proj) - # Construct the new o_proj using ReplicatedLinear - config = vllm_config.model_config.hf_config - new_o_proj = ReplicatedLinear(config.num_attention_heads * - config.v_head_dim, - config.hidden_size, - bias=False, - quant_config=vllm_config.quant_config, - prefix=self.o_proj.prefix) - # Replace the o_proj with the new one - replace_layer(self.o_proj, new_o_proj) diff --git a/vllm_ascend/distributed/parallel_state.py b/vllm_ascend/distributed/parallel_state.py index e886a31113e..02e21932822 100644 --- a/vllm_ascend/distributed/parallel_state.py +++ b/vllm_ascend/distributed/parallel_state.py @@ -8,7 +8,7 @@ init_model_parallel_group) from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.utils import (enable_sp, flashcomm2_enable, +from vllm_ascend.utils import (enable_dsa_cp_with_shard, flashcomm2_enable, flashcomm2_o_shared_enabled) # Currently, mc2 op need their own group coordinator. @@ -37,7 +37,6 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ): assert torch.distributed.is_initialized() world_size = torch.distributed.get_world_size() backend = torch.distributed.get_backend(get_world_group().device_group) - vllm_config = get_current_vllm_config() global_tp_size = parallel_config.tensor_parallel_size global_dp_size = parallel_config.data_parallel_size global_pp_size = parallel_config.pipeline_parallel_size @@ -166,10 +165,10 @@ def _create_shared_weight_group(group_name: str) -> GroupCoordinator: group_name=group_name) global _SHARED_WEIGHT - # TODO: Check if the model is Deepseek V3.2 with enabled SFA CP and activated shared weights. It will then be normalized within the PCP parameters. 
-- clrs97 - is_ds_v32 = hasattr(vllm_config.model_config.hf_config, "index_topk") - if enable_sp() and is_ds_v32 and _SHARED_WEIGHT is None: + + if enable_dsa_cp_with_shard(): _SHARED_WEIGHT = _create_shared_weight_group("CP_shared_weight") + # TODO: Extract and unify the logic across different communication group. if flashcomm2_enable(): flashcomm2_otp_size = get_ascend_config( @@ -225,9 +224,7 @@ def _create_shared_weight_group(group_name: str) -> GroupCoordinator: # Create shared weight group for flashcomm2 oproj if flashcomm2_o_shared_enabled(): assert flashcomm2_otp_size == 1, "flashcomm2_o_shared is only supported when flashcomm2_otp_size is 1" - if _SHARED_WEIGHT is None: - _SHARED_WEIGHT = _create_shared_weight_group( - "flashcomm2_o_shared") + _SHARED_WEIGHT = _create_shared_weight_group("flashcomm2_o_shared") if get_ascend_config().multistream_overlap_gate: global _FC3_QUANT_X diff --git a/vllm_ascend/distributed/utils.py b/vllm_ascend/distributed/utils.py index 6b4b894e580..cd0999b4cbe 100644 --- a/vllm_ascend/distributed/utils.py +++ b/vllm_ascend/distributed/utils.py @@ -1,7 +1,9 @@ import os +from typing import Optional import torch import torch.distributed as dist +from vllm.distributed.parallel_state import GroupCoordinator from vllm.forward_context import get_forward_context from vllm_ascend.distributed.parallel_state import (get_dp_group, @@ -90,3 +92,21 @@ def fc3_all_gather_and_maybe_unpad_impl(x: torch.Tensor, ) -> torch.Tensor: offset += num_tokens_dp x = result return x + + +def all_gather_async(input: torch.Tensor, + group: GroupCoordinator, + output: Optional[torch.Tensor] = None, + async_op: bool = True): + if group.world_size == 1: + return input, None + if output is None: + input_size = input.size() + output_size = (input_size[0] * group.world_size, ) + input_size[1:] + output = torch.empty(output_size, + dtype=input.dtype, + device=input.device) + return output, dist.all_gather_into_tensor(output, + input, + group=group.device_group, + 
async_op=async_op) diff --git a/vllm_ascend/ops/linear_op.py b/vllm_ascend/ops/linear_op.py index 674dab54e0c..f3fba77480e 100644 --- a/vllm_ascend/ops/linear_op.py +++ b/vllm_ascend/ops/linear_op.py @@ -38,6 +38,7 @@ import re from functools import lru_cache +from types import SimpleNamespace from typing import Optional, Union import torch @@ -59,7 +60,8 @@ get_flashcomm2_otp_group, get_mlp_tp_group, get_otp_group) -from vllm_ascend.utils import (enable_sp, flashcomm2_enable, +from vllm_ascend.utils import (enable_dsa_cp, enable_dsa_cp_with_shard, + enable_sp, flashcomm2_enable, get_flashcomm2_reorgnized_batch_ids, matmul_allreduce_enable, mlp_tp_enable, oproj_tp_enable, shared_expert_dp_enabled) @@ -538,7 +540,8 @@ def matmul_and_reduce(self, input_parallel: torch.Tensor, return tensor_model_parallel_all_reduce(output_parallel) pad_size = forward_context.pad_size - if pad_size > 0: + if pad_size > 0 and not (enable_dsa_cp() + and "o_proj" in self.layer.prefix): x = F.pad(x, (0, 0, 0, pad_size)) world_size = self.layer.tp_size @@ -609,9 +612,65 @@ def update_attrs(self): self.unique_prefix = self.layer.unique_prefix +class ShardedCPRowParallelOp(CustomRowParallelOp): + # Initialize `RowParallelLinear` as a replicated linear layer. + def __init__(self, layer): + super().__init__(layer) + + @property + def comm_group(self): + # fake comm group to bypass tp logic + return SimpleNamespace(world_size=1, + rank_in_group=0, + device_group=None) + + def apply_impl( + self, + input_, + ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]: + bias = self.bias if not self.skip_bias_add else None + # Matrix multiply. 
+ assert self.quant_method is not None + output = self.quant_method.apply(self.layer, input_, bias) + output_bias = self.bias if self.skip_bias_add else None + if not self.return_bias: + return output + return output, output_bias + + def update_attrs(self): + super().update_attrs() + self.reduce_results = False + + +class ShardedCPColumnParallelOp(CustomColumnParallelOp): + # Initialize `ColumnParallelLinear` as a replicated linear layer. + @property + def comm_group(self): + # fake comm group to bypass tp logic + return SimpleNamespace(world_size=1, + rank_in_group=0, + device_group=None) + + def apply_impl( + self, + input_, + ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]: + bias = self.bias if not self.skip_bias_add else None + # Matrix multiply. + assert self.quant_method is not None + output = self.quant_method.apply(self.layer, input_, bias) + output_bias = self.bias if self.skip_bias_add else None + if not self.return_bias: + return output + return output, output_bias + + def _get_column_parallel_op( - prefix, layer -) -> Optional[Union[MLPColumnParallelOp, SequenceColumnParallelOp]]: + prefix, layer +) -> Optional[Union[MLPColumnParallelOp, SequenceColumnParallelOp, + ShardedCPColumnParallelOp]]: + if enable_dsa_cp() and ("q_b_proj" in prefix or "kv_b_proj" in prefix): + return ShardedCPColumnParallelOp(layer) if "gate_up_proj" in prefix and mlp_tp_enable( ) and not is_moe_layer(prefix): return MLPColumnParallelOp(layer) @@ -636,7 +695,10 @@ def _get_row_parallel_op( prefix, layer ) -> Optional[Union[MLPRowParallelOp, OProjRowParallelOp, Flashcomm2OProjRowParallelOp, MatmulAllreduceRowParallelOp, - SequenceRowParallelOp]]: + SequenceRowParallelOp, ShardedCPRowParallelOp]]: + if enable_dsa_cp() and "o_proj" in prefix: + if enable_dsa_cp_with_shard(): + return ShardedCPRowParallelOp(layer) if "down_proj" in prefix and mlp_tp_enable() and not is_moe_layer(prefix): return MLPRowParallelOp(layer) if "o_proj" in prefix and oproj_tp_enable(): @@ 
-667,10 +729,11 @@ def get_parallel_op(disable_tp, prefix, layer, direct): and shared_expert_dp_enabled()): return None, 0, 1 custom_op: Optional[Union[MLPColumnParallelOp, SequenceColumnParallelOp, - MLPRowParallelOp, OProjRowParallelOp, - Flashcomm2OProjRowParallelOp, + ShardedCPColumnParallelOp, MLPRowParallelOp, + OProjRowParallelOp, Flashcomm2OProjRowParallelOp, MatmulAllreduceRowParallelOp, - SequenceRowParallelOp]] = None + SequenceRowParallelOp, ShardedCPRowParallelOp, + ShardedCPColumnParallelOp]] = None if direct == "row": custom_op = _get_row_parallel_op(prefix, layer) diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 97f8e2b66cc..3252f7786b3 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -1115,3 +1115,21 @@ def _check(name: str, config: dict): _check( "decode", vllm_config.kv_transfer_config.get_from_extra_config("decode", {})) + + +# Check if the model is Deepseek V3.2 (its HF config has `index_topk`) and `enable_sp()` is on; shared-weight sharding is checked separately. +@functools.cache +def enable_dsa_cp() -> bool: + from vllm.config import get_current_vllm_config + vllm_config = get_current_vllm_config() + is_ds_v32 = hasattr(vllm_config.model_config.hf_config, "index_topk") + return is_ds_v32 and enable_sp() + + +@functools.cache +def enable_dsa_cp_with_shard() -> bool: + if not enable_dsa_cp(): + return False + from vllm.config import get_current_vllm_config + vllm_config = get_current_vllm_config() + return vllm_config.kv_transfer_config is not None and vllm_config.kv_transfer_config.is_kv_producer