Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
3453ed6
integrate FIA operator into mla_cp
Dec 24, 2025
b87e237
make it more readable
Dec 29, 2025
b27f1a0
adapt acl_graph in mla_cp FIA
Dec 31, 2025
8b7243e
adapt graph mode
Jan 5, 2026
1b98655
support mtp
Jan 6, 2026
5c61ba7
remove redundant attributes
Jan 6, 2026
fc08f5c
remove data cleaning
Jan 6, 2026
1035e4c
Update vllm_ascend/attention/context_parallel/mla_cp.py
845473182 Jan 7, 2026
f378a63
fix lint
Jan 7, 2026
bc01c37
fix lint
Jan 7, 2026
4cb19b8
fix lint
Jan 7, 2026
3624e99
fix ut
Jan 8, 2026
99e9cc0
fix lint
Jan 8, 2026
5c1f197
[Ops] replace _update_out_and_lse with _npu_attn_out_lse_update
Jan 6, 2026
1e3486a
fix pre-commit
Jan 20, 2026
362bdab
restore _process_attn_out_lse
Jan 20, 2026
22ad847
restore _process_attn_out_lse
Jan 20, 2026
9b55383
Merge branch 'releases/v0.13.0' of https://github.com/vllm-project/vl…
Jan 20, 2026
3ef5323
Revert "[Ops] replace _update_out_and_lse with _npu_attn_out_lse_update"
Jan 20, 2026
56cfbdf
restore mla_v1
Jan 20, 2026
265cdfb
remove redundant code
Jan 20, 2026
0aa4f8f
fix mla_cp
Jan 21, 2026
3cf8761
Merge branch 'releases/v0.13.0' of https://github.com/vllm-project/vl…
Jan 21, 2026
c266d51
Merge branch 'releases/v0.13.0' of https://github.com/vllm-project/vl…
Jan 21, 2026
1938da6
fix lint
Jan 21, 2026
8e95a79
Merge branch 'releases/v0.13.0' of https://github.com/vllm-project/vl…
Jan 22, 2026
4262ad9
Merge branch 'releases/v0.13.0' of https://github.com/vllm-project/vl…
Jan 22, 2026
fc321b1
Merge branch 'releases/v0.13.0' of https://github.com/vllm-project/vl…
Jan 22, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions tests/ut/attention/test_mla_cp.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,11 +450,11 @@ def test_process_attn_out_lse(self):
self.assertEqual(result.shape[2], self.impl.kv_lora_rank + 1)

