diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index d8abae6080e..570db691163 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -22,8 +22,8 @@
 from vllm_ascend.multistream.context import get_multistream_comm_context
 from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
 from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
-from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, npu_stream_switch,
-                               npu_wait_tensor)
+from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, npu_prefetch,
+                               npu_stream_switch, npu_wait_tensor)
 
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
@@ -627,8 +627,6 @@ def __init__(
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
         self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
-        self.enable_multistream_mla = \
-            ascend_config.torchair_graph_config.enable_multistream_mla
 
         # Adapt torch air graph mode with spec decoding.
         speculative_config = get_current_vllm_config().speculative_config
@@ -636,13 +634,18 @@
             self.spec_token_num = speculative_config.num_speculative_tokens
             assert self.spec_token_num > 0
 
-    def _v_up_proj_and_o_proj(self, x):
+    def _v_up_proj_and_o_proj(self, x, enable_multistream_mla: bool = False):
         # Convert from (B, N, L) to (N, B, L)
         x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
         # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
         x = torch.bmm(x, self.W_UV)
         # Convert from (N, B, V) to (B, N * V)
         x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
+        MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024  # 16MB
+        npu_prefetch(self.o_proj.weight,
+                     x,
+                     max_size=MAX_O_PROJ_PREFETCH_SIZE,
+                     enabled=enable_multistream_mla)
         return self.o_proj(x)[0]
 
     # Return `ql_nope`, `q_pe`
@@ -933,20 +936,17 @@ def exec_kv(
         # npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
         kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
         cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"
-        with npu_stream_switch("mla_secondary",
-                               0,
-                               enabled=self.enable_multistream_mla):
-            k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
-                kv,
-                self.kv_a_layernorm.weight,
-                cos,
-                sin,
-                slots.to(torch.int64),
-                kv_cache[1],
-                kv_cache[0],
-                epsilon=self.kv_a_layernorm.variance_epsilon,
-                cache_mode=cache_mode,
-            )
+        k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
+            kv,
+            self.kv_a_layernorm.weight,
+            cos,
+            sin,
+            slots.to(torch.int64),
+            kv_cache[1],
+            kv_cache[0],
+            epsilon=self.kv_a_layernorm.variance_epsilon,
+            cache_mode=cache_mode,
+        )
         return k_pe, k_nope
 
     def exec_kv_prefill(
@@ -999,6 +999,7 @@ def _forward_decode(
         k_pe: torch.Tensor,
         kv_c_and_k_pe_cache: Tuple[torch.Tensor],
         attn_metadata: AscendMLAMetadata,
+        enable_multistream_mla: bool = False,
     ) -> torch.Tensor:
         decode_meta = attn_metadata.decode
         assert decode_meta is not None
@@ -1093,7 +1094,8 @@
                 out=attn_output)
         current_ms_metadata = get_multistream_comm_context()
         if current_ms_metadata is None:
-            return self._v_up_proj_and_o_proj(attn_output)
+            return self._v_up_proj_and_o_proj(attn_output,
+                                              enable_multistream_mla)
         else:
             current_ms_metadata.before_comm_event.record()
             with torch.npu.stream(current_ms_metadata.comm_stream):
@@ -1109,6 +1111,7 @@ def forward(
         kv_cache: Tuple[torch.Tensor],
         attn_metadata: M,
         output: Optional[torch.Tensor] = None,
+        enable_multistream_mla=False,
     ) -> torch.Tensor:
         assert output is not None, "Output tensor must be provided."
         if attn_metadata is None:
@@ -1158,27 +1161,21 @@ def forward(
             if self.running_in_graph:
                 cos = attn_metadata.decode.cos
                 sin = attn_metadata.decode.sin
-                # Without explicitly controlling the order, IndexByTensor operations
-                # would be placed after `matmul W_KV_T` hindering the overlapping of
-                # KvRmsNormRopeCache and SingleRope.
-                npu_wait_tensor(decode_hs_or_q_c,
-                                cos,
-                                enabled=self.enable_multistream_mla)
-                npu_wait_tensor(decode_hs_or_q_c,
-                                sin,
-                                enabled=self.enable_multistream_mla)
+                with npu_stream_switch("mla_secondary",
+                                       0,
+                                       enabled=enable_multistream_mla):
+                    decode_k_pe, decode_k_nope = self.exec_kv(
+                        hidden_states_or_kv_c_normed, cos, sin, kv_cache,
+                        attn_metadata.slot_mapping)
             decode_ql_nope, decode_q_pe = \
                 self._q_proj_and_k_up_proj(decode_hs_or_q_c)
             if self.running_in_graph:
-                decode_k_pe, decode_k_nope = self.exec_kv(
-                    hidden_states_or_kv_c_normed, cos, sin, kv_cache,
-                    attn_metadata.slot_mapping)
                 with npu_stream_switch("mla_secondary",
                                        0,
-                                       enabled=self.enable_multistream_mla):
+                                       enabled=enable_multistream_mla):
                     npu_wait_tensor(decode_q_pe,
                                     decode_k_pe,
-                                    enabled=self.enable_multistream_mla)
+                                    enabled=enable_multistream_mla)
                     decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
             else:
                 decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
@@ -1253,7 +1250,8 @@ def forward(
         if self.running_in_graph:
             return self._forward_decode(decode_ql_nope, decode_q_pe,
                                         decode_k_nope, decode_k_pe,
-                                        kv_cache, attn_metadata)
+                                        kv_cache, attn_metadata,
+                                        enable_multistream_mla)
         else:
             output_decode = self._forward_decode(decode_ql_nope,
                                                  decode_q_pe,
diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py
index e913777fd0e..fb1ed6f11b7 100644
--- a/vllm_ascend/models/deepseek_v2.py
+++ b/vllm_ascend/models/deepseek_v2.py
@@ -68,8 +68,7 @@
 from vllm_ascend.ops.fused_moe import AscendFusedMoE
 from vllm_ascend.quantization.quant_config import AscendLinearMethod
 from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
-from vllm_ascend.utils import (dispose_tensor, npu_stream_switch,
-                               npu_wait_tensor)
+from vllm_ascend.utils import dispose_tensor, npu_prefetch
 
 
 class CustomDeepseekV2SiluAndMul(SiluAndMul):
@@ -472,21 +471,22 @@ def forward(
             hidden_states: torch.Tensor,
             kv_cache: Optional[torch.Tensor] = None,
             attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
+        forward_context = get_forward_context()
+        enable_multistream_mla = (self.enable_multistream_mla
+                                  and attn_metadata is not None
+                                  and not forward_context.with_prefill
+                                  and attn_metadata.num_decodes > 0)
+        forward_kwargs = {"enable_multistream_mla": enable_multistream_mla}
         if self.q_lora_rank is not None:
+            npu_prefetch(self.q_a_proj.weight,
+                         hidden_states,
+                         enabled=enable_multistream_mla)
             ckq = self.q_a_proj(hidden_states)[0]
-            use_multistream_mla = (self.enable_multistream_mla
-                                   and attn_metadata is not None
-                                   and attn_metadata.num_decodes > 0)
-            npu_wait_tensor(hidden_states, ckq, enabled=use_multistream_mla)
-            with npu_stream_switch("mla_secondary",
-                                   0,
-                                   enabled=use_multistream_mla):
-                hidden_states_or_q_c = self.q_a_layernorm(ckq)
+            hidden_states_or_q_c = self.q_a_layernorm(ckq)
         else:
             hidden_states_or_q_c = hidden_states
         is_mtp_model = attn_metadata is not None and attn_metadata.is_mtp_model
         if self.torchair_graph_enabled and not is_mtp_model:
-            forward_kwargs = {}
             if envs.VLLM_USE_V1:
                 output_shape = hidden_states.shape
                 output = torch.empty(output_shape,
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index f7ca0aba2e3..c63594edc16 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -303,6 +303,19 @@ def npu_wait_tensor(self: torch.Tensor,
     return _npu_wait_tensor(self, dependency) if enabled else self
 
 
+def npu_prefetch(input: torch.Tensor,
+                 dependency: torch.Tensor,
+                 max_size: int = 0,
+                 *,
+                 enabled: bool = True):
+    if not enabled:
+        return
+    input_size = input.element_size() * input.numel()
+    if max_size <= 0 or max_size > input_size:
+        max_size = input_size
+    torch_npu.npu_prefetch(input, dependency, max_size)
+
+
 class AscendSocVersion(Enum):
     A2 = 0
     A3 = 1
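
Below is a minimal usage sketch of the new `npu_prefetch` helper added in `vllm_ascend/utils.py`. It is illustrative only: the tensor names and shapes are hypothetical, and it assumes an Ascend NPU environment with `torch_npu` installed.

```python
import torch
import torch_npu  # noqa: F401  (registers the "npu" device backend)

from vllm_ascend.utils import npu_prefetch

# Hypothetical projection weight and activation tensor on the NPU.
weight = torch.randn(16384, 7168, dtype=torch.float16, device="npu")
hidden = torch.randn(32, 7168, dtype=torch.float16, device="npu")

# Request a hardware prefetch of at most 16 MB of `weight`, issued once
# `hidden` is ready; this mirrors how the diff warms the o_proj and
# q_a_proj weights. The helper clamps max_size to the weight's actual
# byte size, and enabled=False makes the call a no-op, which is how the
# call sites gate it on enable_multistream_mla.
npu_prefetch(weight, hidden, max_size=16 * 1024 * 1024, enabled=True)
```

Because the helper degrades to a no-op when disabled, call sites can thread `enable_multistream_mla` straight through as the `enabled` flag without adding branches of their own.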