Merged
tests/ut/worker/test_pcp_manager.py (7 changes: 5 additions & 2 deletions)
@@ -73,9 +73,12 @@ def test_generate_pcp_metadata_basic(pcp_size, dcp_size, num_reqs, query_lens,
                                          query_lens) - input_batch.num_computed_tokens_cpu
 
     query_lens = torch.tensor(query_lens)
-    result = pcp_manager.generate_pcp_metadata(total_tokens, query_lens,
+    result, _ = pcp_manager.generate_pcp_metadata(total_tokens, query_lens,
                                                input_batch,
-                                               num_scheduled_tokens)
+                                               num_scheduled_tokens,
+                                               torch.tensor([]),
+                                               num_reqs_padded=num_reqs,
+                                               num_reqs=num_reqs)
 
     if not expect_not_none:
         assert result is None, f"Expected to return None, but got {type(result)}"
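Reviewer note: `generate_pcp_metadata` now returns a `(long_seq_metadata, block_table_tensor)` tuple and takes three extra arguments, so the test unpacks and discards the second element. A minimal sketch of the new call shape, assuming the test's `pcp_manager` and `input_batch` fixtures are in scope; the numbers are hypothetical, the empty tensor is a placeholder block table this basic case never reads, and `num_reqs_padded == num_reqs` means no graph-capture padding:

```python
import numpy as np
import torch

# Hypothetical values mirroring one parametrization of the test.
num_reqs = 2
query_lens = [4, 6]
total_tokens = sum(query_lens)
num_scheduled_tokens = np.array(query_lens, dtype=np.int32)

result, _ = pcp_manager.generate_pcp_metadata(
    total_tokens,
    torch.tensor(query_lens),
    input_batch,
    num_scheduled_tokens,
    torch.tensor([]),          # placeholder block table; unused in this case
    num_reqs_padded=num_reqs,  # unit test applies no padding
    num_reqs=num_reqs,
)
```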
vllm_ascend/worker/model_runner_v1.py (27 changes: 18 additions & 9 deletions)
@@ -537,14 +537,14 @@ def _prepare_inputs(
         self.input_batch.block_table.compute_slot_mapping(req_indices, positions_np)
         self.input_batch.block_table.commit_slot_mapping(total_num_scheduled_tokens)
 
-        if self.pcp_size * self.dcp_size > 1:
+        if self.use_cp:
             self.pcp_manager.init_batch_info(
                 num_scheduled_tokens,
                 self.input_batch.num_reqs,
             )
 
         # for pcp, prefill mtp should use origin scheduleroutput ,
-        if self.speculative_config and self.pcp_size * self.dcp_size > 1:
+        if self.speculative_config and self.use_cp:
             self.pcp_manager.generate_pcp_mtp_input(
                 total_num_scheduled_tokens,
                 scheduler_output.num_scheduled_tokens,
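Reviewer note: every inline `self.pcp_size * self.dcp_size > 1` check in this file collapses into `self.use_cp`. The diff does not show where the flag is defined; a sketch of the presumed one-time initialization (the attribute name is the only part this PR confirms):

```python
# Presumed runner __init__ (sketch, not part of this diff): context
# parallelism is active whenever prefill CP and/or decode CP shards a
# sequence across more than one rank.
self.use_cp = self.pcp_size * self.dcp_size > 1
```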
@@ -703,7 +703,7 @@ def _prepare_inputs(
         spec_decode_metadata = None
         num_draft_tokens = None
         num_sampled_tokens = np.ones(num_reqs, dtype=np.int32)
-        if self.pcp_size * self.dcp_size > 1:
+        if self.use_cp:
             logits_indices = self.pcp_manager.get_logits_indices(cu_num_tokens)
             logits_indices = logits_indices.pin_memory().to(self.device, non_blocking=True)
         else:
@@ -925,7 +925,7 @@ def propose_draft_token_ids(
         self._copy_valid_sampled_token_count(next_token_ids, valid_sampled_tokens_count)
 
         req_scheduled_tokens = scheduler_output.num_scheduled_tokens
-        if self.pcp_size * self.dcp_size > 1:
+        if self.use_cp:
             long_seq_metadata = self.long_seq_metadata  # type: ignore
             input_ids_pcp_full = self.pcp_manager.input_ids_pcp_full.gpu
             query_start_loc_pcp_full = self.pcp_manager.query_start_loc_pcp_full.gpu
@@ -1798,11 +1798,17 @@ def _build_attention_metadata(
 
         kv_cache_groups = self.kv_cache_config.kv_cache_groups
 
-        def _get_pcp_metadata(num_tokens):
+        def _get_pcp_metadata(block_table_tensor):
             if not self.use_cp:
-                return None
+                return None, block_table_tensor
             return self.pcp_manager.generate_pcp_metadata(
-                num_tokens, self.query_lens, self.input_batch, num_scheduled_tokens_np
+                num_tokens,
+                self.query_lens,
+                self.input_batch,
+                num_scheduled_tokens_np,
+                block_table_tensor,
+                num_reqs_padded,
+                num_reqs,
             )
 
         def _get_block_table_and_slot_mapping(kv_cache_gid: int):
@@ -1843,8 +1849,8 @@ def _get_block_table_and_slot_mapping(kv_cache_gid: int):
             )
             return blk_table_tensor, slot_mapping
 
-        self.long_seq_metadata = _get_pcp_metadata(num_tokens)
         block_table_gid_0, slot_mapping_gid_0 = _get_block_table_and_slot_mapping(0)
+        self.long_seq_metadata, block_table_gid_0 = _get_pcp_metadata(block_table_gid_0)
 
         cm_base = AscendCommonAttentionMetadata(
             query_start_loc=self.query_start_loc.gpu[: num_reqs_padded + 1],
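Reviewer note: the call order flips here on purpose. The group-0 block table must exist before `_get_pcp_metadata` runs, because `generate_pcp_metadata` may flatten its decode rows and hand back a replacement tensor; the runner must then adopt that return value. A condensed sketch of the contract (names from the diff, other kv-cache groups omitted):

```python
# Build the block table first, then let the PCP manager optionally rewrite
# it (flattened decode rows) together with the long-seq metadata. When
# self.use_cp is false the helper returns (None, block_table_tensor)
# untouched, so the assignment is a no-op on the block table.
block_table_gid_0, slot_mapping_gid_0 = _get_block_table_and_slot_mapping(0)
self.long_seq_metadata, block_table_gid_0 = _get_pcp_metadata(block_table_gid_0)
```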
@@ -2040,11 +2046,14 @@ def _dummy_run(
             # LoRA state when determining the batch descriptor for capture
             force_has_lora=activate_lora,
         )
-        if self.pcp_size * self.dcp_size > 1:
+        if self.use_cp:
             self.pcp_manager.init_batch_info(
                 num_scheduled_tokens,
                 num_reqs,
             )
+            if self.speculative_config:
+                self.pcp_manager.query_lens_pcp_full.cpu[:num_reqs] = torch.from_numpy(num_scheduled_tokens)
+                self.pcp_manager.query_lens_pcp_full.copy_to_gpu()
         if cudagraph_runtime_mode is None:
             cudagraph_runtime_mode = _cudagraph_mode
         else:
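Reviewer note: dummy runs drive graph capture, so with speculative decoding they must seed `query_lens_pcp_full` the same way a real step does; otherwise the flattening path in `generate_pcp_metadata` (below) would read stale lengths during capture. A sketch of the staging, with a `pcp_manager` in scope and under the assumption that every dummy request schedules the same token count:

```python
import numpy as np
import torch

num_reqs = 8
# One verified token plus one draft token per dummy request (hypothetical;
# real dummy runs derive this from the capture batch sizes).
num_scheduled_tokens = np.full(num_reqs, 2, dtype=np.int32)

# Stage per-request query lengths on the pinned CPU view, then push the
# in-use prefix to the device in one copy.
pcp_manager.query_lens_pcp_full.cpu[:num_reqs] = torch.from_numpy(num_scheduled_tokens)
pcp_manager.query_lens_pcp_full.copy_to_gpu()
```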
vllm_ascend/worker/pcp_utils.py (45 changes: 43 additions & 2 deletions)
@@ -24,6 +24,8 @@
 from vllm.config import VllmConfig
 from vllm.v1.utils import CpuGpuBuffer
 
+from vllm_ascend.worker.npu_input_batch import NPUInputBatch
+
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
@@ -514,13 +516,23 @@ def _get_cp_local_seq_lens(
         dcp_local_seq_lens = (base + remainder).reshape([-1, pcp_world_size, dcp_world_size])
         return dcp_local_seq_lens
 
-    def generate_pcp_metadata(self, total_num_scheduled_tokens, query_lens, input_batch, num_scheduled_tokens):
+    def generate_pcp_metadata(
+        self,
+        total_num_scheduled_tokens: int,
+        query_lens: torch.Tensor,
+        input_batch: "NPUInputBatch",
+        num_scheduled_tokens: np.ndarray | None,
+        block_table_tensor: torch.Tensor,
+        num_reqs_padded: int,
+        num_reqs: int,
+    ):
         from vllm_ascend.attention.utils import AscendPrefillContextParallelMetadata
 
         num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_world_size
         self.num_actual_tokens_pcp_padded = num_actual_tokens_pcp_padded
         long_seq_metadata = None
         if self.pcp_world_size * self.dcp_world_size > 1:
+            assert num_scheduled_tokens is not None
             decode_context_lens = (
                 input_batch.num_computed_tokens_cpu[: self.num_decode_reqs]
                 + num_scheduled_tokens[: self.num_decode_reqs]
@@ -544,6 +556,7 @@ def generate_pcp_metadata(self, total_num_scheduled_tokens, query_lens, input_batch, num_scheduled_tokens):
                     self.vllm_config.parallel_config.cp_kv_cache_interleave_size,
                 )
             )
+            ori_query_lens_cpu = None
             if self.decode_threshold > 1:
                 num_computed_tokens_of_pcp_dcp_list = []
                 if self.num_decode_reqs:
@@ -563,10 +576,37 @@ def generate_pcp_metadata(self, total_num_scheduled_tokens, query_lens, input_batch, num_scheduled_tokens):
                         ]
                     )
                 num_computed_tokens_of_pcp_dcp = torch.cat(num_computed_tokens_of_pcp_dcp_list, dim=0)
+
+                # For pcp + spec decode, we flatten block_table
+                # to avoid irregular attn_mask shape, e.g.,
+                # num_decode_req=2, num_prefill_req=3, num_speculative_tokens=1,
+                # ori block_table: [d0, d1, p0, p1, p2]
+                #   (num_reqs_d + num_reqs_p, max_num_blocks),
+                # flattened block_table: [d0, d0, d1, d1, p0, p1, p2]
+                #   (num_reqs_d * decode_threshold + num_reqs_p, max_num_blocks),
+                ori_query_lens_cpu = self.query_lens_pcp_full.cpu[:num_reqs_padded]
+                ori_query_lens = self.query_lens_pcp_full.gpu[:num_reqs_padded]
+                num_prefill_reqs = self.num_prefill_reqs
+                num_decode_reqs = self.num_decode_reqs
+                num_decode_reqs_flatten = ori_query_lens_cpu[:num_decode_reqs].sum().item()
+                block_table_tensor[num_decode_reqs_flatten : num_decode_reqs_flatten + num_prefill_reqs].copy_(
+                    block_table_tensor[num_decode_reqs : num_decode_reqs + num_prefill_reqs].clone()
+                )
+                block_table_tensor[:num_decode_reqs_flatten].copy_(
+                    block_table_tensor[:num_decode_reqs].repeat_interleave(ori_query_lens[:num_decode_reqs], dim=0)
+                )
+                block_table_tensor = block_table_tensor[: num_decode_reqs_flatten + num_prefill_reqs]
+                if num_reqs_padded > num_reqs:
+                    pad_size = num_reqs_padded - num_reqs
+                    ori_query_lens_cpu[-pad_size:] = torch.full([pad_size], ori_query_lens_cpu[-pad_size - 1].item())
+
             long_seq_metadata = AscendPrefillContextParallelMetadata(
                 num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
                 num_computed_tokens_of_pcp_dcp=num_computed_tokens_of_pcp_dcp.numpy(),
             )
+            if ori_query_lens_cpu is not None:
+                long_seq_metadata.query_lens_pcp_full_cpu = ori_query_lens_cpu
+                long_seq_metadata.max_query_len_pcp_full = ori_query_lens_cpu.max().item()
             if self.pcp_world_size > 1:
                 q_head_idx, q_tail_idx = [], []
                 kv_with_q_head_nomask_idx, kv_with_q_head_mask_idx = [], []
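Reviewer note: the flattening the new comment describes, as a self-contained toy (hypothetical sizes). The real code operates in place on the padded runner tensor, which is why it first copies the prefill rows out of the way with `.clone()` before expanding the overlapping decode rows; the toy below uses a fresh tensor for clarity:

```python
import torch

# 2 decode requests (d0, d1) then 3 prefill requests (p0..p2); each row of
# the block table belongs to one request, max_num_blocks = 3.
block_table = torch.tensor([
    [10, 11, 12],  # d0
    [20, 21, 22],  # d1
    [30, 31, 32],  # p0
    [40, 41, 42],  # p1
    [50, 51, 52],  # p2
])
num_decode_reqs, num_prefill_reqs = 2, 3
# decode_threshold = 2: each decode request carries 1 verified + 1 draft token.
query_lens = torch.tensor([2, 2, 7, 5, 9])

num_decode_reqs_flatten = int(query_lens[:num_decode_reqs].sum())  # 4

flat = torch.empty(
    (num_decode_reqs_flatten + num_prefill_reqs, block_table.shape[1]),
    dtype=block_table.dtype,
)
# Prefill rows keep one row per request, shifted past the flattened decodes.
flat[num_decode_reqs_flatten:] = block_table[num_decode_reqs:num_decode_reqs + num_prefill_reqs]
# Each decode row is repeated once per query position it carries.
flat[:num_decode_reqs_flatten] = block_table[:num_decode_reqs].repeat_interleave(
    query_lens[:num_decode_reqs], dim=0
)
print(flat[:, 0].tolist())  # [10, 10, 20, 20, 30, 40, 50] -> d0, d0, d1, d1, p0, p1, p2
```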
@@ -685,8 +725,9 @@ def generate_pcp_metadata(self, total_num_scheduled_tokens, query_lens, input_batch, num_scheduled_tokens):
                 long_seq_metadata.kv_with_q_tail_nomask_idx_tensor = split_q_tail_nomask_idx_tensor_list
                 long_seq_metadata.head_attn_nomask_seqlens = head_attn_nomask_seqlens_list
                 long_seq_metadata.tail_attn_nomask_seqlens = tail_attn_nomask_seqlens_list
+
         self.long_seq_metadata = long_seq_metadata
-        return long_seq_metadata
+        return long_seq_metadata, block_table_tensor
 
     def _list_to_tensor(self, lst, device, dtype=torch.int32):
         tensor_npu = torch.zeros(len(lst), dtype=dtype, device=device)