diff --git a/tests/ut/worker/test_pcp_manager.py b/tests/ut/worker/test_pcp_manager.py
index af37472962d..d01bcafaba8 100644
--- a/tests/ut/worker/test_pcp_manager.py
+++ b/tests/ut/worker/test_pcp_manager.py
@@ -73,9 +73,12 @@ def test_generate_pcp_metadata_basic(pcp_size, dcp_size, num_reqs, query_lens,
                            query_lens) - input_batch.num_computed_tokens_cpu
     query_lens = torch.tensor(query_lens)

-    result = pcp_manager.generate_pcp_metadata(total_tokens, query_lens,
+    result, _ = pcp_manager.generate_pcp_metadata(total_tokens, query_lens,
                                                input_batch,
-                                               num_scheduled_tokens)
+                                               num_scheduled_tokens,
+                                               torch.tensor([]),
+                                               num_reqs_padded=num_reqs,
+                                               num_reqs=num_reqs)

     if not expect_not_none:
         assert result is None, f"Expected to return None, but got {type(result)}"
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 7ecebb60ae8..5a7b94be871 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -537,14 +537,14 @@ def _prepare_inputs(
         self.input_batch.block_table.compute_slot_mapping(req_indices, positions_np)
         self.input_batch.block_table.commit_slot_mapping(total_num_scheduled_tokens)

-        if self.pcp_size * self.dcp_size > 1:
+        if self.use_cp:
             self.pcp_manager.init_batch_info(
                 num_scheduled_tokens,
                 self.input_batch.num_reqs,
             )

         # for pcp, prefill mtp should use origin scheduleroutput ,
-        if self.speculative_config and self.pcp_size * self.dcp_size > 1:
+        if self.speculative_config and self.use_cp:
             self.pcp_manager.generate_pcp_mtp_input(
                 total_num_scheduled_tokens,
                 scheduler_output.num_scheduled_tokens,
@@ -703,7 +703,7 @@ def _prepare_inputs(
         spec_decode_metadata = None
         num_draft_tokens = None
         num_sampled_tokens = np.ones(num_reqs, dtype=np.int32)
-        if self.pcp_size * self.dcp_size > 1:
+        if self.use_cp:
             logits_indices = self.pcp_manager.get_logits_indices(cu_num_tokens)
             logits_indices = logits_indices.pin_memory().to(self.device, non_blocking=True)
         else:
@@ -925,7 +925,7 @@ def propose_draft_token_ids(
         self._copy_valid_sampled_token_count(next_token_ids, valid_sampled_tokens_count)
         req_scheduled_tokens = scheduler_output.num_scheduled_tokens

-        if self.pcp_size * self.dcp_size > 1:
+        if self.use_cp:
             long_seq_metadata = self.long_seq_metadata  # type: ignore
             input_ids_pcp_full = self.pcp_manager.input_ids_pcp_full.gpu
             query_start_loc_pcp_full = self.pcp_manager.query_start_loc_pcp_full.gpu
@@ -1798,11 +1798,17 @@ def _build_attention_metadata(

         kv_cache_groups = self.kv_cache_config.kv_cache_groups

-        def _get_pcp_metadata(num_tokens):
+        def _get_pcp_metadata(block_table_tensor):
             if not self.use_cp:
-                return None
+                return None, block_table_tensor
             return self.pcp_manager.generate_pcp_metadata(
-                num_tokens, self.query_lens, self.input_batch, num_scheduled_tokens_np
+                num_tokens,
+                self.query_lens,
+                self.input_batch,
+                num_scheduled_tokens_np,
+                block_table_tensor,
+                num_reqs_padded,
+                num_reqs,
             )

         def _get_block_table_and_slot_mapping(kv_cache_gid: int):
@@ -1843,8 +1849,8 @@ def _get_block_table_and_slot_mapping(kv_cache_gid: int):
             )
             return blk_table_tensor, slot_mapping

-        self.long_seq_metadata = _get_pcp_metadata(num_tokens)
         block_table_gid_0, slot_mapping_gid_0 = _get_block_table_and_slot_mapping(0)
+        self.long_seq_metadata, block_table_gid_0 = _get_pcp_metadata(block_table_gid_0)

         cm_base = AscendCommonAttentionMetadata(
             query_start_loc=self.query_start_loc.gpu[: num_reqs_padded + 1],
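Note on the two `_build_attention_metadata` hunks above: `_get_pcp_metadata` now runs *after* the group-0 block table is fetched, because it takes the table as input and hands back a possibly rewritten view. A minimal, self-contained sketch of that pass-through contract (the function body and metadata shape below are illustrative stand-ins, not the vllm-ascend implementation):

```python
import torch

def get_pcp_metadata(block_table: torch.Tensor, use_cp: bool):
    # Illustrative stand-in for _get_pcp_metadata: compute (or skip) the
    # metadata and always return the block table alongside it.
    if not use_cp:
        # No context parallelism: nothing to compute, table returned untouched.
        return None, block_table
    metadata = {"num_rows": block_table.shape[0]}  # placeholder metadata
    # The real builder may return a flattened (shorter) view of the table here.
    return metadata, block_table

block_table = torch.zeros((5, 4), dtype=torch.int32)
metadata, block_table = get_pcp_metadata(block_table, use_cp=True)
print(metadata)  # {'num_rows': 5}
```

This is why the statement swap in the `@@ -1843` hunk is load-bearing rather than cosmetic: the table must exist before the metadata builder can flatten it, and downstream consumers must see the returned view, not the original.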
@@ -2040,11 +2046,14 @@ def _dummy_run(
             # LoRA state when determining the batch descriptor for capture
             force_has_lora=activate_lora,
         )
-        if self.pcp_size * self.dcp_size > 1:
+        if self.use_cp:
             self.pcp_manager.init_batch_info(
                 num_scheduled_tokens,
                 num_reqs,
             )
+            if self.speculative_config:
+                self.pcp_manager.query_lens_pcp_full.cpu[:num_reqs] = torch.from_numpy(num_scheduled_tokens)
+                self.pcp_manager.query_lens_pcp_full.copy_to_gpu()
         if cudagraph_runtime_mode is None:
             cudagraph_runtime_mode = _cudagraph_mode
         else:
diff --git a/vllm_ascend/worker/pcp_utils.py b/vllm_ascend/worker/pcp_utils.py
index 9ff3c9e64c3..d2e7e723a66 100644
--- a/vllm_ascend/worker/pcp_utils.py
+++ b/vllm_ascend/worker/pcp_utils.py
@@ -24,6 +24,8 @@
 from vllm.config import VllmConfig
 from vllm.v1.utils import CpuGpuBuffer

+from vllm_ascend.worker.npu_input_batch import NPUInputBatch
+
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput

@@ -514,13 +516,23 @@ def _get_cp_local_seq_lens(
         dcp_local_seq_lens = (base + remainder).reshape([-1, pcp_world_size, dcp_world_size])
         return dcp_local_seq_lens

-    def generate_pcp_metadata(self, total_num_scheduled_tokens, query_lens, input_batch, num_scheduled_tokens):
+    def generate_pcp_metadata(
+        self,
+        total_num_scheduled_tokens: int,
+        query_lens: torch.Tensor,
+        input_batch: "NPUInputBatch",
+        num_scheduled_tokens: np.ndarray | None,
+        block_table_tensor: torch.Tensor,
+        num_reqs_padded: int,
+        num_reqs: int,
+    ):
         from vllm_ascend.attention.utils import AscendPrefillContextParallelMetadata

         num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_world_size
         self.num_actual_tokens_pcp_padded = num_actual_tokens_pcp_padded
         long_seq_metadata = None
         if self.pcp_world_size * self.dcp_world_size > 1:
+            assert num_scheduled_tokens is not None
             decode_context_lens = (
                 input_batch.num_computed_tokens_cpu[: self.num_decode_reqs]
                 + num_scheduled_tokens[: self.num_decode_reqs]
@@ -544,6 +556,7 @@ def generate_pcp_metadata(self, total_num_scheduled_tokens, query_lens, input_ba
                     self.vllm_config.parallel_config.cp_kv_cache_interleave_size,
                 )
             )
+            ori_query_lens_cpu = None
             if self.decode_threshold > 1:
                 num_computed_tokens_of_pcp_dcp_list = []
                 if self.num_decode_reqs:
@@ -563,10 +576,37 @@ def generate_pcp_metadata(self, total_num_scheduled_tokens, query_lens, input_ba
                         ]
                     )
                 num_computed_tokens_of_pcp_dcp = torch.cat(num_computed_tokens_of_pcp_dcp_list, dim=0)
+
+                # For pcp + spec decode, we flatten block_table
+                # to avoid irregular attn_mask shape, e.g.,
+                # num_decode_req=2, num_prefill_req=3, num_speculative_tokens=1,
+                # ori block_table: [d0, d1, p0, p1, p2]
+                # (num_reqs_d + num_reqs_p, max_num_blocks),
+                # flattened block_table: [d0, d0, d1, d1, p0, p1, p2]
+                # (num_reqs_d * decode_threshold + num_reqs_p, max_num_blocks),
+                ori_query_lens_cpu = self.query_lens_pcp_full.cpu[:num_reqs_padded]
+                ori_query_lens = self.query_lens_pcp_full.gpu[:num_reqs_padded]
+                num_prefill_reqs = self.num_prefill_reqs
+                num_decode_reqs = self.num_decode_reqs
+                num_decode_reqs_flatten = ori_query_lens_cpu[:num_decode_reqs].sum().item()
+                block_table_tensor[num_decode_reqs_flatten : num_decode_reqs_flatten + num_prefill_reqs].copy_(
+                    block_table_tensor[num_decode_reqs : num_decode_reqs + num_prefill_reqs].clone()
+                )
+                block_table_tensor[:num_decode_reqs_flatten].copy_(
+                    block_table_tensor[:num_decode_reqs].repeat_interleave(ori_query_lens[:num_decode_reqs], dim=0)
+                )
+                block_table_tensor = block_table_tensor[: num_decode_reqs_flatten + num_prefill_reqs]
+                if num_reqs_padded > num_reqs:
+                    pad_size = num_reqs_padded - num_reqs
+                    ori_query_lens_cpu[-pad_size:] = torch.full([pad_size], ori_query_lens_cpu[-pad_size - 1].item())
+
             long_seq_metadata = AscendPrefillContextParallelMetadata(
                 num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
                 num_computed_tokens_of_pcp_dcp=num_computed_tokens_of_pcp_dcp.numpy(),
             )
+            if ori_query_lens_cpu is not None:
+                long_seq_metadata.query_lens_pcp_full_cpu = ori_query_lens_cpu
+                long_seq_metadata.max_query_len_pcp_full = ori_query_lens_cpu.max().item()
             if self.pcp_world_size > 1:
                 q_head_idx, q_tail_idx = [], []
                 kv_with_q_head_nomask_idx, kv_with_q_head_mask_idx = [], []
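The comment block in the `@@ -563` hunk describes the flattening in prose; the sketch below reproduces it on toy data (2 decode requests with query length 2, i.e. one speculative token each, plus 3 prefill requests). It builds the result out-of-place for clarity, whereas the patch works in place on the padded tensor; all values are illustrative.

```python
import torch

# Toy reproduction of the block-table flattening described above.
num_decode_reqs, num_prefill_reqs = 2, 3
query_lens = torch.tensor([2, 2, 7, 8, 9])
# One row of block ids per request: [d0, d1, p0, p1, p2].
block_table = torch.arange(5).unsqueeze(1).expand(5, 4).contiguous()

# Each decode row is repeated once per query position it owns; prefill
# rows keep a single row each.
flat = torch.cat(
    [
        block_table[:num_decode_reqs].repeat_interleave(query_lens[:num_decode_reqs], dim=0),
        block_table[num_decode_reqs : num_decode_reqs + num_prefill_reqs],
    ],
    dim=0,
)
print(flat[:, 0].tolist())  # [0, 0, 1, 1, 2, 3, 4] -> [d0, d0, d1, d1, p0, p1, p2]
```

Doing this in place forces the ordering seen in the hunk: the prefill rows are relocated first (with `.clone()`, since the source and destination ranges can overlap) and only then are the decode rows expanded into the vacated prefix.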
@@ -685,8 +725,9 @@ def generate_pcp_metadata(self, total_num_scheduled_tokens, query_lens, input_ba
                 long_seq_metadata.kv_with_q_tail_nomask_idx_tensor = split_q_tail_nomask_idx_tensor_list
                 long_seq_metadata.head_attn_nomask_seqlens = head_attn_nomask_seqlens_list
                 long_seq_metadata.tail_attn_nomask_seqlens = tail_attn_nomask_seqlens_list
+        self.long_seq_metadata = long_seq_metadata

-        return long_seq_metadata
+        return long_seq_metadata, block_table_tensor

     def _list_to_tensor(self, lst, device, dtype=torch.int32):
         tensor_npu = torch.zeros(len(lst), dtype=dtype, device=device)
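One more detail worth calling out from the `@@ -563` hunk: when graph capture pads the batch (`num_reqs_padded > num_reqs`), the dummy slots get their query lengths backfilled from the last real request, so the flattened decode region stays rectangular. A small self-contained sketch with made-up values:

```python
import torch

# Sketch of the num_reqs_padded > num_reqs fix-up: graph padding appends
# dummy request slots, whose query lengths are overwritten with the last
# real request's length.
num_reqs, num_reqs_padded = 3, 5
query_lens_cpu = torch.tensor([2, 2, 2, 0, 0])  # two dummy slots from padding
pad_size = num_reqs_padded - num_reqs           # 2
# Index -pad_size - 1 is the last real request.
query_lens_cpu[-pad_size:] = torch.full([pad_size], query_lens_cpu[-pad_size - 1].item())
print(query_lens_cpu.tolist())  # [2, 2, 2, 2, 2]
```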