68 changes: 33 additions & 35 deletions vllm_ascend/attention/mla_v1.py
@@ -22,8 +22,8 @@
from vllm_ascend.multistream.context import get_multistream_comm_context
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
-from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, npu_stream_switch,
-                               npu_wait_tensor)
+from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, npu_prefetch,
+                               npu_stream_switch, npu_wait_tensor)

if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
@@ -627,22 +627,25 @@ def __init__(
ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
-        self.enable_multistream_mla = \
-            ascend_config.torchair_graph_config.enable_multistream_mla

# Adapt torch air graph mode with spec decoding.
speculative_config = get_current_vllm_config().speculative_config
if speculative_config is not None:
self.spec_token_num = speculative_config.num_speculative_tokens
assert self.spec_token_num > 0

-    def _v_up_proj_and_o_proj(self, x):
+    def _v_up_proj_and_o_proj(self, x, enable_multistream_mla: bool = False):
# Convert from (B, N, L) to (N, B, L)
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
x = torch.bmm(x, self.W_UV)
# Convert from (N, B, V) to (B, N * V)
x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
+        MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024  # 16MB
+        npu_prefetch(self.o_proj.weight,
+                     x,
+                     max_size=MAX_O_PROJ_PREFETCH_SIZE,
+                     enabled=enable_multistream_mla)
return self.o_proj(x)[0]

# Return `ql_nope`, `q_pe`
@@ -933,20 +936,17 @@ def exec_kv(
# npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"
-        with npu_stream_switch("mla_secondary",
-                               0,
-                               enabled=self.enable_multistream_mla):
-            k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
-                kv,
-                self.kv_a_layernorm.weight,
-                cos,
-                sin,
-                slots.to(torch.int64),
-                kv_cache[1],
-                kv_cache[0],
-                epsilon=self.kv_a_layernorm.variance_epsilon,
-                cache_mode=cache_mode,
-            )
+        k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
+            kv,
+            self.kv_a_layernorm.weight,
+            cos,
+            sin,
+            slots.to(torch.int64),
+            kv_cache[1],
+            kv_cache[0],
+            epsilon=self.kv_a_layernorm.variance_epsilon,
+            cache_mode=cache_mode,
+        )
return k_pe, k_nope

def exec_kv_prefill(
@@ -999,6 +999,7 @@ def _forward_decode(
k_pe: torch.Tensor,
kv_c_and_k_pe_cache: Tuple[torch.Tensor],
attn_metadata: AscendMLAMetadata,
+        enable_multistream_mla: bool = False,
) -> torch.Tensor:
decode_meta = attn_metadata.decode
assert decode_meta is not None
@@ -1093,7 +1094,8 @@ def _forward_decode(
out=attn_output)
current_ms_metadata = get_multistream_comm_context()
if current_ms_metadata is None:
-            return self._v_up_proj_and_o_proj(attn_output)
+            return self._v_up_proj_and_o_proj(attn_output,
+                                              enable_multistream_mla)
else:
current_ms_metadata.before_comm_event.record()
with torch.npu.stream(current_ms_metadata.comm_stream):
@@ -1109,6 +1111,7 @@ def forward(
kv_cache: Tuple[torch.Tensor],
attn_metadata: M,
output: Optional[torch.Tensor] = None,
+        enable_multistream_mla=False,
) -> torch.Tensor:
assert output is not None, "Output tensor must be provided."
if attn_metadata is None:
@@ -1158,27 +1161,21 @@ def forward(
if self.running_in_graph:
cos = attn_metadata.decode.cos
sin = attn_metadata.decode.sin
-            # Without explicitly controlling the order, IndexByTensor operations
-            # would be placed after `matmul W_KV_T` hindering the overlapping of
-            # KvRmsNormRopeCache and SingleRope.
-            npu_wait_tensor(decode_hs_or_q_c,
-                            cos,
-                            enabled=self.enable_multistream_mla)
-            npu_wait_tensor(decode_hs_or_q_c,
-                            sin,
-                            enabled=self.enable_multistream_mla)
+            with npu_stream_switch("mla_secondary",
+                                   0,
+                                   enabled=enable_multistream_mla):
+                decode_k_pe, decode_k_nope = self.exec_kv(
+                    hidden_states_or_kv_c_normed, cos, sin, kv_cache,
+                    attn_metadata.slot_mapping)
decode_ql_nope, decode_q_pe = \
self._q_proj_and_k_up_proj(decode_hs_or_q_c)
if self.running_in_graph:
-            decode_k_pe, decode_k_nope = self.exec_kv(
-                hidden_states_or_kv_c_normed, cos, sin, kv_cache,
-                attn_metadata.slot_mapping)
with npu_stream_switch("mla_secondary",
0,
-                                   enabled=self.enable_multistream_mla):
+                                   enabled=enable_multistream_mla):
npu_wait_tensor(decode_q_pe,
decode_k_pe,
-                                enabled=self.enable_multistream_mla)
+                                enabled=enable_multistream_mla)
decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
else:
decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
@@ -1253,7 +1250,8 @@ def forward(
if self.running_in_graph:
return self._forward_decode(decode_ql_nope, decode_q_pe,
decode_k_nope, decode_k_pe,
-                                        kv_cache, attn_metadata)
+                                        kv_cache, attn_metadata,
+                                        enable_multistream_mla)
else:
output_decode = self._forward_decode(decode_ql_nope,
decode_q_pe,
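Taken together, the decode path now overlaps the fused KV rmsnorm/rope cache update with the q projection on separate streams. A condensed sketch of the new flow, assembled from the hunks above (simplified for illustration; not the literal file contents):

    # Secondary stream: fused rmsnorm + rope + KV-cache scatter for K.
    with npu_stream_switch("mla_secondary", 0, enabled=enable_multistream_mla):
        decode_k_pe, decode_k_nope = self.exec_kv(
            hidden_states_or_kv_c_normed, cos, sin, kv_cache,
            attn_metadata.slot_mapping)

    # Main stream: q down-/up-projection runs concurrently with exec_kv above.
    decode_ql_nope, decode_q_pe = self._q_proj_and_k_up_proj(decode_hs_or_q_c)

    with npu_stream_switch("mla_secondary", 0, enabled=enable_multistream_mla):
        # Order q_pe's rope after k_pe is produced on the secondary stream.
        npu_wait_tensor(decode_q_pe, decode_k_pe, enabled=enable_multistream_mla)
        decode_q_pe = self.rope_single(decode_q_pe, cos, sin)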
22 changes: 11 additions & 11 deletions vllm_ascend/models/deepseek_v2.py
@@ -68,8 +68,7 @@
from vllm_ascend.ops.fused_moe import AscendFusedMoE
from vllm_ascend.quantization.quant_config import AscendLinearMethod
from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
-from vllm_ascend.utils import (dispose_tensor, npu_stream_switch,
-                               npu_wait_tensor)
+from vllm_ascend.utils import dispose_tensor, npu_prefetch


class CustomDeepseekV2SiluAndMul(SiluAndMul):
@@ -472,21 +471,22 @@ def forward(
hidden_states: torch.Tensor,
kv_cache: Optional[torch.Tensor] = None,
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
+        forward_context = get_forward_context()
+        enable_multistream_mla = (self.enable_multistream_mla
+                                  and attn_metadata is not None
+                                  and not forward_context.with_prefill
+                                  and attn_metadata.num_decodes > 0)
+        forward_kwargs = {"enable_multistream_mla": enable_multistream_mla}
if self.q_lora_rank is not None:
+            npu_prefetch(self.q_a_proj.weight,
+                         hidden_states,
+                         enabled=enable_multistream_mla)
ckq = self.q_a_proj(hidden_states)[0]
-            use_multistream_mla = (self.enable_multistream_mla
-                                   and attn_metadata is not None
-                                   and attn_metadata.num_decodes > 0)
-            npu_wait_tensor(hidden_states, ckq, enabled=use_multistream_mla)
-            with npu_stream_switch("mla_secondary",
-                                   0,
-                                   enabled=use_multistream_mla):
-                hidden_states_or_q_c = self.q_a_layernorm(ckq)
+            hidden_states_or_q_c = self.q_a_layernorm(ckq)
Collaborator:
Compared to #1353, there is a missing line of code here:

    forward_kwargs['ckq'] = ckq

Is there any special consideration for not adding this line, or was it simply forgotten?

Collaborator (author):
The tensor `ckq` is passed into MLA in order to add an `npu_wait_tensor` control edge before `kv_a_proj_with_mqa` on stream 2; in my testing scenario, I found that removing this control edge gives slightly better performance.
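For reference, a minimal sketch of the control edge under discussion, assuming the #1353 wiring (hypothetical here; this PR deliberately omits it):

    # Hypothetical #1353-style wiring, shown only to illustrate the omitted edge.
    forward_kwargs['ckq'] = ckq  # the line this PR does not add
    # ...then, inside the MLA forward, before kv_a_proj_with_mqa on stream 2:
    npu_wait_tensor(hidden_states, ckq, enabled=enable_multistream_mla)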

else:
hidden_states_or_q_c = hidden_states
is_mtp_model = attn_metadata is not None and attn_metadata.is_mtp_model
if self.torchair_graph_enabled and not is_mtp_model:
-            forward_kwargs = {}
if envs.VLLM_USE_V1:
output_shape = hidden_states.shape
output = torch.empty(output_shape,
13 changes: 13 additions & 0 deletions vllm_ascend/utils.py
@@ -303,6 +303,19 @@ def npu_wait_tensor(self: torch.Tensor,
return _npu_wait_tensor(self, dependency) if enabled else self


+def npu_prefetch(input: torch.Tensor,
+                 dependency: torch.Tensor,
+                 max_size: int = 0,
+                 *,
+                 enabled: bool = True):
+    if not enabled:
+        return
+    input_size = input.element_size() * input.numel()
+    if max_size <= 0 or max_size > input_size:
+        max_size = input_size
+    torch_npu.npu_prefetch(input, dependency, max_size)


class AscendSocVersion(Enum):
A2 = 0
A3 = 1
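A minimal usage sketch for the new helper, mirroring the o_proj call site added in mla_v1.py above (the helper is a no-op when enabled=False, and a max_size <= 0 falls back to prefetching the whole tensor):

    MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024  # cap the prefetch at 16 MB
    npu_prefetch(self.o_proj.weight,   # weight tensor to pull toward cache
                 x,                    # dependency: prefetch issues once x is ready
                 max_size=MAX_O_PROJ_PREFETCH_SIZE,
                 enabled=enable_multistream_mla)
    output = self.o_proj(x)[0]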