@patch('vllm_ascend.attention.context_parallel.mla_cp.get_forward_context')
@patch("torch_npu.atb.npu_multi_head_latent_attention")
@patch("torch_npu.npu_fused_infer_attention_score")
@patch('torch_npu.npu_attention_update')
@patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
def test_forward_decode_pcp_dcp(self, mock_npu_attention_update,
mock_npu_multi_head_latent_attention,
mock_npu_fused_infer_attention_score,
mock_get_forward_context):
self.impl.dcp_size = 2
self.impl.pcp_size = 2
Expand All @@ -470,8 +470,8 @@ def test_forward_decode_pcp_dcp(self, mock_npu_attention_update,

q_nope = torch.randn(B, N, self.impl.qk_nope_head_dim)
q_pe = torch.randn(B, N, self.impl.qk_rope_head_dim)
k_nope = torch.randn(NB, BS, 1, self.impl.kv_lora_rank)
k_pe = torch.randn(NB, BS, 1, self.impl.qk_rope_head_dim)
k_nope = torch.randn(NB, 1, BS, self.impl.kv_lora_rank)
k_pe = torch.randn(NB, 1, BS, self.impl.qk_rope_head_dim)

attn_metadata = MagicMock()
attn_metadata.attn_state = AscendAttentionState.SpecDecoding
Expand All @@ -485,7 +485,7 @@ def test_forward_decode_pcp_dcp(self, mock_npu_attention_update,

mock_npu_attention_update.return_value = (torch.randn(
B, self.impl.num_heads, self.impl.kv_lora_rank), None)
mock_npu_multi_head_latent_attention.return_value = [
mock_npu_fused_infer_attention_score.return_value = [
torch.randn(B, N, self.impl.kv_lora_rank),
torch.randn(B, N, 1)
]
Expand Down
12 changes: 8 additions & 4 deletions tests/ut/compilation/test_acl_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,7 +754,7 @@ def setUp(self):

@patch('torch.npu.graph_task_update_end', )
@patch('torch.npu.graph_task_update_begin', MagicMock())
@patch('torch_npu.atb.npu_multi_head_latent_attention', MagicMock())
@patch('torch_npu.npu_fused_infer_attention_score.out', MagicMock())
def test_update_mla_dcp_pcp_params(self, _mock_graph_task_end):
input_positions = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])
block_table = torch.zeros(2, 5, dtype=torch.long)
Expand Down Expand Up @@ -793,16 +793,20 @@ def test_update_mla_dcp_pcp_params(self, _mock_graph_task_end):
qk_rope_head_dim = 32
qk_nope_head_dim = 64
query = torch.randn(4, num_heads, qk_head_dim)
q_pe = query[..., qk_nope_head_dim:]

q_nope = query[..., :qk_nope_head_dim]
q_pe = query[..., qk_rope_head_dim:]
k_nope = torch.randn(4, num_heads, qk_nope_head_dim)
k_pe = torch.randn(4, num_heads, qk_rope_head_dim)
input_layout = "BNSD"
actual_seq_lengths_kv = [1, 1]
out = torch.randn(2, 16, 128)
lse = torch.randn(2, 16, 8)
self.graph_params.attn_params[4] = []
self.graph_params.attn_params[4].append(
(q_nope, q_pe, k_nope, k_pe, block_table, seq_lens, num_heads,
scale, num_kv_heads, out, lse))
(q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads, input_layout,
None, 0, scale, block_table, 128, None, actual_seq_lengths_kv,
out, lse))

