diff --git a/tests/e2e/multicard/4-cards/long_sequence/test_mtp.py b/tests/e2e/multicard/4-cards/long_sequence/test_mtp.py
index 78baeffb97e..cc8eb9b57e6 100644
--- a/tests/e2e/multicard/4-cards/long_sequence/test_mtp.py
+++ b/tests/e2e/multicard/4-cards/long_sequence/test_mtp.py
@@ -20,17 +20,18 @@
 import os
 
 import pytest
 
-from tests.e2e.conftest import VllmRunner
+from tests.e2e.conftest import VllmRunner, wait_until_npu_memory_free
 
 os.environ["HCCL_BUFFSIZE"] = "512"
 
+prompts = [
+    "The capital of France is", "Hello, my name is Tom, I am",
+    "The president of United States is", "AI future is"
+]
+model = "wemaster/deepseek_mtp_main_random_bf16"
+@wait_until_npu_memory_free()
 def test_pcp_dcp_mtp1_eager():
-    prompts = [
-        "The capital of France is", "Hello, my name is Tom, I am",
-        "The president of United States is", "AI future is"
-    ]
-    model = "wemaster/deepseek_mtp_main_random_bf16"
     with VllmRunner(
             model,
             max_model_len=1024,
@@ -50,15 +51,8 @@ def test_pcp_dcp_mtp1_eager():
         runner.generate_greedy(prompts, 32)
 
 
-@pytest.mark.skip(
-    reason="vLLM PR-32118 break this",
-)
+@wait_until_npu_memory_free()
 def test_pcp_dcp_mtp3_eager():
-    prompts = [
-        "The capital of France is", "Hello, my name is Tom, I am",
-        "The president of United States is", "AI future is"
-    ]
-    model = "wemaster/deepseek_mtp_main_random_bf16"
     with VllmRunner(
             model,
             max_model_len=1024,
@@ -78,15 +72,8 @@ def test_pcp_dcp_mtp3_eager():
         runner.generate_greedy(prompts, 32)
 
 
-@pytest.mark.skip(
-    reason="vLLM PR-32118 break this",
-)
+@wait_until_npu_memory_free()
 def test_pcp_dcp_mtp3_piecewise_graph():
-    prompts = [
-        "The capital of France is", "Hello, my name is Tom, I am",
-        "The president of United States is", "AI future is"
-    ]
-    model = "wemaster/deepseek_mtp_main_random_bf16"
     with VllmRunner(
             model,
             max_model_len=1024,
@@ -109,15 +96,8 @@ def test_pcp_dcp_mtp3_piecewise_graph():
         runner.generate_greedy(prompts, 32)
 
 
-@pytest.mark.skip(
-    reason="vLLM PR-32118 break this",
-)
+@wait_until_npu_memory_free()
 def test_pcp_dcp_mtp3_full_graph():
-    prompts = [
-        "The capital of France is", "Hello, my name is Tom, I am",
-        "The president of United States is", "AI future is"
-    ]
-    model = "wemaster/deepseek_mtp_main_random_bf16"
     with VllmRunner(
             model,
             max_model_len=1024,
@@ -140,12 +120,8 @@ def test_pcp_dcp_mtp3_full_graph():
         runner.generate_greedy(prompts, 32)
 
 
+@wait_until_npu_memory_free()
 def test_dcp_mtp3_full_graph():
-    prompts = [
-        "The capital of France is", "Hello, my name is Tom, I am",
-        "The president of United States is", "AI future is"
-    ]
-    model = "wemaster/deepseek_mtp_main_random_bf16"
     with VllmRunner(
             model,
             max_model_len=1024,
diff --git a/tests/ut/worker/test_pcp_manager.py b/tests/ut/worker/test_pcp_manager.py
index 3f5ea17a876..af37472962d 100644
--- a/tests/ut/worker/test_pcp_manager.py
+++ b/tests/ut/worker/test_pcp_manager.py
@@ -141,8 +141,10 @@ def test_update_tokens_for_pcp_basic(tokens, num_reqs, num_computed_tokens,
                                                  dtype=np.int32)
     input_batch.num_prompt_tokens = np.array(num_prompt_tokens,
                                              dtype=np.int32)
     arange_np = np.arange(10000)
+    num_scheduled_tokens = np.array(tokens)
+    pcp_manager.init_batch_info(num_scheduled_tokens, num_reqs)
     pcp_tokens_result, positions_result = pcp_manager.update_tokens_for_pcp(
-        np.array(tokens), arange_np, num_reqs, 1)
+        num_scheduled_tokens, arange_np)
     assert np.array_equal(pcp_tokens_result, expected_pcp_tokens), \
         f"Expected pcp_tokens: {expected_pcp_tokens}, got: {pcp_tokens_result}"
@@ -305,8 +307,8 @@ def test_generate_pcp_mtp_input(
     for i, token_ids_tensor in enumerate(token_ids_tensor_list):
         token_ids_cpu_tensor[i][:token_ids_tensor.size(0)] = token_ids_tensor
-    pcp_manager.generate_pcp_mtp_input(num_reqs, total_num_scheduled_tokens,
-                                       num_scheduled_tokens, False,
+    pcp_manager.init_batch_info(np.array(list(num_scheduled_tokens.values())), num_reqs)
+    pcp_manager.generate_pcp_mtp_input(total_num_scheduled_tokens, num_scheduled_tokens, False,
                                        input_batch, arange_np)
 
     assert torch.equal(
         pcp_manager.input_ids_pcp_full.cpu[:total_num_scheduled_tokens],
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 90f453a5043..26286dfe4fd 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -563,10 +563,16 @@ def _prepare_inputs(
                                                   req_indices, positions_np)
         self.input_batch.block_table.commit_slot_mapping(
             total_num_scheduled_tokens)
+
+        if self.pcp_size * self.dcp_size > 1:
+            self.pcp_manager.init_batch_info(
+                num_scheduled_tokens,
+                self.input_batch.num_reqs,
+            )
+
         # for pcp, prefill mtp should use origin scheduleroutput ,
         if self.speculative_config and self.pcp_size * self.dcp_size > 1:
             self.pcp_manager.generate_pcp_mtp_input(
-                num_reqs,
                 total_num_scheduled_tokens,
                 scheduler_output.num_scheduled_tokens,
                 with_prefill,
@@ -584,8 +590,6 @@ def _prepare_inputs(
                 num_reqs], position_pcp = self.pcp_manager.update_tokens_for_pcp(
                     num_scheduled_tokens[:num_reqs],
                     self.arange_np,
-                    self.input_batch.num_reqs,
-                    self.reorder_batch_threshold,
                 )
             # Re-update after PCP split sequences.
             total_num_scheduled_tokens = sum(num_scheduled_tokens)
@@ -739,8 +743,7 @@ def _prepare_inputs(
             num_draft_tokens = None
             num_sampled_tokens = np.ones(num_reqs, dtype=np.int32)
         if self.pcp_size * self.dcp_size > 1:
-            logits_indices = self.pcp_manager.get_logits_indices(
-                cu_num_tokens, num_reqs)
+            logits_indices = self.pcp_manager.get_logits_indices(cu_num_tokens)
             logits_indices = logits_indices.pin_memory().to(
                 self.device, non_blocking=True)
         else:
@@ -987,9 +990,8 @@ def propose_draft_token_ids(
             num_reqs = self.input_batch.num_reqs
             ori_query_lens = query_start_loc_pcp_full_cpu[1:num_reqs+1] - \
                 query_start_loc_pcp_full_cpu[:num_reqs]
-            num_prefill_reqs = (ori_query_lens
-                                > self.decode_threshold).sum().item()
-            num_decode_reqs = num_reqs - num_prefill_reqs
+            num_prefill_reqs = self.pcp_manager.num_prefill_reqs
+            num_decode_reqs = self.pcp_manager.num_decode_reqs
         else:
             long_seq_metadata = None  # type: ignore
             num_prefill_reqs = 0
@@ -1938,7 +1940,7 @@ def _get_block_table_and_slot_mapping(kv_cache_gid: int):
             )
             return blk_table_tensor, slot_mapping
 
-        long_seq_metdadata = _get_pcp_metadata(num_tokens)
+        self.long_seq_metadata = _get_pcp_metadata(num_tokens)
         block_table_gid_0, slot_mapping_gid_0 = _get_block_table_and_slot_mapping(0)
 
         cm_base = AscendCommonAttentionMetadata(
@@ -1963,7 +1965,7 @@ def _get_block_table_and_slot_mapping(kv_cache_gid: int):
             positions=self.positions.gpu,
             attn_state=self.attn_state,
             decode_token_per_req=self.decode_token_per_req,
-            prefill_context_parallel_metadata=long_seq_metdadata,
+            prefill_context_parallel_metadata=self.long_seq_metadata,
         )
 
         if logits_indices is not None and self.cache_config.kv_sharing_fast_prefill:
@@ -2153,6 +2155,11 @@ def _dummy_run(
                 force_has_lora=activate_lora,
             )
         )
+        if self.pcp_size * self.dcp_size > 1:
+            self.pcp_manager.init_batch_info(
+                num_scheduled_tokens,
+                num_reqs,
+            )
         if cudagraph_runtime_mode is None:
             cudagraph_runtime_mode = _cudagraph_mode
         else:
diff --git a/vllm_ascend/worker/pcp_utils.py b/vllm_ascend/worker/pcp_utils.py
index 5b221e57200..30dcbba1e7f 100644
--- a/vllm_ascend/worker/pcp_utils.py
+++ b/vllm_ascend/worker/pcp_utils.py
@@ -36,6 +36,10 @@ class PCPManager:
     This manager encapsulates all PCP-related buffers and logic so that
     the ModelRunner can access them via `self.pcp_manager`.
     """
+    num_reqs: int = 0
+    num_decode_reqs: int = 0
+    num_prefill_reqs: int = 0
+    num_decode_tokens: int = 0
 
     def __init__(
         self,
@@ -133,12 +137,25 @@
 
         return cu_num_tokens, arange
 
+    def init_batch_info(
+        self,
+        num_scheduled_tokens: np.ndarray,
+        num_reqs: int,
+    ) -> None:
+        self.num_reqs = num_reqs
+        is_prefill = (num_scheduled_tokens[:num_reqs] > self.decode_threshold)
+        if not any(is_prefill):
+            first_prefill = num_reqs
+        else:
+            first_prefill = is_prefill.argmax()
+        self.num_decode_reqs = first_prefill
+        self.num_prefill_reqs = num_reqs - self.num_decode_reqs
+        self.num_decode_tokens = num_scheduled_tokens[:self.num_decode_reqs].sum()
+
     def update_tokens_for_pcp(
         self,
         num_scheduled_tokens: np.ndarray,
         arange_np: np.ndarray,
-        num_reqs: int,
-        reorder_batch_threshold: int | None = None,
     ) -> tuple[np.ndarray, np.ndarray]:
         """
         Update token counts and positions for Prefill Context Parallelism (PCP).
@@ -167,8 +184,6 @@
                 the number of new tokens scheduled per request.
             arange_np: 1D numpy array of length max_buffer_num_tokens used
                 for efficient batched arange operations.
-            num_reqs: Total number of requests in the batch.
-            reorder_batch_threshold: Threshold for decode vs prefill requests.
 
         Returns:
             Tuple (pcp_tokens, pcp_positions):
@@ -187,16 +202,10 @@
         >>> self.pcp_unpad_mask_cpu
         [True, False, True, True, True, True, True, False, False, False,
         True, True, True, True, True, True, True, True]
-        >>> self.pcp_allgather_resotre_idx
+        >>> self.pcp_allgather_restore_idx
        [0, 9, 1, 2, 10, 11, 12, 13, 3, 4, 5, 6, 14, 15, 16, 17, 7, 8]
         """
 
-        assert reorder_batch_threshold is not None, (
-            "PCP depends on reorder batch to split decode and prefill requests."
-        )
-        num_decode_reqs = sum(num_scheduled_tokens <= reorder_batch_threshold)
-        num_decode_tokens = sum(num_scheduled_tokens[:num_decode_reqs])
-
         # DualChunkSwap requires alignment to a multiple of (2 * pcp_world_size).
         # We first pad each request's token count up to that multiple.
         num_padded_scheduled_tokens = np.ceil(
@@ -205,11 +214,11 @@
 
         # PCP does not split decode requests. For decode requests, we instead
         # duplicate the scheduled tokens across the pcp_world_size ranks.
-        num_padded_scheduled_tokens[:num_decode_reqs] = (
-            num_scheduled_tokens[:num_decode_reqs] * self.pcp_world_size)
+        num_padded_scheduled_tokens[:self.num_decode_reqs] = (
+            num_scheduled_tokens[:self.num_decode_reqs] * self.pcp_world_size)
 
         # Record how many pads were added per request (padded - original).
-        self.num_pcp_pads_cpu[:num_reqs] = (num_padded_scheduled_tokens -
+        self.num_pcp_pads_cpu[:self.num_reqs] = (num_padded_scheduled_tokens -
                                             num_scheduled_tokens)
 
         # cu_padded_tokens: cumulative sum of padded token counts,
@@ -221,7 +230,7 @@
         self.pcp_unpad_mask_cpu[:pcp_padded_arange.shape[0]] = (
             pcp_padded_arange < np.repeat(num_scheduled_tokens,
                                           num_padded_scheduled_tokens))
-        unpad_mask_decode = self.pcp_unpad_mask_cpu[:num_decode_tokens *
+        unpad_mask_decode = self.pcp_unpad_mask_cpu[:self.num_decode_tokens *
                                                     self.pcp_world_size]
         unpad_mask_decode = unpad_mask_decode.reshape(
             [-1, self.pcp_world_size])
@@ -233,7 +242,7 @@
         # For prefill requests, we further split the pcp_tokens into two chunks
         # (head and tail). For decode requests, the chunk equals pcp_tokens.
         pcp_chunk_sizes = (pcp_tokens // 2).clip(min=1)
-        pcp_chunk_sizes[:num_decode_reqs] = pcp_tokens[:num_decode_reqs]
+        pcp_chunk_sizes[:self.num_decode_reqs] = pcp_tokens[:self.num_decode_reqs]
 
         # Build arange-style helpers for pcp tokens and chunk sizes:
         # - pcp_arange gives indices repeated for each token in pcp_tokens
@@ -271,16 +280,16 @@ def get_current_rank_positions(positions_start_loc: int | np.ndarray,
             # Fill tail positions. Note decode requests do not have tail chunks,
             # so the tail filling is only for prefill positions.
             positions[~pcp_head_chunk_mask] = (
-                pcp_chunk_arange[num_decode_tokens:] +
-                np.repeat(tail_start_loc, pcp_chunk_sizes)[num_decode_tokens:])
+                pcp_chunk_arange[self.num_decode_tokens:] +
+                np.repeat(tail_start_loc, pcp_chunk_sizes)[self.num_decode_tokens:])
             return positions
 
         positions = get_current_rank_positions(0, self.pcp_world_rank)
 
         # Decode tokens are duplicated only after AG. But their positions are
         # same without prefill context parallel.
-        if num_decode_reqs > 0:
-            positions[:num_decode_tokens] = self._get_cumsum_and_arange(
-                num_scheduled_tokens[:num_decode_reqs], arange_np)[1]
+        if self.num_decode_reqs > 0:
+            positions[:self.num_decode_tokens] = self._get_cumsum_and_arange(
+                num_scheduled_tokens[:self.num_decode_reqs], arange_np)[1]
         # Build the restore index used after allgather.
         padded_pos_start_loc = np.roll(cu_padded_tokens, 1)
@@ -294,27 +303,16 @@ def get_current_rank_positions(positions_start_loc: int | np.ndarray,
             all_positions.argsort())
         self.pcp_allgather_restore_idx.copy_to_gpu(all_positions.shape[0])
 
-        self.pcp_tokens[:num_reqs] = pcp_tokens[:num_reqs]
-        self.total_num_sampled_tokens_pcp = pcp_tokens[:num_reqs].sum()
+        self.pcp_tokens[:self.num_reqs] = pcp_tokens[:self.num_reqs]
+        self.total_num_sampled_tokens_pcp = pcp_tokens[:self.num_reqs].sum()
         return (
-            pcp_tokens[:num_reqs],
+            pcp_tokens[:self.num_reqs],
             positions,
         )
 
-    def get_logits_indices(self, cu_num_tokens: np.ndarray, num_reqs: int):
+    def get_logits_indices(self, cu_num_tokens: np.ndarray):
         return (torch.from_numpy(cu_num_tokens) * self.pcp_world_size -
-                self.num_pcp_pads_cpu_tensor[:num_reqs] - 1)
-
-    def get_discard_request_mask(
-        self,
-        num_computed_tokens_cpu: np.ndarray,
-        num_scheduled_tokens: np.ndarray,
-        num_reqs: int,
-        num_tokens_np: np.ndarray,
-    ):
-        return (num_computed_tokens_cpu[:num_reqs] +
-                num_scheduled_tokens * self.pcp_world_size -
-                self.num_pcp_pads_cpu[:num_reqs]) < num_tokens_np
+                self.num_pcp_pads_cpu_tensor[:self.num_reqs] - 1)
 
     def get_padded_slot_mapping(self, num_tokens: int, num_tokens_padded: int,
                                 slot_mapping: torch.Tensor):
@@ -350,7 +348,6 @@ def get_restore_hidden_states(
 
     def generate_pcp_mtp_input(
         self,
-        num_reqs: int,
         total_num_scheduled_tokens: int,
         num_scheduled_tokens: dict[str, int],
         with_prefill: bool = True,
@@ -369,18 +366,18 @@
         so we record original input_ids here.
         """
         total_num_scheduled_tokens_pcp_full = total_num_scheduled_tokens
-        num_scheduled_tokens_pcp_full = np.empty(num_reqs, dtype=np.int32)
+        num_scheduled_tokens_pcp_full = np.empty(self.num_reqs, dtype=np.int32)
         for i, req_id in enumerate(input_batch.req_ids):
             num_scheduled_tokens_pcp_full[i] = num_scheduled_tokens[req_id]
-        self.query_lens_pcp_full.cpu[:num_reqs] = torch.from_numpy(
+        self.query_lens_pcp_full.cpu[:self.num_reqs] = torch.from_numpy(
             num_scheduled_tokens_pcp_full)
-        req_indices_pcp_full = np.repeat(arange_np[:num_reqs],
+        req_indices_pcp_full = np.repeat(arange_np[:self.num_reqs],
                                          num_scheduled_tokens_pcp_full)
         cu_num_tokens_pcp_full = np.cumsum(num_scheduled_tokens_pcp_full)
         self.query_start_loc_pcp_full.np[0] = 0
-        self.query_start_loc_pcp_full.np[1:num_reqs +
+        self.query_start_loc_pcp_full.np[1:self.num_reqs +
                                          1] = cu_num_tokens_pcp_full
-        self.query_start_loc_pcp_full.np[num_reqs + 1:].fill(-1)
+        self.query_start_loc_pcp_full.np[self.num_reqs + 1:].fill(-1)
         cumsums_offsets_pcp_full = np.repeat(
             cu_num_tokens_pcp_full - num_scheduled_tokens_pcp_full,
             num_scheduled_tokens_pcp_full)
@@ -413,13 +410,13 @@
         if self.decode_threshold > 2 and not with_prefill:
             num_tokens_ori = sum(list(num_scheduled_tokens.values()))
             num_tokens_mtp = \
-                num_tokens_ori + num_reqs * (self.decode_threshold - 2)
+                num_tokens_ori + self.num_reqs * (self.decode_threshold - 2)
             num_tokens_mtp_pad = num_tokens_mtp * self.pcp_world_size
             req_indices_split = np.array_split(req_indices,
-                                               cu_num_tokens)[:num_reqs]
+                                               cu_num_tokens)[:self.num_reqs]
             positions_split = np.array_split(positions_np,
-                                             cu_num_tokens)[:num_reqs]
-            for req_idx in range(num_reqs):
+                                             cu_num_tokens)[:self.num_reqs]
+            for req_idx in range(self.num_reqs):
                 ori_req_indice = req_indices_split[req_idx]
                 ori_position = positions_split[req_idx]
                 req_indices_split[req_idx] = np.append(
@@ -567,25 +564,20 @@ def generate_pcp_metadata(self, total_num_scheduled_tokens, query_lens,
                               input_batch, num_scheduled_tokens):
         from vllm_ascend.attention.utils import \
             AscendPrefillContextParallelMetadata
-        num_reqs = input_batch.num_reqs or query_lens.size(0)
-        query_lens_new = self.query_lens_pcp_full.cpu[:num_reqs] \
-            if self.pcp_world_size > 1 and self.speculative_config else query_lens
-        num_decodes = (query_lens_new <= self.decode_threshold).sum().item()
-        num_prefills = num_reqs - num_decodes
         num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_world_size
         self.num_actual_tokens_pcp_padded = num_actual_tokens_pcp_padded
         long_seq_metadata = None
         if self.pcp_world_size * self.dcp_world_size > 1:
             decode_context_lens = input_batch.num_computed_tokens_cpu[:
-                num_decodes] + num_scheduled_tokens[:
-                    num_decodes]
+                self.num_decode_reqs] + num_scheduled_tokens[:
+                    self.num_decode_reqs]
             prefill_context_lens = input_batch.num_computed_tokens_cpu[
-                num_decodes:num_reqs]
+                self.num_decode_reqs:self.num_reqs]
             context_lens = np.concatenate(
                 [decode_context_lens, prefill_context_lens])
             num_computed_tokens_of_pcp_dcp = torch.zeros(
                 [
-                    num_reqs * self.decode_threshold, self.pcp_world_size,
+                    self.num_reqs * self.decode_threshold, self.pcp_world_size,
                     self.dcp_world_size
                 ],
                 dtype=torch.int32,
@@ -605,24 +597,24 @@ def generate_pcp_metadata(self, total_num_scheduled_tokens, query_lens,
             )
             if self.decode_threshold > 1:
                 num_computed_tokens_of_pcp_dcp_list = []
-                if num_decodes:
+                if self.num_decode_reqs:
                     num_decodes_flatten = \
-                        query_lens[:num_decodes].sum().item()
-                    if query_lens[:num_decodes].min().item(
+                        query_lens[:self.num_decode_reqs].sum().item()
+                    if query_lens[:self.num_decode_reqs].min().item(
                     ) == self.decode_threshold:
                         decode_flatten_idx = list(range(num_decodes_flatten))
                     else:
                         decode_flatten_idx = []
-                        for req_id in range(num_decodes):
+                        for req_id in range(self.num_decode_reqs):
                             offset = (req_id + 1) * self.decode_threshold
                             decode_flatten_idx += \
                                 list(range(offset - query_lens[req_id], offset))
                     num_computed_tokens_of_pcp_dcp_list.append(
                         num_computed_tokens_of_pcp_dcp[decode_flatten_idx])
-                if num_prefills:
+                if self.num_prefill_reqs:
                     num_computed_tokens_of_pcp_dcp_list.append(
                         num_computed_tokens_of_pcp_dcp[
-                            (num_decodes + 1) * self.decode_threshold -
+                            (self.num_decode_reqs + 1) * self.decode_threshold -
                             1::self.decode_threshold])
                 num_computed_tokens_of_pcp_dcp = torch.cat(
                     num_computed_tokens_of_pcp_dcp_list, dim=0)
@@ -643,7 +635,7 @@ def generate_pcp_metadata(self, total_num_scheduled_tokens, query_lens,
             q_head_chunk_id = self.pcp_world_rank
             q_tail_chunk_id = self.pcp_world_size * 2 - 1 - self.pcp_world_rank
             for i, seq_len in enumerate(query_lens):
-                if i < num_decodes:
+                if i < self.num_decode_reqs:
                     continue
                 chunk_len = seq_len // 2
                 chunk_seqlens.append(chunk_len)
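
Below is a minimal sketch of the call order this change assumes on the runner side. The method names and signatures (init_batch_info, update_tokens_for_pcp, get_logits_indices) are taken from the diff above; the helper function itself and the way cu_num_tokens is derived here are illustrative only, since the real runner computes them inside _prepare_inputs.

# Illustrative only: PCPManager.init_batch_info() must now run before
# update_tokens_for_pcp() / generate_pcp_mtp_input() / get_logits_indices(),
# because those methods read the cached num_reqs / num_decode_reqs /
# num_decode_tokens instead of taking them as arguments.
# Assumes the batch is already reordered decodes-first.
import numpy as np


def prepare_pcp_inputs(pcp_manager, num_scheduled_tokens: np.ndarray,
                       num_reqs: int, arange_np: np.ndarray):
    # Cache per-batch decode/prefill counts on the manager (new in this change).
    pcp_manager.init_batch_info(num_scheduled_tokens, num_reqs)

    # Split prefill tokens / duplicate decode tokens across PCP ranks;
    # no longer takes num_reqs or reorder_batch_threshold.
    pcp_tokens, positions = pcp_manager.update_tokens_for_pcp(
        num_scheduled_tokens[:num_reqs], arange_np)

    # cu_num_tokens normally comes from _prepare_inputs; a cumulative sum of
    # the per-rank token counts stands in for it here.
    cu_num_tokens = np.cumsum(pcp_tokens)
    logits_indices = pcp_manager.get_logits_indices(cu_num_tokens)
    return pcp_tokens, positions, logits_indices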