145 changes: 98 additions & 47 deletions vllm_ascend/attention/sfa_v1.py
@@ -24,6 +24,7 @@
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
trans_rope_weight, transdata,
wait_for_kv_layer_from_connector)
from vllm_ascend.distributed.utils import all_gather_async
from vllm_ascend.ops.layer_shard_linear import (
is_hidden_layer, post_process_after_loading_for_shard_weight_series,
reach_layer_for_shard_weight_series,
@@ -226,7 +227,10 @@ def build(

cos = cos[local_start:local_end_with_pad]
sin = sin[local_start:local_end_with_pad]
slot_mapping_cp = slot_mapping[local_start:local_end_with_pad]
slot_mapping_cp = torch.full(size=(num_tokens_per_device, ),
fill_value=-1,
dtype=slot_mapping.dtype,
device=slot_mapping.device)
assert cos.shape[0] == num_tokens_per_device, \
f"cos.shape[0] must be equal to num_tokens_per_device, \
got {cos.shape[0]} and {num_tokens_per_device}"
@@ -503,7 +507,6 @@ def exec_kv(
sin: torch.Tensor,
kv_cache: Tuple,
slots: torch.Tensor,
slots_cp: Optional[torch.Tensor],
):
B = kv_no_split.shape[0]
N = self.num_kv_heads
@@ -514,30 +517,19 @@
cache_mode = "PA"

if self.enable_sfa_cp:
assert slots_cp is not None
_, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache(
kv_no_split,
self.kv_a_layernorm.weight,
cos,
sin,
slots_cp.to(torch.int64),
slots.to(torch.int64),
kv_cache[1],
kv_cache[0],
epsilon=self.kv_a_layernorm.variance_epsilon,
cache_mode=cache_mode,
is_output_kv=True,
)
# TODO: Temporarily adapt SFA-CP and replace it later with PCP. --clrs97
k_pe = get_tp_group().all_gather(k_pe, 0)
k_nope = get_tp_group().all_gather(k_nope, 0)

if kv_cache is not None:
torch_npu.npu_scatter_nd_update_(
kv_cache[0].view(-1, k_nope.shape[-1]), slots.view(-1, 1),
k_nope.view(-1, k_nope.shape[-1]))
torch_npu.npu_scatter_nd_update_(
kv_cache[1].view(-1, k_pe.shape[-1]), slots.view(-1, 1),
k_pe.view(-1, k_pe.shape[-1]))
return k_pe, k_nope
else:
torch_npu.npu_kv_rmsnorm_rope_cache(
kv_no_split,
@@ -550,6 +542,7 @@ def exec_kv(
epsilon=self.kv_a_layernorm.variance_epsilon,
cache_mode=cache_mode,
)
return None, None

def rope_single(
self,
@@ -742,6 +735,7 @@ def forward(
if is_hidden_layer(layer):
reach_layer_for_shard_weight_series(layer)
return output.fill_(0)

has_prefill = attn_metadata.has_prefill
cos = attn_metadata.cos
sin = attn_metadata.sin
@@ -762,6 +756,12 @@ def forward(
need_gather_q_kv=need_gather_q_kv,
num_input_tokens=attn_metadata.num_input_tokens,
)
q, k = self.indexer_select_pre_process(
x=hidden_states,
qr=q_c,
cos=cos,
sin=sin,
need_gather_q_kv=need_gather_q_kv)
else:
assert self.fused_qkv_a_proj is not None, "q lora is required for DSA."
maybe_npu_prefetch(inputs=self.fused_qkv_a_proj.weight,
@@ -780,38 +780,75 @@ def forward(
kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
kv_no_split.contiguous(), need_gather_q_kv)

q, k = self.indexer_select_pre_process(
x=hidden_states,
qr=q_c,
cos=cos,
sin=sin,
need_gather_q_kv=need_gather_q_kv)

if has_prefill:
wait_for_kv_layer_from_connector(layer_name)

slot_mapping = attn_metadata.slot_mapping
slot_mapping_cp = None
if self.enable_sfa_cp:
assert attn_metadata.sfa_cp_context is not None
slot_mapping_cp = attn_metadata.sfa_cp_context.slot_mapping_cp
slot_mapping = attn_metadata.sfa_cp_context.slot_mapping_cp
actual_seq_lengths_query = attn_metadata.sfa_cp_context.actual_seq_lengths_query
actual_seq_lengths_key = attn_metadata.sfa_cp_context.actual_seq_lengths_key

self.exec_kv(kv_no_split, cos, sin, kv_cache, slot_mapping,
slot_mapping_cp)
k_pe, k_nope = self.exec_kv(kv_no_split, cos, sin, kv_cache,
slot_mapping)

if self.enable_sfa_cp and attn_metadata.sfa_cp_context is not None:
for layer in (self.layer_sharding_kwargs or []):
if is_hidden_layer(layer):
reach_layer_for_shard_weight_series(layer)
if self.enable_sfa_cp:
assert k_pe is not None
assert k_nope is not None
# support all_gather kv async for communication calculation overlap
fused_kv_no_split, kv_ag_handle = all_gather_async(
torch.cat([
k_pe.view(-1, k_pe.shape[-1]),
k_nope.view(-1, k_nope.shape[-1]),
k.view(-1, k.shape[-1])
],
dim=1), get_tp_group())

ql_nope, q_pe = self._q_proj_and_k_up_proj(q_c)
q_pe = self.rope_single(q_pe, cos, sin)

topk_indices = self.indexer_select(
if self.enable_sfa_cp:
if kv_ag_handle is not None:
kv_ag_handle.wait()
for layer in (self.layer_sharding_kwargs or []):
if is_hidden_layer(layer):
reach_layer_for_shard_weight_series(layer)

if kv_cache is not None:
assert fused_kv_no_split is not None
k_pe, k_nope, k = fused_kv_no_split.split([
self.qk_rope_head_dim, self.kv_lora_rank, self.head_dim
],
dim=-1)
slot_mapping = attn_metadata.slot_mapping.view(-1, 1)
torch_npu.npu_scatter_nd_update_(
kv_cache[0].view(-1, k_nope.shape[-1]), slot_mapping,
k_nope)
torch_npu.npu_scatter_nd_update_(
kv_cache[1].view(-1, k_pe.shape[-1]), slot_mapping,
k_pe)

topk_indices = self.indexer_select_post_process(
x=hidden_states,
qr=q_c,
q=q,
k=k,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
cos=cos,
sin=sin,
actual_seq_lengths_query=actual_seq_lengths_query,
actual_seq_lengths_key=actual_seq_lengths_key,
need_gather_q_kv=need_gather_q_kv)

attn_output = torch.ops._C_ascend.npu_sparse_flash_attention(
query=ql_nope,
key=kv_cache[0],
@@ -828,6 +865,7 @@ def forward(
layout_kv="PA_BSND",
sparse_mode=3,
)

attn_output = self._v_up_proj(attn_output, has_prefill)
maybe_npu_prefetch(inputs=self.o_proj.weight,
dependency=attn_output,
@@ -836,29 +874,24 @@ def forward(
output[...] = self.o_proj(attn_output)[0]
return output_padded

def indexer_select(
def indexer_select_pre_process(
self,
x: torch.Tensor,
qr: torch.Tensor,
kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
attn_metadata: M,
cos: torch.Tensor,
sin: torch.Tensor,
actual_seq_lengths_query: torch.Tensor,
actual_seq_lengths_key: torch.Tensor,
need_gather_q_kv: bool = False,
):
# q process in new stream
q, _ = self.wq_b(qr) # [b,s,1536] @ [1536,64*128] = [b,s,64*128]
q = q.view(-1, self.n_head, self.head_dim) # [n_toks,64,128]

k_proj, _ = self.wk(x) # [b,s,7168] @ [7168,128] = [b,s,128]
k_proj = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
k_proj, need_gather_q_kv)
k = self.k_norm(k_proj).unsqueeze(1)
k = k.view(-1, 1, self.head_dim)

if HAS_TRITON:
q, _ = self.wq_b(qr) # [b,s,1536] @ [1536,64*128] = [b,s,64*128]
q = q.view(-1, self.n_head, self.head_dim) # [n_toks,64,128]

cos = cos.view(-1, self.qk_rope_head_dim)
sin = sin.view(-1, self.qk_rope_head_dim)
q, k = rope_forward_triton(q,
@@ -868,6 +901,38 @@ def indexer_select(
rope_dim=self.qk_rope_head_dim,
is_neox_style=True)
else:
k_pe, k_nope = torch.split(
k,
[self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim],
dim=-1) # [b,s,64+64]

k_pe = k_pe.unsqueeze(2)
k_pe = torch_npu.npu_interleave_rope(k_pe, cos, sin)
k_pe = k_pe.squeeze(2)

k = torch.cat([k_pe, k_nope], dim=-1) # [b*s,128]
q = None

return q, k

def indexer_select_post_process(
self,
x: torch.Tensor,
qr: torch.Tensor,
q: Optional[torch.Tensor],
k: torch.Tensor,
kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
attn_metadata: M,
cos: torch.Tensor,
sin: torch.Tensor,
actual_seq_lengths_query: torch.Tensor,
actual_seq_lengths_key: torch.Tensor,
need_gather_q_kv: bool = False,
):
if q is None:
q, _ = self.wq_b(qr) # [b,s,1536] @ [1536,64*128] = [b,s,64*128]
q = q.view(-1, self.n_head, self.head_dim) # [n_toks,64,128]

cos_q, sin_q = cos, sin
cos = cos.view(-1, 1, 1, self.qk_rope_head_dim)
sin = sin.view(-1, 1, 1, self.qk_rope_head_dim)
@@ -882,20 +947,6 @@ def indexer_select(
q_pe = q_pe.squeeze(2)
q = torch.cat([q_pe, q_nope], dim=-1) # [b*s,64,128]

k_pe, k_nope = torch.split(
k,
[self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim],
dim=-1) # [b,s,64+64]

k_pe = k_pe.unsqueeze(2)
k_pe = torch_npu.npu_rotary_mul(k_pe, cos, sin)
k_pe = k_pe.squeeze(2)

k = torch.cat([k_pe, k_nope], dim=-1) # [b*s,128]

if self.enable_sfa_cp:
k = get_tp_group().all_gather(k, 0)

if kv_cache is not None:
torch_npu.npu_scatter_nd_update_(kv_cache[2].view(-1, k.shape[-1]),
attn_metadata.slot_mapping.view(
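Note on this file: the main structural change in `forward()` under `enable_sfa_cp` is a compute/communication overlap. `exec_kv` now returns `k_pe`/`k_nope` instead of scattering them synchronously, the fused KV tensor is handed to `all_gather_async`, the query-side projections (`_q_proj_and_k_up_proj`, `rope_single`) run while the collective is in flight, and the handle is waited on only right before the gathered tensor is split back and scattered into the KV cache. A minimal sketch of that overlap pattern follows; the module and tensor names are placeholders for illustration, not code from this PR:

```python
import torch
import torch.distributed as dist


def overlapped_kv_gather(fused_kv: torch.Tensor,
                         q_c: torch.Tensor,
                         q_proj: torch.nn.Linear,
                         group: dist.ProcessGroup):
    """Gather KV across ranks while query-side work proceeds (illustrative only)."""
    world_size = dist.get_world_size(group)
    if world_size == 1:
        return fused_kv, q_proj(q_c)

    # 1. Launch the all-gather without blocking.
    out_shape = (fused_kv.shape[0] * world_size, ) + fused_kv.shape[1:]
    gathered = torch.empty(out_shape, dtype=fused_kv.dtype, device=fused_kv.device)
    handle = dist.all_gather_into_tensor(gathered, fused_kv,
                                         group=group, async_op=True)

    # 2. Query-side work that does not depend on the gathered KV overlaps
    #    with the collective.
    q = q_proj(q_c)

    # 3. Synchronize only when the gathered KV is actually consumed.
    handle.wait()
    return gathered, q
```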
21 changes: 20 additions & 1 deletion vllm_ascend/distributed/utils.py
@@ -1,8 +1,9 @@
import os
from typing import Optional

import torch
import torch.distributed as dist
from vllm.distributed.parallel_state import get_dp_group
from vllm.distributed.parallel_state import GroupCoordinator, get_dp_group
from vllm.forward_context import get_forward_context

from vllm_ascend.distributed.parallel_state import (get_fc3_quant_x_group,
@@ -90,3 +91,21 @@ def fc3_all_gather_and_maybe_unpad_impl(x: torch.Tensor, ) -> torch.Tensor:
offset += num_tokens_dp
x = result
return x


def all_gather_async(input: torch.Tensor,
group: GroupCoordinator,
output: Optional[torch.Tensor] = None,
async_op: bool = True):
if group.world_size == 1:
return input, None
if output is None:
input_size = input.size()
output_size = (input_size[0] * group.world_size, ) + input_size[1:]
output = torch.empty(output_size,
dtype=input.dtype,
device=input.device)
return output, dist.all_gather_into_tensor(output,
input,
group=group.device_group,
async_op=async_op)
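Note on this helper: `all_gather_async` returns the gathered tensor together with the communication handle, and the handle is `None` when the group has a single rank, so callers must guard the `wait()`. A hedged usage sketch (the import path and the placeholder workload are assumptions, not part of this PR):

```python
import torch
from vllm.distributed import get_tp_group  # assumed import path

from vllm_ascend.distributed.utils import all_gather_async


def gather_with_overlap(x: torch.Tensor) -> torch.Tensor:
    gathered, handle = all_gather_async(x, get_tp_group())
    # ... independent computation can run here while the collective is in flight ...
    if handle is not None:  # handle is None when world_size == 1
        handle.wait()
    return gathered
```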
19 changes: 6 additions & 13 deletions vllm_ascend/utils.py
@@ -1165,20 +1165,13 @@ def get_instance(*args, **kwargs):
@lru_cache(maxsize=1)
def enable_dsa_cp() -> bool:
from vllm.config import get_current_vllm_config

vllm_config = get_current_vllm_config()
if vllm_config is None:
return False

model_config = getattr(vllm_config, "model_config", None)
if model_config is None:
return False

hf_text_config = getattr(model_config, "hf_text_config", None)
if hf_text_config is None:
return False

return hasattr(hf_text_config, "index_topk")
is_ds_v32 = hasattr(
vllm_config.model_config, "hf_text_config") and hasattr(
vllm_config.model_config.hf_text_config, "index_topk")
if is_ds_v32 and enable_sp():
return True
return False


@lru_cache(maxsize=1)
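Note on this change: the `enable_dsa_cp` rewrite collapses the nested `None` checks into a single `hasattr` chain and additionally gates on `enable_sp()`. A simplified sketch of the same decision logic, with the config object and the SP flag passed in explicitly for illustration:

```python
def dsa_cp_enabled(vllm_config, sp_enabled: bool) -> bool:
    # DeepSeek-V3.2-style models expose `index_topk` on the HF text config.
    text_cfg = getattr(getattr(vllm_config, "model_config", None),
                       "hf_text_config", None)
    return hasattr(text_cfg, "index_topk") and sp_enabled
```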