with patch("torch_npu._C._npu_setStream", return_value=None):
update_mla_attn_dcp_pcp_params(self.update_stream, forward_context,
Expand Down
128 changes: 81 additions & 47 deletions vllm_ascend/attention/context_parallel/mla_cp.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec

from vllm_ascend.attention.attention_v1 import AscendAttentionState

# isort: off
from vllm_ascend.attention.mla_v1 import (
AscendMLADecodeMetadata, AscendMLAImpl, AscendMLAMetadata,
Expand Down Expand Up @@ -259,8 +261,12 @@ def build_decode_metadata(
batch_seq_mask, non_blocking=True)
batch_seq_mask = self.batch_seq_mask_buf[:batch_seq_mask.shape[0]]
cp_seq_len = torch.where(cp_seq_len == 0, 1, cp_seq_len)
decode_metadata.cp_seq_len = cp_seq_len
decode_metadata.cp_seq_len = cp_seq_len.tolist()
decode_metadata.batch_seq_mask = batch_seq_mask

actual_seq_lengths_q = torch.arange(self.num_decodes_flatten) + 1
decode_metadata.actual_seq_lengths_q = actual_seq_lengths_q

return decode_metadata


Expand Down Expand Up @@ -575,19 +581,51 @@ def _forward_decode(
else:
num_heads = self.num_heads

k_nope = k_nope.view(-1, block_size, self.num_kv_heads,
# Use the PCP/DCP-split computed-token counts from the scheduler to derive the actual seq_len and seq_mask.
k_nope = k_nope.view(-1, self.num_kv_heads, block_size,
self.kv_lora_rank)
k_pe = k_pe.view(-1, block_size, self.num_kv_heads,
k_pe = k_pe.view(-1, self.num_kv_heads, block_size,
self.qk_rope_head_dim)
q_nope = q_nope.view(num_tokens, num_heads, -1)
q_pe = q_pe.view(num_tokens, num_heads, -1)
# Use the PCP/DCP-split computed-token counts from the scheduler to derive the actual seq_len and seq_mask.
seq_len = decode_meta.cp_seq_len

actual_seq_lengths = None
input_layout = "BNSD"

if (attn_metadata.attn_state in [
AscendAttentionState.SpecDecoding,
AscendAttentionState.ChunkedPrefill,
AscendAttentionState.DecodeOnly,
] and self.speculative_config is not None):
input_layout = "TND"
# TODO: The .contiguous() call can be removed once the driver is upgraded.
q_nope = q_nope.view(num_tokens, num_heads, -1).contiguous()
q_pe = q_pe.view(num_tokens, num_heads, -1)
sparse_mode = 3
spec_attn_mask = attn_metadata.decode.attn_mask # type:ignore
actual_seq_lengths = decode_meta.actual_seq_lengths_q
else:
q_nope = q_nope.view(num_tokens, num_heads, 1, -1).contiguous()
q_pe = q_pe.view(num_tokens, num_heads, 1, -1)
sparse_mode = 0
spec_attn_mask = None

common_kwargs = {
"return_lse": True,
"calc_type": "calc_type_ring",
"query_rope": q_pe,
"key_rope": k_pe,
"num_heads": num_heads,
"num_key_value_heads": self.num_kv_heads,
"input_layout": input_layout,
"atten_mask": spec_attn_mask,
"sparse_mode": sparse_mode,
"scale": self.scale,
"antiquant_mode": 0,
"antiquant_scale": None,
"block_table": decode_meta.block_table,
"block_size": block_size,
"actual_seq_lengths": actual_seq_lengths,
"actual_seq_lengths_kv": decode_meta.cp_seq_len,
"softmax_lse_flag": True,
}

forward_context: ForwardContext = get_forward_context()
if forward_context.is_draft_model:
graph_params = get_draft_graph_params()
Expand All @@ -601,57 +639,53 @@ def _forward_decode(
graph_params.events[num_tokens].append(event)
workspace = graph_params.workspaces.get(num_tokens)
if workspace is None:
workspace = torch_npu.atb._npu_multi_head_latent_attention_get_workspace(
q_nope, q_pe, k_nope, k_pe, decode_meta.block_table,
seq_len, num_heads, self.scale, self.num_kv_heads,
**common_kwargs)
workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
q_nope, k_nope, k_nope, **common_kwargs)
update_graph_params_workspaces(num_tokens, workspace)
attn_output = torch.empty_like(q_nope)
softmax_lse = torch.empty((num_tokens, num_heads, 1),
dtype=q_nope.dtype,
device=q_nope.device)
if input_layout == "BNSD":
softmax_lse = torch.empty((num_tokens, num_heads, 1, 1),
dtype=torch.float,
device=q_nope.device)
else:
softmax_lse = torch.empty((num_tokens, num_heads, 1),
dtype=torch.float,
device=q_nope.device)

graph_params.attn_params[num_tokens].append(
(weak_ref_tensors(q_nope), weak_ref_tensors(q_pe),
weak_ref_tensors(k_nope), weak_ref_tensors(k_pe),
decode_meta.block_table, seq_len, num_heads, self.scale,
self.num_kv_heads, weak_ref_tensors(attn_output),
weak_ref_tensors(softmax_lse)))
(weak_ref_tensors(q_nope), weak_ref_tensors(k_nope),
weak_ref_tensors(q_pe), weak_ref_tensors(k_pe), num_heads,
self.num_kv_heads, input_layout,
weak_ref_tensors(spec_attn_mask) if spec_attn_mask is not None
else None, sparse_mode, self.scale,
weak_ref_tensors(decode_meta.block_table), block_size,
actual_seq_lengths, decode_meta.cp_seq_len,
weak_ref_tensors(attn_output), weak_ref_tensors(softmax_lse)))
torch.npu.graph_task_group_begin(stream)
torch_npu.atb.npu_multi_head_latent_attention(
torch_npu.npu_fused_infer_attention_score.out(
q_nope,
q_pe,
k_nope,
k_pe,
decode_meta.block_table,
seq_len,
num_heads,
self.scale,
self.num_kv_heads,
k_nope,
**common_kwargs,
workspace=workspace,
output=attn_output,
lse=softmax_lse)
out=[attn_output, softmax_lse])
handle = torch.npu.graph_task_group_end(stream)
graph_params.handles[num_tokens].append(handle)
else:
attn_output = torch.empty_like(q_nope)
softmax_lse = torch.empty((num_tokens, num_heads, 1),
dtype=q_nope.dtype,
device=q_nope.device)
torch_npu.atb.npu_multi_head_latent_attention(
attn_output, softmax_lse = torch_npu.npu_fused_infer_attention_score(
q_nope,
q_pe,
k_nope,
k_pe,
decode_meta.block_table,
seq_len,
num_heads,
self.scale,
self.num_kv_heads,
return_lse=True,
calc_type="calc_type_ring",
output=attn_output,
lse=softmax_lse)
k_nope,
**common_kwargs,
)
if input_layout == "BNSD":
B_attn, N_attn, S, D = attn_output.shape
B_lse, N_lse, Q_S, _ = softmax_lse.shape

attn_output = attn_output.permute(0, 2, 1, 3).reshape(
B_attn * S, N_attn, D)
softmax_lse = softmax_lse.permute(0, 2, 1, 3).reshape(
B_lse * Q_S, N_lse, 1)

# Update out&lse
attn_out_lse = _process_attn_out_lse(attn_output, softmax_lse,
Expand Down
66 changes: 44 additions & 22 deletions vllm_ascend/compilation/acl_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,38 +466,60 @@ def update_mla_attn_dcp_pcp_params(update_stream, forward_context,
graph_params.handles[runtime_shape],
graph_params.events[runtime_shape],
):
(q_nope, q_pe, k_nope, k_pe, block_table, seq_len, num_heads,
scale, num_kv_heads, attn_output, softmax_lse) = param
(
q_nope,
k_nope,
q_pe,
k_pe,
num_heads,
num_kv_heads,
input_layout,
spec_attn_mask,
sparse_mode,
scale,
block_table,
block_size,
actual_seq_lengths,
actual_seq_lengths_kv,
attn_output,
softmax_lse,
) = param

decode_meta = forward_context.attn_metadata[key].decode
seq_len = decode_meta.cp_seq_len
if isinstance(seq_len, torch.Tensor):
seq_len = seq_len.tolist()
actual_seq_lengths_kv = seq_len

# For pcp + spec decode, we flatten seq_lens
# to avoid irregular attn_mask shape,
# so there's no need to divide runtime_shape by spec_multiple
pad_length = runtime_shape - len(seq_len)
pad_tensor = torch.zeros(pad_length,
dtype=seq_len.dtype,
device=seq_len.device)
seq_len = torch.cat([seq_len, pad_tensor], dim=0)
pad_length = runtime_shape - len(actual_seq_lengths_kv)
if pad_length > 0:
actual_seq_lengths_kv = actual_seq_lengths_kv + [0] * (
runtime_shape - len(actual_seq_lengths_kv))

torch.npu.graph_task_update_begin(update_stream, handle)

torch_npu.atb.npu_multi_head_latent_attention(
torch_npu.npu_fused_infer_attention_score.out(
q_nope,
q_pe,
k_nope,
k_pe,
block_table,
seq_len,
num_heads,
scale,
num_kv_heads,
return_lse=True,
calc_type="calc_type_ring",
k_nope,
query_rope=q_pe,
key_rope=k_pe,
num_heads=num_heads,
num_key_value_heads=num_kv_heads,
input_layout=input_layout,
atten_mask=spec_attn_mask,
sparse_mode=sparse_mode,
scale=scale,
antiquant_mode=0,
antiquant_scale=None,
softmax_lse_flag=True,
block_table=block_table,
block_size=block_size,
actual_seq_lengths_kv=actual_seq_lengths_kv,
actual_seq_lengths=actual_seq_lengths,
workspace=graph_params.workspaces.get(runtime_shape),
output=attn_output,
lse=softmax_lse)
out=[attn_output, softmax_lse],
)
torch.npu.graph_task_update_end(update_stream)

event.record(update_stream)
Expand Down