Merged
5 changes: 0 additions & 5 deletions tests/ops/test_rotary_embedding.py
@@ -136,11 +136,6 @@ def forward_native(


# test with leading dimension and merge seqlen and batch_size as num_tokens
# TODO(ganyi): open this test in the future
@pytest.mark.skip(
reason=
"skip this test by default for now because of ci issue, will enable it in the future"
)
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENS)
134 changes: 90 additions & 44 deletions vllm_ascend/attention/mla_v1.py
@@ -55,7 +55,7 @@ class AscendMLAPrefillMetadata:
input_positions: torch.Tensor
block_table: torch.Tensor
max_query_len: int
max_context_len: int
max_seq_lens: int


@dataclass
@@ -65,6 +65,7 @@ class AscendMLADecodeMetadata:
input_positions: torch.Tensor
block_table: torch.Tensor
seq_lens: torch.Tensor
max_seq_lens: int


@dataclass
@@ -131,11 +132,6 @@ def __init__(self,
self.runner = runner
scheduler_config = runner.scheduler_config
self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
# self.attn_mask = None
# if AscendMLAMetadataBuilder._attn_mask_builder is None:
# AscendMLAMetadataBuilder._attn_mask_builder = AttentionMaskBuilder.initialize_from_len(
# 128, self.runner.model_config.dtype
# )

def reorder_batch(self, input_batch: "InputBatch",
scheduler_output: "SchedulerOutput") -> bool:
@@ -222,12 +218,14 @@ def build(self,
num_reqs]
seq_lens = seq_lens_cpu
max_query_len = query_lens.max().item()
max_context_len = seq_lens.max().item()
max_seq_lens = seq_lens.max().item()

prefill_metadata = None
if self._num_prefills > 0:
reqs_start = self._num_decodes # prefill_start
tokens_start = self._num_decode_tokens
max_query_len = query_lens[tokens_start:].max().item()
Contributor:
Is query_lens a device tensor? If so, this introduces many D2H copies here; is this operation necessary?

Collaborator (Author):
query_lens is actually a CPU tensor, so no D2H operation will happen here; you can refer to line 220.
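
A minimal sketch of the point above (illustrative only, not part of this diff); the "npu" device string is an assumption about the target backend:

```python
import torch

# CPU tensor: .max().item() is a plain host-side read, no device sync involved.
query_lens_cpu = torch.tensor([3, 7, 5])
max_query_len = query_lens_cpu.max().item()  # cheap, stays on the host

# Device tensor (hypothetical; requires an available NPU/CUDA device):
# .item() would force a device-to-host copy plus a synchronization point.
# query_lens_dev = query_lens_cpu.to("npu")
# max_query_len = query_lens_dev.max().item()  # implies a D2H transfer
```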

max_seq_lens = seq_lens[tokens_start:].max().item()

prefill_metadata = AscendMLAPrefillMetadata(
attn_mask=self.runner.attn_mask,
@@ -236,15 +234,17 @@
input_positions=input_positions[tokens_start:],
block_table=block_table[reqs_start:, ...],
max_query_len=max_query_len,
max_context_len=max_context_len,
max_seq_lens=max_seq_lens,
)

decode_metadata = None
if self._num_decodes > 0:
max_seq_lens = seq_lens[:self._num_decodes].max().item()
decode_metadata = AscendMLADecodeMetadata(
input_positions=input_positions[:self._num_decode_tokens],
block_table=block_table[:self._num_decode_tokens, ...],
seq_lens=seq_lens[:self._num_decode_tokens])
seq_lens=seq_lens[:self._num_decode_tokens],
max_seq_lens=max_seq_lens)

return self.metadata_cls( # type: ignore
num_actual_tokens=num_actual_tokens,
@@ -306,12 +306,18 @@ def __init__(
self.qk_rope_head_dim = qk_rope_head_dim
self.qk_head_dim = qk_head_dim
self.v_head_dim = v_head_dim
# TODO: the padding below should be removed once the kernel is ready.
# We found that npu_flash_attention only works with a 128-divisible head_dim, so we pad it to the target size here
# and slice the final result to guarantee its functionality.
self.padding_head_dim = (
Contributor:
In prefill, we use MHA for computation, so head_dim = nope_dim + rope_dim (192), while in decode the absorbed and move_elision strategies are adopted, so head_dim = nope_dim and no padding is needed. Am I right?

Collaborator (Author):
You are definitely right; this padding dim is used in prefill to pad the tensor, not only to match v_head_dim against (qk_rope + qk_nope), but also to satisfy the 128-divisible head_dim alignment requirement of _npu_flash_attention.

(self.qk_nope_head_dim + self.qk_rope_head_dim - 1) // 128 +
1) * 128
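
For reference, a small runnable sketch of the padding arithmetic discussed in the thread above. The dims (qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128) are illustrative assumptions, not values taken from this diff:

```python
import torch
import torch.nn.functional as F

# Illustrative dims (assumed, not from this diff).
qk_nope_head_dim, qk_rope_head_dim, v_head_dim = 128, 64, 128
head_dim = qk_nope_head_dim + qk_rope_head_dim          # 192

# Round up to the next multiple of 128, matching the expression above.
padding_head_dim = ((head_dim - 1) // 128 + 1) * 128    # 256

num_tokens, num_heads = 4, 16
query = torch.randn(num_tokens, num_heads, head_dim)
value = torch.randn(num_tokens, num_heads, v_head_dim)

# F.pad pads the last dimension by (left, right) amounts, filling with zeros.
pad_query = F.pad(query, [0, padding_head_dim - head_dim])    # (4, 16, 256)
pad_value = F.pad(value, [0, padding_head_dim - v_head_dim])  # (4, 16, 256)

# After attention, the padded output is sliced back to v_head_dim, as the diff does.
out = pad_value[:, :, :v_head_dim]                            # (4, 16, 128)
print(pad_query.shape, pad_value.shape, out.shape)
```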

# Hack for V1 for now to avoid torch library overhead (since we are
# already inside an attention custom op), pull out the forward
# method from the rotary embedding and call it directly
# TODO(lucas): we should probably find a cleaner way to do this
self.rotary_emb = rotary_emb.forward_native
self.rotary_emb = rotary_emb

self.q_proj = q_proj
self.kv_b_proj = kv_b_proj
@@ -409,37 +415,73 @@ def _forward_prefill(
) -> torch.Tensor:
assert attn_metadata.prefill is not None

# TODO: enable this compute for flash attention computation
# kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\
# -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
# k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
# key = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
# v_padded = torch.nn.functional.pad(v, [0, query.shape[-1] - v.shape[-1]],
# value=0)
num_tokens = query.size(0)
attn_output = torch.empty(num_tokens,
self.num_heads,
self.v_head_dim,
dtype=query.dtype,
device=query.device)
# current requests is chunked in prefill, disable flash attention with chunked prefill
vanilla_chunked_prefill_mla(
output=attn_output,
query=query,
kv_cache=kv_c_and_k_pe_cache,
block_tables=attn_metadata.prefill.block_table,
query_lens=attn_metadata.prefill.query_lens,
context_lens=attn_metadata.prefill.context_lens,
kv_b_proj=self.kv_b_proj,
max_query_len=attn_metadata.prefill.max_query_len,
max_context_len=attn_metadata.prefill.max_context_len,
nope_dim=self.qk_nope_head_dim,
rope_dim=self.qk_rope_head_dim,
v_head_dim=self.v_head_dim,
scale=self.scale,
alibi_slopes=None,
causal=True)
attn_output = attn_output.view(
attn_output = None
# There are only two possible input states here: ChunkedPrefill or PrefillOnly
if attn_metadata.attn_state == AscendAttentionState.ChunkedPrefill:
attn_output = torch.empty(num_tokens,
self.num_heads * self.v_head_dim,
dtype=query.dtype,
device=query.device)
# the current requests are chunked in prefill, so disable flash attention and fall back to chunked prefill
vanilla_chunked_prefill_mla(
output=attn_output,
query=query,
kv_cache=kv_c_and_k_pe_cache,
block_tables=attn_metadata.prefill.block_table,
query_lens=attn_metadata.prefill.query_lens,
context_lens=attn_metadata.prefill.context_lens,
kv_b_proj=self.kv_b_proj,
max_query_len=attn_metadata.prefill.max_query_len,
max_context_len=attn_metadata.prefill.max_seq_lens,
nope_dim=self.qk_nope_head_dim,
rope_dim=self.qk_rope_head_dim,
v_head_dim=self.v_head_dim,
scale=self.scale,
alibi_slopes=None,
causal=True)
elif attn_metadata.attn_state == AscendAttentionState.PrefillOnly:
attn_output = torch.empty(num_tokens,
self.num_heads,
self.padding_head_dim,
dtype=query.dtype,
device=query.device)
k_nope, value = self.kv_b_proj(kv_c_normed)[0].view(
-1, self.num_heads,
self.qk_nope_head_dim + self.v_head_dim).split(
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
key = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
dim=-1)
pad_query = torch.nn.functional.pad(query, [
0, self.padding_head_dim - self.qk_rope_head_dim -
self.qk_nope_head_dim
],
value=0)
pad_key = torch.nn.functional.pad(key, [
0, self.padding_head_dim - self.qk_rope_head_dim -
self.qk_nope_head_dim
],
value=0)
pad_value = torch.nn.functional.pad(
value, [0, self.padding_head_dim - self.v_head_dim], value=0)
torch_npu._npu_flash_attention(
query=pad_query,
key=pad_key,
value=pad_value,
mask=attn_metadata.attn_mask,
seq_len=attn_metadata.prefill.context_lens,
scale_value=self.scale,
num_heads=self.num_heads,
num_kv_heads=self.num_heads,
out=attn_output)
attn_output = attn_output.view(
-1, self.num_heads,
self.padding_head_dim)[:, :, :self.v_head_dim]
else:
raise RuntimeError(
"Unexpected path reached, AscendMLAImpl should only have PrefillOnly and ChunkedPrefill scenario in forward prefill, please file a bug to vllm-ascend !"
)
attn_output = attn_output.reshape(
[num_tokens, self.num_heads * self.v_head_dim])
return self.o_proj(attn_output)[0]

Expand All @@ -457,7 +499,7 @@ def _forward_decode(

q = torch.cat([q_nope, q_pe], dim=-1)
num_tokens = q.size(0)
attn_output = torch.randn(
attn_output = torch.empty(
[num_tokens, self.num_heads, self.kv_lora_rank],
dtype=q.dtype,
device=q.device)
@@ -522,8 +564,10 @@ def forward(
decode_ql_nope, decode_q_pe = \
self._q_proj_and_k_up_proj(decode_hs_or_q_c)
decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
attn_metadata.decode.input_positions, decode_q_pe.contiguous(),
decode_k_pe)
attn_metadata.decode.input_positions,
decode_q_pe.contiguous(),
decode_k_pe,
max_seq_len=attn_metadata.decode.max_seq_lens)

if has_prefill:
assert attn_metadata.prefill is not None
@@ -533,7 +577,9 @@

prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
attn_metadata.prefill.input_positions,
prefill_q_pe.contiguous(), prefill_k_pe)
prefill_q_pe.contiguous(),
prefill_k_pe,
max_seq_len=attn_metadata.prefill.max_seq_lens)

if kv_cache.numel() > 0:
key = torch.cat([