61 changes: 38 additions & 23 deletions vllm_ascend/attention/mla_v1.py
@@ -21,7 +21,7 @@
from vllm_ascend.multistream.context import get_multistream_comm_context
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor
from vllm_ascend.utils import npu_prefetch, npu_stream_switch, npu_wait_tensor
from vllm_ascend.worker.npu_input_batch import InputBatch

if TYPE_CHECKING:
@@ -579,13 +579,18 @@
" please make sure after the tensor parallel split, num_heads / num_kv_heads in "
"{32, 64, 128}.")

def _v_up_proj_and_o_proj(self, x):
def _v_up_proj_and_o_proj(self, x, enable_multistream_mla: bool = False):
# Convert from (B, N, L) to (N, B, L)
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
x = torch.bmm(x, self.W_UV)
# Convert from (N, B, V) to (B, N * V)
x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024 # 16MB
npu_prefetch(self.o_proj.weight,
x,
max_size=MAX_O_PROJ_PREFETCH_SIZE,
enabled=enable_multistream_mla)
return self.o_proj(x, is_prefill=False)[0]

# Return `ql_nope`, `q_pe`
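For context, the intent of the prefetch added above: while the up-projection result `x` is still being produced, the `o_proj` weight is pulled toward cache (capped at 16 MB), and the call degrades to a no-op unless multistream MLA is enabled. A minimal standalone sketch of the same idiom; the module and tensor names below are illustrative, not taken from this diff:

# Sketch only, assuming an `npu_prefetch` helper like the one added to
# vllm_ascend/utils.py in this PR.
MAX_PREFETCH_BYTES = 16 * 1024 * 1024  # prefetch at most 16 MB of the weight

def project_output(o_proj, x, enabled=False):
    # Tie the weight prefetch to `x` for stream ordering, so the weight load
    # overlaps the compute that produces `x`; skipped entirely when disabled.
    npu_prefetch(o_proj.weight, x, max_size=MAX_PREFETCH_BYTES, enabled=enabled)
    return o_proj(x, is_prefill=False)[0]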
@@ -864,7 +869,6 @@
sin: torch.Tensor,
kv_cache: Tuple,
slots: torch.Tensor,
enable_multistream_mla: bool = False,
):

B = hidden_states.shape[0]
@@ -874,21 +878,18 @@
# npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"
with npu_stream_switch("mla_secondary",
0,
enabled=enable_multistream_mla):
k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
kv,
self.kv_a_layernorm.weight,
cos,
sin,
slots.to(torch.int64),
kv_cache[1],
kv_cache[0],
epsilon=self.kv_a_layernorm.variance_epsilon,
cache_mode=cache_mode,
)
return k_pe, k_nope
k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
kv,
self.kv_a_layernorm.weight,
cos,
sin,
slots.to(torch.int64),
kv_cache[1],
kv_cache[0],
epsilon=self.kv_a_layernorm.variance_epsilon,
cache_mode=cache_mode,
)
return k_pe, k_nope, kv
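Two things change here: the secondary-stream switch that used to wrap this kernel moves out to the call site in forward(), and the raw kv view is now returned alongside k_pe/k_nope so the caller can fence on it. A condensed sketch of the decode-path caller, using only names that appear in the forward() hunk further down (indentation and ordering inferred for illustration):

# Sketch of the caller side (decode path of forward(), condensed):
with npu_stream_switch("mla_secondary", 0, enabled=enable_multistream_mla):
    # Wait for ckq (computed in deepseek_v2.py) before touching the KV cache.
    npu_wait_tensor(hidden_states_or_kv_c_normed, ckq,
                    enabled=enable_multistream_mla)
    decode_k_pe, decode_k_nope, decode_kv = self.exec_kv(
        hidden_states_or_kv_c_normed, cos, sin, kv_cache,
        attn_metadata.slot_mapping)
# Keep the q-side projections ordered behind the KV rmsnorm/rope cache write.
npu_wait_tensor(decode_hs_or_q_c, decode_kv, enabled=enable_multistream_mla)
decode_ql_nope, decode_q_pe = self._q_proj_and_k_up_proj(decode_hs_or_q_c)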


def exec_kv_prefill(
self,
@@ -940,6 +941,7 @@
k_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: AscendMLAMetadata,
enable_multistream_mla: bool = False,
) -> torch.Tensor:
decode_meta = attn_metadata.decode
assert decode_meta is not None
@@ -1020,7 +1022,8 @@
out=attn_output)
current_ms_metadata = get_multistream_comm_context()
if current_ms_metadata is None:
return self._v_up_proj_and_o_proj(attn_output)
return self._v_up_proj_and_o_proj(attn_output,
enable_multistream_mla)
else:
current_ms_metadata.before_comm_event.record()
with torch.npu.stream(current_ms_metadata.comm_stream):
@@ -1037,6 +1040,7 @@
attn_metadata: M,
output: Optional[torch.Tensor] = None,
enable_multistream_mla: bool = False,
ckq: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert output is not None, "Output tensor must be provided."
if attn_metadata is None:
@@ -1091,6 +1095,15 @@
sin = sin[attn_metadata.decode.input_positions]
cos = cos[:, None, None, :]
sin = sin[:, None, None, :]
with npu_stream_switch("mla_secondary",
0,
enabled=enable_multistream_mla):
npu_wait_tensor(hidden_states_or_kv_c_normed,
ckq,
enabled=enable_multistream_mla)
decode_k_pe, decode_k_nope, decode_kv = self.exec_kv(
hidden_states_or_kv_c_normed, cos, sin, kv_cache,
attn_metadata.slot_mapping)
# Without explicitly controlling the order, IndexByTensor operations
# would be placed after `matmul W_KV_T` hindering the overlapping of
# KvRmsNormRopeCache and SingleRope.
@@ -1100,12 +1113,13 @@
npu_wait_tensor(decode_hs_or_q_c,
sin,
enabled=enable_multistream_mla)
npu_wait_tensor(decode_hs_or_q_c,
decode_kv,
enabled=enable_multistream_mla)

decode_ql_nope, decode_q_pe = \
self._q_proj_and_k_up_proj(decode_hs_or_q_c)
if self.running_in_graph:
decode_k_pe, decode_k_nope = self.exec_kv(
hidden_states_or_kv_c_normed, cos, sin, kv_cache,
attn_metadata.slot_mapping, enable_multistream_mla)
with npu_stream_switch("mla_secondary",
0,
enabled=enable_multistream_mla):
@@ -1194,7 +1208,8 @@
if self.running_in_graph:
return self._forward_decode(decode_ql_nope, decode_q_pe,
decode_k_nope, decode_k_pe,
kv_cache, attn_metadata)
kv_cache, attn_metadata,
enable_multistream_mla)
else:
output_decode = self._forward_decode(decode_ql_nope,
decode_q_pe,
13 changes: 6 additions & 7 deletions vllm_ascend/models/deepseek_v2.py
@@ -74,8 +74,7 @@
from vllm_ascend.ops.fused_moe import AscendFusedMoE
from vllm_ascend.quantization.quant_config import AscendLinearMethod
from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
from vllm_ascend.utils import (dispose_tensor, npu_stream_switch,
npu_wait_tensor)
from vllm_ascend.utils import dispose_tensor, npu_prefetch


class CustomDeepseekV2SiluAndMul(SiluAndMul):
@@ -567,12 +566,12 @@
and attn_metadata.num_decodes > 0)
forward_kwargs = {"enable_multistream_mla": enable_multistream_mla}
if self.q_lora_rank is not None:
npu_prefetch(self.q_a_proj.weight,
hidden_states,
enabled=enable_multistream_mla)
ckq = self.q_a_proj(hidden_states)[0]
npu_wait_tensor(hidden_states, ckq, enabled=enable_multistream_mla)
with npu_stream_switch("mla_secondary",
0,
enabled=enable_multistream_mla):
hidden_states_or_q_c = self.q_a_layernorm(ckq)
hidden_states_or_q_c = self.q_a_layernorm(ckq)
forward_kwargs['ckq'] = ckq
else:
hidden_states_or_q_c = hidden_states
if self.torchair_graph_enabled:
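Read together with the attention-side changes above, the decoder layer now computes `ckq` eagerly, prefetches the `q_a_proj` weight against the incoming hidden states, and hands `ckq` to the attention layer so the multistream synchronization happens there instead of here. A condensed sketch of the resulting branch (names taken from the diff; indentation is inferred):

forward_kwargs = {"enable_multistream_mla": enable_multistream_mla}
if self.q_lora_rank is not None:
    # Warm up q_a_proj's weight while hidden_states is still settling.
    npu_prefetch(self.q_a_proj.weight, hidden_states,
                 enabled=enable_multistream_mla)
    ckq = self.q_a_proj(hidden_states)[0]
    hidden_states_or_q_c = self.q_a_layernorm(ckq)
    forward_kwargs["ckq"] = ckq  # consumed by npu_wait_tensor in mla_v1.forward()
else:
    hidden_states_or_q_c = hidden_states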
14 changes: 14 additions & 0 deletions vllm_ascend/utils.py
@@ -416,6 +416,20 @@
return _npu_wait_tensor(self, dependency) if enabled else self


# TODO(wxy): Move to ops module
def npu_prefetch(input: torch.Tensor,
dependency: torch.Tensor,
max_size: int = 0,
*,
enabled: bool = True):
if not enabled:
return
input_size = input.element_size() * input.numel()
if max_size <= 0 or max_size > input_size:
max_size = input_size
torch_npu.npu_prefetch(input, dependency, max_size)
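Usage notes on the helper, with an illustrative call (the `weight` and `activation` names are placeholders): `max_size` is clamped to the tensor's byte size, so `max_size=0` (the default) means "prefetch the whole tensor", and `enabled=False` turns the call into a no-op, which lets call sites keep it unconditionally in the hot path.

# Illustrative only: warm at most 16 MB of `weight` while `activation` is produced.
npu_prefetch(weight, activation, max_size=16 * 1024 * 1024,
             enabled=multistream_enabled)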


# TODO(zzzzwwjj): move this into forward_context
class FusedMoEState(Enum):
AllGather = 0