16 changes: 11 additions & 5 deletions tests/ut/compilation/test_acl_graph.py
@@ -24,12 +24,15 @@
from tests.ut.base import TestBase
from vllm_ascend.attention.attention_v1 import (AscendMetadata,
AscendMetadataForDecode)
+from vllm_ascend.attention.context_parallel.attention_cp import \
+    AscendAttentionCPImpl
+from vllm_ascend.attention.context_parallel.mla_cp import AscendMlaCPImpl
from vllm_ascend.attention.mla_v1 import (AscendMLADecodeMetadata,
AscendMLAMetadata)
from vllm_ascend.compilation.acl_graph import (
ACLGraphEntry, ACLGraphWrapper, get_draft_graph_params, get_graph_params,
-    set_draft_graph_params, set_graph_params, update_attn_dcp_pcp_params,
-    update_draft_graph_params_workspaces, update_mla_attn_dcp_pcp_params)
+    set_draft_graph_params, set_graph_params,
+    update_draft_graph_params_workspaces)


class TestACLGraphEntry(TestBase):
@@ -811,8 +814,9 @@ def test_update_mla_dcp_pcp_params(self, _mock_graph_task_end):
out, lse))

with patch("torch_npu._C._npu_setStream", return_value=None):
-            update_mla_attn_dcp_pcp_params(self.update_stream, forward_context,
-                                           4)
+            AscendMlaCPImpl.update_graph_params(
+                self.update_stream, forward_context, 4
+            )

_mock_graph_task_end.assert_called_once()

@@ -852,6 +856,8 @@ def test_update_attn_dcp_pcp_params(self, _mock_graph_task_end):
out, lse, 2, 0, 0))

with patch("torch_npu._C._npu_setStream", return_value=None):
-            update_attn_dcp_pcp_params(self.update_stream, forward_context, 4)
+            AscendAttentionCPImpl.update_graph_params(
+                self.update_stream, forward_context, 4, None
+            )

_mock_graph_task_end.assert_called_once()
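These two test changes track the refactor in this PR: the free functions update_mla_attn_dcp_pcp_params and update_attn_dcp_pcp_params from vllm_ascend.compilation.acl_graph are superseded by update_graph_params static methods on the attention impl classes. A minimal sketch of a caller dispatching through the shared hook (the helper below is hypothetical, not part of this PR):

from vllm_ascend.attention.context_parallel.attention_cp import AscendAttentionCPImpl
from vllm_ascend.attention.context_parallel.mla_cp import AscendMlaCPImpl

def update_captured_graph(impl_cls, update_stream, forward_context, num_tokens):
    # Each impl class exposes the same static entry point; vllm_config can be
    # passed as None where the impl ignores it, as the tests above do.
    impl_cls.update_graph_params(update_stream, forward_context, num_tokens, None)

# e.g. update_captured_graph(AscendMlaCPImpl, update_stream, forward_context, 4)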
12 changes: 6 additions & 6 deletions tests/ut/spec_decode/test_eagle_proposer.py
@@ -333,11 +333,11 @@ def test_dummy_run_with_prefill(self, mock_context, mock_get_context):
self.proposer.dummy_run(num_tokens=64, with_prefill=True, num_reqs=4)
self.assertTrue(self.proposer._runnable.call_count == 1)

@patch("vllm_ascend.spec_decode.eagle_proposer.update_attn_params")
@patch("vllm_ascend.spec_decode.eagle_proposer.update_full_graph_params")
@patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context")
@patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context")
def test_dummy_run_in_graph_capture(self, mock_context, mock_get_context,
-                                        mock_update_attn_params):
+                                        mock_update_full_graph_params):
last_use_cuda_graph = self.proposer.use_cuda_graph
mock_return_context = MagicMock()
mock_return_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
@@ -352,14 +352,14 @@ def test_dummy_run_in_graph_capture(self, mock_context, mock_get_context,
in_graph_capturing=True,
aclgraph_runtime_mode=CUDAGraphMode.FULL)
self.assertTrue(self.proposer._runnable.call_count == 1)
-        mock_update_attn_params.assert_not_called()
+        mock_update_full_graph_params.assert_not_called()
self.proposer.use_cuda_graph = last_use_cuda_graph

@patch("vllm_ascend.spec_decode.eagle_proposer.update_attn_params")
@patch("vllm_ascend.spec_decode.eagle_proposer.update_full_graph_params")
@patch("vllm_ascend.spec_decode.eagle_proposer.get_forward_context")
@patch("vllm_ascend.spec_decode.eagle_proposer.set_ascend_forward_context")
def test_dummy_run_in_graph_run(self, mock_context, mock_get_context,
-                                    mock_update_attn_params):
+                                    mock_update_full_graph_params):
last_use_cuda_graph = self.proposer.use_cuda_graph
mock_return_context = MagicMock()
mock_return_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
@@ -374,7 +374,7 @@ def test_dummy_run_in_graph_run(self, mock_context, mock_get_context,
in_graph_capturing=False,
aclgraph_runtime_mode=CUDAGraphMode.FULL)
self.assertTrue(self.proposer._runnable.call_count == 1)
-        self.assertTrue(mock_update_attn_params.call_count == 1)
+        self.assertTrue(mock_update_full_graph_params.call_count == 1)
self.proposer.use_cuda_graph = last_use_cuda_graph


138 changes: 138 additions & 0 deletions vllm_ascend/attention/attention_v1.py
@@ -371,6 +371,144 @@ def __init__(
self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer
)

@staticmethod
def update_graph_params(
update_stream,
forward_context,
num_tokens,
vllm_config,
speculative_config=None,
num_dcp_pcp_tokens=None,
):
if using_paged_attention(num_tokens, vllm_config):
# Paged Attention update logic
if forward_context.is_draft_model:
graph_params = get_draft_graph_params()
else:
graph_params = get_graph_params()
with torch.npu.stream(update_stream):
for key, param, handle, event in zip(
forward_context.attn_metadata,
graph_params.attn_params[num_tokens],
graph_params.handles[num_tokens],
graph_params.events[num_tokens],
):
(
query,
key_cache,
value_cache,
num_kv_heads,
num_heads,
scale,
block_table,
seq_lens,
output,
) = param
seq_lens = forward_context.attn_metadata[key].seq_lens

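# Re-query the workspace size for the updated arguments so the captured
# paged-attention task has enough scratch memory at replay time.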
workspace = torch_npu._npu_paged_attention_get_workspace(
query=query,
key_cache=key_cache,
value_cache=value_cache,
num_kv_heads=num_kv_heads,
num_heads=num_heads,
scale_value=scale,
block_table=block_table,
context_lens=seq_lens,
out=output,
)
torch.npu.graph_task_update_begin(update_stream, handle)
torch_npu._npu_paged_attention(
query=query,
key_cache=key_cache,
value_cache=value_cache,
num_kv_heads=num_kv_heads,
num_heads=num_heads,
scale_value=scale,
block_table=block_table,
context_lens=seq_lens,
out=output,
workspace=workspace,
)
torch.npu.graph_task_update_end(update_stream)
event.record(update_stream)
else:
# FIA update logic
if forward_context.is_draft_model:
graph_params = get_draft_graph_params()
attn_metadata = forward_context.draft_attn_metadatas
attn_keys = list(attn_metadata[0].keys())
else:
graph_params = get_graph_params()
attn_metadata = forward_context.attn_metadata
attn_keys = list(attn_metadata.keys())
# For Qwen3-next, since the kv_cache_config has already categorized
# linear_attn and self_attn, the attn_metadata is first arranged with
# self_attn followed by linear_attn. Therefore, using zip directly
# filters out the update operations for linear_attn.
# TODO: We use a new variable `attn_keys` to keep the loop count
# correct when zipping, given the new structure of the attn_metadata
# when running with the merged full eagle-graph. Verify this with Qwen3-next.
num_layers = len(attn_keys)
if num_layers == 0:
return
if forward_context.is_draft_model:
attn_keys = attn_keys * (len(graph_params.attn_params[num_tokens]) // num_layers)
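# Illustration (hypothetical numbers): with 2 draft layers captured over
# 3 speculative steps, attn_params holds 6 entries, so attn_keys becomes
# [layer0, layer1] * 3 and the zip below visits every captured entry.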
attn_count = 0
with torch.npu.stream(update_stream):
for key, param, handle, event in zip(
attn_keys,
graph_params.attn_params[num_tokens],
graph_params.handles[num_tokens],
graph_params.events[num_tokens],
):
(
query,
key_cache,
value,
block_tables,
attn_mask,
block_size,
seq_lens,
query_start_loc,
num_kv_heads,
num_heads,
scale,
attn_output,
softmax_lse,
) = param

if forward_context.is_draft_model:
draft_step = attn_count // num_layers
seq_lens = attn_metadata[draft_step][key].seq_lens_list
actual_seq_lengths_q = attn_metadata[draft_step][key].actual_seq_lengths_q
attn_count = attn_count + 1
else:
seq_lens = attn_metadata[key].seq_lens_list
actual_seq_lengths_q = attn_metadata[key].actual_seq_lengths_q

torch.npu.graph_task_update_begin(update_stream, handle)
torch_npu.npu_fused_infer_attention_score.out(
query=query,
key=key_cache,
value=value,
block_table=block_tables,
atten_mask=attn_mask,
input_layout="TND",
block_size=block_size,
actual_seq_lengths=actual_seq_lengths_q,
actual_seq_lengths_kv=seq_lens,
num_key_value_heads=num_kv_heads,
num_heads=num_heads,
scale=scale,
sparse_mode=3,
workspace=graph_params.workspaces.get(num_tokens),
out=[attn_output, softmax_lse],
)
torch.npu.graph_task_update_end(update_stream)

event.record(update_stream)

def process_weights_after_loading(self, act_dtype: torch.dtype):
super().process_weights_after_loading(act_dtype)
if flashcomm2_oshard_manager.flashcomm2_oshard_enable():
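For reference, a minimal sketch of the replay-side update protocol the new method follows; the refresh_captured_op helper and run_op callable are illustrative, while graph_task_update_begin/graph_task_update_end and the recorded event mirror the code above:

import torch
import torch_npu  # noqa: F401  (registers the torch.npu backend)

def refresh_captured_op(update_stream, handle, event, run_op):
    # Re-issue the attention op on a side stream between begin/end so the
    # captured task picks up fresh arguments (seq lens, workspace) without
    # re-capturing the whole graph; the event lets the main stream synchronize.
    with torch.npu.stream(update_stream):
        torch.npu.graph_task_update_begin(update_stream, handle)
        run_op()  # same operator and output tensors as at capture time
        torch.npu.graph_task_update_end(update_stream)
        event.record(update_stream)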
73 changes: 73 additions & 0 deletions vllm_ascend/attention/context_parallel/attention_cp.py
@@ -276,6 +276,79 @@ def __init__(
self.dcp_rank = get_decode_context_model_parallel_rank() if self.dcp_size > 1 else 0
self.dcp_group = get_dcp_group().device_group if self.dcp_size > 1 else None

@staticmethod
def update_graph_params(
update_stream,
forward_context,
num_tokens,
vllm_config,
speculative_config=None,
num_dcp_pcp_tokens=None,
):
graph_params = get_graph_params()
# FIXME: Behold! We are using a temporary hack here to update the args
# for each layer's attention op in the graph.
with torch.npu.stream(update_stream):
for key, param, handle, event in zip(
forward_context.attn_metadata,
graph_params.attn_params[num_tokens],
graph_params.handles[num_tokens],
graph_params.events[num_tokens],
):
(
q_nope,
k_nope,
value,
num_heads,
num_kv_heads,
scale,
block_table,
block_size,
actual_seq_lengths_kv,
actual_seq_lengths_q,
attn_output,
softmax_lse,
dcp_size,
pcp_rank,
dcp_rank,
) = param
attn_metadata = forward_context.attn_metadata[key]
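# Each (pcp_rank, dcp_rank) pair owns a slice of the KV cache, so this
# rank's per-request KV lengths come from its column of the table below.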
actual_seq_lengths_kv = attn_metadata.decode_meta.num_computed_tokens_of_pcp_dcp[:, pcp_rank, dcp_rank]
pad_length = num_tokens - len(actual_seq_lengths_kv)
if pad_length > 0:
pad_tensor = np.zeros(pad_length, dtype=actual_seq_lengths_kv.dtype)
actual_seq_lengths_kv = np.concatenate([actual_seq_lengths_kv, pad_tensor])

actual_seq_lengths_q = attn_metadata.actual_seq_lengths_q

if dcp_size > 1:
num_heads = num_heads * dcp_size

torch.npu.graph_task_update_begin(update_stream, handle)

torch_npu.npu_fused_infer_attention_score.out(
q_nope,
k_nope,
value,
num_heads=num_heads,
num_key_value_heads=num_kv_heads,
input_layout="TND",
atten_mask=None,
scale=scale,
antiquant_mode=0,
antiquant_scale=None,
softmax_lse_flag=True,
block_table=block_table,
block_size=block_size,
actual_seq_lengths_kv=actual_seq_lengths_kv,
actual_seq_lengths=actual_seq_lengths_q,
workspace=graph_params.workspaces.get(num_tokens),
out=[attn_output, softmax_lse],
)
torch.npu.graph_task_update_end(update_stream)

event.record(update_stream)

def _attention_with_nomask_and_mask(
self,
q: torch.Tensor,
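A minimal sketch of the padding step above, assuming a numpy array of per-rank computed-token counts that may be shorter than the captured num_tokens:

import numpy as np

def pad_seq_lens_kv(actual_seq_lengths_kv: np.ndarray, num_tokens: int) -> np.ndarray:
    # The captured graph always executes with num_tokens request slots, so
    # trailing slots with no live request are padded with zero-length sequences.
    pad_length = num_tokens - len(actual_seq_lengths_kv)
    if pad_length > 0:
        pad_tensor = np.zeros(pad_length, dtype=actual_seq_lengths_kv.dtype)
        actual_seq_lengths_kv = np.concatenate([actual_seq_lengths_kv, pad_tensor])
    return actual_seq_lengths_kv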
79 changes: 79 additions & 0 deletions vllm_ascend/attention/context_parallel/mla_cp.py
@@ -284,6 +284,85 @@ def __init__(
self.dcp_rank = get_decode_context_model_parallel_rank() if self.dcp_size > 1 else 0
self.dcp_group = get_dcp_group().device_group if self.dcp_size > 1 else None

@staticmethod
def update_graph_params(
update_stream,
forward_context,
num_tokens,
vllm_config=None,
speculative_config=None,
num_dcp_pcp_tokens=None,
):
if forward_context.is_draft_model:
graph_params = get_draft_graph_params()
else:
graph_params = get_graph_params()
# FIXME: Behold! We are using a temporary hack here to update the args
# for each layer's attention op in the graph.
with torch.npu.stream(update_stream):
for key, param, handle, event in zip(
forward_context.attn_metadata,
graph_params.attn_params[num_tokens],
graph_params.handles[num_tokens],
graph_params.events[num_tokens],
):
(
q_nope,
k_nope,
q_pe,
k_pe,
num_heads,
num_kv_heads,
input_layout,
spec_attn_mask,
sparse_mode,
scale,
block_table,
block_size,
actual_seq_lengths,
actual_seq_lengths_kv,
attn_output,
softmax_lse,
) = param

decode_meta = forward_context.attn_metadata[key].decode
seq_len = decode_meta.cp_seq_len
if isinstance(seq_len, torch.Tensor):
seq_len = seq_len.tolist()
actual_seq_lengths_kv = seq_len

pad_length = num_tokens - len(actual_seq_lengths_kv)
if pad_length > 0:
    actual_seq_lengths_kv = actual_seq_lengths_kv + [0] * pad_length

torch.npu.graph_task_update_begin(update_stream, handle)

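# Note: k_nope is passed as both key and value (MLA's compressed KV),
# while the rotary components go in separately via query_rope/key_rope.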
torch_npu.npu_fused_infer_attention_score.out(
q_nope,
k_nope,
k_nope,
query_rope=q_pe,
key_rope=k_pe,
num_heads=num_heads,
num_key_value_heads=num_kv_heads,
input_layout=input_layout,
atten_mask=spec_attn_mask,
sparse_mode=sparse_mode,
scale=scale,
antiquant_mode=0,
antiquant_scale=None,
softmax_lse_flag=True,
block_table=block_table,
block_size=block_size,
actual_seq_lengths_kv=actual_seq_lengths_kv,
actual_seq_lengths=actual_seq_lengths,
workspace=graph_params.workspaces.get(num_tokens),
out=[attn_output, softmax_lse],
)
torch.npu.graph_task_update_end(update_stream)

event.record(update_stream)

def get_num_actual_tokens(self, attn_metadata: M):
if self.pcp_size > 1:
return attn_metadata.num_actual_tokens_pcp_padded // self.pcp_size