Merged
30 commits
6ae53b5
integrate FIA operator into mla_cp
Dec 24, 2025
08de021
make it more readable
Dec 29, 2025
048b04f
Merge branch 'main' of https://github.com/vllm-project/vllm-ascend in…
Dec 30, 2025
daafaff
adapt acl_graph in mla_cp FIA
Dec 31, 2025
cab49ba
Merge branch 'main' of https://github.com/vllm-project/vllm-ascend in…
Dec 31, 2025
452c663
adapt graph mode
Jan 5, 2026
6733ce3
support mtp
Jan 6, 2026
3650848
Merge branch 'main' of https://github.com/vllm-project/vllm-ascend in…
Jan 6, 2026
410be4d
remove redundant attributes
Jan 6, 2026
8d06f81
remove data cleaning
Jan 6, 2026
1352315
Update vllm_ascend/attention/context_parallel/mla_cp.py
845473182 Jan 7, 2026
47072e3
fix lint
Jan 7, 2026
120ac20
Merge branch 'FIA_rebase' of https://github.com/845473182/vllm-ascend…
Jan 7, 2026
7e899c6
fix lint
Jan 7, 2026
40afa15
fix lint
Jan 7, 2026
4134757
Merge branch 'main' into FIA_rebase
845473182 Jan 8, 2026
c3f5465
fix ut
Jan 8, 2026
b559ab0
Merge branch 'FIA_rebase' of https://github.com/845473182/vllm-ascend…
Jan 8, 2026
92436a2
fix lint
Jan 8, 2026
a2a6f72
[Ops] replace _update_out_and_lse with _npu_attn_out_lse_update
Jan 6, 2026
6a563e2
Merge branch 'ops' of https://github.com/YzTongNiar/vllm-ascend into …
Jan 9, 2026
73976cb
Merge branch 'main' of https://github.com/vllm-project/vllm-ascend in…
Jan 19, 2026
7b1dd4a
fix pre-commit
Jan 20, 2026
bba3ddf
restore _process_attn_out_lse
Jan 20, 2026
92b50c3
restore _process_attn_out_lse
Jan 20, 2026
c51a43b
fix ut
Jan 20, 2026
0d80040
Revert "[Ops] replace _update_out_and_lse with _npu_attn_out_lse_update"
Jan 20, 2026
188edfa
Merge branch 'main' of https://github.com/vllm-project/vllm-ascend in…
Jan 21, 2026
a22aa13
Merge branch 'main' of https://github.com/vllm-project/vllm-ascend in…
Jan 21, 2026
8b2138c
Merge branch 'main' of https://github.com/vllm-project/vllm-ascend in…
Jan 22, 2026
10 changes: 5 additions & 5 deletions tests/ut/attention/test_mla_cp.py
@@ -450,11 +450,11 @@ def test_process_attn_out_lse(self):
self.assertEqual(result.shape[2], self.impl.kv_lora_rank + 1)

@patch('vllm_ascend.attention.context_parallel.mla_cp.get_forward_context')
@patch("torch_npu.atb.npu_multi_head_latent_attention")
@patch("torch_npu.npu_fused_infer_attention_score")
@patch('torch_npu.npu_attention_update')
@patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
def test_forward_decode_pcp_dcp(self, mock_npu_attention_update,
mock_npu_multi_head_latent_attention,
mock_npu_fused_infer_attention_score,
mock_get_forward_context):
self.impl.dcp_size = 2
self.impl.pcp_size = 2
@@ -470,8 +470,8 @@ def test_forward_decode_pcp_dcp(self, mock_npu_attention_update,

q_nope = torch.randn(B, N, self.impl.qk_nope_head_dim)
q_pe = torch.randn(B, N, self.impl.qk_rope_head_dim)
k_nope = torch.randn(NB, BS, 1, self.impl.kv_lora_rank)
k_pe = torch.randn(NB, BS, 1, self.impl.qk_rope_head_dim)
k_nope = torch.randn(NB, 1, BS, self.impl.kv_lora_rank)
k_pe = torch.randn(NB, 1, BS, self.impl.qk_rope_head_dim)

attn_metadata = MagicMock()
attn_metadata.attn_state = AscendAttentionState.SpecDecoding
@@ -485,7 +485,7 @@ def test_forward_decode_pcp_dcp(self, mock_npu_attention_update,

mock_npu_attention_update.return_value = (torch.randn(
B, self.impl.num_heads, self.impl.kv_lora_rank), None)
mock_npu_multi_head_latent_attention.return_value = [
mock_npu_fused_infer_attention_score.return_value = [
torch.randn(B, N, self.impl.kv_lora_rank),
torch.randn(B, N, 1)
]
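The fixture change above swaps the block-size and kv-head axes: k_nope and k_pe are now built as (num_blocks, num_kv_heads, block_size, head_dim) instead of (num_blocks, block_size, num_kv_heads, head_dim), matching the BNSD view the fused-infer-attention decode path reshapes the KV cache into. A minimal sketch of the layout change, using illustrative sizes rather than the real test configuration:

    import torch

    # Illustrative sizes only; the test uses its own NB/BS/kv_lora_rank.
    NB, BS, kv_lora_rank = 4, 128, 512

    # Old fixture layout: (num_blocks, block_size, num_kv_heads, head_dim)
    k_nope_old = torch.randn(NB, BS, 1, kv_lora_rank)

    # New fixture layout: (num_blocks, num_kv_heads, block_size, head_dim),
    # i.e. the BNSD ordering consumed by the new kernel call.
    k_nope_new = k_nope_old.permute(0, 2, 1, 3).contiguous()
    assert k_nope_new.shape == (NB, 1, BS, kv_lora_rank)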
12 changes: 8 additions & 4 deletions tests/ut/compilation/test_acl_graph.py
@@ -754,7 +754,7 @@ def setUp(self):

@patch('torch.npu.graph_task_update_end', )
@patch('torch.npu.graph_task_update_begin', MagicMock())
@patch('torch_npu.atb.npu_multi_head_latent_attention', MagicMock())
@patch('torch_npu.npu_fused_infer_attention_score.out', MagicMock())
def test_update_mla_dcp_pcp_params(self, _mock_graph_task_end):
input_positions = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])
block_table = torch.zeros(2, 5, dtype=torch.long)
@@ -793,16 +793,20 @@ def test_update_mla_dcp_pcp_params(self, _mock_graph_task_end):
qk_rope_head_dim = 32
qk_nope_head_dim = 64
query = torch.randn(4, num_heads, qk_head_dim)
q_pe = query[..., qk_nope_head_dim:]

q_nope = query[..., :qk_nope_head_dim]
q_pe = query[..., qk_rope_head_dim:]
k_nope = torch.randn(4, num_heads, qk_nope_head_dim)
k_pe = torch.randn(4, num_heads, qk_rope_head_dim)
input_layout = "BNSD"
actual_seq_lengths_kv = [1, 1]
out = torch.randn(2, 16, 128)
lse = torch.randn(2, 16, 8)
self.graph_params.attn_params[4] = []
self.graph_params.attn_params[4].append(
(q_nope, q_pe, k_nope, k_pe, block_table, seq_lens, num_heads,
scale, num_kv_heads, out, lse))
(q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads, input_layout,
None, 0, scale, block_table, 128, None, actual_seq_lengths_kv,
out, lse))

with patch("torch_npu._C._npu_setStream", return_value=None):
update_mla_attn_dcp_pcp_params(self.update_stream, forward_context,
127 changes: 77 additions & 50 deletions vllm_ascend/attention/context_parallel/mla_cp.py
@@ -14,6 +14,8 @@
from vllm.utils.math_utils import cdiv
from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec

from vllm_ascend.attention.attention_v1 import AscendAttentionState

# isort: off
from vllm_ascend.attention.mla_v1 import (
AscendMLADecodeMetadata,
@@ -244,8 +246,12 @@ def build_decode_metadata(
self.batch_seq_mask_buf[: batch_seq_mask.shape[0]].copy_(batch_seq_mask, non_blocking=True)
batch_seq_mask = self.batch_seq_mask_buf[: batch_seq_mask.shape[0]]
cp_seq_len = torch.where(cp_seq_len == 0, 1, cp_seq_len)
decode_metadata.cp_seq_len = cp_seq_len
decode_metadata.cp_seq_len = cp_seq_len.tolist()
decode_metadata.batch_seq_mask = batch_seq_mask

actual_seq_lengths_q = torch.arange(self.num_decodes_flatten) + 1
decode_metadata.actual_seq_lengths_q = actual_seq_lengths_q

return decode_metadata


@@ -535,18 +541,53 @@ def _forward_decode(
num_heads = self.num_heads * self.dcp_size
else:
num_heads = self.num_heads

k_nope = k_nope.view(-1, block_size, self.num_kv_heads, self.kv_lora_rank)
k_pe = k_pe.view(-1, block_size, self.num_kv_heads, self.qk_rope_head_dim)
q_nope = q_nope.view(num_tokens, num_heads, -1)
q_pe = q_pe.view(num_tokens, num_heads, -1)
# use pcp & dcp split computed token nums from scheduler to compute actual seq_len and seq_mask
seq_len = decode_meta.cp_seq_len
k_nope = k_nope.view(-1, self.num_kv_heads, block_size, self.kv_lora_rank)
k_pe = k_pe.view(-1, self.num_kv_heads, block_size, self.qk_rope_head_dim)

actual_seq_lengths = None
input_layout = "BNSD"

if (
attn_metadata.attn_state
in [
AscendAttentionState.SpecDecoding,
AscendAttentionState.ChunkedPrefill,
AscendAttentionState.DecodeOnly,
]
and self.speculative_config is not None
):
input_layout = "TND"
# TODO: If the driver is upgraded later, the contiguous function can be deleted.
q_nope = q_nope.view(num_tokens, num_heads, -1).contiguous()
q_pe = q_pe.view(num_tokens, num_heads, -1)
sparse_mode = 3
spec_attn_mask = attn_metadata.decode.attn_mask # type:ignore
actual_seq_lengths = decode_meta.actual_seq_lengths_q
else:
q_nope = q_nope.view(num_tokens, num_heads, 1, -1).contiguous()
q_pe = q_pe.view(num_tokens, num_heads, 1, -1)
sparse_mode = 0
spec_attn_mask = None

common_kwargs = {
"return_lse": True,
"calc_type": "calc_type_ring",
"query_rope": q_pe,
"key_rope": k_pe,
"num_heads": num_heads,
"num_key_value_heads": self.num_kv_heads,
"input_layout": input_layout,
"atten_mask": spec_attn_mask,
"sparse_mode": sparse_mode,
"scale": self.scale,
"antiquant_mode": 0,
"antiquant_scale": None,
"block_table": decode_meta.block_table,
"block_size": block_size,
"actual_seq_lengths": actual_seq_lengths,
"actual_seq_lengths_kv": decode_meta.cp_seq_len,
"softmax_lse_flag": True,
}

forward_context: ForwardContext = get_forward_context()
if forward_context.is_draft_model:
graph_params = get_draft_graph_params()
@@ -560,72 +601,58 @@ def _forward_decode(
graph_params.events[num_tokens].append(event)
workspace = graph_params.workspaces.get(num_tokens)
if workspace is None:
workspace = torch_npu.atb._npu_multi_head_latent_attention_get_workspace(
workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
q_nope,
q_pe,
k_nope,
k_pe,
decode_meta.block_table,
seq_len,
num_heads,
self.scale,
self.num_kv_heads,
k_nope,
**common_kwargs,
)
update_graph_params_workspaces(num_tokens, workspace)
attn_output = torch.empty_like(q_nope)
softmax_lse = torch.empty((num_tokens, num_heads, 1), dtype=q_nope.dtype, device=q_nope.device)
if input_layout == "BNSD":
softmax_lse = torch.empty((num_tokens, num_heads, 1, 1), dtype=torch.float, device=q_nope.device)
else:
softmax_lse = torch.empty((num_tokens, num_heads, 1), dtype=torch.float, device=q_nope.device)

graph_params.attn_params[num_tokens].append(
(
weak_ref_tensors(q_nope),
weak_ref_tensors(q_pe),
weak_ref_tensors(k_nope),
weak_ref_tensors(q_pe),
weak_ref_tensors(k_pe),
decode_meta.block_table,
seq_len,
num_heads,
self.scale,
self.num_kv_heads,
input_layout,
weak_ref_tensors(spec_attn_mask) if spec_attn_mask is not None else None,
sparse_mode,
self.scale,
weak_ref_tensors(decode_meta.block_table),
block_size,
actual_seq_lengths,
decode_meta.cp_seq_len,
weak_ref_tensors(attn_output),
weak_ref_tensors(softmax_lse),
)
)
torch.npu.graph_task_group_begin(stream)
torch_npu.atb.npu_multi_head_latent_attention(
q_nope,
q_pe,
k_nope,
k_pe,
decode_meta.block_table,
seq_len,
num_heads,
self.scale,
self.num_kv_heads,
**common_kwargs,
workspace=workspace,
output=attn_output,
lse=softmax_lse,
torch_npu.npu_fused_infer_attention_score.out(
q_nope, k_nope, k_nope, **common_kwargs, workspace=workspace, out=[attn_output, softmax_lse]
)
handle = torch.npu.graph_task_group_end(stream)
graph_params.handles[num_tokens].append(handle)
else:
attn_output = torch.empty_like(q_nope)
softmax_lse = torch.empty((num_tokens, num_heads, 1), dtype=q_nope.dtype, device=q_nope.device)
torch_npu.atb.npu_multi_head_latent_attention(
attn_output, softmax_lse = torch_npu.npu_fused_infer_attention_score(
q_nope,
q_pe,
k_nope,
k_pe,
decode_meta.block_table,
seq_len,
num_heads,
self.scale,
self.num_kv_heads,
return_lse=True,
calc_type="calc_type_ring",
output=attn_output,
lse=softmax_lse,
k_nope,
**common_kwargs,
)
if input_layout == "BNSD":
B_attn, N_attn, S, D = attn_output.shape
B_lse, N_lse, Q_S, _ = softmax_lse.shape

attn_output = attn_output.permute(0, 2, 1, 3).reshape(B_attn * S, N_attn, D)
softmax_lse = softmax_lse.permute(0, 2, 1, 3).reshape(B_lse * Q_S, N_lse, 1)

# Update out&lse
attn_out_lse = _process_attn_out_lse(attn_output, softmax_lse, decode_meta.batch_seq_mask)
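A note on the BNSD post-processing added at the end of _forward_decode above: the fused kernel returns attn_output as (B, N, S, D) and softmax_lse as (B, N, S_q, 1), and the code flattens both back to the token-major shapes that _process_attn_out_lse consumes. A minimal sketch of that reshape with dummy tensors and illustrative sizes:

    import torch

    # Illustrative sizes; the real values come from the decode batch.
    B, N, S, D = 2, 16, 1, 512
    attn_output = torch.randn(B, N, S, D)   # kernel output, BNSD
    softmax_lse = torch.randn(B, N, S, 1)   # per-head log-sum-exp, BNSD

    # Flatten back to token-major layout for the out/lse update step.
    attn_output = attn_output.permute(0, 2, 1, 3).reshape(B * S, N, D)
    softmax_lse = softmax_lse.permute(0, 2, 1, 3).reshape(B * S, N, 1)
    assert attn_output.shape == (B * S, N, D)
    assert softmax_lse.shape == (B * S, N, 1)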
55 changes: 33 additions & 22 deletions vllm_ascend/compilation/acl_graph.py
@@ -487,45 +487,56 @@ def update_mla_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape
):
(
q_nope,
q_pe,
k_nope,
q_pe,
k_pe,
block_table,
seq_len,
num_heads,
scale,
num_kv_heads,
input_layout,
spec_attn_mask,
sparse_mode,
scale,
block_table,
block_size,
actual_seq_lengths,
actual_seq_lengths_kv,
attn_output,
softmax_lse,
) = param

decode_meta = forward_context.attn_metadata[key].decode
seq_len = decode_meta.cp_seq_len
if isinstance(seq_len, torch.Tensor):
seq_len = seq_len.tolist()
actual_seq_lengths_kv = seq_len
Comment on lines 508 to +511
Contributor


critical

Dynamic parameters such as block_table, spec_attn_mask, and actual_seq_lengths are not updated during graph replay. They are read from the param tuple which contains values from the time of graph capture. This will cause the replayed graph to execute with stale data, leading to incorrect attention outputs. These parameters must be updated from the current forward_context at every step, similar to how actual_seq_lengths_kv is being updated.

            block_table = decode_meta.block_table
            spec_attn_mask = decode_meta.attn_mask
            actual_seq_lengths = decode_meta.actual_seq_lengths_q
            seq_len = decode_meta.cp_seq_len
            if isinstance(seq_len, torch.Tensor):
                seq_len = seq_len.tolist()
            actual_seq_lengths_kv = seq_len


# For pcp + spec decode, we flatten seq_lens
# to avoid irregular attn_mask shape,
# so there's no need to divide runtime_shape by spec_multiple
pad_length = runtime_shape - len(seq_len)
pad_tensor = torch.zeros(pad_length, dtype=seq_len.dtype, device=seq_len.device)
seq_len = torch.cat([seq_len, pad_tensor], dim=0)
pad_length = runtime_shape - len(actual_seq_lengths_kv)
if pad_length > 0:
actual_seq_lengths_kv = actual_seq_lengths_kv + [0] * (runtime_shape - len(actual_seq_lengths_kv))

torch.npu.graph_task_update_begin(update_stream, handle)

torch_npu.atb.npu_multi_head_latent_attention(
torch_npu.npu_fused_infer_attention_score.out(
q_nope,
q_pe,
k_nope,
k_pe,
block_table,
seq_len,
num_heads,
scale,
num_kv_heads,
return_lse=True,
calc_type="calc_type_ring",
k_nope,
query_rope=q_pe,
key_rope=k_pe,
num_heads=num_heads,
num_key_value_heads=num_kv_heads,
input_layout=input_layout,
atten_mask=spec_attn_mask,
sparse_mode=sparse_mode,
scale=scale,
antiquant_mode=0,
antiquant_scale=None,
softmax_lse_flag=True,
block_table=block_table,
block_size=block_size,
actual_seq_lengths_kv=actual_seq_lengths_kv,
actual_seq_lengths=actual_seq_lengths,
workspace=graph_params.workspaces.get(runtime_shape),
output=attn_output,
lse=softmax_lse,
out=[attn_output, softmax_lse],
)
torch.npu.graph_task_update_end(update_stream)

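On the replay path above, the per-request KV sequence lengths are re-read from the current decode metadata and zero-padded out to the captured runtime shape, since the replayed graph always runs at a fixed batch size. A minimal sketch of that padding step; the helper name is illustrative and not part of the module:

    def pad_actual_seq_lengths_kv(cp_seq_len, runtime_shape):
        # cp_seq_len may arrive as a tensor or a list; the kernel expects a
        # plain list whose length matches the captured runtime shape.
        seq = cp_seq_len.tolist() if hasattr(cp_seq_len, "tolist") else list(cp_seq_len)
        pad_length = runtime_shape - len(seq)
        return seq + [0] * pad_length if pad_length > 0 else seq

    assert pad_actual_seq_lengths_kv([3, 5], 4) == [3, 5, 0, 0]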