From 3f64d8e141f6d95d4ad8a629954258df3b65880a Mon Sep 17 00:00:00 2001 From: zhenwenqi2024 Date: Fri, 12 Dec 2025 19:02:51 +0800 Subject: [PATCH 01/43] [bugfix] asyncscheduler bug fix Signed-off-by: zhenwenqi2024 --- vllm_ascend/worker/model_runner_v1.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index c6c881ebc30..8f39a4bf38a 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -1874,8 +1874,10 @@ def propose_draft_token_ids(sampled_token_ids): return AsyncGPUModelRunnerOutput( model_runner_output=model_runner_output, sampled_token_ids=sampled_token_ids, + logprobs_tensors=sampler_output.logprobs_tensors, invalid_req_indices=invalid_req_indices, async_output_copy_stream=self.async_output_copy_stream, + vocab_size=self.input_batch.vocab_size, ) def _build_dummy_attn_metadata( From b634097ec3080f0e176cef89da387b17ab860ed1 Mon Sep 17 00:00:00 2001 From: zhenwenqi2024 Date: Fri, 12 Dec 2025 22:49:03 +0800 Subject: [PATCH 02/43] [Bugfix] asyncscheduler bug fix Signed-off-by: zhenwenqi2024 --- .github/workflows/_e2e_test.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/_e2e_test.yaml b/.github/workflows/_e2e_test.yaml index 213793d20ff..b665e5fd50c 100644 --- a/.github/workflows/_e2e_test.yaml +++ b/.github/workflows/_e2e_test.yaml @@ -93,6 +93,7 @@ jobs: pytest -sv tests/e2e/singlecard/test_completion_with_prompt_embeds.py pytest -sv tests/e2e/singlecard/test_aclgraph_accuracy.py pytest -sv tests/e2e/singlecard/test_aclgraph_mem.py + pytest -sv tests/e2e/singlecard/test_async_scheduling.py pytest -sv tests/e2e/singlecard/test_camem.py pytest -sv tests/e2e/singlecard/test_guided_decoding.py # torch 2.8 doesn't work with lora, fix me From 0edb42c1fe86cf97f52a88f9938fce34ae41bfd7 Mon Sep 17 00:00:00 2001 From: zhenwenqi2024 Date: Sat, 13 Dec 2025 10:56:19 +0800 Subject: [PATCH 03/43] [Bugfix] asyncscheduler bug fix 
Signed-off-by: zhenwenqi2024 --- vllm_ascend/worker/model_runner_v1.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 4ac423dac4a..02a5acae1a4 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -3474,7 +3474,7 @@ def __init__(self, *args, **kwargs) -> None: try: # replace cuda APIs with xpu APIs, this should work by default - torch.cuda.Event = _EventPlaceholder + torch.cuda.Event = torch.npu.Event torch.cuda.Stream = torch.npu.Stream torch.cuda.default_stream = torch.npu.default_stream torch.cuda.current_stream = torch.npu.current_stream From 52d66d357497b93fc761eb3ea7404eaee1fa3cb6 Mon Sep 17 00:00:00 2001 From: zhenwenqi2024 Date: Sat, 13 Dec 2025 11:35:22 +0800 Subject: [PATCH 04/43] [Bugfix] asyncscheduler bug fix Signed-off-by: zhenwenqi2024 --- tests/e2e/singlecard/test_async_scheduling.py | 38 +++++++++---------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/tests/e2e/singlecard/test_async_scheduling.py b/tests/e2e/singlecard/test_async_scheduling.py index 4f4eb05fb5f..af598c33078 100644 --- a/tests/e2e/singlecard/test_async_scheduling.py +++ b/tests/e2e/singlecard/test_async_scheduling.py @@ -45,25 +45,25 @@ def test_without_spec_decoding(monkeypatch: pytest.MonkeyPatch, ): run_tests(monkeypatch, MODEL, test_configs, test_sampling_params) -def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch): - """Test consistency and acceptance rates with some different combos of - preemption, executor, async scheduling, prefill chunking, - spec decoding model length. 
- """ - - spec_config = { - "method": "mtp", - "num_speculative_tokens": 2, - } - - # test_preemption, executor, async_scheduling, - # spec_config, test_prefill_chunking - test_configs = [ - (False, "mp", True, spec_config, False), - (False, "mp", False, spec_config, False), - ] - - run_tests(monkeypatch, MTP_MODEL, test_configs, [{}]) +# def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch): +# """Test consistency and acceptance rates with some different combos of +# preemption, executor, async scheduling, prefill chunking, +# spec decoding model length. +# """ + +# spec_config = { +# "method": "mtp", +# "num_speculative_tokens": 2, +# } + +# # test_preemption, executor, async_scheduling, +# # spec_config, test_prefill_chunking +# test_configs = [ +# (False, "mp", True, spec_config, False), +# (False, "mp", False, spec_config, False), +# ] + +# run_tests(monkeypatch, MTP_MODEL, test_configs, [{}]) @dynamo_config.patch(cache_size_limit=16) From e01e0a4c0250acf3c3fac5d7911ecf2643bd6c46 Mon Sep 17 00:00:00 2001 From: zhenwenqi2024 Date: Sat, 13 Dec 2025 11:41:55 +0800 Subject: [PATCH 05/43] [Bugfix] asyncscheduler bug fix Signed-off-by: zhenwenqi2024 --- tests/e2e/singlecard/test_async_scheduling.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/e2e/singlecard/test_async_scheduling.py b/tests/e2e/singlecard/test_async_scheduling.py index af598c33078..fd6e76d25f5 100644 --- a/tests/e2e/singlecard/test_async_scheduling.py +++ b/tests/e2e/singlecard/test_async_scheduling.py @@ -17,8 +17,12 @@ first_prompt = ("The following numbers of the sequence " + ", ".join(str(i) for i in range(10)) + " are:") -example_prompts = [first_prompt, "In one word, the capital of France is " - ] + [f"Tell me about the number {i}: " for i in range(32)] +example_prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] default_params = dict( temperature=0.0, # greedy From 
1bec2317c26646b1dc7d9fdd4c4293cf9ad046d9 Mon Sep 17 00:00:00 2001 From: zhenwenqi2024 Date: Sat, 13 Dec 2025 11:42:58 +0800 Subject: [PATCH 06/43] [Bugfix] asyncscheduler bug fix Signed-off-by: zhenwenqi2024 --- tests/e2e/singlecard/test_async_scheduling.py | 38 +++++++++---------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/tests/e2e/singlecard/test_async_scheduling.py b/tests/e2e/singlecard/test_async_scheduling.py index fd6e76d25f5..05f7bdd31be 100644 --- a/tests/e2e/singlecard/test_async_scheduling.py +++ b/tests/e2e/singlecard/test_async_scheduling.py @@ -49,25 +49,25 @@ def test_without_spec_decoding(monkeypatch: pytest.MonkeyPatch, ): run_tests(monkeypatch, MODEL, test_configs, test_sampling_params) -# def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch): -# """Test consistency and acceptance rates with some different combos of -# preemption, executor, async scheduling, prefill chunking, -# spec decoding model length. -# """ - -# spec_config = { -# "method": "mtp", -# "num_speculative_tokens": 2, -# } - -# # test_preemption, executor, async_scheduling, -# # spec_config, test_prefill_chunking -# test_configs = [ -# (False, "mp", True, spec_config, False), -# (False, "mp", False, spec_config, False), -# ] - -# run_tests(monkeypatch, MTP_MODEL, test_configs, [{}]) +def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch): + """Test consistency and acceptance rates with some different combos of + preemption, executor, async scheduling, prefill chunking, + spec decoding model length. 
+ """ + + spec_config = { + "method": "mtp", + "num_speculative_tokens": 2, + } + + # test_preemption, executor, async_scheduling, + # spec_config, test_prefill_chunking + test_configs = [ + (False, "mp", True, spec_config, False), + (False, "mp", False, spec_config, False), + ] + + run_tests(monkeypatch, MTP_MODEL, test_configs, [{}]) @dynamo_config.patch(cache_size_limit=16) From 1df5994e54f5d441369f4b17d2f41ab1da1b9899 Mon Sep 17 00:00:00 2001 From: zhenwenqi2024 Date: Sat, 13 Dec 2025 11:54:26 +0800 Subject: [PATCH 07/43] [Bugfix] asyncscheduler bug fix Signed-off-by: zhenwenqi2024 --- tests/e2e/singlecard/test_async_scheduling.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/e2e/singlecard/test_async_scheduling.py b/tests/e2e/singlecard/test_async_scheduling.py index 05f7bdd31be..aab24911de6 100644 --- a/tests/e2e/singlecard/test_async_scheduling.py +++ b/tests/e2e/singlecard/test_async_scheduling.py @@ -18,11 +18,11 @@ first_prompt = ("The following numbers of the sequence " + ", ".join(str(i) for i in range(10)) + " are:") example_prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", - ] + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] default_params = dict( temperature=0.0, # greedy From 690b76594f1b665d3af02eac0234d2fabdea63e7 Mon Sep 17 00:00:00 2001 From: zhenwenqi2024 Date: Sat, 13 Dec 2025 12:45:51 +0800 Subject: [PATCH 08/43] [Bugfix] asyncscheduler bug fix Signed-off-by: zhenwenqi2024 --- tests/e2e/singlecard/test_async_scheduling.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/e2e/singlecard/test_async_scheduling.py b/tests/e2e/singlecard/test_async_scheduling.py index aab24911de6..fde1db580ed 100644 --- a/tests/e2e/singlecard/test_async_scheduling.py +++ b/tests/e2e/singlecard/test_async_scheduling.py @@ -30,7 +30,6 @@ min_tokens=18, ) - def 
test_without_spec_decoding(monkeypatch: pytest.MonkeyPatch, ): """Test consistency of combos of async scheduling, preemption, uni/multiproc executor, prefill chunking.""" From 6f28654e91a47913d1cf03b8ca85d978a4c311b3 Mon Sep 17 00:00:00 2001 From: zhenwenqi2024 Date: Sat, 13 Dec 2025 14:05:49 +0800 Subject: [PATCH 09/43] [Bugfix] asyncscheduler bug fix Signed-off-by: zhenwenqi2024 --- tests/e2e/singlecard/test_async_scheduling.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/e2e/singlecard/test_async_scheduling.py b/tests/e2e/singlecard/test_async_scheduling.py index fde1db580ed..aab24911de6 100644 --- a/tests/e2e/singlecard/test_async_scheduling.py +++ b/tests/e2e/singlecard/test_async_scheduling.py @@ -30,6 +30,7 @@ min_tokens=18, ) + def test_without_spec_decoding(monkeypatch: pytest.MonkeyPatch, ): """Test consistency of combos of async scheduling, preemption, uni/multiproc executor, prefill chunking.""" From 2774391ecb160f370ed27d28608dd2831934a1cb Mon Sep 17 00:00:00 2001 From: zhenwenqi2024 Date: Sat, 20 Dec 2025 15:45:27 +0800 Subject: [PATCH 10/43] [Feature] refactor model_runner for pcp & dcp Signed-off-by: zhenwenqi2024 --- vllm_ascend/spec_decode/eagle_proposer.py | 6 +- vllm_ascend/spec_decode/mtp_proposer.py | 10 +- vllm_ascend/utils.py | 599 ++++++++++++++++++++ vllm_ascend/worker/model_runner_v1.py | 637 ++++++---------------- 4 files changed, 759 insertions(+), 493 deletions(-) diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py index 356efad6517..280c9f66430 100644 --- a/vllm_ascend/spec_decode/eagle_proposer.py +++ b/vllm_ascend/spec_decode/eagle_proposer.py @@ -202,9 +202,9 @@ def generate_token_ids(self, req_scheduled_tokens = scheduler_output.num_scheduled_tokens if self.pcp_size > 1: long_seq_metadata = self.runner.long_seq_metadata - input_ids_pcp_full = self.runner.input_ids_pcp_full - query_start_loc_pcp_full = self.runner.query_start_loc_pcp_full - query_start_loc_pcp_full_cpu = 
self.runner.query_start_loc_pcp_full_cpu + input_ids_pcp_full = self.runner.pcp_manager.input_ids_pcp_full.gpu + query_start_loc_pcp_full = self.runner.pcp_manager.query_start_loc_pcp_full.gpu + query_start_loc_pcp_full_cpu = self.runner.pcp_manager.query_start_loc_pcp_full.cpu num_reqs = self.runner.input_batch.num_reqs ori_query_lens = query_start_loc_pcp_full_cpu[1:num_reqs+1] - \ query_start_loc_pcp_full_cpu[:num_reqs] diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py index 4deef9a2987..837b41d5ca1 100644 --- a/vllm_ascend/spec_decode/mtp_proposer.py +++ b/vllm_ascend/spec_decode/mtp_proposer.py @@ -371,9 +371,9 @@ def generate_token_ids(self, req_scheduled_tokens = scheduler_output.num_scheduled_tokens if self.pcp_size > 1: long_seq_metadata = self.runner.long_seq_metadata - input_ids_pcp_full = self.runner.input_ids_pcp_full - query_start_loc_pcp_full = self.runner.query_start_loc_pcp_full - query_start_loc_pcp_full_cpu = self.runner.query_start_loc_pcp_full_cpu + input_ids_pcp_full = self.runner.pcp_manager.input_ids_pcp_full.gpu + query_start_loc_pcp_full = self.runner.pcp_manager.query_start_loc_pcp_full.gpu + query_start_loc_pcp_full_cpu = self.runner.pcp_manager.query_start_loc_pcp_full.cpu num_reqs = self.runner.input_batch.num_reqs ori_query_lens = query_start_loc_pcp_full_cpu[1:num_reqs+1] - \ query_start_loc_pcp_full_cpu[:num_reqs] @@ -820,8 +820,8 @@ def _propose( if self.pcp_size > 1: hidden_states = get_pcp_group().all_gather(hidden_states, 0) hidden_states = torch.index_select( - hidden_states, 0, self.runner. - pcp_allgather_restore_idx[:hidden_states.shape[0]]) + hidden_states, 0, self.runner.pcp_manager. 
+ pcp_allgather_restore_idx.gpu[:hidden_states.shape[0]]) sample_hidden_states = hidden_states[last_token_indices] logits = self.model.compute_logits(sample_hidden_states) diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index e7b3b8ecedf..de8d9476b84 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -26,12 +26,17 @@ from threading import Lock from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union +import numpy as np import torch import torch_npu # noqa: F401 from packaging.version import InvalidVersion, Version from torch_npu.npu.streams import Event from vllm.logger import logger from vllm.sequence import IntermediateTensors +# from vllm.distributed.parallel_state import get_pcp_group +from vllm.utils.math_utils import cdiv +# from vllm.config import VllmConfig +from vllm.v1.utils import CpuGpuBuffer import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config @@ -1086,3 +1091,597 @@ def dispose_layer(layer: Any): def replace_layer(original_layer: Any, new_layer: Any): original_layer.__class__ = new_layer.__class__ original_layer.__dict__ = new_layer.__dict__ + + +class PCPManager: + """ + Manager for Prefill Context Parallelism (PCP) metadata and buffers. + + This manager encapsulates all PCP-related buffers and logic so that the + ModelRunner can access them via `self.pcp_manager`. 
+ """ + + def __init__( + self, + pcp_world_size: int, + pcp_rank: int, + dcp_world_size: int, + dcp_rank: int, + max_buffer_num_tokens: int, + max_num_reqs: int, + # speculative_config, + device: torch.device, + vllm_config: VllmConfig, + pin_memory: bool = False, + ) -> None: + self.pcp_world_size = pcp_world_size + self.pcp_world_rank = pcp_rank + self.dcp_world_size = dcp_world_size + self.dcp_world_rank = dcp_rank + self.speculative_config = vllm_config.speculative_config + self.decode_threshold = 1 + ( + self.speculative_config.num_speculative_tokens + if self.speculative_config else 0) + self.vllm_config = vllm_config + self.max_num_tokens = self.vllm_config.scheduler_config.max_num_batched_tokens + self.max_num_reqs = self.vllm_config.scheduler_config.max_num_seqs + self.device = device + self.pcp_allgather_restore_idx = CpuGpuBuffer( + max_buffer_num_tokens, + dtype=torch.int64, + device=device, + pin_memory=pin_memory, + ) + self.pcp_padded_slot_mapping = torch.empty( + (max_buffer_num_tokens, ), + dtype=torch.int32, + device=device, + ) + self.num_pcp_pads_cpu_tensor = torch.zeros((max_num_reqs, ), + device="cpu", + dtype=torch.int64) + self.num_pcp_pads_cpu = self.num_pcp_pads_cpu_tensor.numpy() + self.pcp_unpad_mask_cpu_tensor = torch.zeros( + (max_buffer_num_tokens, ), + device="cpu", + dtype=torch.bool, + ) + self.num_actual_tokens_pcp_padded = 0 + self.pcp_unpad_mask_cpu = self.pcp_unpad_mask_cpu_tensor.numpy() + self.cp_kv_recover_idx_for_chunk: List[List[int]] = [ + [] for _ in range(self.pcp_world_size) + ] + self.full_indices = list( + range(self.max_num_tokens * self.pcp_world_size * + self.dcp_world_size + self.pcp_world_size * + self.dcp_world_size * self.max_num_reqs)) + if self.speculative_config and self.pcp_world_size > 1: + self.input_ids_pcp_full = CpuGpuBuffer(self.max_num_tokens, + device=device, + pin_memory=pin_memory, + dtype=torch.int32) + self.query_start_loc_pcp_full = CpuGpuBuffer(self.max_num_reqs + 1, + device=device, + 
pin_memory=pin_memory, + dtype=torch.int32) + self.positions_pcp_full = torch.zeros(self.max_num_tokens, + dtype=torch.int64, + device="cpu", + pin_memory=True) + self.positions_pcp_full_np = self.positions_pcp_full.numpy() + + def _get_cumsum_and_arange( + self, + num_scheduled_tokens: np.ndarray, + arange_np: np.ndarray, + cumsum_dtype: np.dtype | None = None, + ) -> tuple[np.ndarray, np.ndarray]: + """Get the cumulative sum and batched arange of the given array. + # E.g., [2, 5, 3] -> ([2, 7, 10], [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]) + # Equivalent to but faster than: + # np.concatenate([np.arange(n) for n in num_scheduled_tokens]) + """ + # Step 1. [2, 5, 3] -> [2, 7, 10] + cu_num_tokens = np.cumsum(num_scheduled_tokens, dtype=cumsum_dtype) + total_num_tokens = cu_num_tokens[-1] + # Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7] + cumsums_offsets = np.repeat(cu_num_tokens - num_scheduled_tokens, + num_scheduled_tokens) + # Step 3. [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + arange = arange_np[:total_num_tokens] - cumsums_offsets + + return cu_num_tokens, arange + + def update_tokens_for_pcp( + self, + num_scheduled_tokens: np.ndarray, + arange_np: np.ndarray, + num_reqs: int, + reorder_batch_threshold: int | None = None, + ) -> tuple[np.ndarray, np.ndarray]: + """ + Update token counts and positions for Prefill Context Parallelism (PCP). + + When using Prefill Context Parallelism, each request's prefill sequence is + split across multiple PCP ranks. The splitting strategy used here is the + "DualChunkSwap" style: each request's (padded) sequence is split into + 2 * pcp_world_size chunks and ranks are assigned chunks in an interleaved + head/tail pattern to balance load. + + This function: + - Computes how many tokens each request should be processed by the current + PCP rank (pcp_tokens). + - Computes the flattened positions of those tokens within the local + padded buffer (pcp_positions). 
+ - Updates runner state arrays used to restore original order and mask out + padded tokens after allgather: + - self.num_pcp_pads_cpu: number of pads added per request + - self.pcp_unpad_mask_cpu: boolean mask marking real tokens in the + padded allgather buffer + - self.pcp_allgather_restore_idx: index array used to restore original + ordering after per-rank allgather and interleaving. + + Args: + num_scheduled_tokens: 1D numpy array of length num_reqs containing + the number of new tokens scheduled per request. + arange_np: 1D numpy array of length max_buffer_num_tokens used for + efficient batched arange operations. + num_reqs: Total number of requests in the batch. + reorder_batch_threshold: Threshold for decode vs prefill requests. + + Returns: + Tuple (pcp_tokens, pcp_positions): + - pcp_tokens: number of tokens per request that this PCP rank will + actually process (after splitting / replication). + - pcp_positions: flattened positions for those tokens on this rank, + used to build the positions buffer for the model. + + Example: + >>> Assume tokens = [1, 5, 8], pcp_world_size = 2. After _update_tokens_for_pcp. + >>> pcp_rank = 0 get ([1, 4, 4], [0, 0, 1, 6, 7, 0, 1, 6, 7]) + >>> pcp_rank = 1 get ([1, 4, 4], [0, 2, 3, 4, 5, 2, 3, 4, 5]) + >>> Meanwhile, the following results are same for each pcp rank + >>> self.num_pcp_pads_cpu + [1, 3, 0] + >>> self.pcp_unpad_mask_cpu + [True, False, True, True, True, True, True, False, False, + False, True, True, True, True, True, True, True, True] + >>> self.pcp_allgather_resotre_idx + [0, 9, 1, 2, 10, 11, 12, 13, 3, 4, 5, 6, 14, 15, 16, 17, 7, 8] + """ + + assert reorder_batch_threshold is not None, ( + "PCP depends on reorder batch to split decode and prefill requests." + ) + num_decode_reqs = sum(num_scheduled_tokens <= reorder_batch_threshold) + num_decode_tokens = sum(num_scheduled_tokens[:num_decode_reqs]) + + # DualChunkSwap requires alignment to a multiple of (2 * pcp_world_size). 
+ # We first pad each request's token count up to that multiple. + num_padded_scheduled_tokens = np.ceil( + num_scheduled_tokens / (2 * self.pcp_world_size)).astype( + np.int32) * (2 * self.pcp_world_size) + + # PCP does not split decode requests. For decode requests, we instead + # duplicate the scheduled tokens across the pcp_world_size ranks. + num_padded_scheduled_tokens[:num_decode_reqs] = ( + num_scheduled_tokens[:num_decode_reqs] * self.pcp_world_size) + + # Record how many pads were added per request (padded - original). + self.num_pcp_pads_cpu[:num_reqs] = (num_padded_scheduled_tokens - + num_scheduled_tokens) + + # cu_padded_tokens: cumulative sum of padded token counts, + # pcp_padded_arange: per-request arange flattened for padded tokens. + cu_padded_tokens, pcp_padded_arange = self._get_cumsum_and_arange( + num_padded_scheduled_tokens, arange_np) + # Build the mask that marks which positions in the padded allgather buffer + # correspond to real (unpadded) tokens. + self.pcp_unpad_mask_cpu[:pcp_padded_arange.shape[0]] = ( + pcp_padded_arange < np.repeat(num_scheduled_tokens, + num_padded_scheduled_tokens)) + + pcp_tokens = num_padded_scheduled_tokens // self.pcp_world_size + + # Compute per-request "chunk sizes" for the head/tail splitting. + # For prefill requests, we further split the pcp_tokens into two chunks + # (head and tail). For decode requests, the chunk equals pcp_tokens. 
+ pcp_chunk_sizes = (pcp_tokens // 2).clip(min=1) + pcp_chunk_sizes[:num_decode_reqs] = pcp_tokens[:num_decode_reqs] + + # Build arange-style helpers for pcp tokens and chunk sizes: + # - pcp_arange gives indices repeated for each token in pcp_tokens + # - pcp_chunk_arange gives indices repeated for each position inside chunks + _, pcp_arange = self._get_cumsum_and_arange(pcp_tokens, arange_np) + _, pcp_chunk_arange = self._get_cumsum_and_arange( + pcp_chunk_sizes, arange_np) + + # Mask that marks whether a position belongs to the head chunk (True) + # or the tail chunk (False). For decode requests, tail chunk won't exist + # and is handled specially below. + pcp_head_chunk_mask = pcp_arange < np.repeat(pcp_chunk_sizes, + pcp_tokens) + + def get_current_rank_positions(positions_start_loc: int | np.ndarray, + rank: int): + """ + Compute flattened positions for the given rank with a given start + offset for each request (positions_start_loc). + + - For head chunks: start at positions_start_loc + rank * chunk_size. + - For tail chunks: start at positions_start_loc + (2*pcp_world_size- rank - + 1) * chunk_size. + - For decode requests: no tail chunks; their positions are filled from the + contiguous (unpadded) `tokens` arange instead (handled after). + """ + positions = np.zeros(len(pcp_head_chunk_mask), dtype=np.int32) + head_start_loc = positions_start_loc + rank * pcp_chunk_sizes + tail_start_loc = ( + positions_start_loc + + (2 * self.pcp_world_size - rank - 1) * pcp_chunk_sizes) + # Fill head positions using chunk arange offset by head_start_loc. + positions[pcp_head_chunk_mask] = pcp_chunk_arange + np.repeat( + head_start_loc, pcp_chunk_sizes) + # Fill tail positions. Note decode requests do not have tail chunks, + # so the tail filling is only for prefill positions. 
+ positions[~pcp_head_chunk_mask] = ( + pcp_chunk_arange[num_decode_tokens:] + + np.repeat(tail_start_loc, pcp_chunk_sizes)[num_decode_tokens:]) + return positions + + positions = get_current_rank_positions(0, self.pcp_world_rank) + # Decode tokens are duplicated only after AG. But their positions are + # same without prefill context parallel. + if num_decode_reqs > 0: + positions[:num_decode_tokens] = self._get_cumsum_and_arange( + num_scheduled_tokens[:num_decode_reqs], arange_np)[1] + + # Build the restore index used after allgather. + padded_pos_start_loc = np.roll(cu_padded_tokens, 1) + padded_pos_start_loc[0] = 0 + all_positions_lst = [ + get_current_rank_positions(padded_pos_start_loc, rank_i) + for rank_i in range(self.pcp_world_size) + ] + all_positions = np.concatenate(all_positions_lst) + logger.info(f"all_positions is {all_positions}") + self.pcp_allgather_restore_idx.np[:all_positions.shape[0]] = ( + all_positions.argsort()) + self.pcp_allgather_restore_idx.copy_to_gpu(all_positions.shape[0]) + + return ( + pcp_tokens[:num_reqs], + positions, + ) + + def get_logits_indices(self, cu_num_tokens: np.ndarray, num_reqs: int): + return (torch.from_numpy(cu_num_tokens) * self.pcp_world_size - + self.num_pcp_pads_cpu_tensor[:num_reqs] - 1) + + def get_discard_request_mask( + self, + num_computed_tokens_cpu: np.ndarray, + num_scheduled_tokens: np.ndarray, + num_reqs: int, + num_tokens_np: np.ndarray, + ): + return (num_computed_tokens_cpu[:num_reqs] + + num_scheduled_tokens * self.pcp_world_size - + self.num_pcp_pads_cpu[:num_reqs]) < num_tokens_np + + def get_padded_slot_mapping(self, num_tokens: list, + slot_mapping: torch.Tensor): + # After pcp allgather and restore, there are padded tokens in kv, + # so we need pad slotmapping for alignment. + # print(f"num_actual_tokens_pcp_padded: {self.num_actual_tokens_pcp_padded}") + pcp_padded_slot_mapping = self.pcp_padded_slot_mapping[:num_tokens * + self. 
+ pcp_world_size] + cp_unpad_mask = self.pcp_unpad_mask_cpu_tensor[:num_tokens * + self.pcp_world_size] + pcp_padded_slot_mapping.fill_(-1) + pcp_padded_slot_mapping[cp_unpad_mask] = slot_mapping + return pcp_padded_slot_mapping + + def get_restore_hidden_states( + self, + hidden_states: torch.Tensor, + ): + # NOTE we must `slice` hidden_states because pcp_allgather_restore_idx + # ignores the padding from CUDA Graph. + from vllm.distributed.parallel_state import get_pcp_group + hidden_states = get_pcp_group().all_gather( + hidden_states[:self.num_actual_tokens_pcp_padded // + self.pcp_world_size], + 0, + ) + restore_idx = self.pcp_allgather_restore_idx.gpu[:hidden_states. + shape[0]] + return torch.index_select( + hidden_states, + 0, + restore_idx, + ) + + def generate_pcp_mtp_input( + self, + num_reqs: int, + total_num_scheduled_tokens: int, + num_scheduled_tokens: dict[str, int], + input_batch, + arange_np: np.ndarray, + ): + """ + While pcp > 1, model inputs (input_ids, position, etc.) are split across pcp group, + but mtp need to shift original input_ids before pcp splitting, + so we record original input_ids here. 
+ """ + total_num_scheduled_tokens_pcp_full = total_num_scheduled_tokens + num_scheduled_tokens_pcp_full = np.empty(num_reqs, dtype=np.int32) + for i, req_id in enumerate(input_batch.req_ids): + num_scheduled_tokens_pcp_full[i] = num_scheduled_tokens[req_id] + req_indices_pcp_full = np.repeat(arange_np[:num_reqs], + num_scheduled_tokens_pcp_full) + cu_num_tokens_pcp_full = np.cumsum(num_scheduled_tokens_pcp_full) + self.query_start_loc_pcp_full.np[0] = 0 + self.query_start_loc_pcp_full.np[1:num_reqs + + 1] = cu_num_tokens_pcp_full + self.query_start_loc_pcp_full.np[num_reqs + 1:].fill(-1) + cumsums_offsets_pcp_full = np.repeat( + cu_num_tokens_pcp_full - num_scheduled_tokens_pcp_full, + num_scheduled_tokens_pcp_full) + arange_pcp_full = arange_np[:total_num_scheduled_tokens_pcp_full] - cumsums_offsets_pcp_full + positions_pcp_full_np = self.positions_pcp_full_np[: + total_num_scheduled_tokens_pcp_full] + np.add(input_batch.num_computed_tokens_cpu[req_indices_pcp_full], + arange_pcp_full, + out=positions_pcp_full_np) + token_indices_pcp_full = ( + positions_pcp_full_np + + req_indices_pcp_full * input_batch.token_ids_cpu.shape[1]) + torch.index_select(input_batch.token_ids_cpu_tensor.flatten(), + 0, + torch.from_numpy(token_indices_pcp_full), + out=self.input_ids_pcp_full. + cpu[:total_num_scheduled_tokens_pcp_full]) + self.query_start_loc_pcp_full.copy_to_gpu(num_reqs + 1) + self.input_ids_pcp_full.copy_to_gpu( + total_num_scheduled_tokens_pcp_full) + + def _get_cp_local_seq_lens( + self, + seq_lens: torch.Tensor, + pcp_world_size: int = 1, + dcp_world_size: int = 1, + cp_kv_cache_interleave_size: int = 1, + ) -> torch.Tensor: + """While using pcp or dcp, kv_cache size stored on each rank may be different, + use this function to calculate split decode seq_lens of each (p/d)cp rank. 
+ """ + num_requests = seq_lens.size(0) + total_world_size = pcp_world_size * dcp_world_size + seq_lens_tiled = seq_lens.unsqueeze(-1).repeat(1, total_world_size) + rank_offsets = (torch.arange(total_world_size, + dtype=torch.int32).unsqueeze(0).repeat( + num_requests, 1)) + base = (seq_lens_tiled // cp_kv_cache_interleave_size // + total_world_size * cp_kv_cache_interleave_size) + remainder = seq_lens_tiled - base * total_world_size + remainder = torch.clip( + remainder - rank_offsets * cp_kv_cache_interleave_size, + 0, + cp_kv_cache_interleave_size, + ) + dcp_local_seq_lens = (base + remainder).reshape( + [-1, pcp_world_size, dcp_world_size]) + return dcp_local_seq_lens + + def generate_kv_idx(self, scheduler_output, input_batch): + if not self.pcp_world_size > 1: + return + self.cp_kv_recover_idx_for_chunk = [[] + for _ in range(self.pcp_world_size) + ] + + for i, req_id in enumerate(input_batch.req_ids): + num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ + req_id] + is_prefill = input_batch.num_computed_tokens_cpu[ + i] < input_batch.num_prompt_tokens[i] + if is_prefill: + num_cp_padded_scheduled_tokens = cdiv( + num_scheduled_tokens, + 2 * self.pcp_world_size) * (2 * self.pcp_world_size) + chunk_size = num_cp_padded_scheduled_tokens // ( + 2 * self.pcp_world_size) + num_added_recover_tokens = len( + self.cp_kv_recover_idx_for_chunk[0]) * self.pcp_world_size + for rank in range(self.pcp_world_size): + self.cp_kv_recover_idx_for_chunk[rank].extend( + self.full_indices[rank * chunk_size + + num_added_recover_tokens:(rank + 1) * + chunk_size + + num_added_recover_tokens]) + self.cp_kv_recover_idx_for_chunk[rank].extend( + self.full_indices[num_cp_padded_scheduled_tokens - + (rank + 1) * chunk_size + + num_added_recover_tokens: + num_cp_padded_scheduled_tokens - + rank * chunk_size + + num_added_recover_tokens]) + + cp_kv_recover_idx_for_chunk = torch.from_numpy( + np.concatenate(self.cp_kv_recover_idx_for_chunk)).to( + device=self.device) + 
logger.info( + f"cp_kv_recover_idx_for_chunk is {cp_kv_recover_idx_for_chunk}" + ) + cp_kv_recover_idx_for_chunk.copy_(torch.tensor( + np.array( + self.cp_kv_recover_idx_for_chunk).flatten().tolist()), + non_blocking=True) + logger.info( + f"cp_kv_recover_idx_for_chunk22222 is {cp_kv_recover_idx_for_chunk}" + ) + self.cp_kv_recover_idx_for_chunk = cp_kv_recover_idx_for_chunk.to( + torch.float32).argsort().to(torch.int32) + + def generate_pcp_metadata(self, total_num_scheduled_tokens, query_lens, + attn_mask, input_batch): + from vllm_ascend.attention.utils import \ + AscendPrefillContextParallelMetadata + num_reqs = input_batch.num_reqs or self.query_lens.size(0) + num_decodes = sum(input_batch.num_computed_tokens_cpu[:num_reqs] >= + input_batch.num_prompt_tokens[:num_reqs]) + num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_world_size + self.num_actual_tokens_pcp_padded = num_actual_tokens_pcp_padded + long_seq_metadata = None + if self.pcp_world_size * self.dcp_world_size > 1: + decode_context_lens = input_batch.num_tokens[:num_decodes] + prefill_context_lens = input_batch.num_computed_tokens_cpu[ + num_decodes:num_reqs] + context_lens = np.concatenate( + [decode_context_lens, prefill_context_lens]) + num_computed_tokens_of_pcp_dcp = torch.zeros( + [ + num_reqs * self.decode_threshold, self.pcp_world_size, + self.dcp_world_size + ], + dtype=torch.int32, + ) + # For pcp + spec decode, we flatten seq_lens + # to avoid irregular spec_attn_mask shape + for decode_idx in range(self.decode_threshold): + num_computed_tokens_of_pcp_dcp[ + self.decode_threshold - 1 - decode_idx::self.decode_threshold] = \ + self._get_cp_local_seq_lens( + torch.tensor(context_lens), + self.pcp_world_size, + self.dcp_world_size, + self.vllm_config.parallel_config.cp_kv_cache_interleave_size, + ) + long_seq_metadata = AscendPrefillContextParallelMetadata( + num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded, + 
num_computed_tokens_of_pcp_dcp=num_computed_tokens_of_pcp_dcp. + numpy()) + if self.pcp_world_size > 1: + q_head_idx, q_tail_idx = [], [] + kv_with_q_head_nomask_idx, kv_with_q_head_mask_idx = [], [] + kv_with_q_tail_nomask_idx, kv_with_q_tail_mask_idx = [], [] + chunk_seqlens = [] + kv_with_q_head_nomask_seqlens, kv_with_q_tail_nomask_seqlens = [], [] + q_req_offset = 0 + kv_req_offset = 0 + q_head_chunk_id = self.pcp_world_rank + q_tail_chunk_id = self.pcp_world_size * 2 - 1 - self.pcp_world_rank + for i, seq_len in enumerate(query_lens): + if i < num_decodes: + continue + chunk_len = seq_len // 2 + chunk_seqlens.append(chunk_len) + q_head_idx.extend( + list(range(q_req_offset, q_req_offset + chunk_len))) + kv_with_q_head_nomask_idx.extend( + list( + range(kv_req_offset, kv_req_offset + + chunk_len * q_head_chunk_id))) + kv_with_q_head_mask_idx.extend( + list( + range( + kv_req_offset + chunk_len * q_head_chunk_id, + kv_req_offset + chunk_len * + (q_head_chunk_id + 1)))) + kv_with_q_head_nomask_seqlens.append(chunk_len * + q_head_chunk_id) + + q_tail_idx.extend( + list( + range(q_req_offset + chunk_len, + q_req_offset + chunk_len * 2))) + kv_with_q_tail_nomask_idx.extend( + list( + range(kv_req_offset, kv_req_offset + + chunk_len * q_tail_chunk_id))) + kv_with_q_tail_mask_idx.extend( + list( + range( + kv_req_offset + chunk_len * q_tail_chunk_id, + kv_req_offset + chunk_len * + (q_tail_chunk_id + 1)))) + kv_with_q_tail_nomask_seqlens.append(chunk_len * + q_tail_chunk_id) + + q_req_offset += seq_len + kv_req_offset += seq_len * self.pcp_world_size + + # Convert lists to tensors and move to device + def _list_to_tensor(lst, device, dtype=torch.int32): + tensor_npu = torch.zeros(len(lst), + dtype=dtype, + device=device) + tensor_npu.copy_(torch.tensor(lst, dtype=dtype), + non_blocking=True) + return tensor_npu + + q_head_idx_tensor = _list_to_tensor(q_head_idx, self.device) + q_tail_idx_tensor = _list_to_tensor(q_tail_idx, self.device) + self.q_head_idx_tensor = 
q_head_idx_tensor + self.q_tail_idx_tensor = q_tail_idx_tensor + + q_full_idx = torch.cat([q_head_idx_tensor, q_tail_idx_tensor]) + q_full_idx = q_full_idx.to(torch.float32).argsort().to( + torch.int32) + self.q_full_idx = q_full_idx + + self.kv_idx_names = { + 'kv_with_q_head_nomask_idx_tensor': + kv_with_q_head_nomask_idx, + 'kv_with_q_head_mask_idx_tensor': kv_with_q_head_mask_idx, + 'kv_with_q_tail_nomask_idx_tensor': + kv_with_q_tail_nomask_idx, + 'kv_with_q_tail_mask_idx_tensor': kv_with_q_tail_mask_idx + } + for key, value in self.kv_idx_names.items(): + tensor_npu = _list_to_tensor(value, self.device) + self.kv_idx_names[key] = tensor_npu + + attn_mask_seqlens = torch.tensor( + [chunk_seqlens, chunk_seqlens], dtype=torch.int32) + head_attn_nomask_seqlens = torch.tensor( + [chunk_seqlens, kv_with_q_head_nomask_seqlens], + dtype=torch.int32) + tail_attn_nomask_seqlens = torch.tensor( + [chunk_seqlens, kv_with_q_tail_nomask_seqlens], + dtype=torch.int32) + pcp_prefill_mask = attn_mask + + self.extra_long_seq_kwargs = { + 'attn_mask_seqlens': attn_mask_seqlens, + 'head_attn_nomask_seqlens': head_attn_nomask_seqlens, + 'tail_attn_nomask_seqlens': tail_attn_nomask_seqlens, + 'pcp_prefill_mask': pcp_prefill_mask + } + long_seq_metadata.pcp_allgather_restore_idx = self.pcp_allgather_restore_idx.gpu[: + num_actual_tokens_pcp_padded] + long_seq_metadata.cp_kv_recover_idx_for_chunk = self.cp_kv_recover_idx_for_chunk + long_seq_metadata.q_head_idx_tensor = self.q_head_idx_tensor + long_seq_metadata.q_tail_idx_tensor = self.q_tail_idx_tensor + long_seq_metadata.q_full_idx = self.q_full_idx + long_seq_metadata.kv_with_q_head_nomask_idx_tensor = self.kv_idx_names[ + 'kv_with_q_head_nomask_idx_tensor'] + long_seq_metadata.kv_with_q_head_mask_idx_tensor = self.kv_idx_names[ + 'kv_with_q_head_mask_idx_tensor'] + long_seq_metadata.kv_with_q_tail_nomask_idx_tensor = self.kv_idx_names[ + 'kv_with_q_tail_nomask_idx_tensor'] + long_seq_metadata.kv_with_q_tail_mask_idx_tensor = 
self.kv_idx_names[ + 'kv_with_q_tail_mask_idx_tensor'] + long_seq_metadata.attn_mask_seqlens = self.extra_long_seq_kwargs[ + 'attn_mask_seqlens'] + long_seq_metadata.head_attn_nomask_seqlens = self.extra_long_seq_kwargs[ + 'head_attn_nomask_seqlens'] + long_seq_metadata.tail_attn_nomask_seqlens = self.extra_long_seq_kwargs[ + 'tail_attn_nomask_seqlens'] + long_seq_metadata.pcp_prefill_mask = self.extra_long_seq_kwargs[ + 'pcp_prefill_mask'] + self.long_seq_metadata = long_seq_metadata + return long_seq_metadata diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 19e8a310277..f7920674094 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -24,7 +24,7 @@ from copy import copy, deepcopy from dataclasses import dataclass from multiprocessing import Manager -from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, NamedTuple, Optional, Union import numpy as np import regex as re @@ -86,8 +86,7 @@ from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.attention.attention_mask import AttentionMaskBuilder from vllm_ascend.attention.attention_v1 import AscendAttentionState -from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, - AscendPrefillContextParallelMetadata) +from vllm_ascend.attention.utils import AscendCommonAttentionMetadata # yapf conflicts with isort for this block # yapf: disable from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper, @@ -114,9 +113,11 @@ from vllm_ascend.spec_decode.eagle_proposer import EagleProposer from vllm_ascend.spec_decode.interface import SpecDcodeType from vllm_ascend.spec_decode.mtp_proposer import MtpProposer -from vllm_ascend.utils import (AscendDeviceType, ProfileExecuteDuration, - enable_sp, get_ascend_device_type, is_moe_model, - lmhead_tp_enable, maybe_trans_nz) +from vllm_ascend.utils import (AscendDeviceType, PCPManager, + 
ProfileExecuteDuration, enable_sp, + get_ascend_device_type, is_moe_model, + lmhead_tp_enable, maybe_trans_nz, + vllm_version_is) from vllm_ascend.worker.npu_input_batch import NPUInputBatch from vllm_ascend.ascend_forward_context import ( # isort: skip @@ -210,6 +211,27 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): self.pcp_rank = 0 if self.pcp_size > 1: self.model_config.max_model_len += 2 * self.pcp_size * self.max_num_reqs + max_buffer_num_tokens = self.max_num_tokens + logger.info(f"self.max_num_tokens is {self.max_num_tokens}") + if self.pcp_size > 1: + max_buffer_num_tokens = (self.max_num_tokens + + self.max_num_reqs * 2 * self.pcp_size) + self.pcp_manager = PCPManager( + self.pcp_size, + self.pcp_rank, + self.dcp_size, + self.dcp_rank, + max_buffer_num_tokens, + self.max_num_reqs, + self.device, + self.vllm_config, + self.pin_memory, + ) + # TODO(zhenwenqi) after https://github.com/vllm-project/vllm/pull/28988 is merged, we can delete this + self.input_ids = self._make_buffer(max_buffer_num_tokens, + dtype=torch.int32) + self.positions = self._make_buffer(max_buffer_num_tokens, + dtype=torch.int64) if envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP: self.prefetch_stream = torch.npu.Stream(device=device) else: @@ -242,15 +264,24 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): # Set up Attention self.use_sparse = hasattr(self.vllm_config.model_config.hf_config, "index_topk") - self.attn_backend = get_attn_backend( - 0, - self.dtype, - None, - self.block_size, - use_mla=self.model_config.use_mla, - use_sparse=self.use_sparse, - use_mm_prefix=self.model_config is not None - and self.model_config.is_mm_prefix_lm) + if vllm_version_is('0.12.0'): + self.attn_backend = get_attn_backend( + 0, + self.dtype, + None, + self.block_size, + use_mla=self.model_config.use_mla, + use_sparse=self.use_sparse) + else: + self.attn_backend = get_attn_backend( + 0, + self.dtype, + None, + self.block_size, + 
use_mla=self.model_config.use_mla, + use_sparse=self.use_sparse, + use_mm_prefix=self.model_config is not None + and self.model_config.is_mm_prefix_lm) self.attn_mask_builder = AttentionMaskBuilder(self.device) self._set_up_drafter() @@ -267,31 +298,6 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): set_mc2_tokens_capacity(vllm_config, self.max_num_reqs, self.uniform_decode_query_len) set_mc2_mask(vllm_config, self.device) - self.pcp_allgather_restore_idx = torch.zeros( - self.max_num_tokens + 2 * self.pcp_size * self.max_num_reqs, - dtype=torch.int32, - device=self.device) - self.cp_kv_recover_idx_for_chunk: List[List[int]] = [ - [] for _ in range(self.pcp_size) - ] - - self.num_pcp_pads = torch.zeros(self.max_num_reqs, dtype=torch.int32) - self.pcp_padded_slot_mapping = torch.zeros( - self.max_num_tokens + 2 * self.pcp_size * self.max_num_reqs, - dtype=torch.int32, - device=self.device) - self.num_actual_tokens_pcp_padded = 0 - if self.speculative_config and self.pcp_size > 1: - self.input_ids_pcp_full = self._make_buffer(self.max_num_tokens, - dtype=torch.int32) - self.query_start_loc_pcp_full = self._make_buffer( - self.max_num_reqs + 1, dtype=torch.int32) - self.positions_pcp_full = torch.zeros(self.max_num_tokens, - dtype=torch.int64, - device="cpu", - pin_memory=True) - self.decode_token_per_req += self.speculative_config.num_speculative_tokens - self.positions_pcp_full_np = self.positions_pcp_full.numpy() self.decode_threshold = 1 + ( self.speculative_config.num_speculative_tokens if self.speculative_config else 0) @@ -491,49 +497,6 @@ def _make_attention_mask(self, attn_state) -> torch.Tensor: return self.attn_mask_builder.get_mla_mask(self.dtype) return self.attn_mask_builder.get_splitfuse_attn_mask() - def generate_kv_idx(self, scheduler_output): - if not self.pcp_size > 1: - return - self.cp_kv_recover_idx_for_chunk = [[] for _ in range(self.pcp_size)] - - for i, req_id in enumerate(self.input_batch.req_ids): - num_scheduled_tokens 
= scheduler_output.num_scheduled_tokens[ - req_id] - is_prefill = self.input_batch.num_computed_tokens_cpu[ - i] < self.input_batch.num_prompt_tokens[i] - if is_prefill: - num_cp_padded_scheduled_tokens = cdiv( - num_scheduled_tokens, - 2 * self.pcp_size) * (2 * self.pcp_size) - full_indices = list( - range(self.max_num_tokens * self.pcp_size * self.dcp_size + - self.pcp_size * self.dcp_size * self.max_num_reqs)) - chunk_size = num_cp_padded_scheduled_tokens // (2 * - self.pcp_size) - num_added_recover_tokens = len( - self.cp_kv_recover_idx_for_chunk[0]) * self.pcp_size - for rank in range(self.pcp_size): - self.cp_kv_recover_idx_for_chunk[rank].extend( - full_indices[rank * chunk_size + - num_added_recover_tokens:(rank + 1) * - chunk_size + num_added_recover_tokens]) - self.cp_kv_recover_idx_for_chunk[rank].extend( - full_indices[num_cp_padded_scheduled_tokens - - (rank + 1) * chunk_size + - num_added_recover_tokens: - num_cp_padded_scheduled_tokens - - rank * chunk_size + - num_added_recover_tokens]) - - cp_kv_recover_idx_for_chunk = torch.from_numpy( - np.concatenate( - self.cp_kv_recover_idx_for_chunk)).to(device=self.device) - cp_kv_recover_idx_for_chunk.copy_(torch.tensor( - np.array(self.cp_kv_recover_idx_for_chunk).flatten().tolist()), - non_blocking=True) - self.cp_kv_recover_idx_for_chunk = cp_kv_recover_idx_for_chunk.to( - torch.float32).argsort().to(torch.int32) - def _prepare_inputs( self, scheduler_output: "SchedulerOutput", @@ -554,32 +517,51 @@ def _prepare_inputs( req_ids = self.input_batch.req_ids tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids] num_scheduled_tokens = np.array(tokens, dtype=np.int32) - + # for pcp, prefill mtp should use origin scheduleroutput , + if self.speculative_config and self.pcp_size > 1: + self.pcp_manager.generate_pcp_mtp_input( + num_reqs, total_num_scheduled_tokens, + scheduler_output.num_scheduled_tokens, self.input_batch, + self.arange_np) req_indices = np.repeat(self.arange_np[:num_reqs], 
num_scheduled_tokens) - _, arange = self._get_cumsum_and_arange(num_scheduled_tokens) - positions_np = np.add( - self.input_batch.num_computed_tokens_cpu[req_indices], - arange, - ) + # Get positions. + positions_np = self.positions.np[:total_num_scheduled_tokens] + cu_num_tokens, arange = self._get_cumsum_and_arange( + num_scheduled_tokens) + np.add(self.input_batch.num_computed_tokens_cpu[req_indices], + arange, + out=positions_np) self.input_batch.block_table.compute_slot_mapping( req_indices, positions_np) self.input_batch.block_table.commit_slot_mapping( total_num_scheduled_tokens) - total_num_pcp_pads = 0 if self.pcp_size > 1: if not self.vllm_config.model_config.use_mla: - self.generate_kv_idx(scheduler_output) - tokens, position_pcp, pcp_unpad_mask = self._update_tokens_for_pcp( - tokens) - num_scheduled_tokens = np.array(tokens, dtype=np.int32) - total_num_scheduled_tokens = sum(num_scheduled_tokens[:num_reqs]) - total_num_pcp_pads = torch.sum(self.num_pcp_pads).item() - else: - position_pcp, pcp_unpad_mask = None, None - self.num_pcp_pads = self.num_pcp_pads[:num_reqs] + self.pcp_manager.generate_kv_idx(scheduler_output, + self.input_batch) + num_scheduled_tokens[: + num_reqs], position_pcp = self.pcp_manager.update_tokens_for_pcp( + num_scheduled_tokens[:num_reqs], + self.arange_np, + self.input_batch.num_reqs, + self.reorder_batch_threshold, + ) + # Re-update after PCP split sequences. 
+ total_num_scheduled_tokens = sum(num_scheduled_tokens) + scheduler_output.total_num_scheduled_tokens = total_num_scheduled_tokens + req_indices = np.repeat(self.arange_np[:num_reqs], + num_scheduled_tokens) + cu_num_tokens, _ = self._get_cumsum_and_arange( + num_scheduled_tokens) + positions_np = self.positions.np[:total_num_scheduled_tokens] + np.add( + self.input_batch.num_computed_tokens_cpu[req_indices], + position_pcp[:total_num_scheduled_tokens], + out=positions_np, + ) max_num_scheduled_tokens = max(tokens) if not scheduler_output.scheduled_spec_decode_tokens: @@ -618,6 +600,7 @@ def _prepare_inputs( with_prefill = attn_state not in [ AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding ] + self.attn_mask = self._make_attention_mask(attn_state) self.query_lens = torch.from_numpy(num_scheduled_tokens) @@ -627,7 +610,7 @@ def _prepare_inputs( (maybe_padded_num_tokens, num_tokens_across_dp, with_prefill) = self._sync_metadata_across_dp(num_input_tokens, with_prefill) - + self.with_prefill = with_prefill # TODO: Now that num_input_tokens is basically identical with maybe_padded_num_tokens # We should consider removing maybe_padded_num_tokens later num_input_tokens = maybe_padded_num_tokens @@ -636,24 +619,6 @@ def _prepare_inputs( if self.lora_config: self.set_active_loras(self.input_batch, num_scheduled_tokens) - # Get request indices. - # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] - req_indices = np.repeat(self.arange_np[:num_reqs], - num_scheduled_tokens) - - # cu_num_tokens: [2, 5, 3] -> [2, 7, 10] - # arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - cu_num_tokens, arange = self._get_cumsum_and_arange( - num_scheduled_tokens) - - if self.pcp_size > 1: - positions_np = self.positions.np[:total_num_scheduled_tokens] - np.add(self.input_batch.num_computed_tokens_cpu[req_indices], - position_pcp[:total_num_scheduled_tokens], - out=positions_np) - else: - self.positions.np[:total_num_scheduled_tokens] = positions_np - # Calculate M-RoPE positions. 
# Only relevant for models using M-RoPE (e.g, Qwen2-VL) if self.uses_mrope: @@ -746,14 +711,6 @@ def _prepare_inputs( cu_num_tokens) self.positions.cpu[total_num_scheduled_tokens:num_input_tokens].zero_() self.positions.copy_to_gpu() - - attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens, - num_valid_tokens) - self.attn_mask = self._make_attention_mask(attn_state) - self.attn_state = attn_state # type: ignore - - self.with_prefill = with_prefill - self.num_tokens_across_dp = num_tokens_across_dp attn_metadata: dict[str, Any] = {} # Record the index of requests that should not be sampled, @@ -882,9 +839,8 @@ def _prepare_inputs( # TODO: Support prompt logprobs. spec_decode_metadata = None if self.pcp_size * self.dcp_size > 1: - logits_indices = torch.from_numpy( - cu_num_tokens - ) * self.pcp_size - self.num_pcp_pads[:num_reqs] - 1 + logits_indices = self.pcp_manager.get_logits_indices( + cu_num_tokens, num_reqs) logits_indices = logits_indices.pin_memory().to( self.device, non_blocking=True) else: @@ -906,7 +862,10 @@ def _prepare_inputs( >= self.input_batch.num_prompt_tokens[req_idx]) else -1) spec_decode_metadata = self._calc_spec_decode_metadata( - num_draft_tokens, cu_num_tokens, self.num_pcp_pads[:num_reqs]) + num_draft_tokens, + cu_num_tokens, + num_pcp_pads=self.pcp_manager.num_pcp_pads_cpu[:num_reqs] + if self.pcp_size > 1 else None) logits_indices = spec_decode_metadata.logits_indices # For DECODE only cuda graph of some attention backends (e.g., GDN). 
@@ -928,22 +887,10 @@ def _prepare_inputs( self.num_accepted_tokens.np[num_reqs:].fill(1) self.num_accepted_tokens.copy_to_gpu() - if self.speculative_config and self.pcp_size > 1: - self._generate_pcp_mtp_input( - num_reqs, scheduler_output.total_num_scheduled_tokens, - scheduler_output.num_scheduled_tokens) - - long_seq_metadata = self._generate_pcp_metadata( - total_num_scheduled_tokens) # Prepare the attention metadata for each KV cache group and make layers # in the same group share the same metadata. for kv_cache_group_id, kv_cache_group_spec in enumerate( self.kv_cache_config.kv_cache_groups): - # NOTE: This is strange, why did we use total_num_scheduled_tokens before? - slot_mapping_size = (total_num_scheduled_tokens - if self.pcp_size == 1 else - total_num_scheduled_tokens * self.pcp_size - - total_num_pcp_pads) if isinstance(kv_cache_group_spec.kv_cache_spec, EncoderOnlyAttentionSpec): # Encoder-only layers do not have KV cache, so we need to @@ -959,30 +906,29 @@ def _prepare_inputs( device=self.device, ) else: + maybe_pcp_full_tokens = ( + num_input_tokens if self.pcp_size == 1 else + total_num_scheduled_tokens * self.pcp_size - + sum(self.pcp_manager.num_pcp_pads_cpu[:num_reqs])) blk_table = self.input_batch.block_table[kv_cache_group_id] blk_table_tensor = blk_table.get_device_tensor() - blk_table.slot_mapping.gpu[slot_mapping_size:].fill_(0) - if self.pcp_size > 1: - slot_mapping_for_pcp = blk_table.slot_mapping.gpu[: - long_seq_metadata - . - num_actual_tokens_pcp_padded] - slot_mapping_for_pcp[slot_mapping_size:].fill_(-1) - assert pcp_unpad_mask is not None - pcp_padded_slot_mapping = self.pcp_padded_slot_mapping[: - pcp_unpad_mask - . - shape[ - 0]] - pcp_padded_slot_mapping.fill_(-1) - pcp_padded_slot_mapping[ - pcp_unpad_mask] = slot_mapping_for_pcp[: - slot_mapping_size] - slot_mapping_for_pcp[:long_seq_metadata. 
- num_actual_tokens_pcp_padded] = pcp_padded_slot_mapping - blk_table.slot_mapping.gpu[:long_seq_metadata.num_actual_tokens_pcp_padded] = \ - slot_mapping_for_pcp - slot_mapping = blk_table.slot_mapping.gpu + # blk_table.slot_mapping.gpu[slot_mapping_size:].fill_(-1) + slot_mapping = blk_table.slot_mapping.gpu[: + maybe_pcp_full_tokens] + if self.pcp_size == 1: + slot_mapping[ + total_num_scheduled_tokens:num_input_tokens].fill_(-1) + # blk_table_tensor[num_reqs:].fill_(-1) + slot_mapping = blk_table.slot_mapping.gpu + if self.pcp_size > 1: + self.long_seq_metadata = self.pcp_manager.generate_pcp_metadata( + total_num_scheduled_tokens, self.query_lens, + self.attn_mask, self.input_batch) + slot_mapping = slot_mapping[:maybe_pcp_full_tokens] + slot_mapping = self.pcp_manager.get_padded_slot_mapping( + total_num_scheduled_tokens, + slot_mapping, + ) # NOTE: This is a temporary hack, now in GPUModelRunner, this prepare_inputs # has been split to multiple parts, and there are 3 parts that is related to this @@ -1021,7 +967,7 @@ def _prepare_inputs( seq_lens_cpu=self.seq_lens.cpu[:num_reqs], seq_lens=self.seq_lens.gpu[:num_reqs], num_reqs=num_reqs, - num_actual_tokens=slot_mapping_size, + num_actual_tokens=total_num_scheduled_tokens, num_input_tokens=num_input_tokens, actual_seq_lengths_q=self.actual_seq_lengths_q, # TODO: change this to the right block table for linear attn @@ -1034,7 +980,7 @@ def _prepare_inputs( attn_state=self.attn_state, max_query_len=max_num_scheduled_tokens, decode_token_per_req=self.decode_token_per_req, - prefill_context_parallel_metadata=long_seq_metadata, + prefill_context_parallel_metadata=self.long_seq_metadata, ) if self.speculative_config and self.pcp_size > 1: @@ -1045,8 +991,8 @@ def _prepare_inputs( # (num_reqs_d + num_reqs_p, max_num_blocks), # flattened block_table: [d0, d0, d1, d1, p0, p1, p2] # (num_reqs_d * decode_threshold + num_reqs_p, max_num_blocks), - ori_query_lens = self.query_start_loc_pcp_full.cpu[1:num_reqs + 1] - \ - 
self.query_start_loc_pcp_full.cpu[:num_reqs] + ori_query_lens = self.pcp_manager.query_start_loc_pcp_full.cpu[1:num_reqs+1] - \ + self.pcp_manager.query_start_loc_pcp_full.cpu[:num_reqs] num_prefill_reqs = (ori_query_lens > self.decode_threshold).sum().item() num_decode_reqs = num_reqs - num_prefill_reqs @@ -1161,15 +1107,8 @@ def _generate_process_reqs_hidden_states(self, maybe_padded_num_tokens, pad_size = get_forward_context().pad_size if pad_size > 0: hidden_states = hidden_states[:-pad_size, :] - - if self.pcp_size > 1: - hidden_states = get_pcp_group().all_gather( - hidden_states[:self.num_actual_tokens_pcp_padded // - self.pcp_size], 0) - hidden_states = torch.index_select( - hidden_states, 0, - self.pcp_allgather_restore_idx[:hidden_states.shape[0]]) - return hidden_states + return hidden_states if self.pcp_size == 1 else self.pcp_manager.get_restore_hidden_states( + hidden_states) def _build_attn_state(self, num_reqs, num_scheduled_tokens, num_valid_tokens): @@ -1206,7 +1145,7 @@ def _calc_spec_decode_metadata( self, num_draft_tokens: np.ndarray, cu_num_scheduled_tokens: np.ndarray, - num_pcp_pads: np.ndarray, + num_pcp_pads: np.ndarray | None, ) -> SpecDecodeMetadata: # Inputs: # cu_num_scheduled_tokens: [ 4, 104, 107, 207, 209] @@ -1823,7 +1762,9 @@ def _build_dummy_attn_metadata( self.cp_kv_recover_idx = torch.zeros(self.max_num_tokens, dtype=torch.int32, device=self.device) - long_seq_metadata = self._generate_pcp_metadata(num_tokens) + long_seq_metadata = None if self.pcp_size == 1 else self.pcp_manager.generate_pcp_metadata( + num_tokens, self.query_lens, self.attn_mask, + self.input_batch) if long_seq_metadata is not None: pcp_world_size = get_pcp_group().world_size dcp_world_size = get_dcp_group().world_size @@ -1867,19 +1808,36 @@ def _build_dummy_attn_metadata( self.speculative_config.method == "mtp": attn_state = AscendAttentionState.SpecDecoding - common_metadata = CommonAttentionMetadata( - query_start_loc=self.query_start_loc.gpu[:num_reqs + 
1], - query_start_loc_cpu=self.query_start_loc.cpu[:num_reqs + + if vllm_version_is("0.12.0"): + common_metadata = CommonAttentionMetadata( + query_start_loc=self.query_start_loc.gpu[:num_reqs + 1], - _seq_lens_cpu=self.seq_lens.cpu[:num_reqs], - seq_lens=self.seq_lens.cpu[:num_reqs], - num_reqs=num_reqs, - num_actual_tokens=num_tokens, - block_table_tensor=block_table_tensor[:num_reqs], - slot_mapping=slot_mapping.gpu, - _num_computed_tokens_cpu=num_computed_tokens_cpu, - max_query_len=max_query_len, - max_seq_len=seq_lens) + query_start_loc_cpu=self.query_start_loc. + cpu[:num_reqs + 1], + seq_lens_cpu=self.seq_lens.cpu[:num_reqs], + seq_lens=self.seq_lens.cpu[:num_reqs], + num_reqs=num_reqs, + num_actual_tokens=num_tokens, + block_table_tensor=block_table_tensor[:num_reqs], + slot_mapping=slot_mapping.gpu, + num_computed_tokens_cpu=num_computed_tokens_cpu, + max_query_len=max_query_len, + max_seq_len=seq_lens) + else: + common_metadata = CommonAttentionMetadata( + query_start_loc=self.query_start_loc.gpu[:num_reqs + + 1], + query_start_loc_cpu=self.query_start_loc. 
+ cpu[:num_reqs + 1], + _seq_lens_cpu=self.seq_lens.cpu[:num_reqs], + seq_lens=self.seq_lens.cpu[:num_reqs], + num_reqs=num_reqs, + num_actual_tokens=num_tokens, + block_table_tensor=block_table_tensor[:num_reqs], + slot_mapping=slot_mapping.gpu, + _num_computed_tokens_cpu=num_computed_tokens_cpu, + max_query_len=max_query_len, + max_seq_len=seq_lens) for attn_group in self.attn_groups[kv_cache_group_id]: builder = attn_group.get_metadata_builder() @@ -2177,7 +2135,10 @@ def profile_run(self) -> None: self._dummy_run(mc2_tokens_capacity, with_prefill=True, is_profile=True) + if self.pcp_size > 1: + self.mx_num_tokens = self.max_num_tokens // self.pcp_size super().profile_run() + self.mx_num_tokens = self.scheduler_config.max_num_batched_tokens def eplb_warmup(self): if self.dynamic_eplb and not self.is_eplb_warmuped: @@ -3023,300 +2984,6 @@ def capture_model(self) -> None: logger.info("Graph capturing finished in %.0f secs, took %.2f GiB", elapsed_time, npu_graph_size / (1 << 30)) - def _update_tokens_for_pcp(self, tokens): - num_reqs = self.input_batch.num_reqs - self.num_pcp_pads = self.num_pcp_pads[:num_reqs] - tokens = np.array(tokens, dtype=np.int32) - num_decode_reqs = sum( - self.input_batch.num_computed_tokens_cpu[:num_reqs] >= - self.input_batch.num_prompt_tokens[:num_reqs]) - num_decode_tokens = sum(tokens[:num_decode_reqs]) - num_padded_scheduled_tokens = np.ceil( - tokens / - (2 * self.pcp_size)).astype(np.int32) * (2 * self.pcp_size) - num_padded_scheduled_tokens[:num_decode_reqs] = ( - tokens[:num_decode_reqs] * self.pcp_size) - self.num_pcp_pads = torch.tensor(num_padded_scheduled_tokens - tokens) - cu_padded_tokens, pcp_padded_arange = \ - self._get_cumsum_and_arange(num_padded_scheduled_tokens) - unpad_mask = torch.from_numpy( - pcp_padded_arange < np.repeat(tokens, num_padded_scheduled_tokens)) - unpad_mask_decode = unpad_mask[:num_decode_tokens * self.pcp_size] - unpad_mask_decode = unpad_mask_decode.reshape([-1, self.pcp_size]) - 
unpad_mask_decode[:, 0] = True - unpad_mask_decode[:, 1:] = False - - pcp_tokens = num_padded_scheduled_tokens // self.pcp_size - pcp_chunk_sizes = (pcp_tokens // 2).clip(min=1) - pcp_chunk_sizes[:num_decode_reqs] = pcp_tokens[:num_decode_reqs] - _, pcp_arange = self._get_cumsum_and_arange(pcp_tokens) - _, pcp_chunk_arange = self._get_cumsum_and_arange(pcp_chunk_sizes) - pcp_head_chunk_mask = pcp_arange < np.repeat(pcp_chunk_sizes, - pcp_tokens) - - def get_current_rank_positions(cu_tokens, rank): - positions_start_loc = np.zeros_like(cu_tokens) - positions_start_loc[1:] = cu_tokens[:-1] - positions = np.zeros(len(pcp_head_chunk_mask), dtype=np.int32) - head_start_loc = positions_start_loc + rank * pcp_chunk_sizes - tail_start_loc = positions_start_loc + \ - (2 * self.pcp_size - rank - 1) * pcp_chunk_sizes - positions[pcp_head_chunk_mask] = pcp_chunk_arange + \ - np.repeat(head_start_loc, pcp_chunk_sizes) - # Decode reqs do not have tail chunks. - positions[~pcp_head_chunk_mask] = \ - pcp_chunk_arange[num_decode_tokens:] + \ - np.repeat(tail_start_loc, pcp_chunk_sizes)[num_decode_tokens:] - return positions - - positions = get_current_rank_positions( - np.zeros(num_reqs, dtype=np.int32), self.pcp_rank) - # Decode tokens are duplicate and their positions always be 0. 
- if num_decode_reqs > 0: - positions[:num_decode_tokens] = self._get_cumsum_and_arange( - tokens[:num_decode_reqs])[1] - - all_positions = [ - get_current_rank_positions(cu_padded_tokens, rank_i) - for rank_i in range(self.pcp_size) - ] - all_positions_tensor = torch.from_numpy(np.concatenate(all_positions)) - self.pcp_allgather_restore_idx[:all_positions_tensor.shape[0]].copy_( - all_positions_tensor.float().argsort().long(), non_blocking=True) - return pcp_tokens, positions, unpad_mask - - def _get_cp_local_seq_lens( - self, - seq_lens: torch.Tensor, - pcp_world_size: int = 1, - dcp_world_size: int = 1, - cp_kv_cache_interleave_size: int = 1, - ) -> torch.Tensor: - """While using pcp or dcp, kv_cache size stored on each rank may be different, - use this function to calculate split decode seq_lens of each (p/d)cp rank. - """ - num_requests = seq_lens.size(0) - total_world_size = pcp_world_size * dcp_world_size - seq_lens_tiled = seq_lens.unsqueeze(-1).repeat(1, total_world_size) - rank_offsets = (torch.arange(total_world_size, - dtype=torch.int32).unsqueeze(0).repeat( - num_requests, 1)) - base = (seq_lens_tiled // cp_kv_cache_interleave_size // - total_world_size * cp_kv_cache_interleave_size) - remainder = seq_lens_tiled - base * total_world_size - remainder = torch.clip( - remainder - rank_offsets * cp_kv_cache_interleave_size, - 0, - cp_kv_cache_interleave_size, - ) - dcp_local_seq_lens = (base + remainder).reshape( - [-1, pcp_world_size, dcp_world_size]) - return dcp_local_seq_lens - - def _generate_pcp_metadata(self, total_num_scheduled_tokens): - # In dummy run num_reqs == 0, update it from seq_lens - num_reqs = self.input_batch.num_reqs or self.query_lens.size(0) - num_decodes = sum(self.input_batch.num_computed_tokens_cpu[:num_reqs] - >= self.input_batch.num_prompt_tokens[:num_reqs]) - num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_size - self.num_actual_tokens_pcp_padded = num_actual_tokens_pcp_padded - long_seq_metadata = None - 
if self.pcp_size * self.dcp_size > 1: - decode_context_lens = self.input_batch.num_tokens[:num_decodes] - prefill_context_lens = self.input_batch.num_computed_tokens_cpu[ - num_decodes:num_reqs] - context_lens = np.concatenate( - [decode_context_lens, prefill_context_lens]) - num_computed_tokens_of_pcp_dcp = torch.zeros( - [ - num_reqs * self.decode_threshold, self.pcp_size, - self.dcp_size - ], - dtype=torch.int32, - ) - # For pcp + spec decode, we flatten seq_lens - # to avoid irregular spec_attn_mask shape - for decode_idx in range(self.decode_threshold): - num_computed_tokens_of_pcp_dcp[ - self.decode_threshold - 1 - decode_idx::self.decode_threshold] = \ - self._get_cp_local_seq_lens( - torch.tensor(context_lens), - self.pcp_size, - self.dcp_size, - self.parallel_config.cp_kv_cache_interleave_size, - ) - long_seq_metadata = AscendPrefillContextParallelMetadata( - num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded, - num_computed_tokens_of_pcp_dcp=num_computed_tokens_of_pcp_dcp. 
- numpy()) - if self.pcp_size > 1: - q_head_idx, q_tail_idx = [], [] - kv_with_q_head_nomask_idx, kv_with_q_head_mask_idx = [], [] - kv_with_q_tail_nomask_idx, kv_with_q_tail_mask_idx = [], [] - chunk_seqlens = [] - kv_with_q_head_nomask_seqlens, kv_with_q_tail_nomask_seqlens = [], [] - q_req_offset = 0 - kv_req_offset = 0 - q_head_chunk_id = self.pcp_rank - q_tail_chunk_id = self.pcp_size * 2 - 1 - self.pcp_rank - for i, seq_len in enumerate(self.query_lens): - if i < num_decodes: - continue - chunk_len = seq_len // 2 - chunk_seqlens.append(chunk_len) - q_head_idx.extend( - list(range(q_req_offset, q_req_offset + chunk_len))) - kv_with_q_head_nomask_idx.extend( - list( - range(kv_req_offset, kv_req_offset + - chunk_len * q_head_chunk_id))) - kv_with_q_head_mask_idx.extend( - list( - range( - kv_req_offset + chunk_len * q_head_chunk_id, - kv_req_offset + chunk_len * - (q_head_chunk_id + 1)))) - kv_with_q_head_nomask_seqlens.append(chunk_len * - q_head_chunk_id) - - q_tail_idx.extend( - list( - range(q_req_offset + chunk_len, - q_req_offset + chunk_len * 2))) - kv_with_q_tail_nomask_idx.extend( - list( - range(kv_req_offset, kv_req_offset + - chunk_len * q_tail_chunk_id))) - kv_with_q_tail_mask_idx.extend( - list( - range( - kv_req_offset + chunk_len * q_tail_chunk_id, - kv_req_offset + chunk_len * - (q_tail_chunk_id + 1)))) - kv_with_q_tail_nomask_seqlens.append(chunk_len * - q_tail_chunk_id) - - q_req_offset += seq_len - kv_req_offset += seq_len * self.pcp_size - - # Convert lists to tensors and move to device - def _list_to_tensor(lst, device, dtype=torch.int32): - tensor_npu = torch.zeros(len(lst), - dtype=dtype, - device=device) - tensor_npu.copy_(torch.tensor(lst, dtype=dtype), - non_blocking=True) - return tensor_npu - - q_head_idx_tensor = _list_to_tensor(q_head_idx, self.device) - q_tail_idx_tensor = _list_to_tensor(q_tail_idx, self.device) - self.q_head_idx_tensor = q_head_idx_tensor - self.q_tail_idx_tensor = q_tail_idx_tensor - - q_full_idx = 
torch.cat([q_head_idx_tensor, q_tail_idx_tensor]) - q_full_idx = q_full_idx.to(torch.float32).argsort().to( - torch.int32) - self.q_full_idx = q_full_idx - - self.kv_idx_names = { - 'kv_with_q_head_nomask_idx_tensor': - kv_with_q_head_nomask_idx, - 'kv_with_q_head_mask_idx_tensor': kv_with_q_head_mask_idx, - 'kv_with_q_tail_nomask_idx_tensor': - kv_with_q_tail_nomask_idx, - 'kv_with_q_tail_mask_idx_tensor': kv_with_q_tail_mask_idx - } - for key, value in self.kv_idx_names.items(): - tensor_npu = _list_to_tensor(value, self.device) - self.kv_idx_names[key] = tensor_npu - - attn_mask_seqlens = torch.tensor( - [chunk_seqlens, chunk_seqlens], dtype=torch.int32) - head_attn_nomask_seqlens = torch.tensor( - [chunk_seqlens, kv_with_q_head_nomask_seqlens], - dtype=torch.int32) - tail_attn_nomask_seqlens = torch.tensor( - [chunk_seqlens, kv_with_q_tail_nomask_seqlens], - dtype=torch.int32) - pcp_prefill_mask = self.attn_mask - - self.extra_long_seq_kwargs = { - 'attn_mask_seqlens': attn_mask_seqlens, - 'head_attn_nomask_seqlens': head_attn_nomask_seqlens, - 'tail_attn_nomask_seqlens': tail_attn_nomask_seqlens, - 'pcp_prefill_mask': pcp_prefill_mask - } - long_seq_metadata.pcp_allgather_restore_idx = self.pcp_allgather_restore_idx[: - num_actual_tokens_pcp_padded] - long_seq_metadata.cp_kv_recover_idx_for_chunk = self.cp_kv_recover_idx_for_chunk - long_seq_metadata.q_head_idx_tensor = self.q_head_idx_tensor - long_seq_metadata.q_tail_idx_tensor = self.q_tail_idx_tensor - long_seq_metadata.q_full_idx = self.q_full_idx - long_seq_metadata.kv_with_q_head_nomask_idx_tensor = self.kv_idx_names[ - 'kv_with_q_head_nomask_idx_tensor'] - long_seq_metadata.kv_with_q_head_mask_idx_tensor = self.kv_idx_names[ - 'kv_with_q_head_mask_idx_tensor'] - long_seq_metadata.kv_with_q_tail_nomask_idx_tensor = self.kv_idx_names[ - 'kv_with_q_tail_nomask_idx_tensor'] - long_seq_metadata.kv_with_q_tail_mask_idx_tensor = self.kv_idx_names[ - 'kv_with_q_tail_mask_idx_tensor'] - 
long_seq_metadata.attn_mask_seqlens = self.extra_long_seq_kwargs[ - 'attn_mask_seqlens'] - long_seq_metadata.head_attn_nomask_seqlens = self.extra_long_seq_kwargs[ - 'head_attn_nomask_seqlens'] - long_seq_metadata.tail_attn_nomask_seqlens = self.extra_long_seq_kwargs[ - 'tail_attn_nomask_seqlens'] - long_seq_metadata.pcp_prefill_mask = self.extra_long_seq_kwargs[ - 'pcp_prefill_mask'] - self.long_seq_metadata = long_seq_metadata - return long_seq_metadata - - def _generate_pcp_mtp_input( - self, - num_reqs: int, - total_num_scheduled_tokens: int, - num_scheduled_tokens: dict[str, int], - ): - """ - While pcp > 1, model inputs (input_ids, position, etc.) are split across pcp group, - but mtp need to shift original input_ids before pcp splitting, - so we record original input_ids here. - """ - total_num_scheduled_tokens_pcp_full = total_num_scheduled_tokens - num_scheduled_tokens_pcp_full = np.empty(num_reqs, dtype=np.int32) - for i, req_id in enumerate(self.input_batch.req_ids): - num_scheduled_tokens_pcp_full[i] = num_scheduled_tokens[req_id] - req_indices_pcp_full = np.repeat(self.arange_np[:num_reqs], - num_scheduled_tokens_pcp_full) - cu_num_tokens_pcp_full = np.cumsum(num_scheduled_tokens_pcp_full) - self.query_start_loc_pcp_full.np[0] = 0 - self.query_start_loc_pcp_full.np[1:num_reqs + - 1] = cu_num_tokens_pcp_full - self.query_start_loc_pcp_full.np[num_reqs + 1:].fill(-1) - cumsums_offsets_pcp_full = np.repeat( - cu_num_tokens_pcp_full - num_scheduled_tokens_pcp_full, - num_scheduled_tokens_pcp_full) - arange_pcp_full = self.arange_np[: - total_num_scheduled_tokens_pcp_full] - cumsums_offsets_pcp_full - positions_pcp_full_np = self.positions_pcp_full_np[: - total_num_scheduled_tokens_pcp_full] - np.add(self.input_batch.num_computed_tokens_cpu[req_indices_pcp_full], - arange_pcp_full, - out=positions_pcp_full_np) - token_indices_pcp_full = ( - positions_pcp_full_np + - req_indices_pcp_full * self.input_batch.token_ids_cpu.shape[1]) - 
torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(), - 0, - torch.from_numpy(token_indices_pcp_full), - out=self.input_ids_pcp_full. - cpu[:total_num_scheduled_tokens_pcp_full]) - self.query_start_loc_pcp_full.copy_to_gpu() - self.input_ids_pcp_full.gpu[:total_num_scheduled_tokens_pcp_full].copy_( - self.input_ids_pcp_full.cpu[:total_num_scheduled_tokens_pcp_full], - non_blocking=True, - ) - @contextmanager def _torch_cuda_wrapper(): From d1b87a4393de1b29e7511e53a738ecdc0532cfde Mon Sep 17 00:00:00 2001 From: zhenwenqi2024 Date: Sat, 20 Dec 2025 15:51:09 +0800 Subject: [PATCH 11/43] [Feature] refactor model_runner for pcp & dcp Signed-off-by: zhenwenqi2024 --- vllm_ascend/worker/model_runner_v1.py | 71 +++++++++------------------ 1 file changed, 22 insertions(+), 49 deletions(-) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index f7920674094..7e1b42bc0f8 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -116,8 +116,7 @@ from vllm_ascend.utils import (AscendDeviceType, PCPManager, ProfileExecuteDuration, enable_sp, get_ascend_device_type, is_moe_model, - lmhead_tp_enable, maybe_trans_nz, - vllm_version_is) + lmhead_tp_enable, maybe_trans_nz) from vllm_ascend.worker.npu_input_batch import NPUInputBatch from vllm_ascend.ascend_forward_context import ( # isort: skip @@ -264,24 +263,15 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): # Set up Attention self.use_sparse = hasattr(self.vllm_config.model_config.hf_config, "index_topk") - if vllm_version_is('0.12.0'): - self.attn_backend = get_attn_backend( - 0, - self.dtype, - None, - self.block_size, - use_mla=self.model_config.use_mla, - use_sparse=self.use_sparse) - else: - self.attn_backend = get_attn_backend( - 0, - self.dtype, - None, - self.block_size, - use_mla=self.model_config.use_mla, - use_sparse=self.use_sparse, - use_mm_prefix=self.model_config is not None - and 
self.model_config.is_mm_prefix_lm) + self.attn_backend = get_attn_backend( + 0, + self.dtype, + None, + self.block_size, + use_mla=self.model_config.use_mla, + use_sparse=self.use_sparse, + use_mm_prefix=self.model_config is not None + and self.model_config.is_mm_prefix_lm) self.attn_mask_builder = AttentionMaskBuilder(self.device) self._set_up_drafter() @@ -1808,36 +1798,19 @@ def _build_dummy_attn_metadata( self.speculative_config.method == "mtp": attn_state = AscendAttentionState.SpecDecoding - if vllm_version_is("0.12.0"): - common_metadata = CommonAttentionMetadata( - query_start_loc=self.query_start_loc.gpu[:num_reqs + - 1], - query_start_loc_cpu=self.query_start_loc. - cpu[:num_reqs + 1], - seq_lens_cpu=self.seq_lens.cpu[:num_reqs], - seq_lens=self.seq_lens.cpu[:num_reqs], - num_reqs=num_reqs, - num_actual_tokens=num_tokens, - block_table_tensor=block_table_tensor[:num_reqs], - slot_mapping=slot_mapping.gpu, - num_computed_tokens_cpu=num_computed_tokens_cpu, - max_query_len=max_query_len, - max_seq_len=seq_lens) - else: - common_metadata = CommonAttentionMetadata( - query_start_loc=self.query_start_loc.gpu[:num_reqs + + common_metadata = CommonAttentionMetadata( + query_start_loc=self.query_start_loc.gpu[:num_reqs + 1], + query_start_loc_cpu=self.query_start_loc.cpu[:num_reqs + 1], - query_start_loc_cpu=self.query_start_loc. 
- cpu[:num_reqs + 1], - _seq_lens_cpu=self.seq_lens.cpu[:num_reqs], - seq_lens=self.seq_lens.cpu[:num_reqs], - num_reqs=num_reqs, - num_actual_tokens=num_tokens, - block_table_tensor=block_table_tensor[:num_reqs], - slot_mapping=slot_mapping.gpu, - _num_computed_tokens_cpu=num_computed_tokens_cpu, - max_query_len=max_query_len, - max_seq_len=seq_lens) + _seq_lens_cpu=self.seq_lens.cpu[:num_reqs], + seq_lens=self.seq_lens.cpu[:num_reqs], + num_reqs=num_reqs, + num_actual_tokens=num_tokens, + block_table_tensor=block_table_tensor[:num_reqs], + slot_mapping=slot_mapping.gpu, + _num_computed_tokens_cpu=num_computed_tokens_cpu, + max_query_len=max_query_len, + max_seq_len=seq_lens) for attn_group in self.attn_groups[kv_cache_group_id]: builder = attn_group.get_metadata_builder() From b072357847b127fbc0d2137e83ff09ccbee480d4 Mon Sep 17 00:00:00 2001 From: zhenwenqi2024 Date: Sat, 20 Dec 2025 15:55:53 +0800 Subject: [PATCH 12/43] [Feature] refactor model_runner for pcp & dcp Signed-off-by: zhenwenqi2024 --- vllm_ascend/utils.py | 3 --- vllm_ascend/worker/model_runner_v1.py | 2 +- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index de8d9476b84..e6591300e36 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -33,9 +33,7 @@ from torch_npu.npu.streams import Event from vllm.logger import logger from vllm.sequence import IntermediateTensors -# from vllm.distributed.parallel_state import get_pcp_group from vllm.utils.math_utils import cdiv -# from vllm.config import VllmConfig from vllm.v1.utils import CpuGpuBuffer import vllm_ascend.envs as envs_ascend @@ -1109,7 +1107,6 @@ def __init__( dcp_rank: int, max_buffer_num_tokens: int, max_num_reqs: int, - # speculative_config, device: torch.device, vllm_config: VllmConfig, pin_memory: bool = False, diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 7e1b42bc0f8..58f3212c725 100644 --- 
a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -211,7 +211,6 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): if self.pcp_size > 1: self.model_config.max_model_len += 2 * self.pcp_size * self.max_num_reqs max_buffer_num_tokens = self.max_num_tokens - logger.info(f"self.max_num_tokens is {self.max_num_tokens}") if self.pcp_size > 1: max_buffer_num_tokens = (self.max_num_tokens + self.max_num_reqs * 2 * self.pcp_size) @@ -231,6 +230,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): dtype=torch.int32) self.positions = self._make_buffer(max_buffer_num_tokens, dtype=torch.int64) + self.long_seq_metadata = None if envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP: self.prefetch_stream = torch.npu.Stream(device=device) else: From 5ec1942018652f7c0ff05a3d6a1474e1e71c9207 Mon Sep 17 00:00:00 2001 From: zhenwenqi2024 Date: Sat, 20 Dec 2025 16:43:45 +0800 Subject: [PATCH 13/43] [Feature] refactor model_runner for pcp & dcp Signed-off-by: zhenwenqi2024 --- vllm_ascend/utils.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index e6591300e36..6cd480ce81a 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -1341,7 +1341,6 @@ def get_current_rank_positions(positions_start_loc: int | np.ndarray, for rank_i in range(self.pcp_world_size) ] all_positions = np.concatenate(all_positions_lst) - logger.info(f"all_positions is {all_positions}") self.pcp_allgather_restore_idx.np[:all_positions.shape[0]] = ( all_positions.argsort()) self.pcp_allgather_restore_idx.copy_to_gpu(all_positions.shape[0]) @@ -1370,7 +1369,6 @@ def get_padded_slot_mapping(self, num_tokens: list, slot_mapping: torch.Tensor): # After pcp allgather and restore, there are padded tokens in kv, # so we need pad slotmapping for alignment. 
- # print(f"num_actual_tokens_pcp_padded: {self.num_actual_tokens_pcp_padded}") pcp_padded_slot_mapping = self.pcp_padded_slot_mapping[:num_tokens * self. pcp_world_size] @@ -1510,16 +1508,10 @@ def generate_kv_idx(self, scheduler_output, input_batch): cp_kv_recover_idx_for_chunk = torch.from_numpy( np.concatenate(self.cp_kv_recover_idx_for_chunk)).to( device=self.device) - logger.info( - f"cp_kv_recover_idx_for_chunk is {cp_kv_recover_idx_for_chunk}" - ) cp_kv_recover_idx_for_chunk.copy_(torch.tensor( np.array( self.cp_kv_recover_idx_for_chunk).flatten().tolist()), non_blocking=True) - logger.info( - f"cp_kv_recover_idx_for_chunk22222 is {cp_kv_recover_idx_for_chunk}" - ) self.cp_kv_recover_idx_for_chunk = cp_kv_recover_idx_for_chunk.to( torch.float32).argsort().to(torch.int32) @@ -1527,7 +1519,7 @@ def generate_pcp_metadata(self, total_num_scheduled_tokens, query_lens, attn_mask, input_batch): from vllm_ascend.attention.utils import \ AscendPrefillContextParallelMetadata - num_reqs = input_batch.num_reqs or self.query_lens.size(0) + num_reqs = input_batch.num_reqs or query_lens.size(0) num_decodes = sum(input_batch.num_computed_tokens_cpu[:num_reqs] >= input_batch.num_prompt_tokens[:num_reqs]) num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_world_size From daea466c88044fa072202a852d4c16249448214b Mon Sep 17 00:00:00 2001 From: zhenwenqi2024 Date: Sat, 20 Dec 2025 17:22:22 +0800 Subject: [PATCH 14/43] [Feature] refactor model_runner for pcp & dcp Signed-off-by: zhenwenqi2024 --- vllm_ascend/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 6cd480ce81a..8035e87a9ff 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -1365,7 +1365,7 @@ def get_discard_request_mask( num_scheduled_tokens * self.pcp_world_size - self.num_pcp_pads_cpu[:num_reqs]) < num_tokens_np - def get_padded_slot_mapping(self, num_tokens: list, + def get_padded_slot_mapping(self, num_tokens: 
int, slot_mapping: torch.Tensor): # After pcp allgather and restore, there are padded tokens in kv, # so we need pad slotmapping for alignment. From 72f11e552b58a615fab07cd6a10ba15b7afc3f1c Mon Sep 17 00:00:00 2001 From: zhenwenqi2024 Date: Mon, 22 Dec 2025 17:31:54 +0800 Subject: [PATCH 15/43] cleancode Signed-off-by: zhenwenqi2024 --- vllm_ascend/utils.py | 5 ++--- vllm_ascend/worker/model_runner_v1.py | 3 +++ 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 8035e87a9ff..abe8004c011 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -1129,8 +1129,8 @@ def __init__( device=device, pin_memory=pin_memory, ) - self.pcp_padded_slot_mapping = torch.empty( - (max_buffer_num_tokens, ), + self.pcp_padded_slot_mapping = torch.full( + (max_buffer_num_tokens, ), fill_value=-1, dtype=torch.int32, device=device, ) @@ -1374,7 +1374,6 @@ def get_padded_slot_mapping(self, num_tokens: int, pcp_world_size] cp_unpad_mask = self.pcp_unpad_mask_cpu_tensor[:num_tokens * self.pcp_world_size] - pcp_padded_slot_mapping.fill_(-1) pcp_padded_slot_mapping[cp_unpad_mask] = slot_mapping return pcp_padded_slot_mapping diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 58f3212c725..c043dc59cb1 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -919,6 +919,9 @@ def _prepare_inputs( total_num_scheduled_tokens, slot_mapping, ) + blk_table.slot_mapping.gpu[:self.pcp_manager.num_actual_tokens_pcp_padded] = slot_mapping + + # NOTE: This is a temporary hack, now in GPUModelRunner, this prepare_inputs # has been split to multiple parts, and there are 3 parts that is related to this From b8b906838408474413f9c2e83eede1f77c572fa5 Mon Sep 17 00:00:00 2001 From: zhenwenqi2024 Date: Mon, 22 Dec 2025 19:57:54 +0800 Subject: [PATCH 16/43] cleancode Signed-off-by: zhenwenqi2024 --- tests/ut/worker/test_model_runner_v1.py | 473 ----------------- 
tests/ut/worker/test_pcp_manager.py | 320 ++++++++++++ vllm_ascend/utils.py | 589 +-------------------- vllm_ascend/worker/model_runner_v1.py | 514 +++---------------- vllm_ascend/worker/pcp_utils.py | 654 ++++++++++++++++++++++++ 5 files changed, 1036 insertions(+), 1514 deletions(-) delete mode 100644 tests/ut/worker/test_model_runner_v1.py create mode 100644 tests/ut/worker/test_pcp_manager.py create mode 100644 vllm_ascend/worker/pcp_utils.py diff --git a/tests/ut/worker/test_model_runner_v1.py b/tests/ut/worker/test_model_runner_v1.py deleted file mode 100644 index 8ff26a6f9fe..00000000000 --- a/tests/ut/worker/test_model_runner_v1.py +++ /dev/null @@ -1,473 +0,0 @@ -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# This file is a part of the vllm-ascend project. 
- -from unittest.mock import MagicMock - -import numpy as np -import pytest -import torch - -from vllm_ascend.worker.model_runner_v1 import NPUModelRunner - - -@pytest.mark.parametrize( - "pcp_size, dcp_size, num_reqs, query_lens, num_decodes, use_mla, total_tokens, expect_not_none", - [ - (1, 1, 5, [10, 20, 30, 40, 50], 2, False, 100, False), - (1, 2, 3, [20, 30, 40], 1, False, 50, True), - (2, 1, 4, [5, 10, 40, 60], 2, False, 100, True), - (2, 1, 4, [5, 10, 40, 60], 2, True, 100, True), - (2, 1, 3, [5, 10, 15], 3, False, 50, True), - (2, 1, 3, [40, 50, 60], 0, False, 150, True), - ]) -def test_generate_pcp_metadata_basic(pcp_size, dcp_size, num_reqs, query_lens, - num_decodes, use_mla, total_tokens, - expect_not_none): - mock_runner = MagicMock(spec=NPUModelRunner) - mock_runner.pcp_size = pcp_size - mock_runner.dcp_size = dcp_size - mock_runner.decode_threshold = 4 - mock_runner.pcp_rank = 0 - mock_runner.device = torch.device('cpu') - mock_runner.dtype = torch.float32 - - mock_runner.parallel_config = MagicMock() - mock_runner.parallel_config.cp_kv_cache_interleave_size = 64 - - mock_runner.vllm_config = MagicMock() - mock_runner.vllm_config.model_config = MagicMock() - mock_runner.vllm_config.model_config.use_mla = use_mla - - mock_runner.input_batch = MagicMock() - mock_runner.input_batch.num_reqs = num_reqs - mock_runner.speculative_config = None - - num_computed_tokens = [] - num_prompt_tokens = [] - num_tokens = [] - - for i in range(num_reqs): - if i < num_decodes: - num_computed_tokens.append(query_lens[i]) - num_prompt_tokens.append(query_lens[i] // 2) - num_tokens.append(query_lens[i]) - else: - num_computed_tokens.append(0) - num_prompt_tokens.append(query_lens[i]) - num_tokens.append(query_lens[i]) - - mock_runner.input_batch.num_computed_tokens_cpu = torch.tensor( - num_computed_tokens) - mock_runner.input_batch.num_prompt_tokens = torch.tensor(num_prompt_tokens) - mock_runner.input_batch.num_tokens = torch.tensor(num_tokens) - - 
mock_runner.query_lens = torch.tensor(query_lens) - - mock_runner._get_cp_local_seq_lens = NPUModelRunner._get_cp_local_seq_lens.__get__( - mock_runner, NPUModelRunner) - - mock_runner.pcp_allgather_restore_idx = torch.arange(total_tokens * 2) - mock_runner.cp_kv_recover_idx_for_chunk = torch.arange(total_tokens) - - mock_runner.long_seq_metadata = None - mock_runner.num_actual_tokens_pcp_padded = 0 - mock_runner.kv_idx_names = {} - mock_runner.extra_long_seq_kwargs = {} - mock_runner.attn_mask = None - mock_runner.q_head_idx_tensor = None - mock_runner.q_tail_idx_tensor = None - mock_runner.q_full_idx = None - - method = NPUModelRunner._generate_pcp_metadata.__get__( - mock_runner, NPUModelRunner) - result = method(total_tokens) - - if not expect_not_none: - assert result is None, f"Expected to return None, but got {type(result)}" - else: - assert result is not None, "Expected to return a metadata object, but got None." - - assert hasattr(result, 'num_actual_tokens_pcp_padded') - assert hasattr(result, 'num_computed_tokens_of_pcp_dcp') - - if pcp_size > 1: - assert hasattr(result, 'pcp_allgather_restore_idx') - - has_prefill_requests = (num_reqs - num_decodes) > 0 - if has_prefill_requests: - assert hasattr(result, 'q_head_idx_tensor') - assert hasattr(result, 'q_tail_idx_tensor') - assert hasattr(result, 'q_full_idx') - assert hasattr(result, 'kv_with_q_head_nomask_idx_tensor') - assert hasattr(result, 'kv_with_q_head_mask_idx_tensor') - assert hasattr(result, 'kv_with_q_tail_nomask_idx_tensor') - assert hasattr(result, 'kv_with_q_tail_mask_idx_tensor') - assert hasattr(result, 'attn_mask_seqlens') - assert hasattr(result, 'head_attn_nomask_seqlens') - assert hasattr(result, 'tail_attn_nomask_seqlens') - - if hasattr(result, 'pcp_prefill_mask' - ) and result.pcp_prefill_mask is not None: - if use_mla: - assert result.pcp_prefill_mask.shape == (512, 512) - else: - assert result.pcp_prefill_mask.shape == (2048, 2048) - else: - if hasattr(result, 
'pcp_prefill_mask'): - if result.pcp_prefill_mask is not None: - if use_mla: - assert result.pcp_prefill_mask.shape == (512, 512) - else: - assert result.pcp_prefill_mask.shape == (2048, - 2048) - - -def test_generate_pcp_metadata_edge_cases(): - mock_runner = MagicMock() - mock_runner.pcp_size = 2 - mock_runner.dcp_size = 1 - mock_runner.input_batch = MagicMock() - mock_runner.input_batch.num_reqs = 0 - mock_runner.query_lens = torch.tensor([10, 20, 30]) - - assert (mock_runner.input_batch.num_reqs - or mock_runner.query_lens.size(0)) == 3 - - mock_runner.input_batch.num_reqs = 100 - mock_runner.query_lens = torch.ones(100) * 1000 - - for rank in [0, 1]: - mock_runner.pcp_rank = rank - q_head_chunk_id = rank - q_tail_chunk_id = 2 * 2 - 1 - rank - assert q_head_chunk_id == rank - assert q_tail_chunk_id == 3 - rank - - -def test_pcp_allgather_restore_idx_slicing(): - mock_runner = MagicMock() - mock_runner.pcp_size = 2 - mock_runner.pcp_allgather_restore_idx = torch.arange(1000) - - total_num_scheduled_tokens = 200 - num_actual_tokens_pcp_padded = total_num_scheduled_tokens * 2 - - expected_slice = mock_runner.pcp_allgather_restore_idx[: - num_actual_tokens_pcp_padded] - assert len(expected_slice) == 400 - assert expected_slice[0] == 0 - assert expected_slice[-1] == 399 - - -@pytest.mark.parametrize( - "tokens, num_reqs, num_computed_tokens, num_prompt_tokens," \ - "pcp_size, pcp_rank, decode_threshold, expected_pcp_tokens", - [ - # Case 1: prefill only - ([8, 12, 16], 3, [0, 0, 0], [8, 12, 16], 4, 0, 1, [2, 4, 4]), - - # Case 2: mix prefill and decode (with spec decode) - ([8, 4, 12], 3, [8, 4, 0], [8, 4, 12], 4, 0, 8, [8, 4, 4]), - - # Case 3: request which need to be padded - ([3, 7, 9], 3, [0, 0, 0], [3, 7, 9], 4, 0, 1, [2, 2, 4]), - - # Case 4: single request - ([10], 1, [0], [10], 4, 0, 1, [4]), - ]) -def test_update_tokens_for_pcp_basic(tokens, num_reqs, num_computed_tokens, - num_prompt_tokens, pcp_size, pcp_rank, - decode_threshold, expected_pcp_tokens): - 
mock_runner = MagicMock(spec=NPUModelRunner) - mock_runner.pcp_size = pcp_size - mock_runner.pcp_rank = pcp_rank - - mock_runner.input_batch = MagicMock() - mock_runner.input_batch.num_reqs = num_reqs - mock_runner.input_batch.num_computed_tokens_cpu = np.array( - num_computed_tokens, dtype=np.int32) - mock_runner.input_batch.num_prompt_tokens = np.array(num_prompt_tokens, - dtype=np.int32) - - mock_runner.pcp_allgather_restore_idx = torch.zeros(1000, dtype=torch.long) - - mock_runner.num_pcp_pads = [0] * num_reqs - mock_runner.arange_np = np.arange(10000) - mock_runner.decode_threshold = decode_threshold - - mock_runner._update_tokens_for_pcp = NPUModelRunner._update_tokens_for_pcp.__get__( - mock_runner, NPUModelRunner) - mock_runner._get_cumsum_and_arange = NPUModelRunner._get_cumsum_and_arange.__get__( - mock_runner, NPUModelRunner) - - pcp_tokens_result, positions_result, unpad_mask_result = mock_runner._update_tokens_for_pcp( - tokens) - - assert np.array_equal(pcp_tokens_result, expected_pcp_tokens), \ - f"Expected pcp_tokens: {expected_pcp_tokens}, got: {pcp_tokens_result}" - - total_pcp_tokens: int = np.sum(pcp_tokens_result) - assert positions_result.shape == (total_pcp_tokens,), \ - f"Positions shape mismatch. 
Expected length {total_pcp_tokens}, got {positions_result.shape}" - - padded_tokens = [ - (t + 2 * pcp_size - 1) // (2 * pcp_size) * - (2 * pcp_size) if num_computed_tokens[i] == 0 else t * pcp_size - for i, t in enumerate(tokens) - ] - total_padded_tokens: int = np.sum(padded_tokens) - assert unpad_mask_result.shape[0] == total_padded_tokens, \ - f"unpad_mask size mismatch: expected {total_padded_tokens}, got {unpad_mask_result.shape[0]}" - - -def test_update_tokens_for_pcp_with_padding(): - mock_runner = MagicMock(spec=NPUModelRunner) - mock_runner.pcp_size = 4 - mock_runner.pcp_rank = 0 - - mock_runner.arange_np = np.arange(10000) - - mock_runner.input_batch = MagicMock() - mock_runner.input_batch.num_reqs = 3 - mock_runner.input_batch.num_computed_tokens_cpu = np.array([0, 0, 0], - dtype=np.int32) - mock_runner.input_batch.num_prompt_tokens = np.array([5, 9, 13], - dtype=np.int32) - - mock_runner.num_pcp_pads = [0, 0, 0] - mock_runner.pcp_allgather_restore_idx = torch.zeros(1000, dtype=torch.long) - mock_runner.decode_threshold = 1 - - mock_runner._update_tokens_for_pcp = NPUModelRunner._update_tokens_for_pcp.__get__( - mock_runner, NPUModelRunner) - mock_runner._get_cumsum_and_arange = NPUModelRunner._get_cumsum_and_arange.__get__( - mock_runner, NPUModelRunner) - - tokens = [5, 9, 13] - - pcp_tokens, positions, unpad_mask = mock_runner._update_tokens_for_pcp( - tokens) - - expected_pcp_tokens = [2, 4, 4] - assert np.array_equal(pcp_tokens, expected_pcp_tokens), \ - f"Expected {expected_pcp_tokens}, got {pcp_tokens}" - - expected_pads = [3, 7, 3] - assert np.array_equal(mock_runner.num_pcp_pads, expected_pads), \ - f"Expected padding {expected_pads}, got {mock_runner.num_pcp_pads}" - - -def test_update_tokens_for_pcp_unpad_mask(): - mock_runner = MagicMock(spec=NPUModelRunner) - mock_runner.pcp_size = 4 - mock_runner.pcp_rank = 0 - - mock_runner.arange_np = np.arange(10000) - - mock_runner.input_batch = MagicMock() - mock_runner.input_batch.num_reqs = 2 - 
mock_runner.input_batch.num_computed_tokens_cpu = np.array([0, 0], - dtype=np.int32) - mock_runner.input_batch.num_prompt_tokens = np.array([5, 7], - dtype=np.int32) - - mock_runner.num_pcp_pads = [0, 0] - mock_runner.pcp_allgather_restore_idx = torch.zeros(1000, dtype=torch.long) - mock_runner.decode_threshold = 1 - - mock_runner._update_tokens_for_pcp = NPUModelRunner._update_tokens_for_pcp.__get__( - mock_runner, NPUModelRunner) - mock_runner._get_cumsum_and_arange = NPUModelRunner._get_cumsum_and_arange.__get__( - mock_runner, NPUModelRunner) - - tokens = [5, 7] - - pcp_tokens, positions, unpad_mask = mock_runner._update_tokens_for_pcp( - tokens) - - assert unpad_mask.dtype == torch.bool, \ - f"unpad_mask should be bool, got {unpad_mask.dtype}" - - padded_tokens = [8, 8] - expected_length = sum(padded_tokens) - assert unpad_mask.shape[0] == expected_length, \ - f"unpad_mask length mismatch: expected {expected_length}, got {unpad_mask.shape[0]}" - - expected_mask = [True] * 5 + [False] * 3 + [True] * 7 + [False] * 1 - actual_mask = unpad_mask.numpy().tolist() - assert actual_mask == expected_mask, \ - f"unpad_mask incorrect. 
Expected {expected_mask}, got {actual_mask}" - - -# yapf: disable -@pytest.mark.parametrize( - "seq_lens, pcp_world_size, dcp_world_size, cp_kv_cache_interleave_size, target", - [ - # without pcp and dcp - (torch.tensor([1, 2, 128, 129]), 1, 1, 1, - torch.tensor([[[1]], [[2]], [[128]], [[129]]])), - # pcp - (torch.tensor([1, 2, 128, 129]), 2, 1, 1, - torch.tensor([[[1], [0]], [[1], [1]], [[64], [64]], [[65], [64]]])), - # dcp - (torch.tensor([1, 2, 128, 129]), 1, 2, 1, - torch.tensor([[[1, 0]], [[1, 1]], [[64, 64]], [[65, 64]]])), - # pcp + dcp - (torch.tensor([1, 2, 128, 129]), 2, 2, 1, - torch.tensor([[[1, 0], [0, 0]], [[1, 1], [0, 0]], - [[32, 32], [32, 32]], [[33, 32], [32, 32]]])), - # specify interleave_size - (torch.tensor([1, 2, 128, 129]), 2, 1, 2, - torch.tensor([[[1], [0]], [[2], [0]], [[64], [64]], [[65], [64]]])), - (torch.tensor([1, 2, 128, 129]), 2, 1, 128, - torch.tensor([[[1], [0]], [[2], [0]], [[128], [0]], [[128], [1]]])), - (torch.tensor([1, 2, 128, 129, 256, 257]), 2, 2, 128, - torch.tensor([[[1, 0], [0, 0]], [[2, 0], [0, 0]], - [[128, 0], [0, 0]], [[128, 1], [0, 0]], - [[128, 128], [0, 0]], [[128, 128], [1, 0]]])), - ] -) -# yapf: enable -def test_get_cp_local_seq_lens( - seq_lens, - pcp_world_size, - dcp_world_size, - cp_kv_cache_interleave_size, - target, -): - mock_runner = MagicMock(spec=NPUModelRunner) - ret = NPUModelRunner._get_cp_local_seq_lens(mock_runner, seq_lens, - pcp_world_size, dcp_world_size, - cp_kv_cache_interleave_size) - assert torch.equal(ret, target) - - -@pytest.fixture -def pcp_mtp_mock_runner(): - # set up pcp & mtp related buffers - max_num_reqs = 4 - max_model_len = 4096 - max_num_tokens = 4096 - mock_runner = MagicMock(spec=NPUModelRunner) - mock_runner.device = 'cpu' - mock_runner.pin_memory = False - - # Init model_runner pcp_mtp related buffers - mock_runner.query_start_loc_pcp_full = NPUModelRunner._make_buffer( - mock_runner, max_num_reqs + 1, dtype=torch.int32) - - positions_buff = torch.zeros(max_num_tokens, 
- dtype=torch.int64, - device="cpu") - mock_runner.positions_pcp_full = positions_buff - mock_runner.positions_pcp_full_np = positions_buff.numpy() - - mock_runner.input_ids_pcp_full = NPUModelRunner._make_buffer( - mock_runner, max_num_tokens, dtype=torch.int32) - mock_runner.query_lens_pcp_full = NPUModelRunner._make_buffer( - mock_runner, max_num_reqs, dtype=torch.int32) - mock_runner.decode_threshold = 1 - - mock_runner.arange_np = np.arange(max_model_len) - mock_runner.input_batch = MagicMock() - mock_runner.input_batch.num_computed_tokens_cpu = \ - np.zeros(max_num_reqs, dtype=np.int32) - token_ids_cpu_tensor = torch.zeros( - (max_num_reqs, max_model_len), - device="cpu", - dtype=torch.int32, - ) - mock_runner.input_batch.token_ids_cpu_tensor = token_ids_cpu_tensor - mock_runner.input_batch.token_ids_cpu = token_ids_cpu_tensor.numpy() - return mock_runner - - -# yapf: disable -@pytest.mark.parametrize( - "req_ids, num_computed_tokens," \ - "token_ids_tensor_list," \ - "num_reqs, total_num_scheduled_tokens, num_scheduled_tokens," \ - "target_input_ids_pcp_full, target_query_start_loc_pcp_full", - [ - # prefill - ( - ['0'], np.array([0]), - [torch.tensor([0, 671, 6102, 294, 8760, 344])], - 1, 6, {'0': 6}, - torch.tensor([0, 671, 6102, 294, 8760, 344]), - torch.tensor([0, 6]) - ), - # decode - ( - ['0'], np.array([6]), - [torch.tensor([0, 671, 6102, 294, 8760, 344, 88907, 0])], - 1, 2, {'0': 2}, - torch.tensor([88907, 0]), - torch.tensor([0, 2]) - ), - # decode + prefill - ( - ['0', '1'], np.array([6, 0]), - [ - torch.tensor([0, 671, 6102, 294, 8760, 344, 88907, 0]), - torch.tensor([0, 19923, 14, 1026, 2329, 344, 9807, 14, 342, 1030]), - ], - 2, 12, {'0': 2, '1': 10}, - torch.tensor([88907, 0, 0, 19923, 14, 1026, 2329, 344, 9807, 14, 342, 1030]), - torch.tensor([0, 2, 12]) - ), - # decodes + prefills - ( - ['0', '1', '2', '3'], np.array([6, 8, 0, 0]), - [ - torch.tensor([0, 671, 6102, 294, 8760, 344, 88907, 0]), - torch.tensor([0, 19923, 14, 1026, 2329, 344, 
9807, 14, 342, 0]), - torch.tensor([0, 671, 8749, 294, 3702, 4106, 344, 88907]), - torch.tensor([0, 671, 5335, 1469, 7539, 305, 6397]), - ], - 4, 19, {'0': 2, '1': 2, '2': 8, '3': 7}, - torch.tensor([88907, 0, 342, 0, 0, 671, 8749, 294, 3702, 4106, 344, 88907, - 0, 671, 5335, 1469, 7539, 305, 6397]), - torch.tensor([0, 2, 4, 12, 19]) - ), - ]) -# yapf: enable -def test_generate_pcp_mtp_input( - pcp_mtp_mock_runner, - req_ids, - num_computed_tokens, - token_ids_tensor_list, - num_reqs, - total_num_scheduled_tokens, - num_scheduled_tokens, - target_input_ids_pcp_full, - target_query_start_loc_pcp_full, -): - mock_runner = pcp_mtp_mock_runner - token_ids_cpu_tensor = mock_runner.input_batch.token_ids_cpu_tensor - - # Set input_batch - mock_runner.input_batch.req_ids = req_ids - mock_runner.input_batch.num_computed_tokens_cpu[:num_computed_tokens. - size] = num_computed_tokens - for i, token_ids_tensor in enumerate(token_ids_tensor_list): - token_ids_cpu_tensor[i][:token_ids_tensor.size(0)] = token_ids_tensor - - NPUModelRunner._generate_pcp_mtp_input(mock_runner, num_reqs, - total_num_scheduled_tokens, - num_scheduled_tokens) - assert torch.equal( - mock_runner.input_ids_pcp_full.cpu[:total_num_scheduled_tokens], - target_input_ids_pcp_full) - assert torch.equal(mock_runner.query_start_loc_pcp_full.cpu[:num_reqs + 1], - target_query_start_loc_pcp_full) diff --git a/tests/ut/worker/test_pcp_manager.py b/tests/ut/worker/test_pcp_manager.py new file mode 100644 index 00000000000..9ea5220d749 --- /dev/null +++ b/tests/ut/worker/test_pcp_manager.py @@ -0,0 +1,320 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. + +from unittest.mock import MagicMock + +import numpy as np +import pytest +import torch + +from vllm_ascend.worker.pcp_utils import PCPManager + + +@pytest.mark.parametrize( + "pcp_size, dcp_size, num_reqs, query_lens, num_decodes, use_mla, total_tokens, expect_not_none", + [ + (1, 1, 5, [10, 20, 30, 40, 50], 2, False, 100, False), + (1, 2, 3, [20, 30, 40], 1, False, 50, True), + (2, 1, 4, [5, 10, 40, 60], 2, False, 100, True), + (2, 1, 4, [5, 10, 40, 60], 2, True, 100, True), + (2, 1, 3, [5, 10, 15], 3, False, 50, True), + (2, 1, 3, [40, 50, 60], 0, False, 150, True), + ]) +def test_generate_pcp_metadata_basic(pcp_size, dcp_size, num_reqs, query_lens, + num_decodes, use_mla, total_tokens, + expect_not_none): + vllm_config = MagicMock() + vllm_config.model_config = MagicMock() + vllm_config.model_config.use_mla = use_mla + vllm_config.parallel_config.cp_kv_cache_interleave_size = 64 + vllm_config.speculative_config.num_speculative_tokens=0 + + + pcp_manager = PCPManager(pcp_world_size=pcp_size, + pcp_rank=0, + dcp_world_size=dcp_size, + dcp_rank=0, + max_buffer_num_tokens=10000, + max_num_reqs=1000, + device="cpu", + vllm_config=vllm_config, + pin_memory=False) + input_batch = MagicMock() + input_batch.num_reqs = num_reqs + + num_computed_tokens = [] + num_prompt_tokens = [] + num_tokens = [] + + for i in range(num_reqs): + if i < num_decodes: + num_computed_tokens.append(query_lens[i]) + num_prompt_tokens.append(query_lens[i] // 2) + num_tokens.append(query_lens[i]) + else: + 
num_computed_tokens.append(0) + num_prompt_tokens.append(query_lens[i]) + num_tokens.append(query_lens[i]) + + input_batch.num_computed_tokens_cpu = torch.tensor( + num_computed_tokens) + input_batch.num_prompt_tokens = torch.tensor(num_prompt_tokens) + input_batch.num_tokens = torch.tensor(num_tokens) + + query_lens = torch.tensor(query_lens) + result = pcp_manager.generate_pcp_metadata(total_tokens, query_lens, None, input_batch) + + if not expect_not_none: + assert result is None, f"Expected to return None, but got {type(result)}" + else: + assert result is not None, "Expected to return a metadata object, but got None." + + assert hasattr(result, 'num_actual_tokens_pcp_padded') + assert hasattr(result, 'num_computed_tokens_of_pcp_dcp') + + if pcp_size > 1: + assert hasattr(result, 'pcp_allgather_restore_idx') + + has_prefill_requests = (num_reqs - num_decodes) > 0 + if has_prefill_requests: + assert hasattr(result, 'q_head_idx_tensor') + assert hasattr(result, 'q_tail_idx_tensor') + assert hasattr(result, 'q_full_idx') + assert hasattr(result, 'kv_with_q_head_nomask_idx_tensor') + assert hasattr(result, 'kv_with_q_head_mask_idx_tensor') + assert hasattr(result, 'kv_with_q_tail_nomask_idx_tensor') + assert hasattr(result, 'kv_with_q_tail_mask_idx_tensor') + assert hasattr(result, 'attn_mask_seqlens') + assert hasattr(result, 'head_attn_nomask_seqlens') + assert hasattr(result, 'tail_attn_nomask_seqlens') + + if hasattr(result, 'pcp_prefill_mask' + ) and result.pcp_prefill_mask is not None: + if use_mla: + assert result.pcp_prefill_mask.shape == (512, 512) + else: + assert result.pcp_prefill_mask.shape == (2048, 2048) + else: + if hasattr(result, 'pcp_prefill_mask'): + if result.pcp_prefill_mask is not None: + if use_mla: + assert result.pcp_prefill_mask.shape == (512, 512) + else: + assert result.pcp_prefill_mask.shape == (2048, + 2048) + +@pytest.mark.parametrize( + "tokens, num_reqs, num_computed_tokens, num_prompt_tokens, pcp_size, pcp_rank, 
expected_pcp_tokens", + [ + # Case 1: prefill only + ([8, 12, 16], 3, [0, 0, 0], [8, 12, 16], 4, 0, [2, 4, 4]), + + # # Case 2: mix prefill and decode + ([8, 4, 12], 3, [8, 4, 0], [8, 0, 12], 4, 0, [2, 2, 4]), + + # # Case 3: request which need to be padded + ([3, 7, 9], 3, [0, 0, 0], [3, 7, 9], 4, 0, [2, 2, 4]), + + # Case 4: single request + ([10], 1, [0], [10], 4, 0, [4]), + ]) +def test_update_tokens_for_pcp_basic(tokens, num_reqs, num_computed_tokens, + num_prompt_tokens, pcp_size, pcp_rank, + expected_pcp_tokens): + vllm_config = MagicMock() + vllm_config.model_config = MagicMock() + vllm_config.speculative_config.num_speculative_tokens=0 + + + pcp_manager = PCPManager(pcp_world_size=pcp_size, + pcp_rank=0, + dcp_world_size=1, + dcp_rank=0, + max_buffer_num_tokens=10000, + max_num_reqs=1000, + device="cpu", + vllm_config=vllm_config, + pin_memory=False) + input_batch = MagicMock() + input_batch.num_reqs = num_reqs + input_batch.num_computed_tokens_cpu = np.array( + num_computed_tokens, dtype=np.int32) + input_batch.num_prompt_tokens = np.array(num_prompt_tokens, + dtype=np.int32) + arange_np = np.arange(10000) + pcp_tokens_result, positions_result = pcp_manager.update_tokens_for_pcp(np.array(tokens), arange_np, num_reqs, 1) + + assert np.array_equal(pcp_tokens_result, expected_pcp_tokens), \ + f"Expected pcp_tokens: {expected_pcp_tokens}, got: {pcp_tokens_result}" + + total_pcp_tokens: int = np.sum(pcp_tokens_result) + assert positions_result.shape == (total_pcp_tokens,), \ + f"Positions shape mismatch. 
Expected length {total_pcp_tokens}, got {positions_result.shape}" + + +# yapf: disable +@pytest.mark.parametrize( + "seq_lens, pcp_world_size, dcp_world_size, cp_kv_cache_interleave_size, target", + [ + # without pcp and dcp + (torch.tensor([1, 2, 128, 129]), 1, 1, 1, + torch.tensor([[[1]], [[2]], [[128]], [[129]]])), + # pcp + (torch.tensor([1, 2, 128, 129]), 2, 1, 1, + torch.tensor([[[1], [0]], [[1], [1]], [[64], [64]], [[65], [64]]])), + # dcp + (torch.tensor([1, 2, 128, 129]), 1, 2, 1, + torch.tensor([[[1, 0]], [[1, 1]], [[64, 64]], [[65, 64]]])), + # pcp + dcp + (torch.tensor([1, 2, 128, 129]), 2, 2, 1, + torch.tensor([[[1, 0], [0, 0]], [[1, 1], [0, 0]], + [[32, 32], [32, 32]], [[33, 32], [32, 32]]])), + # specify interleave_size + (torch.tensor([1, 2, 128, 129]), 2, 1, 2, + torch.tensor([[[1], [0]], [[2], [0]], [[64], [64]], [[65], [64]]])), + (torch.tensor([1, 2, 128, 129]), 2, 1, 128, + torch.tensor([[[1], [0]], [[2], [0]], [[128], [0]], [[128], [1]]])), + (torch.tensor([1, 2, 128, 129, 256, 257]), 2, 2, 128, + torch.tensor([[[1, 0], [0, 0]], [[2, 0], [0, 0]], + [[128, 0], [0, 0]], [[128, 1], [0, 0]], + [[128, 128], [0, 0]], [[128, 128], [1, 0]]])), + ] +) +# yapf: enable +def test_get_cp_local_seq_lens( + seq_lens, + pcp_world_size, + dcp_world_size, + cp_kv_cache_interleave_size, + target, +): + vllm_config = MagicMock() + vllm_config.model_config = MagicMock() + vllm_config.speculative_config.num_speculative_tokens=0 + pcp_manager = PCPManager(pcp_world_size=pcp_world_size, + pcp_rank=0, + dcp_world_size=dcp_world_size, + dcp_rank=0, + max_buffer_num_tokens=10000, + max_num_reqs=1000, + device="cpu", + vllm_config=vllm_config, + pin_memory=False) + ret = pcp_manager._get_cp_local_seq_lens(seq_lens, + pcp_world_size, dcp_world_size, + cp_kv_cache_interleave_size) + assert torch.equal(ret, target) + +# yapf: disable +@pytest.mark.parametrize( + "req_ids, num_computed_tokens," \ + "token_ids_tensor_list," \ + "num_reqs, total_num_scheduled_tokens, 
num_scheduled_tokens," \ + "target_input_ids_pcp_full, target_query_start_loc_pcp_full", + [ + # prefill + ( + ['0'], np.array([0]), + [torch.tensor([0, 671, 6102, 294, 8760, 344])], + 1, 6, {'0': 6}, + torch.tensor([0, 671, 6102, 294, 8760, 344]), + torch.tensor([0, 6]) + ), + # decode + ( + ['0'], np.array([6]), + [torch.tensor([0, 671, 6102, 294, 8760, 344, 88907, 0])], + 1, 2, {'0': 2}, + torch.tensor([88907, 0]), + torch.tensor([0, 2]) + ), + # decode + prefill + ( + ['0', '1'], np.array([6, 0]), + [ + torch.tensor([0, 671, 6102, 294, 8760, 344, 88907, 0]), + torch.tensor([0, 19923, 14, 1026, 2329, 344, 9807, 14, 342, 1030]), + ], + 2, 12, {'0': 2, '1': 10}, + torch.tensor([88907, 0, 0, 19923, 14, 1026, 2329, 344, 9807, 14, 342, 1030]), + torch.tensor([0, 2, 12]) + ), + # decodes + prefills + ( + ['0', '1', '2', '3'], np.array([6, 8, 0, 0]), + [ + torch.tensor([0, 671, 6102, 294, 8760, 344, 88907, 0]), + torch.tensor([0, 19923, 14, 1026, 2329, 344, 9807, 14, 342, 0]), + torch.tensor([0, 671, 8749, 294, 3702, 4106, 344, 88907]), + torch.tensor([0, 671, 5335, 1469, 7539, 305, 6397]), + ], + 4, 19, {'0': 2, '1': 2, '2': 8, '3': 7}, + torch.tensor([88907, 0, 342, 0, 0, 671, 8749, 294, 3702, 4106, 344, 88907, + 0, 671, 5335, 1469, 7539, 305, 6397]), + torch.tensor([0, 2, 4, 12, 19]) + ), + ]) +# yapf: enable +def test_generate_pcp_mtp_input( + req_ids, + num_computed_tokens, + token_ids_tensor_list, + num_reqs, + total_num_scheduled_tokens, + num_scheduled_tokens, + target_input_ids_pcp_full, + target_query_start_loc_pcp_full, +): + max_num_reqs = 4 + max_model_len = 4096 + max_num_tokens = 4096 + vllm_config = MagicMock() + vllm_config.model_config = MagicMock() + vllm_config.speculative_config.num_speculative_tokens=1 + vllm_config.scheduler_config.max_num_seqs = max_num_reqs + vllm_config.scheduler_config.max_num_batched_tokens = max_model_len + pcp_manager = PCPManager(pcp_world_size=2, + pcp_rank=0, + dcp_world_size=1, + dcp_rank=0, + 
max_buffer_num_tokens=max_num_tokens, + max_num_reqs=max_num_reqs, + device="cpu", + vllm_config=vllm_config, + pin_memory=False) + arange_np = np.arange(max_model_len) + input_batch = MagicMock() + input_batch.num_computed_tokens_cpu = \ + np.zeros(max_num_reqs, dtype=np.int32) + token_ids_cpu_tensor = torch.zeros( + (max_num_reqs, max_model_len), + device="cpu", + dtype=torch.int32, + ) + input_batch.token_ids_cpu_tensor = token_ids_cpu_tensor + input_batch.token_ids_cpu = token_ids_cpu_tensor.numpy() + token_ids_cpu_tensor = input_batch.token_ids_cpu_tensor + + # Set input_batch + input_batch.req_ids = req_ids + input_batch.num_computed_tokens_cpu[:num_computed_tokens. + size] = num_computed_tokens + for i, token_ids_tensor in enumerate(token_ids_tensor_list): + token_ids_cpu_tensor[i][:token_ids_tensor.size(0)] = token_ids_tensor + + pcp_manager.generate_pcp_mtp_input(num_reqs, total_num_scheduled_tokens, num_scheduled_tokens, input_batch, arange_np) + assert torch.equal( + pcp_manager.input_ids_pcp_full.cpu[:total_num_scheduled_tokens], + target_input_ids_pcp_full) + assert torch.equal(pcp_manager.query_start_loc_pcp_full.cpu[:num_reqs + 1], + target_query_start_loc_pcp_full) diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 48784c47b34..8daa964febc 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -26,15 +26,12 @@ from threading import Lock from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union -import numpy as np import torch import torch_npu # noqa: F401 from packaging.version import InvalidVersion, Version from torch_npu.npu.streams import Event from vllm.logger import logger from vllm.sequence import IntermediateTensors -from vllm.utils.math_utils import cdiv -from vllm.v1.utils import CpuGpuBuffer import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config @@ -1103,588 +1100,4 @@ def dispose_layer(layer: Any): def replace_layer(original_layer: Any, new_layer: Any): 
original_layer.__class__ = new_layer.__class__ - original_layer.__dict__ = new_layer.__dict__ - - -class PCPManager: - """ - Manager for Prefill Context Parallelism (PCP) metadata and buffers. - - This manager encapsulates all PCP-related buffers and logic so that the - ModelRunner can access them via `self.pcp_manager`. - """ - - def __init__( - self, - pcp_world_size: int, - pcp_rank: int, - dcp_world_size: int, - dcp_rank: int, - max_buffer_num_tokens: int, - max_num_reqs: int, - device: torch.device, - vllm_config: VllmConfig, - pin_memory: bool = False, - ) -> None: - self.pcp_world_size = pcp_world_size - self.pcp_world_rank = pcp_rank - self.dcp_world_size = dcp_world_size - self.dcp_world_rank = dcp_rank - self.speculative_config = vllm_config.speculative_config - self.decode_threshold = 1 + ( - self.speculative_config.num_speculative_tokens - if self.speculative_config else 0) - self.vllm_config = vllm_config - self.max_num_tokens = self.vllm_config.scheduler_config.max_num_batched_tokens - self.max_num_reqs = self.vllm_config.scheduler_config.max_num_seqs - self.device = device - self.pcp_allgather_restore_idx = CpuGpuBuffer( - max_buffer_num_tokens, - dtype=torch.int64, - device=device, - pin_memory=pin_memory, - ) - self.pcp_padded_slot_mapping = torch.full( - (max_buffer_num_tokens, ), fill_value=-1, - dtype=torch.int32, - device=device, - ) - self.num_pcp_pads_cpu_tensor = torch.zeros((max_num_reqs, ), - device="cpu", - dtype=torch.int64) - self.num_pcp_pads_cpu = self.num_pcp_pads_cpu_tensor.numpy() - self.pcp_unpad_mask_cpu_tensor = torch.zeros( - (max_buffer_num_tokens, ), - device="cpu", - dtype=torch.bool, - ) - self.num_actual_tokens_pcp_padded = 0 - self.pcp_unpad_mask_cpu = self.pcp_unpad_mask_cpu_tensor.numpy() - self.cp_kv_recover_idx_for_chunk: List[List[int]] = [ - [] for _ in range(self.pcp_world_size) - ] - self.full_indices = list( - range(self.max_num_tokens * self.pcp_world_size * - self.dcp_world_size + self.pcp_world_size * - 
self.dcp_world_size * self.max_num_reqs)) - if self.speculative_config and self.pcp_world_size > 1: - self.input_ids_pcp_full = CpuGpuBuffer(self.max_num_tokens, - device=device, - pin_memory=pin_memory, - dtype=torch.int32) - self.query_start_loc_pcp_full = CpuGpuBuffer(self.max_num_reqs + 1, - device=device, - pin_memory=pin_memory, - dtype=torch.int32) - self.positions_pcp_full = torch.zeros(self.max_num_tokens, - dtype=torch.int64, - device="cpu", - pin_memory=True) - self.positions_pcp_full_np = self.positions_pcp_full.numpy() - - def _get_cumsum_and_arange( - self, - num_scheduled_tokens: np.ndarray, - arange_np: np.ndarray, - cumsum_dtype: np.dtype | None = None, - ) -> tuple[np.ndarray, np.ndarray]: - """Get the cumulative sum and batched arange of the given array. - # E.g., [2, 5, 3] -> ([2, 7, 10], [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]) - # Equivalent to but faster than: - # np.concatenate([np.arange(n) for n in num_scheduled_tokens]) - """ - # Step 1. [2, 5, 3] -> [2, 7, 10] - cu_num_tokens = np.cumsum(num_scheduled_tokens, dtype=cumsum_dtype) - total_num_tokens = cu_num_tokens[-1] - # Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7] - cumsums_offsets = np.repeat(cu_num_tokens - num_scheduled_tokens, - num_scheduled_tokens) - # Step 3. [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - arange = arange_np[:total_num_tokens] - cumsums_offsets - - return cu_num_tokens, arange - - def update_tokens_for_pcp( - self, - num_scheduled_tokens: np.ndarray, - arange_np: np.ndarray, - num_reqs: int, - reorder_batch_threshold: int | None = None, - ) -> tuple[np.ndarray, np.ndarray]: - """ - Update token counts and positions for Prefill Context Parallelism (PCP). - - When using Prefill Context Parallelism, each request's prefill sequence is - split across multiple PCP ranks. 
The splitting strategy used here is the - "DualChunkSwap" style: each request's (padded) sequence is split into - 2 * pcp_world_size chunks and ranks are assigned chunks in an interleaved - head/tail pattern to balance load. - - This function: - - Computes how many tokens each request should be processed by the current - PCP rank (pcp_tokens). - - Computes the flattened positions of those tokens within the local - padded buffer (pcp_positions). - - Updates runner state arrays used to restore original order and mask out - padded tokens after allgather: - - self.num_pcp_pads_cpu: number of pads added per request - - self.pcp_unpad_mask_cpu: boolean mask marking real tokens in the - padded allgather buffer - - self.pcp_allgather_restore_idx: index array used to restore original - ordering after per-rank allgather and interleaving. - - Args: - num_scheduled_tokens: 1D numpy array of length num_reqs containing - the number of new tokens scheduled per request. - arange_np: 1D numpy array of length max_buffer_num_tokens used for - efficient batched arange operations. - num_reqs: Total number of requests in the batch. - reorder_batch_threshold: Threshold for decode vs prefill requests. - - Returns: - Tuple (pcp_tokens, pcp_positions): - - pcp_tokens: number of tokens per request that this PCP rank will - actually process (after splitting / replication). - - pcp_positions: flattened positions for those tokens on this rank, - used to build the positions buffer for the model. - - Example: - >>> Assume tokens = [1, 5, 8], pcp_world_size = 2. After _update_tokens_for_pcp. 
- >>> pcp_rank = 0 get ([1, 4, 4], [0, 0, 1, 6, 7, 0, 1, 6, 7]) - >>> pcp_rank = 1 get ([1, 4, 4], [0, 2, 3, 4, 5, 2, 3, 4, 5]) - >>> Meanwhile, the following results are same for each pcp rank - >>> self.num_pcp_pads_cpu - [1, 3, 0] - >>> self.pcp_unpad_mask_cpu - [True, False, True, True, True, True, True, False, False, - False, True, True, True, True, True, True, True, True] - >>> self.pcp_allgather_resotre_idx - [0, 9, 1, 2, 10, 11, 12, 13, 3, 4, 5, 6, 14, 15, 16, 17, 7, 8] - """ - - assert reorder_batch_threshold is not None, ( - "PCP depends on reorder batch to split decode and prefill requests." - ) - num_decode_reqs = sum(num_scheduled_tokens <= reorder_batch_threshold) - num_decode_tokens = sum(num_scheduled_tokens[:num_decode_reqs]) - - # DualChunkSwap requires alignment to a multiple of (2 * pcp_world_size). - # We first pad each request's token count up to that multiple. - num_padded_scheduled_tokens = np.ceil( - num_scheduled_tokens / (2 * self.pcp_world_size)).astype( - np.int32) * (2 * self.pcp_world_size) - - # PCP does not split decode requests. For decode requests, we instead - # duplicate the scheduled tokens across the pcp_world_size ranks. - num_padded_scheduled_tokens[:num_decode_reqs] = ( - num_scheduled_tokens[:num_decode_reqs] * self.pcp_world_size) - - # Record how many pads were added per request (padded - original). - self.num_pcp_pads_cpu[:num_reqs] = (num_padded_scheduled_tokens - - num_scheduled_tokens) - - # cu_padded_tokens: cumulative sum of padded token counts, - # pcp_padded_arange: per-request arange flattened for padded tokens. - cu_padded_tokens, pcp_padded_arange = self._get_cumsum_and_arange( - num_padded_scheduled_tokens, arange_np) - # Build the mask that marks which positions in the padded allgather buffer - # correspond to real (unpadded) tokens. 
- self.pcp_unpad_mask_cpu[:pcp_padded_arange.shape[0]] = ( - pcp_padded_arange < np.repeat(num_scheduled_tokens, - num_padded_scheduled_tokens)) - - pcp_tokens = num_padded_scheduled_tokens // self.pcp_world_size - - # Compute per-request "chunk sizes" for the head/tail splitting. - # For prefill requests, we further split the pcp_tokens into two chunks - # (head and tail). For decode requests, the chunk equals pcp_tokens. - pcp_chunk_sizes = (pcp_tokens // 2).clip(min=1) - pcp_chunk_sizes[:num_decode_reqs] = pcp_tokens[:num_decode_reqs] - - # Build arange-style helpers for pcp tokens and chunk sizes: - # - pcp_arange gives indices repeated for each token in pcp_tokens - # - pcp_chunk_arange gives indices repeated for each position inside chunks - _, pcp_arange = self._get_cumsum_and_arange(pcp_tokens, arange_np) - _, pcp_chunk_arange = self._get_cumsum_and_arange( - pcp_chunk_sizes, arange_np) - - # Mask that marks whether a position belongs to the head chunk (True) - # or the tail chunk (False). For decode requests, tail chunk won't exist - # and is handled specially below. - pcp_head_chunk_mask = pcp_arange < np.repeat(pcp_chunk_sizes, - pcp_tokens) - - def get_current_rank_positions(positions_start_loc: int | np.ndarray, - rank: int): - """ - Compute flattened positions for the given rank with a given start - offset for each request (positions_start_loc). - - - For head chunks: start at positions_start_loc + rank * chunk_size. - - For tail chunks: start at positions_start_loc + (2*pcp_world_size- rank - - 1) * chunk_size. - - For decode requests: no tail chunks; their positions are filled from the - contiguous (unpadded) `tokens` arange instead (handled after). 
- """ - positions = np.zeros(len(pcp_head_chunk_mask), dtype=np.int32) - head_start_loc = positions_start_loc + rank * pcp_chunk_sizes - tail_start_loc = ( - positions_start_loc + - (2 * self.pcp_world_size - rank - 1) * pcp_chunk_sizes) - # Fill head positions using chunk arange offset by head_start_loc. - positions[pcp_head_chunk_mask] = pcp_chunk_arange + np.repeat( - head_start_loc, pcp_chunk_sizes) - # Fill tail positions. Note decode requests do not have tail chunks, - # so the tail filling is only for prefill positions. - positions[~pcp_head_chunk_mask] = ( - pcp_chunk_arange[num_decode_tokens:] + - np.repeat(tail_start_loc, pcp_chunk_sizes)[num_decode_tokens:]) - return positions - - positions = get_current_rank_positions(0, self.pcp_world_rank) - # Decode tokens are duplicated only after AG. But their positions are - # same without prefill context parallel. - if num_decode_reqs > 0: - positions[:num_decode_tokens] = self._get_cumsum_and_arange( - num_scheduled_tokens[:num_decode_reqs], arange_np)[1] - - # Build the restore index used after allgather. 
- padded_pos_start_loc = np.roll(cu_padded_tokens, 1) - padded_pos_start_loc[0] = 0 - all_positions_lst = [ - get_current_rank_positions(padded_pos_start_loc, rank_i) - for rank_i in range(self.pcp_world_size) - ] - all_positions = np.concatenate(all_positions_lst) - self.pcp_allgather_restore_idx.np[:all_positions.shape[0]] = ( - all_positions.argsort()) - self.pcp_allgather_restore_idx.copy_to_gpu(all_positions.shape[0]) - - return ( - pcp_tokens[:num_reqs], - positions, - ) - - def get_logits_indices(self, cu_num_tokens: np.ndarray, num_reqs: int): - return (torch.from_numpy(cu_num_tokens) * self.pcp_world_size - - self.num_pcp_pads_cpu_tensor[:num_reqs] - 1) - - def get_discard_request_mask( - self, - num_computed_tokens_cpu: np.ndarray, - num_scheduled_tokens: np.ndarray, - num_reqs: int, - num_tokens_np: np.ndarray, - ): - return (num_computed_tokens_cpu[:num_reqs] + - num_scheduled_tokens * self.pcp_world_size - - self.num_pcp_pads_cpu[:num_reqs]) < num_tokens_np - - def get_padded_slot_mapping(self, num_tokens: int, - slot_mapping: torch.Tensor): - # After pcp allgather and restore, there are padded tokens in kv, - # so we need pad slotmapping for alignment. - pcp_padded_slot_mapping = self.pcp_padded_slot_mapping[:num_tokens * - self. - pcp_world_size] - cp_unpad_mask = self.pcp_unpad_mask_cpu_tensor[:num_tokens * - self.pcp_world_size] - pcp_padded_slot_mapping[cp_unpad_mask] = slot_mapping - return pcp_padded_slot_mapping - - def get_restore_hidden_states( - self, - hidden_states: torch.Tensor, - ): - # NOTE we must `slice` hidden_states because pcp_allgather_restore_idx - # ignores the padding from CUDA Graph. - from vllm.distributed.parallel_state import get_pcp_group - hidden_states = get_pcp_group().all_gather( - hidden_states[:self.num_actual_tokens_pcp_padded // - self.pcp_world_size], - 0, - ) - restore_idx = self.pcp_allgather_restore_idx.gpu[:hidden_states. 
- shape[0]] - return torch.index_select( - hidden_states, - 0, - restore_idx, - ) - - def generate_pcp_mtp_input( - self, - num_reqs: int, - total_num_scheduled_tokens: int, - num_scheduled_tokens: dict[str, int], - input_batch, - arange_np: np.ndarray, - ): - """ - While pcp > 1, model inputs (input_ids, position, etc.) are split across pcp group, - but mtp need to shift original input_ids before pcp splitting, - so we record original input_ids here. - """ - total_num_scheduled_tokens_pcp_full = total_num_scheduled_tokens - num_scheduled_tokens_pcp_full = np.empty(num_reqs, dtype=np.int32) - for i, req_id in enumerate(input_batch.req_ids): - num_scheduled_tokens_pcp_full[i] = num_scheduled_tokens[req_id] - req_indices_pcp_full = np.repeat(arange_np[:num_reqs], - num_scheduled_tokens_pcp_full) - cu_num_tokens_pcp_full = np.cumsum(num_scheduled_tokens_pcp_full) - self.query_start_loc_pcp_full.np[0] = 0 - self.query_start_loc_pcp_full.np[1:num_reqs + - 1] = cu_num_tokens_pcp_full - self.query_start_loc_pcp_full.np[num_reqs + 1:].fill(-1) - cumsums_offsets_pcp_full = np.repeat( - cu_num_tokens_pcp_full - num_scheduled_tokens_pcp_full, - num_scheduled_tokens_pcp_full) - arange_pcp_full = arange_np[:total_num_scheduled_tokens_pcp_full] - cumsums_offsets_pcp_full - positions_pcp_full_np = self.positions_pcp_full_np[: - total_num_scheduled_tokens_pcp_full] - np.add(input_batch.num_computed_tokens_cpu[req_indices_pcp_full], - arange_pcp_full, - out=positions_pcp_full_np) - token_indices_pcp_full = ( - positions_pcp_full_np + - req_indices_pcp_full * input_batch.token_ids_cpu.shape[1]) - torch.index_select(input_batch.token_ids_cpu_tensor.flatten(), - 0, - torch.from_numpy(token_indices_pcp_full), - out=self.input_ids_pcp_full. 
- cpu[:total_num_scheduled_tokens_pcp_full]) - self.query_start_loc_pcp_full.copy_to_gpu(num_reqs + 1) - self.input_ids_pcp_full.copy_to_gpu( - total_num_scheduled_tokens_pcp_full) - - def _get_cp_local_seq_lens( - self, - seq_lens: torch.Tensor, - pcp_world_size: int = 1, - dcp_world_size: int = 1, - cp_kv_cache_interleave_size: int = 1, - ) -> torch.Tensor: - """While using pcp or dcp, kv_cache size stored on each rank may be different, - use this function to calculate split decode seq_lens of each (p/d)cp rank. - """ - num_requests = seq_lens.size(0) - total_world_size = pcp_world_size * dcp_world_size - seq_lens_tiled = seq_lens.unsqueeze(-1).repeat(1, total_world_size) - rank_offsets = (torch.arange(total_world_size, - dtype=torch.int32).unsqueeze(0).repeat( - num_requests, 1)) - base = (seq_lens_tiled // cp_kv_cache_interleave_size // - total_world_size * cp_kv_cache_interleave_size) - remainder = seq_lens_tiled - base * total_world_size - remainder = torch.clip( - remainder - rank_offsets * cp_kv_cache_interleave_size, - 0, - cp_kv_cache_interleave_size, - ) - dcp_local_seq_lens = (base + remainder).reshape( - [-1, pcp_world_size, dcp_world_size]) - return dcp_local_seq_lens - - def generate_kv_idx(self, scheduler_output, input_batch): - if not self.pcp_world_size > 1: - return - self.cp_kv_recover_idx_for_chunk = [[] - for _ in range(self.pcp_world_size) - ] - - for i, req_id in enumerate(input_batch.req_ids): - num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ - req_id] - is_prefill = input_batch.num_computed_tokens_cpu[ - i] < input_batch.num_prompt_tokens[i] - if is_prefill: - num_cp_padded_scheduled_tokens = cdiv( - num_scheduled_tokens, - 2 * self.pcp_world_size) * (2 * self.pcp_world_size) - chunk_size = num_cp_padded_scheduled_tokens // ( - 2 * self.pcp_world_size) - num_added_recover_tokens = len( - self.cp_kv_recover_idx_for_chunk[0]) * self.pcp_world_size - for rank in range(self.pcp_world_size): - 
self.cp_kv_recover_idx_for_chunk[rank].extend( - self.full_indices[rank * chunk_size + - num_added_recover_tokens:(rank + 1) * - chunk_size + - num_added_recover_tokens]) - self.cp_kv_recover_idx_for_chunk[rank].extend( - self.full_indices[num_cp_padded_scheduled_tokens - - (rank + 1) * chunk_size + - num_added_recover_tokens: - num_cp_padded_scheduled_tokens - - rank * chunk_size + - num_added_recover_tokens]) - - cp_kv_recover_idx_for_chunk = torch.from_numpy( - np.concatenate(self.cp_kv_recover_idx_for_chunk)).to( - device=self.device) - cp_kv_recover_idx_for_chunk.copy_(torch.tensor( - np.array( - self.cp_kv_recover_idx_for_chunk).flatten().tolist()), - non_blocking=True) - self.cp_kv_recover_idx_for_chunk = cp_kv_recover_idx_for_chunk.to( - torch.float32).argsort().to(torch.int32) - - def generate_pcp_metadata(self, total_num_scheduled_tokens, query_lens, - attn_mask, input_batch): - from vllm_ascend.attention.utils import \ - AscendPrefillContextParallelMetadata - num_reqs = input_batch.num_reqs or query_lens.size(0) - num_decodes = sum(input_batch.num_computed_tokens_cpu[:num_reqs] >= - input_batch.num_prompt_tokens[:num_reqs]) - num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_world_size - self.num_actual_tokens_pcp_padded = num_actual_tokens_pcp_padded - long_seq_metadata = None - if self.pcp_world_size * self.dcp_world_size > 1: - decode_context_lens = input_batch.num_tokens[:num_decodes] - prefill_context_lens = input_batch.num_computed_tokens_cpu[ - num_decodes:num_reqs] - context_lens = np.concatenate( - [decode_context_lens, prefill_context_lens]) - num_computed_tokens_of_pcp_dcp = torch.zeros( - [ - num_reqs * self.decode_threshold, self.pcp_world_size, - self.dcp_world_size - ], - dtype=torch.int32, - ) - # For pcp + spec decode, we flatten seq_lens - # to avoid irregular spec_attn_mask shape - for decode_idx in range(self.decode_threshold): - num_computed_tokens_of_pcp_dcp[ - self.decode_threshold - 1 - 
decode_idx::self.decode_threshold] = \ - self._get_cp_local_seq_lens( - torch.tensor(context_lens), - self.pcp_world_size, - self.dcp_world_size, - self.vllm_config.parallel_config.cp_kv_cache_interleave_size, - ) - long_seq_metadata = AscendPrefillContextParallelMetadata( - num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded, - num_computed_tokens_of_pcp_dcp=num_computed_tokens_of_pcp_dcp. - numpy()) - if self.pcp_world_size > 1: - q_head_idx, q_tail_idx = [], [] - kv_with_q_head_nomask_idx, kv_with_q_head_mask_idx = [], [] - kv_with_q_tail_nomask_idx, kv_with_q_tail_mask_idx = [], [] - chunk_seqlens = [] - kv_with_q_head_nomask_seqlens, kv_with_q_tail_nomask_seqlens = [], [] - q_req_offset = 0 - kv_req_offset = 0 - q_head_chunk_id = self.pcp_world_rank - q_tail_chunk_id = self.pcp_world_size * 2 - 1 - self.pcp_world_rank - for i, seq_len in enumerate(query_lens): - if i < num_decodes: - continue - chunk_len = seq_len // 2 - chunk_seqlens.append(chunk_len) - q_head_idx.extend( - list(range(q_req_offset, q_req_offset + chunk_len))) - kv_with_q_head_nomask_idx.extend( - list( - range(kv_req_offset, kv_req_offset + - chunk_len * q_head_chunk_id))) - kv_with_q_head_mask_idx.extend( - list( - range( - kv_req_offset + chunk_len * q_head_chunk_id, - kv_req_offset + chunk_len * - (q_head_chunk_id + 1)))) - kv_with_q_head_nomask_seqlens.append(chunk_len * - q_head_chunk_id) - - q_tail_idx.extend( - list( - range(q_req_offset + chunk_len, - q_req_offset + chunk_len * 2))) - kv_with_q_tail_nomask_idx.extend( - list( - range(kv_req_offset, kv_req_offset + - chunk_len * q_tail_chunk_id))) - kv_with_q_tail_mask_idx.extend( - list( - range( - kv_req_offset + chunk_len * q_tail_chunk_id, - kv_req_offset + chunk_len * - (q_tail_chunk_id + 1)))) - kv_with_q_tail_nomask_seqlens.append(chunk_len * - q_tail_chunk_id) - - q_req_offset += seq_len - kv_req_offset += seq_len * self.pcp_world_size - - # Convert lists to tensors and move to device - def _list_to_tensor(lst, device, 
dtype=torch.int32): - tensor_npu = torch.zeros(len(lst), - dtype=dtype, - device=device) - tensor_npu.copy_(torch.tensor(lst, dtype=dtype), - non_blocking=True) - return tensor_npu - - q_head_idx_tensor = _list_to_tensor(q_head_idx, self.device) - q_tail_idx_tensor = _list_to_tensor(q_tail_idx, self.device) - self.q_head_idx_tensor = q_head_idx_tensor - self.q_tail_idx_tensor = q_tail_idx_tensor - - q_full_idx = torch.cat([q_head_idx_tensor, q_tail_idx_tensor]) - q_full_idx = q_full_idx.to(torch.float32).argsort().to( - torch.int32) - self.q_full_idx = q_full_idx - - self.kv_idx_names = { - 'kv_with_q_head_nomask_idx_tensor': - kv_with_q_head_nomask_idx, - 'kv_with_q_head_mask_idx_tensor': kv_with_q_head_mask_idx, - 'kv_with_q_tail_nomask_idx_tensor': - kv_with_q_tail_nomask_idx, - 'kv_with_q_tail_mask_idx_tensor': kv_with_q_tail_mask_idx - } - for key, value in self.kv_idx_names.items(): - tensor_npu = _list_to_tensor(value, self.device) - self.kv_idx_names[key] = tensor_npu - - attn_mask_seqlens = torch.tensor( - [chunk_seqlens, chunk_seqlens], dtype=torch.int32) - head_attn_nomask_seqlens = torch.tensor( - [chunk_seqlens, kv_with_q_head_nomask_seqlens], - dtype=torch.int32) - tail_attn_nomask_seqlens = torch.tensor( - [chunk_seqlens, kv_with_q_tail_nomask_seqlens], - dtype=torch.int32) - pcp_prefill_mask = attn_mask - - self.extra_long_seq_kwargs = { - 'attn_mask_seqlens': attn_mask_seqlens, - 'head_attn_nomask_seqlens': head_attn_nomask_seqlens, - 'tail_attn_nomask_seqlens': tail_attn_nomask_seqlens, - 'pcp_prefill_mask': pcp_prefill_mask - } - long_seq_metadata.pcp_allgather_restore_idx = self.pcp_allgather_restore_idx.gpu[: - num_actual_tokens_pcp_padded] - long_seq_metadata.cp_kv_recover_idx_for_chunk = self.cp_kv_recover_idx_for_chunk - long_seq_metadata.q_head_idx_tensor = self.q_head_idx_tensor - long_seq_metadata.q_tail_idx_tensor = self.q_tail_idx_tensor - long_seq_metadata.q_full_idx = self.q_full_idx - 
long_seq_metadata.kv_with_q_head_nomask_idx_tensor = self.kv_idx_names[ - 'kv_with_q_head_nomask_idx_tensor'] - long_seq_metadata.kv_with_q_head_mask_idx_tensor = self.kv_idx_names[ - 'kv_with_q_head_mask_idx_tensor'] - long_seq_metadata.kv_with_q_tail_nomask_idx_tensor = self.kv_idx_names[ - 'kv_with_q_tail_nomask_idx_tensor'] - long_seq_metadata.kv_with_q_tail_mask_idx_tensor = self.kv_idx_names[ - 'kv_with_q_tail_mask_idx_tensor'] - long_seq_metadata.attn_mask_seqlens = self.extra_long_seq_kwargs[ - 'attn_mask_seqlens'] - long_seq_metadata.head_attn_nomask_seqlens = self.extra_long_seq_kwargs[ - 'head_attn_nomask_seqlens'] - long_seq_metadata.tail_attn_nomask_seqlens = self.extra_long_seq_kwargs[ - 'tail_attn_nomask_seqlens'] - long_seq_metadata.pcp_prefill_mask = self.extra_long_seq_kwargs[ - 'pcp_prefill_mask'] - self.long_seq_metadata = long_seq_metadata - return long_seq_metadata + original_layer.__dict__ = new_layer.__dict__ \ No newline at end of file diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 7f3d25c6c2d..8803b8ea171 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -112,11 +112,11 @@ from vllm_ascend.spec_decode import get_spec_decode_method from vllm_ascend.spec_decode.eagle_proposer import EagleProposer from vllm_ascend.spec_decode.mtp_proposer import MtpProposer -from vllm_ascend.utils import (AscendDeviceType, PCPManager, - ProfileExecuteDuration, enable_sp, - get_ascend_device_type, is_moe_model, +from vllm_ascend.utils import (AscendDeviceType, ProfileExecuteDuration, + enable_sp, get_ascend_device_type, is_moe_model, lmhead_tp_enable, maybe_trans_nz) from vllm_ascend.worker.npu_input_batch import NPUInputBatch +from vllm_ascend.worker.pcp_utils import PCPManager from vllm_ascend.ascend_forward_context import ( # isort: skip MoECommType, get_mc2_tokens_capacity, select_moe_comm_method, @@ -287,32 +287,6 @@ def __init__(self, vllm_config: 
VllmConfig, device: torch.device): set_mc2_tokens_capacity(vllm_config, self.max_num_reqs, self.uniform_decode_query_len) set_mc2_mask(vllm_config, self.device) - self.pcp_allgather_restore_idx = torch.zeros( - self.max_num_tokens + 2 * self.pcp_size * self.max_num_reqs, - dtype=torch.int32, - device=self.device) - self.cp_kv_recover_idx_for_chunk: List[List[int]] = [ - [] for _ in range(self.pcp_size) - ] - - self.num_pcp_pads = torch.zeros(self.max_num_reqs, dtype=torch.int32) - self.pcp_padded_slot_mapping = torch.zeros( - self.max_num_tokens + 2 * self.pcp_size * self.max_num_reqs, - dtype=torch.int32, - device=self.device) - self.num_actual_tokens_pcp_padded = 0 - if self.speculative_config and self.pcp_size * self.dcp_size > 1: - self.input_ids_pcp_full = self._make_buffer(self.max_num_tokens, - dtype=torch.int32) - self.query_start_loc_pcp_full = self._make_buffer( - self.max_num_reqs + 1, dtype=torch.int32) - self.positions_pcp_full = torch.zeros(self.max_num_tokens, - dtype=torch.int64, - device="cpu", - pin_memory=True) - self.positions_pcp_full_np = self.positions_pcp_full.numpy() - self.query_lens_pcp_full = self._make_buffer(self.max_num_reqs, - dtype=torch.int32) self.decode_threshold = 1 + ( self.speculative_config.num_speculative_tokens if self.speculative_config else 0) @@ -536,14 +510,29 @@ def _prepare_inputs( req_ids = self.input_batch.req_ids tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids] num_scheduled_tokens = np.array(tokens, dtype=np.int32) - # for pcp, prefill mtp should use origin scheduleroutput , - if self.speculative_config and self.pcp_size > 1: - self.pcp_manager.generate_pcp_mtp_input( - num_reqs, total_num_scheduled_tokens, - scheduler_output.num_scheduled_tokens, self.input_batch, - self.arange_np) + req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens) + if not scheduler_output.scheduled_spec_decode_tokens: + num_valid_tokens = np.array(tokens, dtype=np.int32) + else: + num_valid_tokens = 
np.array([ + num_tokens - + len(scheduler_output.scheduled_spec_decode_tokens.get(i, [])) + for num_tokens, i in zip(tokens, req_ids) + ], + dtype=np.int32) + # Get the attention state. + attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens, + num_valid_tokens) + self.attn_state = attn_state # type: ignore + + # Determine if it's a splitfuse batch + with_prefill = attn_state not in [ + AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding + ] + self.attn_mask = self._make_attention_mask(attn_state) + # Get positions. positions_np = self.positions.np[:total_num_scheduled_tokens] cu_num_tokens, arange = self._get_cumsum_and_arange( @@ -556,32 +545,38 @@ def _prepare_inputs( req_indices, positions_np) self.input_batch.block_table.commit_slot_mapping( total_num_scheduled_tokens) + # for pcp, prefill mtp should use origin scheduleroutput , + if self.speculative_config and self.pcp_size > 1: + self.pcp_manager.generate_pcp_mtp_input( + num_reqs, total_num_scheduled_tokens, + scheduler_output.num_scheduled_tokens, with_prefill, + req_indices, positions_np, cu_num_tokens) if self.pcp_size > 1: if not self.vllm_config.model_config.use_mla: - self.generate_kv_idx(scheduler_output) - tokens_before_update = tokens.copy() - tokens, position_pcp, pcp_unpad_mask = self._update_tokens_for_pcp( - tokens) - num_scheduled_tokens = np.array(tokens, dtype=np.int32) - total_num_scheduled_tokens = sum(num_scheduled_tokens[:num_reqs]) - total_num_pcp_pads = torch.sum(self.num_pcp_pads).item() - else: - position_pcp, pcp_unpad_mask = None, None - self.num_pcp_pads = self.num_pcp_pads[:num_reqs] - + self.pcp_manager.generate_kv_idx(scheduler_output, + self.input_batch) + num_scheduled_tokens[: + num_reqs], position_pcp = self.pcp_manager.update_tokens_for_pcp( + num_scheduled_tokens[:num_reqs], + self.arange_np, + self.input_batch.num_reqs, + self.reorder_batch_threshold, + ) + # Re-update after PCP split sequences. 
+ total_num_scheduled_tokens = sum(num_scheduled_tokens) + scheduler_output.total_num_scheduled_tokens = total_num_scheduled_tokens + req_indices = np.repeat(self.arange_np[:num_reqs], + num_scheduled_tokens) + cu_num_tokens, _ = self._get_cumsum_and_arange( + num_scheduled_tokens) + positions_np = self.positions.np[:total_num_scheduled_tokens] + np.add( + self.input_batch.num_computed_tokens_cpu[req_indices], + position_pcp[:total_num_scheduled_tokens], + out=positions_np, + ) max_num_scheduled_tokens = max(tokens) - if not scheduler_output.scheduled_spec_decode_tokens: - num_valid_tokens = np.array(tokens, dtype=np.int32) - else: - num_valid_tokens = np.array([ - num_tokens - - len(scheduler_output.scheduled_spec_decode_tokens.get(i, [])) - for num_tokens, i in zip((tokens_before_update if self. - pcp_size > 1 else tokens), req_ids) - ], - dtype=np.int32) - if (self.use_aclgraph and total_num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): # Add padding to the batch size. @@ -598,18 +593,6 @@ def _prepare_inputs( else: # Eager mode. num_input_tokens = total_num_scheduled_tokens - - # Get the attention state. - attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens, - num_valid_tokens) - self.attn_state = attn_state # type: ignore - - # Determine if it's a splitfuse batch - with_prefill = attn_state not in [ - AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding - ] - self.attn_mask = self._make_attention_mask(attn_state) - self.query_lens = torch.from_numpy(num_scheduled_tokens) # Get info across DP ranks. 
@@ -870,8 +853,10 @@ def _prepare_inputs( >= self.input_batch.num_prompt_tokens[req_idx]) else -1) spec_decode_metadata = self._calc_spec_decode_metadata( - num_draft_tokens, cu_num_tokens, - self.num_pcp_pads[:num_reqs].numpy()) + num_draft_tokens, + cu_num_tokens, + num_pcp_pads=self.pcp_manager.num_pcp_pads_cpu[:num_reqs] + if self.pcp_size > 1 else None) logits_indices = spec_decode_metadata.logits_indices # For DECODE only cuda graph of some attention backends (e.g., GDN). @@ -893,14 +878,6 @@ def _prepare_inputs( self.num_accepted_tokens.np[num_reqs:].fill(1) self.num_accepted_tokens.copy_to_gpu() - if self.speculative_config and self.pcp_size * self.dcp_size > 1: - self._generate_pcp_mtp_input( - num_reqs, scheduler_output.total_num_scheduled_tokens, - scheduler_output.num_scheduled_tokens, with_prefill, - req_indices, positions_np, cu_num_tokens) - - long_seq_metadata = self._generate_pcp_metadata( - total_num_scheduled_tokens) # Prepare the attention metadata for each KV cache group and make layers # in the same group share the same metadata. for kv_cache_group_id, kv_cache_group_spec in enumerate( @@ -926,13 +903,11 @@ def _prepare_inputs( sum(self.pcp_manager.num_pcp_pads_cpu[:num_reqs])) blk_table = self.input_batch.block_table[kv_cache_group_id] blk_table_tensor = blk_table.get_device_tensor() - # blk_table.slot_mapping.gpu[slot_mapping_size:].fill_(-1) slot_mapping = blk_table.slot_mapping.gpu[: maybe_pcp_full_tokens] if self.pcp_size == 1: slot_mapping[ total_num_scheduled_tokens:num_input_tokens].fill_(-1) - # blk_table_tensor[num_reqs:].fill_(-1) slot_mapping = blk_table.slot_mapping.gpu if self.pcp_size > 1: self.long_seq_metadata = self.pcp_manager.generate_pcp_metadata( @@ -943,9 +918,8 @@ def _prepare_inputs( total_num_scheduled_tokens, slot_mapping, ) - blk_table.slot_mapping.gpu[:self.pcp_manager.num_actual_tokens_pcp_padded] = slot_mapping - - + blk_table.slot_mapping.gpu[:self.pcp_manager. 
+ num_actual_tokens_pcp_padded] = slot_mapping # NOTE: This is a temporary hack, now in GPUModelRunner, this prepare_inputs # has been split to multiple parts, and there are 3 parts that is related to this @@ -1008,13 +982,12 @@ def _prepare_inputs( # (num_reqs_d + num_reqs_p, max_num_blocks), # flattened block_table: [d0, d0, d1, d1, p0, p1, p2] # (num_reqs_d * decode_threshold + num_reqs_p, max_num_blocks), - ori_query_lens_cpu = self.query_lens_pcp_full.cpu[:num_reqs] - ori_query_lens = self.query_lens_pcp_full.gpu[:num_reqs] + ori_query_lens = self.pcp_manager.query_start_loc_pcp_full_cpu[1:num_reqs+1] - \ + self.pcp_manager.query_start_loc_pcp_full_cpu[:num_reqs] num_prefill_reqs = (ori_query_lens > self.decode_threshold).sum().item() num_decode_reqs = num_reqs - num_prefill_reqs - num_decode_reqs_flatten = \ - ori_query_lens_cpu[:num_decode_reqs].sum().item() + num_decode_reqs_flatten = num_decode_reqs * self.decode_threshold blk_table_tensor[ num_decode_reqs_flatten:num_decode_reqs_flatten + num_prefill_reqs].copy_( @@ -1025,12 +998,6 @@ def _prepare_inputs( ori_query_lens[:num_decode_reqs], dim=0)) common_attn_metadata.block_table_tensor = \ blk_table_tensor[:num_decode_reqs_flatten + num_prefill_reqs] - long_seq_metadata.query_lens_pcp_full_cpu = ori_query_lens_cpu - if 'pad_size' in locals() and pad_size > 0: - ori_query_lens_cpu[-pad_size:] = \ - torch.full([pad_size], ori_query_lens_cpu[-pad_size - 1].item()) - long_seq_metadata.max_query_len_pcp_full = \ - ori_query_lens_cpu.max().item() if self.speculative_config and \ self.spec_decode_common_attn_metadata is None: @@ -2990,365 +2957,6 @@ def capture_model(self) -> None: logger.info("Graph capturing finished in %.0f secs, took %.2f GiB", elapsed_time, npu_graph_size / (1 << 30)) - def _update_tokens_for_pcp(self, tokens): - num_reqs = self.input_batch.num_reqs - self.num_pcp_pads = self.num_pcp_pads[:num_reqs] - tokens = np.array(tokens, dtype=np.int32) - num_decode_reqs = (np.array(tokens) <= 
self.decode_threshold).sum() - num_decode_tokens = sum(tokens[:num_decode_reqs]) - num_padded_scheduled_tokens = np.ceil( - tokens / - (2 * self.pcp_size)).astype(np.int32) * (2 * self.pcp_size) - num_padded_scheduled_tokens[:num_decode_reqs] = ( - tokens[:num_decode_reqs] * self.pcp_size) - self.num_pcp_pads = torch.tensor(num_padded_scheduled_tokens - tokens) - cu_padded_tokens, pcp_padded_arange = \ - self._get_cumsum_and_arange(num_padded_scheduled_tokens) - unpad_mask = torch.from_numpy( - pcp_padded_arange < np.repeat(tokens, num_padded_scheduled_tokens)) - unpad_mask_decode = unpad_mask[:num_decode_tokens * self.pcp_size] - unpad_mask_decode = unpad_mask_decode.reshape([-1, self.pcp_size]) - unpad_mask_decode[:, 0] = True - unpad_mask_decode[:, 1:] = False - - pcp_tokens = num_padded_scheduled_tokens // self.pcp_size - pcp_chunk_sizes = (pcp_tokens // 2).clip(min=1) - pcp_chunk_sizes[:num_decode_reqs] = pcp_tokens[:num_decode_reqs] - _, pcp_arange = self._get_cumsum_and_arange(pcp_tokens) - _, pcp_chunk_arange = self._get_cumsum_and_arange(pcp_chunk_sizes) - pcp_head_chunk_mask = pcp_arange < np.repeat(pcp_chunk_sizes, - pcp_tokens) - - def get_current_rank_positions(cu_tokens, rank): - positions_start_loc = np.zeros_like(cu_tokens) - positions_start_loc[1:] = cu_tokens[:-1] - positions = np.zeros(len(pcp_head_chunk_mask), dtype=np.int32) - head_start_loc = positions_start_loc + rank * pcp_chunk_sizes - tail_start_loc = positions_start_loc + \ - (2 * self.pcp_size - rank - 1) * pcp_chunk_sizes - positions[pcp_head_chunk_mask] = pcp_chunk_arange + \ - np.repeat(head_start_loc, pcp_chunk_sizes) - # Decode reqs do not have tail chunks. - positions[~pcp_head_chunk_mask] = \ - pcp_chunk_arange[num_decode_tokens:] + \ - np.repeat(tail_start_loc, pcp_chunk_sizes)[num_decode_tokens:] - return positions - - positions = get_current_rank_positions( - np.zeros(num_reqs, dtype=np.int32), self.pcp_rank) - # Decode tokens are duplicate and their positions always be 0. 
- if num_decode_reqs > 0: - positions[:num_decode_tokens] = self._get_cumsum_and_arange( - tokens[:num_decode_reqs])[1] - - all_positions = [ - get_current_rank_positions(cu_padded_tokens, rank_i) - for rank_i in range(self.pcp_size) - ] - all_positions_tensor = torch.from_numpy(np.concatenate(all_positions)) - self.pcp_allgather_restore_idx[:all_positions_tensor.shape[0]].copy_( - all_positions_tensor.float().argsort().long(), non_blocking=True) - return pcp_tokens, positions, unpad_mask - - def _get_cp_local_seq_lens( - self, - seq_lens: torch.Tensor, - pcp_world_size: int = 1, - dcp_world_size: int = 1, - cp_kv_cache_interleave_size: int = 1, - ) -> torch.Tensor: - """While using pcp or dcp, kv_cache size stored on each rank may be different, - use this function to calculate split decode seq_lens of each (p/d)cp rank. - """ - num_requests = seq_lens.size(0) - total_world_size = pcp_world_size * dcp_world_size - seq_lens_tiled = seq_lens.unsqueeze(-1).repeat(1, total_world_size) - rank_offsets = (torch.arange(total_world_size, - dtype=torch.int32).unsqueeze(0).repeat( - num_requests, 1)) - base = (seq_lens_tiled // cp_kv_cache_interleave_size // - total_world_size * cp_kv_cache_interleave_size) - remainder = seq_lens_tiled - base * total_world_size - remainder = torch.clip( - remainder - rank_offsets * cp_kv_cache_interleave_size, - 0, - cp_kv_cache_interleave_size, - ) - dcp_local_seq_lens = (base + remainder).reshape( - [-1, pcp_world_size, dcp_world_size]) - return dcp_local_seq_lens - - def _generate_pcp_metadata(self, total_num_scheduled_tokens): - # In dummy run num_reqs == 0, update it from seq_lens - num_reqs = self.input_batch.num_reqs or self.query_lens.size(0) - query_lens = self.query_lens_pcp_full.cpu[:num_reqs] \ - if self.pcp_size > 1 and self.speculative_config else self.query_lens - num_decodes = (query_lens <= self.decode_threshold).sum().item() - num_prefills = num_reqs - num_decodes - num_actual_tokens_pcp_padded = total_num_scheduled_tokens * 
self.pcp_size - self.num_actual_tokens_pcp_padded = num_actual_tokens_pcp_padded - long_seq_metadata = None - if self.pcp_size * self.dcp_size > 1: - decode_context_lens = self.input_batch.num_tokens[:num_decodes] - prefill_context_lens = self.input_batch.num_computed_tokens_cpu[ - num_decodes:num_reqs] - context_lens = np.concatenate( - [decode_context_lens, prefill_context_lens]) - num_computed_tokens_of_pcp_dcp = torch.zeros( - [ - num_reqs * self.decode_threshold, self.pcp_size, - self.dcp_size - ], - dtype=torch.int32, - ) - # For pcp + spec decode, we flatten seq_lens - # to avoid irregular spec_attn_mask shape. - # Same as block_table, we flatten decode seq_lens to query_lens, - # and keep prefill seq_lens unchanged. - for decode_idx in range(self.decode_threshold): - num_computed_tokens_of_pcp_dcp[ - self.decode_threshold - 1 - decode_idx::self.decode_threshold] = \ - self._get_cp_local_seq_lens( - torch.tensor(context_lens) - decode_idx, - self.pcp_size, - self.dcp_size, - self.parallel_config.cp_kv_cache_interleave_size, - ) - if self.decode_threshold > 1: - num_computed_tokens_of_pcp_dcp_list = [] - if num_decodes: - num_decodes_flatten = \ - self.query_lens[:num_decodes].sum().item() - if self.query_lens[:num_decodes].min().item( - ) == self.decode_threshold: - decode_flatten_idx = list(range(num_decodes_flatten)) - else: - decode_flatten_idx = [] - for req_id in range(num_decodes): - offset = (req_id + 1) * self.decode_threshold - decode_flatten_idx += \ - list(range(offset - self.query_lens[req_id], offset)) - num_computed_tokens_of_pcp_dcp_list.append( - num_computed_tokens_of_pcp_dcp[decode_flatten_idx]) - if num_prefills: - num_computed_tokens_of_pcp_dcp_list.append( - num_computed_tokens_of_pcp_dcp[ - (num_decodes + 1) * self.decode_threshold - - 1::self.decode_threshold]) - num_computed_tokens_of_pcp_dcp = torch.cat( - num_computed_tokens_of_pcp_dcp_list, dim=0) - long_seq_metadata = AscendPrefillContextParallelMetadata( - 
num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded, - num_computed_tokens_of_pcp_dcp=num_computed_tokens_of_pcp_dcp. - numpy()) - if self.pcp_size > 1: - q_head_idx, q_tail_idx = [], [] - kv_with_q_head_nomask_idx, kv_with_q_head_mask_idx = [], [] - kv_with_q_tail_nomask_idx, kv_with_q_tail_mask_idx = [], [] - chunk_seqlens = [] - kv_with_q_head_nomask_seqlens, kv_with_q_tail_nomask_seqlens = [], [] - q_req_offset = 0 - kv_req_offset = 0 - q_head_chunk_id = self.pcp_rank - q_tail_chunk_id = self.pcp_size * 2 - 1 - self.pcp_rank - for i, seq_len in enumerate(self.query_lens): - if i < num_decodes: - continue - chunk_len = seq_len // 2 - chunk_seqlens.append(chunk_len) - q_head_idx.extend( - list(range(q_req_offset, q_req_offset + chunk_len))) - kv_with_q_head_nomask_idx.extend( - list( - range(kv_req_offset, kv_req_offset + - chunk_len * q_head_chunk_id))) - kv_with_q_head_mask_idx.extend( - list( - range( - kv_req_offset + chunk_len * q_head_chunk_id, - kv_req_offset + chunk_len * - (q_head_chunk_id + 1)))) - kv_with_q_head_nomask_seqlens.append(chunk_len * - q_head_chunk_id) - - q_tail_idx.extend( - list( - range(q_req_offset + chunk_len, - q_req_offset + chunk_len * 2))) - kv_with_q_tail_nomask_idx.extend( - list( - range(kv_req_offset, kv_req_offset + - chunk_len * q_tail_chunk_id))) - kv_with_q_tail_mask_idx.extend( - list( - range( - kv_req_offset + chunk_len * q_tail_chunk_id, - kv_req_offset + chunk_len * - (q_tail_chunk_id + 1)))) - kv_with_q_tail_nomask_seqlens.append(chunk_len * - q_tail_chunk_id) - - q_req_offset += seq_len - kv_req_offset += seq_len * self.pcp_size - - # Convert lists to tensors and move to device - def _list_to_tensor(lst, device, dtype=torch.int32): - tensor_npu = torch.zeros(len(lst), - dtype=dtype, - device=device) - tensor_npu.copy_(torch.tensor(lst, dtype=dtype), - non_blocking=True) - return tensor_npu - - q_head_idx_tensor = _list_to_tensor(q_head_idx, self.device) - q_tail_idx_tensor = _list_to_tensor(q_tail_idx, 
self.device) - self.q_head_idx_tensor = q_head_idx_tensor - self.q_tail_idx_tensor = q_tail_idx_tensor - - q_full_idx = torch.cat([q_head_idx_tensor, q_tail_idx_tensor]) - q_full_idx = q_full_idx.to(torch.float32).argsort().to( - torch.int32) - self.q_full_idx = q_full_idx - - self.kv_idx_names = { - 'kv_with_q_head_nomask_idx_tensor': - kv_with_q_head_nomask_idx, - 'kv_with_q_head_mask_idx_tensor': kv_with_q_head_mask_idx, - 'kv_with_q_tail_nomask_idx_tensor': - kv_with_q_tail_nomask_idx, - 'kv_with_q_tail_mask_idx_tensor': kv_with_q_tail_mask_idx - } - for key, value in self.kv_idx_names.items(): - tensor_npu = _list_to_tensor(value, self.device) - self.kv_idx_names[key] = tensor_npu - - attn_mask_seqlens = torch.tensor( - [chunk_seqlens, chunk_seqlens], dtype=torch.int32) - head_attn_nomask_seqlens = torch.tensor( - [chunk_seqlens, kv_with_q_head_nomask_seqlens], - dtype=torch.int32) - tail_attn_nomask_seqlens = torch.tensor( - [chunk_seqlens, kv_with_q_tail_nomask_seqlens], - dtype=torch.int32) - pcp_prefill_mask = self.attn_mask - - self.extra_long_seq_kwargs = { - 'attn_mask_seqlens': attn_mask_seqlens, - 'head_attn_nomask_seqlens': head_attn_nomask_seqlens, - 'tail_attn_nomask_seqlens': tail_attn_nomask_seqlens, - 'pcp_prefill_mask': pcp_prefill_mask - } - long_seq_metadata.pcp_allgather_restore_idx = self.pcp_allgather_restore_idx[: - num_actual_tokens_pcp_padded] - long_seq_metadata.cp_kv_recover_idx_for_chunk = self.cp_kv_recover_idx_for_chunk - long_seq_metadata.q_head_idx_tensor = self.q_head_idx_tensor - long_seq_metadata.q_tail_idx_tensor = self.q_tail_idx_tensor - long_seq_metadata.q_full_idx = self.q_full_idx - long_seq_metadata.kv_with_q_head_nomask_idx_tensor = self.kv_idx_names[ - 'kv_with_q_head_nomask_idx_tensor'] - long_seq_metadata.kv_with_q_head_mask_idx_tensor = self.kv_idx_names[ - 'kv_with_q_head_mask_idx_tensor'] - long_seq_metadata.kv_with_q_tail_nomask_idx_tensor = self.kv_idx_names[ - 'kv_with_q_tail_nomask_idx_tensor'] - 
long_seq_metadata.kv_with_q_tail_mask_idx_tensor = self.kv_idx_names[ - 'kv_with_q_tail_mask_idx_tensor'] - long_seq_metadata.attn_mask_seqlens = self.extra_long_seq_kwargs[ - 'attn_mask_seqlens'] - long_seq_metadata.head_attn_nomask_seqlens = self.extra_long_seq_kwargs[ - 'head_attn_nomask_seqlens'] - long_seq_metadata.tail_attn_nomask_seqlens = self.extra_long_seq_kwargs[ - 'tail_attn_nomask_seqlens'] - long_seq_metadata.pcp_prefill_mask = self.extra_long_seq_kwargs[ - 'pcp_prefill_mask'] - self.long_seq_metadata = long_seq_metadata - return long_seq_metadata - - def _generate_pcp_mtp_input( - self, - num_reqs: int, - total_num_scheduled_tokens: int, - num_scheduled_tokens: dict[str, int], - with_prefill: bool = True, - req_indices=None, - positions_np=None, - cu_num_tokens=None, - ): - """ - While pcp > 1, model inputs (input_ids, position, etc.) are split across pcp group, - but mtp need to shift original input_ids before pcp splitting, - so we record original input_ids here. - """ - total_num_scheduled_tokens_pcp_full = total_num_scheduled_tokens - num_scheduled_tokens_pcp_full = np.empty(num_reqs, dtype=np.int32) - for i, req_id in enumerate(self.input_batch.req_ids): - num_scheduled_tokens_pcp_full[i] = num_scheduled_tokens[req_id] - self.query_lens_pcp_full.cpu[:num_reqs] = torch.from_numpy( - num_scheduled_tokens_pcp_full) - req_indices_pcp_full = np.repeat(self.arange_np[:num_reqs], - num_scheduled_tokens_pcp_full) - cu_num_tokens_pcp_full = np.cumsum(num_scheduled_tokens_pcp_full) - self.query_start_loc_pcp_full.np[0] = 0 - self.query_start_loc_pcp_full.np[1:num_reqs + - 1] = cu_num_tokens_pcp_full - self.query_start_loc_pcp_full.np[num_reqs + 1:].fill(-1) - cumsums_offsets_pcp_full = np.repeat( - cu_num_tokens_pcp_full - num_scheduled_tokens_pcp_full, - num_scheduled_tokens_pcp_full) - arange_pcp_full = self.arange_np[: - total_num_scheduled_tokens_pcp_full] - cumsums_offsets_pcp_full - positions_pcp_full_np = self.positions_pcp_full_np[: - 
total_num_scheduled_tokens_pcp_full] - np.add(self.input_batch.num_computed_tokens_cpu[req_indices_pcp_full], - arange_pcp_full, - out=positions_pcp_full_np) - token_indices_pcp_full = ( - positions_pcp_full_np + - req_indices_pcp_full * self.input_batch.token_ids_cpu.shape[1]) - torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(), - 0, - torch.from_numpy(token_indices_pcp_full), - out=self.input_ids_pcp_full. - cpu[:total_num_scheduled_tokens_pcp_full]) - self.query_lens_pcp_full.copy_to_gpu() - self.query_start_loc_pcp_full.copy_to_gpu() - self.input_ids_pcp_full.gpu[:total_num_scheduled_tokens_pcp_full].copy_( - self.input_ids_pcp_full.cpu[:total_num_scheduled_tokens_pcp_full], - non_blocking=True, - ) - self.cu_num_tokens_pcp_full = cu_num_tokens_pcp_full - # For mtpx, pre-allocate mtp slot_mapping here - if self.decode_threshold > 2 and not with_prefill: - num_tokens_ori = sum(list(num_scheduled_tokens.values())) - num_tokens_mtp = \ - num_tokens_ori + num_reqs * (self.decode_threshold - 2) - num_tokens_mtp_pad = num_tokens_mtp * self.pcp_size - req_indices_split = np.array_split(req_indices, - cu_num_tokens)[:num_reqs] - positions_split = np.array_split(positions_np, - cu_num_tokens)[:num_reqs] - for req_idx in range(num_reqs): - ori_req_indice = req_indices_split[req_idx] - ori_position = positions_split[req_idx] - req_indices_split[req_idx] = np.append( - ori_req_indice, - np.repeat(ori_req_indice[-1], self.decode_threshold - 2)) - positions_split[req_idx] = np.append( - ori_position, - np.arange(ori_position[-1] + 1, - ori_position[-1] + self.decode_threshold - 1)) - req_indices_mtp = np.concatenate(req_indices_split) - positions_mtp = np.concatenate(positions_split) - self.input_batch.block_table.compute_slot_mapping( - req_indices_mtp, positions_mtp) - mtp_slot_ori = self.input_batch.block_table.block_tables[ - 0].slot_mapping.cpu[:num_tokens_mtp] - unpad_mask = np.repeat(False, num_tokens_mtp_pad) - unpad_mask[::self.pcp_size] = True - 
mtp_slot_pad = \ - torch.full([num_tokens_mtp_pad], -1, dtype=torch.int32) - mtp_slot_pad[unpad_mask] = mtp_slot_ori - self.mtp_slot_pad = mtp_slot_pad.to(self.device, non_blocking=True) - @contextmanager def _torch_cuda_wrapper(): diff --git a/vllm_ascend/worker/pcp_utils.py b/vllm_ascend/worker/pcp_utils.py new file mode 100644 index 00000000000..650c21b40ee --- /dev/null +++ b/vllm_ascend/worker/pcp_utils.py @@ -0,0 +1,654 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# Adapted from vllm-project/vllm/vllm/worker/worker.py +# + +import atexit +import functools +import math +import os +from contextlib import contextmanager, nullcontext +from enum import Enum +from threading import Lock +from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union + +import numpy as np +import torch +from vllm.utils.math_utils import cdiv +from vllm.v1.utils import CpuGpuBuffer + +import vllm_ascend.envs as envs_ascend + +class PCPManager: + """ + Manager for Prefill Context Parallelism (PCP) metadata and buffers. + + This manager encapsulates all PCP-related buffers and logic so that the + ModelRunner can access them via `self.pcp_manager`. 
+ """ + + def __init__( + self, + pcp_world_size: int, + pcp_rank: int, + dcp_world_size: int, + dcp_rank: int, + max_buffer_num_tokens: int, + max_num_reqs: int, + device: torch.device, + vllm_config: VllmConfig, + pin_memory: bool = False, + ) -> None: + self.pcp_world_size = pcp_world_size + self.pcp_world_rank = pcp_rank + self.dcp_world_size = dcp_world_size + self.dcp_world_rank = dcp_rank + self.speculative_config = vllm_config.speculative_config + self.decode_threshold = 1 + ( + self.speculative_config.num_speculative_tokens + if self.speculative_config else 0) + self.vllm_config = vllm_config + self.max_num_tokens = self.vllm_config.scheduler_config.max_num_batched_tokens + self.max_num_reqs = self.vllm_config.scheduler_config.max_num_seqs + self.device = device + self.pcp_allgather_restore_idx = CpuGpuBuffer( + max_buffer_num_tokens, + dtype=torch.int64, + device=device, + pin_memory=pin_memory, + ) + self.pcp_padded_slot_mapping = torch.full( + (max_buffer_num_tokens, ), fill_value=-1, + dtype=torch.int32, + device=device, + ) + self.num_pcp_pads_cpu_tensor = torch.zeros((max_num_reqs, ), + device="cpu", + dtype=torch.int64) + self.num_pcp_pads_cpu = self.num_pcp_pads_cpu_tensor.numpy() + self.pcp_unpad_mask_cpu_tensor = torch.zeros( + (max_buffer_num_tokens, ), + device="cpu", + dtype=torch.bool, + ) + self.num_actual_tokens_pcp_padded = 0 + self.pcp_unpad_mask_cpu = self.pcp_unpad_mask_cpu_tensor.numpy() + self.cp_kv_recover_idx_for_chunk: List[List[int]] = [ + [] for _ in range(self.pcp_world_size) + ] + self.full_indices = list( + range(self.max_num_tokens * self.pcp_world_size * + self.dcp_world_size + self.pcp_world_size * + self.dcp_world_size * self.max_num_reqs)) + if self.speculative_config and self.pcp_world_size > 1: + self.input_ids_pcp_full = self._make_buffer(self.max_num_tokens, + dtype=torch.int32) + self.query_start_loc_pcp_full = self._make_buffer( + self.max_num_reqs + 1, dtype=torch.int32) + self.positions_pcp_full = 
torch.zeros(self.max_num_tokens, + dtype=torch.int64, + device="cpu", + pin_memory=True) + self.positions_pcp_full_np = self.positions_pcp_full.numpy() + self.query_lens_pcp_full = self._make_buffer(self.max_num_reqs, + dtype=torch.int32) + + def _get_cumsum_and_arange( + self, + num_scheduled_tokens: np.ndarray, + arange_np: np.ndarray, + cumsum_dtype: np.dtype | None = None, + ) -> tuple[np.ndarray, np.ndarray]: + """Get the cumulative sum and batched arange of the given array. + # E.g., [2, 5, 3] -> ([2, 7, 10], [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]) + # Equivalent to but faster than: + # np.concatenate([np.arange(n) for n in num_scheduled_tokens]) + """ + # Step 1. [2, 5, 3] -> [2, 7, 10] + cu_num_tokens = np.cumsum(num_scheduled_tokens, dtype=cumsum_dtype) + total_num_tokens = cu_num_tokens[-1] + # Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7] + cumsums_offsets = np.repeat(cu_num_tokens - num_scheduled_tokens, + num_scheduled_tokens) + # Step 3. [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + arange = arange_np[:total_num_tokens] - cumsums_offsets + + return cu_num_tokens, arange + + def update_tokens_for_pcp( + self, + num_scheduled_tokens: np.ndarray, + arange_np: np.ndarray, + num_reqs: int, + reorder_batch_threshold: int | None = None, + ) -> tuple[np.ndarray, np.ndarray]: + """ + Update token counts and positions for Prefill Context Parallelism (PCP). + + When using Prefill Context Parallelism, each request's prefill sequence is + split across multiple PCP ranks. The splitting strategy used here is the + "DualChunkSwap" style: each request's (padded) sequence is split into + 2 * pcp_world_size chunks and ranks are assigned chunks in an interleaved + head/tail pattern to balance load. + + This function: + - Computes how many tokens each request should be processed by the current + PCP rank (pcp_tokens). + - Computes the flattened positions of those tokens within the local + padded buffer (pcp_positions). 
+ - Updates runner state arrays used to restore original order and mask out + padded tokens after allgather: + - self.num_pcp_pads_cpu: number of pads added per request + - self.pcp_unpad_mask_cpu: boolean mask marking real tokens in the + padded allgather buffer + - self.pcp_allgather_restore_idx: index array used to restore original + ordering after per-rank allgather and interleaving. + + Args: + num_scheduled_tokens: 1D numpy array of length num_reqs containing + the number of new tokens scheduled per request. + arange_np: 1D numpy array of length max_buffer_num_tokens used for + efficient batched arange operations. + num_reqs: Total number of requests in the batch. + reorder_batch_threshold: Threshold for decode vs prefill requests. + + Returns: + Tuple (pcp_tokens, pcp_positions): + - pcp_tokens: number of tokens per request that this PCP rank will + actually process (after splitting / replication). + - pcp_positions: flattened positions for those tokens on this rank, + used to build the positions buffer for the model. + + Example: + >>> Assume tokens = [1, 5, 8], pcp_world_size = 2. After update_tokens_for_pcp. + >>> pcp_rank = 0 get ([1, 4, 4], [0, 0, 1, 6, 7, 0, 1, 6, 7]) + >>> pcp_rank = 1 get ([1, 4, 4], [0, 2, 3, 4, 5, 2, 3, 4, 5]) + >>> Meanwhile, the following results are the same for each pcp rank + >>> self.num_pcp_pads_cpu + [1, 3, 0] + >>> self.pcp_unpad_mask_cpu + [True, False, True, True, True, True, True, False, False, + False, True, True, True, True, True, True, True, True] + >>> self.pcp_allgather_restore_idx + [0, 9, 1, 2, 10, 11, 12, 13, 3, 4, 5, 6, 14, 15, 16, 17, 7, 8] + """ + + assert reorder_batch_threshold is not None, ( + "PCP depends on reorder batch to split decode and prefill requests." + ) + num_decode_reqs = sum(num_scheduled_tokens <= reorder_batch_threshold) + num_decode_tokens = sum(num_scheduled_tokens[:num_decode_reqs]) + + # DualChunkSwap requires alignment to a multiple of (2 * pcp_world_size).
+ # We first pad each request's token count up to that multiple. + num_padded_scheduled_tokens = np.ceil( + num_scheduled_tokens / (2 * self.pcp_world_size)).astype( + np.int32) * (2 * self.pcp_world_size) + + # PCP does not split decode requests. For decode requests, we instead + # duplicate the scheduled tokens across the pcp_world_size ranks. + num_padded_scheduled_tokens[:num_decode_reqs] = ( + num_scheduled_tokens[:num_decode_reqs] * self.pcp_world_size) + + # Record how many pads were added per request (padded - original). + self.num_pcp_pads_cpu[:num_reqs] = (num_padded_scheduled_tokens - + num_scheduled_tokens) + + # cu_padded_tokens: cumulative sum of padded token counts, + # pcp_padded_arange: per-request arange flattened for padded tokens. + cu_padded_tokens, pcp_padded_arange = self._get_cumsum_and_arange( + num_padded_scheduled_tokens, arange_np) + # Build the mask that marks which positions in the padded allgather buffer + # correspond to real (unpadded) tokens. + self.pcp_unpad_mask_cpu[:pcp_padded_arange.shape[0]] = ( + pcp_padded_arange < np.repeat(num_scheduled_tokens, + num_padded_scheduled_tokens)) + + pcp_tokens = num_padded_scheduled_tokens // self.pcp_world_size + + # Compute per-request "chunk sizes" for the head/tail splitting. + # For prefill requests, we further split the pcp_tokens into two chunks + # (head and tail). For decode requests, the chunk equals pcp_tokens. 
+ pcp_chunk_sizes = (pcp_tokens // 2).clip(min=1) + pcp_chunk_sizes[:num_decode_reqs] = pcp_tokens[:num_decode_reqs] + + # Build arange-style helpers for pcp tokens and chunk sizes: + # - pcp_arange gives indices repeated for each token in pcp_tokens + # - pcp_chunk_arange gives indices repeated for each position inside chunks + _, pcp_arange = self._get_cumsum_and_arange(pcp_tokens, arange_np) + _, pcp_chunk_arange = self._get_cumsum_and_arange( + pcp_chunk_sizes, arange_np) + + # Mask that marks whether a position belongs to the head chunk (True) + # or the tail chunk (False). For decode requests, tail chunk won't exist + # and is handled specially below. + pcp_head_chunk_mask = pcp_arange < np.repeat(pcp_chunk_sizes, + pcp_tokens) + + def get_current_rank_positions(positions_start_loc: int | np.ndarray, + rank: int): + """ + Compute flattened positions for the given rank with a given start + offset for each request (positions_start_loc). + + - For head chunks: start at positions_start_loc + rank * chunk_size. + - For tail chunks: start at positions_start_loc + (2*pcp_world_size- rank - + 1) * chunk_size. + - For decode requests: no tail chunks; their positions are filled from the + contiguous (unpadded) `tokens` arange instead (handled after). + """ + positions = np.zeros(len(pcp_head_chunk_mask), dtype=np.int32) + head_start_loc = positions_start_loc + rank * pcp_chunk_sizes + tail_start_loc = ( + positions_start_loc + + (2 * self.pcp_world_size - rank - 1) * pcp_chunk_sizes) + # Fill head positions using chunk arange offset by head_start_loc. + positions[pcp_head_chunk_mask] = pcp_chunk_arange + np.repeat( + head_start_loc, pcp_chunk_sizes) + # Fill tail positions. Note decode requests do not have tail chunks, + # so the tail filling is only for prefill positions. 
+ positions[~pcp_head_chunk_mask] = ( + pcp_chunk_arange[num_decode_tokens:] + + np.repeat(tail_start_loc, pcp_chunk_sizes)[num_decode_tokens:]) + return positions + + positions = get_current_rank_positions(0, self.pcp_world_rank) + # Decode tokens are duplicated only after AG. But their positions are + # same without prefill context parallel. + if num_decode_reqs > 0: + positions[:num_decode_tokens] = self._get_cumsum_and_arange( + num_scheduled_tokens[:num_decode_reqs], arange_np)[1] + + # Build the restore index used after allgather. + padded_pos_start_loc = np.roll(cu_padded_tokens, 1) + padded_pos_start_loc[0] = 0 + all_positions_lst = [ + get_current_rank_positions(padded_pos_start_loc, rank_i) + for rank_i in range(self.pcp_world_size) + ] + all_positions = np.concatenate(all_positions_lst) + self.pcp_allgather_restore_idx.np[:all_positions.shape[0]] = ( + all_positions.argsort()) + self.pcp_allgather_restore_idx.copy_to_gpu(all_positions.shape[0]) + + return ( + pcp_tokens[:num_reqs], + positions, + ) + + def get_logits_indices(self, cu_num_tokens: np.ndarray, num_reqs: int): + return (torch.from_numpy(cu_num_tokens) * self.pcp_world_size - + self.num_pcp_pads_cpu_tensor[:num_reqs] - 1) + + def get_discard_request_mask( + self, + num_computed_tokens_cpu: np.ndarray, + num_scheduled_tokens: np.ndarray, + num_reqs: int, + num_tokens_np: np.ndarray, + ): + return (num_computed_tokens_cpu[:num_reqs] + + num_scheduled_tokens * self.pcp_world_size - + self.num_pcp_pads_cpu[:num_reqs]) < num_tokens_np + + def get_padded_slot_mapping(self, num_tokens: int, + slot_mapping: torch.Tensor): + # After pcp allgather and restore, there are padded tokens in kv, + # so we need pad slotmapping for alignment. + pcp_padded_slot_mapping = self.pcp_padded_slot_mapping[:num_tokens * + self. 
+ pcp_world_size] + cp_unpad_mask = self.pcp_unpad_mask_cpu_tensor[:num_tokens * + self.pcp_world_size] + pcp_padded_slot_mapping[cp_unpad_mask] = slot_mapping + return pcp_padded_slot_mapping + + def get_restore_hidden_states( + self, + hidden_states: torch.Tensor, + ): + # NOTE we must `slice` hidden_states because pcp_allgather_restore_idx + # ignores the padding from CUDA Graph. + from vllm.distributed.parallel_state import get_pcp_group + hidden_states = get_pcp_group().all_gather( + hidden_states[:self.num_actual_tokens_pcp_padded // + self.pcp_world_size], + 0, + ) + restore_idx = self.pcp_allgather_restore_idx.gpu[:hidden_states. + shape[0]] + return torch.index_select( + hidden_states, + 0, + restore_idx, + ) + + def generate_pcp_mtp_input( + self, + num_reqs: int, + total_num_scheduled_tokens: int, + num_scheduled_tokens: dict[str, int], + with_prefill: bool = True, + req_indices=None, + positions_np=None, + cu_num_tokens=None, + ): + """ + While pcp > 1, model inputs (input_ids, position, etc.) are split across pcp group, + but mtp need to shift original input_ids before pcp splitting, + so we record original input_ids here. 
+ """ + with_prefill = total_num_scheduled_tokens <= self.decode_threshold * num_reqs + total_num_scheduled_tokens_pcp_full = total_num_scheduled_tokens + num_scheduled_tokens_pcp_full = np.empty(num_reqs, dtype=np.int32) + for i, req_id in enumerate(self.input_batch.req_ids): + num_scheduled_tokens_pcp_full[i] = num_scheduled_tokens[req_id] + self.query_lens_pcp_full.cpu[:num_reqs] = torch.from_numpy( + num_scheduled_tokens_pcp_full) + req_indices_pcp_full = np.repeat(self.arange_np[:num_reqs], + num_scheduled_tokens_pcp_full) + cu_num_tokens_pcp_full = np.cumsum(num_scheduled_tokens_pcp_full) + self.query_start_loc_pcp_full.np[0] = 0 + self.query_start_loc_pcp_full.np[1:num_reqs + + 1] = cu_num_tokens_pcp_full + self.query_start_loc_pcp_full.np[num_reqs + 1:].fill(-1) + cumsums_offsets_pcp_full = np.repeat( + cu_num_tokens_pcp_full - num_scheduled_tokens_pcp_full, + num_scheduled_tokens_pcp_full) + arange_pcp_full = self.arange_np[: + total_num_scheduled_tokens_pcp_full] - cumsums_offsets_pcp_full + positions_pcp_full_np = self.positions_pcp_full_np[: + total_num_scheduled_tokens_pcp_full] + np.add(self.input_batch.num_computed_tokens_cpu[req_indices_pcp_full], + arange_pcp_full, + out=positions_pcp_full_np) + token_indices_pcp_full = ( + positions_pcp_full_np + + req_indices_pcp_full * self.input_batch.token_ids_cpu.shape[1]) + torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(), + 0, + torch.from_numpy(token_indices_pcp_full), + out=self.input_ids_pcp_full. 
+ cpu[:total_num_scheduled_tokens_pcp_full]) + self.query_lens_pcp_full.copy_to_gpu() + self.query_start_loc_pcp_full.copy_to_gpu() + self.input_ids_pcp_full.copy_to_gpu(total_num_scheduled_tokens_pcp_full) + self.cu_num_tokens_pcp_full = cu_num_tokens_pcp_full + # For mtpx, pre-allocate mtp slot_mapping here + if self.decode_threshold > 2 and not with_prefill: + num_tokens_ori = sum(list(num_scheduled_tokens.values())) + num_tokens_mtp = \ + num_tokens_ori + num_reqs * (self.decode_threshold - 2) + num_tokens_mtp_pad = num_tokens_mtp * self.pcp_size + req_indices_split = np.array_split(req_indices, + cu_num_tokens)[:num_reqs] + positions_split = np.array_split(positions_np, + cu_num_tokens)[:num_reqs] + for req_idx in range(num_reqs): + ori_req_indice = req_indices_split[req_idx] + ori_position = positions_split[req_idx] + req_indices_split[req_idx] = np.append( + ori_req_indice, + np.repeat(ori_req_indice[-1], self.decode_threshold - 2)) + positions_split[req_idx] = np.append( + ori_position, + np.arange(ori_position[-1] + 1, + ori_position[-1] + self.decode_threshold - 1)) + req_indices_mtp = np.concatenate(req_indices_split) + positions_mtp = np.concatenate(positions_split) + self.input_batch.block_table.compute_slot_mapping( + req_indices_mtp, positions_mtp) + mtp_slot_ori = self.input_batch.block_table.block_tables[ + 0].slot_mapping.cpu[:num_tokens_mtp] + unpad_mask = np.repeat(False, num_tokens_mtp_pad) + unpad_mask[::self.pcp_size] = True + mtp_slot_pad = \ + torch.full([num_tokens_mtp_pad], -1, dtype=torch.int32) + mtp_slot_pad[unpad_mask] = mtp_slot_ori + self.mtp_slot_pad = mtp_slot_pad.to(self.device, non_blocking=True) + + def _get_cp_local_seq_lens( + self, + seq_lens: torch.Tensor, + pcp_world_size: int = 1, + dcp_world_size: int = 1, + cp_kv_cache_interleave_size: int = 1, + ) -> torch.Tensor: + """While using pcp or dcp, kv_cache size stored on each rank may be different, + use this function to calculate split decode seq_lens of each (p/d)cp rank. 
+ """ + num_requests = seq_lens.size(0) + total_world_size = pcp_world_size * dcp_world_size + seq_lens_tiled = seq_lens.unsqueeze(-1).repeat(1, total_world_size) + rank_offsets = (torch.arange(total_world_size, + dtype=torch.int32).unsqueeze(0).repeat( + num_requests, 1)) + base = (seq_lens_tiled // cp_kv_cache_interleave_size // + total_world_size * cp_kv_cache_interleave_size) + remainder = seq_lens_tiled - base * total_world_size + remainder = torch.clip( + remainder - rank_offsets * cp_kv_cache_interleave_size, + 0, + cp_kv_cache_interleave_size, + ) + dcp_local_seq_lens = (base + remainder).reshape( + [-1, pcp_world_size, dcp_world_size]) + return dcp_local_seq_lens + + def generate_kv_idx(self, scheduler_output, input_batch): + if not self.pcp_world_size > 1: + return + self.cp_kv_recover_idx_for_chunk = [[] + for _ in range(self.pcp_world_size) + ] + + for i, req_id in enumerate(input_batch.req_ids): + num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ + req_id] + is_prefill = input_batch.num_computed_tokens_cpu[ + i] < input_batch.num_prompt_tokens[i] + if is_prefill: + num_cp_padded_scheduled_tokens = cdiv( + num_scheduled_tokens, + 2 * self.pcp_world_size) * (2 * self.pcp_world_size) + chunk_size = num_cp_padded_scheduled_tokens // ( + 2 * self.pcp_world_size) + num_added_recover_tokens = len( + self.cp_kv_recover_idx_for_chunk[0]) * self.pcp_world_size + for rank in range(self.pcp_world_size): + self.cp_kv_recover_idx_for_chunk[rank].extend( + self.full_indices[rank * chunk_size + + num_added_recover_tokens:(rank + 1) * + chunk_size + + num_added_recover_tokens]) + self.cp_kv_recover_idx_for_chunk[rank].extend( + self.full_indices[num_cp_padded_scheduled_tokens - + (rank + 1) * chunk_size + + num_added_recover_tokens: + num_cp_padded_scheduled_tokens - + rank * chunk_size + + num_added_recover_tokens]) + + cp_kv_recover_idx_for_chunk = torch.from_numpy( + np.concatenate(self.cp_kv_recover_idx_for_chunk)).to( + device=self.device) + 
cp_kv_recover_idx_for_chunk.copy_(torch.tensor( + np.array( + self.cp_kv_recover_idx_for_chunk).flatten().tolist()), + non_blocking=True) + self.cp_kv_recover_idx_for_chunk = cp_kv_recover_idx_for_chunk.to( + torch.float32).argsort().to(torch.int32) + + def generate_pcp_metadata(self, total_num_scheduled_tokens, query_lens, + attn_mask, input_batch): + from vllm_ascend.attention.utils import \ + AscendPrefillContextParallelMetadata + num_reqs = input_batch.num_reqs or query_lens.size(0) + num_decodes = sum(input_batch.num_computed_tokens_cpu[:num_reqs] >= + input_batch.num_prompt_tokens[:num_reqs]) + num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_world_size + self.num_actual_tokens_pcp_padded = num_actual_tokens_pcp_padded + long_seq_metadata = None + if self.pcp_world_size * self.dcp_world_size > 1: + decode_context_lens = input_batch.num_tokens[:num_decodes] + prefill_context_lens = input_batch.num_computed_tokens_cpu[ + num_decodes:num_reqs] + context_lens = np.concatenate( + [decode_context_lens, prefill_context_lens]) + num_computed_tokens_of_pcp_dcp = torch.zeros( + [ + num_reqs * self.decode_threshold, self.pcp_world_size, + self.dcp_world_size + ], + dtype=torch.int32, + ) + # For pcp + spec decode, we flatten seq_lens + # to avoid irregular spec_attn_mask shape + for decode_idx in range(self.decode_threshold): + num_computed_tokens_of_pcp_dcp[ + self.decode_threshold - 1 - decode_idx::self.decode_threshold] = \ + self._get_cp_local_seq_lens( + torch.tensor(context_lens), + self.pcp_world_size, + self.dcp_world_size, + self.vllm_config.parallel_config.cp_kv_cache_interleave_size, + ) + long_seq_metadata = AscendPrefillContextParallelMetadata( + num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded, + num_computed_tokens_of_pcp_dcp=num_computed_tokens_of_pcp_dcp. 
+ numpy()) + if self.pcp_world_size > 1: + q_head_idx, q_tail_idx = [], [] + kv_with_q_head_nomask_idx, kv_with_q_head_mask_idx = [], [] + kv_with_q_tail_nomask_idx, kv_with_q_tail_mask_idx = [], [] + chunk_seqlens = [] + kv_with_q_head_nomask_seqlens, kv_with_q_tail_nomask_seqlens = [], [] + q_req_offset = 0 + kv_req_offset = 0 + q_head_chunk_id = self.pcp_world_rank + q_tail_chunk_id = self.pcp_world_size * 2 - 1 - self.pcp_world_rank + for i, seq_len in enumerate(query_lens): + if i < num_decodes: + continue + chunk_len = seq_len // 2 + chunk_seqlens.append(chunk_len) + q_head_idx.extend( + list(range(q_req_offset, q_req_offset + chunk_len))) + kv_with_q_head_nomask_idx.extend( + list( + range(kv_req_offset, kv_req_offset + + chunk_len * q_head_chunk_id))) + kv_with_q_head_mask_idx.extend( + list( + range( + kv_req_offset + chunk_len * q_head_chunk_id, + kv_req_offset + chunk_len * + (q_head_chunk_id + 1)))) + kv_with_q_head_nomask_seqlens.append(chunk_len * + q_head_chunk_id) + + q_tail_idx.extend( + list( + range(q_req_offset + chunk_len, + q_req_offset + chunk_len * 2))) + kv_with_q_tail_nomask_idx.extend( + list( + range(kv_req_offset, kv_req_offset + + chunk_len * q_tail_chunk_id))) + kv_with_q_tail_mask_idx.extend( + list( + range( + kv_req_offset + chunk_len * q_tail_chunk_id, + kv_req_offset + chunk_len * + (q_tail_chunk_id + 1)))) + kv_with_q_tail_nomask_seqlens.append(chunk_len * + q_tail_chunk_id) + + q_req_offset += seq_len + kv_req_offset += seq_len * self.pcp_world_size + + # Convert lists to tensors and move to device + def _list_to_tensor(lst, device, dtype=torch.int32): + tensor_npu = torch.zeros(len(lst), + dtype=dtype, + device=device) + tensor_npu.copy_(torch.tensor(lst, dtype=dtype), + non_blocking=True) + return tensor_npu + + q_head_idx_tensor = _list_to_tensor(q_head_idx, self.device) + q_tail_idx_tensor = _list_to_tensor(q_tail_idx, self.device) + self.q_head_idx_tensor = q_head_idx_tensor + self.q_tail_idx_tensor = q_tail_idx_tensor + + 
q_full_idx = torch.cat([q_head_idx_tensor, q_tail_idx_tensor]) + q_full_idx = q_full_idx.to(torch.float32).argsort().to( + torch.int32) + self.q_full_idx = q_full_idx + + self.kv_idx_names = { + 'kv_with_q_head_nomask_idx_tensor': + kv_with_q_head_nomask_idx, + 'kv_with_q_head_mask_idx_tensor': kv_with_q_head_mask_idx, + 'kv_with_q_tail_nomask_idx_tensor': + kv_with_q_tail_nomask_idx, + 'kv_with_q_tail_mask_idx_tensor': kv_with_q_tail_mask_idx + } + for key, value in self.kv_idx_names.items(): + tensor_npu = _list_to_tensor(value, self.device) + self.kv_idx_names[key] = tensor_npu + + attn_mask_seqlens = torch.tensor( + [chunk_seqlens, chunk_seqlens], dtype=torch.int32) + head_attn_nomask_seqlens = torch.tensor( + [chunk_seqlens, kv_with_q_head_nomask_seqlens], + dtype=torch.int32) + tail_attn_nomask_seqlens = torch.tensor( + [chunk_seqlens, kv_with_q_tail_nomask_seqlens], + dtype=torch.int32) + pcp_prefill_mask = attn_mask + + self.extra_long_seq_kwargs = { + 'attn_mask_seqlens': attn_mask_seqlens, + 'head_attn_nomask_seqlens': head_attn_nomask_seqlens, + 'tail_attn_nomask_seqlens': tail_attn_nomask_seqlens, + 'pcp_prefill_mask': pcp_prefill_mask + } + long_seq_metadata.pcp_allgather_restore_idx = self.pcp_allgather_restore_idx.gpu[: + num_actual_tokens_pcp_padded] + long_seq_metadata.cp_kv_recover_idx_for_chunk = self.cp_kv_recover_idx_for_chunk + long_seq_metadata.q_head_idx_tensor = self.q_head_idx_tensor + long_seq_metadata.q_tail_idx_tensor = self.q_tail_idx_tensor + long_seq_metadata.q_full_idx = self.q_full_idx + long_seq_metadata.kv_with_q_head_nomask_idx_tensor = self.kv_idx_names[ + 'kv_with_q_head_nomask_idx_tensor'] + long_seq_metadata.kv_with_q_head_mask_idx_tensor = self.kv_idx_names[ + 'kv_with_q_head_mask_idx_tensor'] + long_seq_metadata.kv_with_q_tail_nomask_idx_tensor = self.kv_idx_names[ + 'kv_with_q_tail_nomask_idx_tensor'] + long_seq_metadata.kv_with_q_tail_mask_idx_tensor = self.kv_idx_names[ + 'kv_with_q_tail_mask_idx_tensor'] + 
long_seq_metadata.attn_mask_seqlens = self.extra_long_seq_kwargs[ + 'attn_mask_seqlens'] + long_seq_metadata.head_attn_nomask_seqlens = self.extra_long_seq_kwargs[ + 'head_attn_nomask_seqlens'] + long_seq_metadata.tail_attn_nomask_seqlens = self.extra_long_seq_kwargs[ + 'tail_attn_nomask_seqlens'] + long_seq_metadata.pcp_prefill_mask = self.extra_long_seq_kwargs[ + 'pcp_prefill_mask'] + self.long_seq_metadata = long_seq_metadata + return long_seq_metadata From c79baeffa6c676987c1fa7174632ac6e2b454f09 Mon Sep 17 00:00:00 2001 From: zhenwenqi2024 Date: Mon, 22 Dec 2025 20:05:38 +0800 Subject: [PATCH 17/43] cleancode Signed-off-by: zhenwenqi2024 --- vllm_ascend/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 8daa964febc..4ef58970a1a 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -1100,4 +1100,4 @@ def dispose_layer(layer: Any): def replace_layer(original_layer: Any, new_layer: Any): original_layer.__class__ = new_layer.__class__ - original_layer.__dict__ = new_layer.__dict__ \ No newline at end of file + original_layer.__dict__ = new_layer.__dict__ From 040be70a9e6bce092317ecaeed153f5550e3404e Mon Sep 17 00:00:00 2001 From: zhenwenqi2024 Date: Mon, 22 Dec 2025 20:16:46 +0800 Subject: [PATCH 18/43] cleancode Signed-off-by: zhenwenqi2024 --- vllm_ascend/worker/pcp_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm_ascend/worker/pcp_utils.py b/vllm_ascend/worker/pcp_utils.py index 650c21b40ee..af99fbbc022 100644 --- a/vllm_ascend/worker/pcp_utils.py +++ b/vllm_ascend/worker/pcp_utils.py @@ -352,7 +352,6 @@ def generate_pcp_mtp_input( but mtp need to shift original input_ids before pcp splitting, so we record original input_ids here. 
""" - with_prefill = total_num_scheduled_tokens <= self.decode_threshold * num_reqs total_num_scheduled_tokens_pcp_full = total_num_scheduled_tokens num_scheduled_tokens_pcp_full = np.empty(num_reqs, dtype=np.int32) for i, req_id in enumerate(self.input_batch.req_ids): From bb9174da2295496324ebcef3217297e4eeb42b67 Mon Sep 17 00:00:00 2001 From: zhenwenqi2024 Date: Mon, 22 Dec 2025 20:51:37 +0800 Subject: [PATCH 19/43] cleancode Signed-off-by: zhenwenqi2024 --- .pre-commit-config.yaml | 8 +++--- tests/ut/worker/test_pcp_manager.py | 40 ++++++++++++++------------- vllm_ascend/worker/model_runner_v1.py | 3 +- vllm_ascend/worker/pcp_utils.py | 36 +++++++++++------------- 4 files changed, 43 insertions(+), 44 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 82bde178ec0..b16317c7466 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -55,10 +55,10 @@ repos: hooks: - id: pymarkdown args: [fix] -- repo: https://github.com/rhysd/actionlint - rev: v1.7.7 - hooks: - - id: actionlint +# - repo: https://github.com/rhysd/actionlint +# rev: v1.7.7 +# hooks: +# - id: actionlint - repo: local hooks: # For local development, you can run mypy using tools/mypy.sh script if needed. 
diff --git a/tests/ut/worker/test_pcp_manager.py b/tests/ut/worker/test_pcp_manager.py index 9ea5220d749..f1ea4dd911e 100644 --- a/tests/ut/worker/test_pcp_manager.py +++ b/tests/ut/worker/test_pcp_manager.py @@ -37,9 +37,8 @@ def test_generate_pcp_metadata_basic(pcp_size, dcp_size, num_reqs, query_lens, vllm_config.model_config = MagicMock() vllm_config.model_config.use_mla = use_mla vllm_config.parallel_config.cp_kv_cache_interleave_size = 64 - vllm_config.speculative_config.num_speculative_tokens=0 + vllm_config.speculative_config.num_speculative_tokens = 0 - pcp_manager = PCPManager(pcp_world_size=pcp_size, pcp_rank=0, dcp_world_size=dcp_size, @@ -66,13 +65,13 @@ def test_generate_pcp_metadata_basic(pcp_size, dcp_size, num_reqs, query_lens, num_prompt_tokens.append(query_lens[i]) num_tokens.append(query_lens[i]) - input_batch.num_computed_tokens_cpu = torch.tensor( - num_computed_tokens) + input_batch.num_computed_tokens_cpu = torch.tensor(num_computed_tokens) input_batch.num_prompt_tokens = torch.tensor(num_prompt_tokens) input_batch.num_tokens = torch.tensor(num_tokens) query_lens = torch.tensor(query_lens) - result = pcp_manager.generate_pcp_metadata(total_tokens, query_lens, None, input_batch) + result = pcp_manager.generate_pcp_metadata(total_tokens, query_lens, None, + input_batch) if not expect_not_none: assert result is None, f"Expected to return None, but got {type(result)}" @@ -113,6 +112,7 @@ def test_generate_pcp_metadata_basic(pcp_size, dcp_size, num_reqs, query_lens, assert result.pcp_prefill_mask.shape == (2048, 2048) + @pytest.mark.parametrize( "tokens, num_reqs, num_computed_tokens, num_prompt_tokens, pcp_size, pcp_rank, expected_pcp_tokens", [ @@ -133,9 +133,8 @@ def test_update_tokens_for_pcp_basic(tokens, num_reqs, num_computed_tokens, expected_pcp_tokens): vllm_config = MagicMock() vllm_config.model_config = MagicMock() - vllm_config.speculative_config.num_speculative_tokens=0 + vllm_config.speculative_config.num_speculative_tokens = 0 - 
pcp_manager = PCPManager(pcp_world_size=pcp_size, pcp_rank=0, dcp_world_size=1, @@ -147,12 +146,12 @@ def test_update_tokens_for_pcp_basic(tokens, num_reqs, num_computed_tokens, pin_memory=False) input_batch = MagicMock() input_batch.num_reqs = num_reqs - input_batch.num_computed_tokens_cpu = np.array( - num_computed_tokens, dtype=np.int32) - input_batch.num_prompt_tokens = np.array(num_prompt_tokens, - dtype=np.int32) + input_batch.num_computed_tokens_cpu = np.array(num_computed_tokens, + dtype=np.int32) + input_batch.num_prompt_tokens = np.array(num_prompt_tokens, dtype=np.int32) arange_np = np.arange(10000) - pcp_tokens_result, positions_result = pcp_manager.update_tokens_for_pcp(np.array(tokens), arange_np, num_reqs, 1) + pcp_tokens_result, positions_result = pcp_manager.update_tokens_for_pcp( + np.array(tokens), arange_np, num_reqs, 1) assert np.array_equal(pcp_tokens_result, expected_pcp_tokens), \ f"Expected pcp_tokens: {expected_pcp_tokens}, got: {pcp_tokens_result}" @@ -200,7 +199,7 @@ def test_get_cp_local_seq_lens( ): vllm_config = MagicMock() vllm_config.model_config = MagicMock() - vllm_config.speculative_config.num_speculative_tokens=0 + vllm_config.speculative_config.num_speculative_tokens = 0 pcp_manager = PCPManager(pcp_world_size=pcp_world_size, pcp_rank=0, dcp_world_size=dcp_world_size, @@ -210,11 +209,12 @@ def test_get_cp_local_seq_lens( device="cpu", vllm_config=vllm_config, pin_memory=False) - ret = pcp_manager._get_cp_local_seq_lens(seq_lens, - pcp_world_size, dcp_world_size, - cp_kv_cache_interleave_size) + ret = pcp_manager._get_cp_local_seq_lens(seq_lens, pcp_world_size, + dcp_world_size, + cp_kv_cache_interleave_size) assert torch.equal(ret, target) + # yapf: disable @pytest.mark.parametrize( "req_ids, num_computed_tokens," \ @@ -280,7 +280,7 @@ def test_generate_pcp_mtp_input( max_num_tokens = 4096 vllm_config = MagicMock() vllm_config.model_config = MagicMock() - vllm_config.speculative_config.num_speculative_tokens=1 + 
vllm_config.speculative_config.num_speculative_tokens = 1 vllm_config.scheduler_config.max_num_seqs = max_num_reqs vllm_config.scheduler_config.max_num_batched_tokens = max_model_len pcp_manager = PCPManager(pcp_world_size=2, @@ -308,11 +308,13 @@ def test_generate_pcp_mtp_input( # Set input_batch input_batch.req_ids = req_ids input_batch.num_computed_tokens_cpu[:num_computed_tokens. - size] = num_computed_tokens + size] = num_computed_tokens for i, token_ids_tensor in enumerate(token_ids_tensor_list): token_ids_cpu_tensor[i][:token_ids_tensor.size(0)] = token_ids_tensor - pcp_manager.generate_pcp_mtp_input(num_reqs, total_num_scheduled_tokens, num_scheduled_tokens, input_batch, arange_np) + pcp_manager.generate_pcp_mtp_input(num_reqs, total_num_scheduled_tokens, + num_scheduled_tokens, input_batch, + arange_np) assert torch.equal( pcp_manager.input_ids_pcp_full.cpu[:total_num_scheduled_tokens], target_input_ids_pcp_full) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 94ccacb0a84..ee7c145d01e 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -550,7 +550,8 @@ def _prepare_inputs( self.pcp_manager.generate_pcp_mtp_input( num_reqs, total_num_scheduled_tokens, scheduler_output.num_scheduled_tokens, with_prefill, - req_indices, positions_np, cu_num_tokens) + self.input_batch, self.arange_np, req_indices, positions_np, + cu_num_tokens) if self.pcp_size > 1: if not self.vllm_config.model_config.use_mla: diff --git a/vllm_ascend/worker/pcp_utils.py b/vllm_ascend/worker/pcp_utils.py index af99fbbc022..7fde93a783e 100644 --- a/vllm_ascend/worker/pcp_utils.py +++ b/vllm_ascend/worker/pcp_utils.py @@ -17,21 +17,14 @@ # Adapted from vllm-project/vllm/vllm/worker/worker.py # -import atexit -import functools -import math -import os -from contextlib import contextmanager, nullcontext -from enum import Enum -from threading import Lock -from typing import TYPE_CHECKING, Any, List, 
Optional, Tuple, Union +from typing import List import numpy as np import torch +from vllm.config import VllmConfig from vllm.utils.math_utils import cdiv from vllm.v1.utils import CpuGpuBuffer -import vllm_ascend.envs as envs_ascend class PCPManager: """ @@ -72,7 +65,8 @@ def __init__( pin_memory=pin_memory, ) self.pcp_padded_slot_mapping = torch.full( - (max_buffer_num_tokens, ), fill_value=-1, + (max_buffer_num_tokens, ), + fill_value=-1, dtype=torch.int32, device=device, ) @@ -343,6 +337,8 @@ def generate_pcp_mtp_input( total_num_scheduled_tokens: int, num_scheduled_tokens: dict[str, int], with_prefill: bool = True, + input_batch=None, + arange_np=None, req_indices=None, positions_np=None, cu_num_tokens=None, @@ -354,11 +350,11 @@ def generate_pcp_mtp_input( """ total_num_scheduled_tokens_pcp_full = total_num_scheduled_tokens num_scheduled_tokens_pcp_full = np.empty(num_reqs, dtype=np.int32) - for i, req_id in enumerate(self.input_batch.req_ids): + for i, req_id in enumerate(input_batch.req_ids): num_scheduled_tokens_pcp_full[i] = num_scheduled_tokens[req_id] self.query_lens_pcp_full.cpu[:num_reqs] = torch.from_numpy( num_scheduled_tokens_pcp_full) - req_indices_pcp_full = np.repeat(self.arange_np[:num_reqs], + req_indices_pcp_full = np.repeat(arange_np[:num_reqs], num_scheduled_tokens_pcp_full) cu_num_tokens_pcp_full = np.cumsum(num_scheduled_tokens_pcp_full) self.query_start_loc_pcp_full.np[0] = 0 @@ -368,24 +364,24 @@ def generate_pcp_mtp_input( cumsums_offsets_pcp_full = np.repeat( cu_num_tokens_pcp_full - num_scheduled_tokens_pcp_full, num_scheduled_tokens_pcp_full) - arange_pcp_full = self.arange_np[: - total_num_scheduled_tokens_pcp_full] - cumsums_offsets_pcp_full + arange_pcp_full = arange_np[:total_num_scheduled_tokens_pcp_full] - cumsums_offsets_pcp_full positions_pcp_full_np = self.positions_pcp_full_np[: total_num_scheduled_tokens_pcp_full] - np.add(self.input_batch.num_computed_tokens_cpu[req_indices_pcp_full], + 
np.add(input_batch.num_computed_tokens_cpu[req_indices_pcp_full], arange_pcp_full, out=positions_pcp_full_np) token_indices_pcp_full = ( positions_pcp_full_np + - req_indices_pcp_full * self.input_batch.token_ids_cpu.shape[1]) - torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(), + req_indices_pcp_full * input_batch.token_ids_cpu.shape[1]) + torch.index_select(input_batch.token_ids_cpu_tensor.flatten(), 0, torch.from_numpy(token_indices_pcp_full), out=self.input_ids_pcp_full. cpu[:total_num_scheduled_tokens_pcp_full]) self.query_lens_pcp_full.copy_to_gpu() self.query_start_loc_pcp_full.copy_to_gpu() - self.input_ids_pcp_full.copy_to_gpu(total_num_scheduled_tokens_pcp_full) + self.input_ids_pcp_full.copy_to_gpu( + total_num_scheduled_tokens_pcp_full) self.cu_num_tokens_pcp_full = cu_num_tokens_pcp_full # For mtpx, pre-allocate mtp slot_mapping here if self.decode_threshold > 2 and not with_prefill: @@ -409,9 +405,9 @@ def generate_pcp_mtp_input( ori_position[-1] + self.decode_threshold - 1)) req_indices_mtp = np.concatenate(req_indices_split) positions_mtp = np.concatenate(positions_split) - self.input_batch.block_table.compute_slot_mapping( + input_batch.block_table.compute_slot_mapping( req_indices_mtp, positions_mtp) - mtp_slot_ori = self.input_batch.block_table.block_tables[ + mtp_slot_ori = input_batch.block_table.block_tables[ 0].slot_mapping.cpu[:num_tokens_mtp] unpad_mask = np.repeat(False, num_tokens_mtp_pad) unpad_mask[::self.pcp_size] = True From 48795d3cdf67ff27d0630118c0db6f6c601ffb06 Mon Sep 17 00:00:00 2001 From: zhenwenqi2024 Date: Mon, 22 Dec 2025 20:52:06 +0800 Subject: [PATCH 20/43] cleancode Signed-off-by: zhenwenqi2024 --- .pre-commit-config.yaml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b16317c7466..82bde178ec0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -55,10 +55,10 @@ repos: hooks: - id: pymarkdown args: [fix] -# 
- repo: https://github.com/rhysd/actionlint -# rev: v1.7.7 -# hooks: -# - id: actionlint +- repo: https://github.com/rhysd/actionlint + rev: v1.7.7 + hooks: + - id: actionlint - repo: local hooks: # For local development, you can run mypy using tools/mypy.sh script if needed. From d457ed4e1d5b8b6dfed564d3029ad7841ae447f7 Mon Sep 17 00:00:00 2001 From: zhenwenqi2024 Date: Mon, 22 Dec 2025 21:14:16 +0800 Subject: [PATCH 21/43] cleancode Signed-off-by: zhenwenqi2024 --- tests/ut/worker/test_pcp_manager.py | 2 +- vllm_ascend/worker/pcp_utils.py | 18 ++++++++++++------ 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/tests/ut/worker/test_pcp_manager.py b/tests/ut/worker/test_pcp_manager.py index f1ea4dd911e..baabc99475d 100644 --- a/tests/ut/worker/test_pcp_manager.py +++ b/tests/ut/worker/test_pcp_manager.py @@ -313,7 +313,7 @@ def test_generate_pcp_mtp_input( token_ids_cpu_tensor[i][:token_ids_tensor.size(0)] = token_ids_tensor pcp_manager.generate_pcp_mtp_input(num_reqs, total_num_scheduled_tokens, - num_scheduled_tokens, input_batch, + num_scheduled_tokens, False, input_batch, arange_np) assert torch.equal( pcp_manager.input_ids_pcp_full.cpu[:total_num_scheduled_tokens], diff --git a/vllm_ascend/worker/pcp_utils.py b/vllm_ascend/worker/pcp_utils.py index 7fde93a783e..ab6e577c4be 100644 --- a/vllm_ascend/worker/pcp_utils.py +++ b/vllm_ascend/worker/pcp_utils.py @@ -89,17 +89,23 @@ def __init__( self.dcp_world_size + self.pcp_world_size * self.dcp_world_size * self.max_num_reqs)) if self.speculative_config and self.pcp_world_size > 1: - self.input_ids_pcp_full = self._make_buffer(self.max_num_tokens, - dtype=torch.int32) - self.query_start_loc_pcp_full = self._make_buffer( - self.max_num_reqs + 1, dtype=torch.int32) + self.input_ids_pcp_full = CpuGpuBuffer(self.max_num_tokens, + dtype=torch.int32, + device=device, + pin_memory=pin_memory) + self.query_start_loc_pcp_full = CpuGpuBuffer(self.max_num_reqs + 1, + dtype=torch.int32, + device=device, + 
pin_memory=pin_memory) self.positions_pcp_full = torch.zeros(self.max_num_tokens, dtype=torch.int64, device="cpu", pin_memory=True) self.positions_pcp_full_np = self.positions_pcp_full.numpy() - self.query_lens_pcp_full = self._make_buffer(self.max_num_reqs, - dtype=torch.int32) + self.query_lens_pcp_full = CpuGpuBuffer(self.max_num_reqs, + dtype=torch.int32, + device="cpu", + pin_memory=True) def _get_cumsum_and_arange( self, From bb16241892d1fb6cdb685ba1a1d4d5fb91883d69 Mon Sep 17 00:00:00 2001 From: zhenwenqi2024 Date: Mon, 22 Dec 2025 21:14:37 +0800 Subject: [PATCH 22/43] cleancode Signed-off-by: zhenwenqi2024 --- tests/ut/worker/test_pcp_manager.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/ut/worker/test_pcp_manager.py b/tests/ut/worker/test_pcp_manager.py index baabc99475d..9f5863abe1c 100644 --- a/tests/ut/worker/test_pcp_manager.py +++ b/tests/ut/worker/test_pcp_manager.py @@ -313,8 +313,8 @@ def test_generate_pcp_mtp_input( token_ids_cpu_tensor[i][:token_ids_tensor.size(0)] = token_ids_tensor pcp_manager.generate_pcp_mtp_input(num_reqs, total_num_scheduled_tokens, - num_scheduled_tokens, False, input_batch, - arange_np) + num_scheduled_tokens, False, + input_batch, arange_np) assert torch.equal( pcp_manager.input_ids_pcp_full.cpu[:total_num_scheduled_tokens], target_input_ids_pcp_full) From b2c15af6eeba7e14f5b2230fa35b2cbe747e56c3 Mon Sep 17 00:00:00 2001 From: zhenwenqi2024 Date: Mon, 22 Dec 2025 22:52:10 +0800 Subject: [PATCH 23/43] cleancode Signed-off-by: zhenwenqi2024 --- vllm_ascend/spec_decode/mtp_proposer.py | 6 +++--- vllm_ascend/worker/model_runner_v1.py | 12 +++++++++--- vllm_ascend/worker/pcp_utils.py | 4 ++-- 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py index f17bf5eaa90..a0ec9d38e98 100644 --- a/vllm_ascend/spec_decode/mtp_proposer.py +++ b/vllm_ascend/spec_decode/mtp_proposer.py @@ -381,9 +381,9 @@ def 
generate_token_ids(self, req_scheduled_tokens = scheduler_output.num_scheduled_tokens if self.pcp_size * self.dcp_size > 1: long_seq_metadata = self.runner.long_seq_metadata - input_ids_pcp_full = self.runner.input_ids_pcp_full.gpu - query_start_loc_pcp_full = self.runner.query_start_loc_pcp_full.gpu - query_start_loc_pcp_full_cpu = self.runner.query_start_loc_pcp_full.cpu + input_ids_pcp_full = self.runner.pcp_manager.input_ids_pcp_full.gpu + query_start_loc_pcp_full = self.runner.pcp_manager.query_start_loc_pcp_full.gpu + query_start_loc_pcp_full_cpu = self.runner.pcp_manager.query_start_loc_pcp_full.cpu num_reqs = self.runner.input_batch.num_reqs ori_query_lens = query_start_loc_pcp_full_cpu[1:num_reqs+1] - \ query_start_loc_pcp_full_cpu[:num_reqs] diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index ee7c145d01e..9fa5b7e83c1 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -985,12 +985,13 @@ def _prepare_inputs( # (num_reqs_d + num_reqs_p, max_num_blocks), # flattened block_table: [d0, d0, d1, d1, p0, p1, p2] # (num_reqs_d * decode_threshold + num_reqs_p, max_num_blocks), - ori_query_lens = self.pcp_manager.query_start_loc_pcp_full_cpu[1:num_reqs+1] - \ - self.pcp_manager.query_start_loc_pcp_full_cpu[:num_reqs] + ori_query_lens_cpu = self.pcp_manager.query_lens_pcp_full.cpu[:num_reqs] + ori_query_lens = self.pcp_manager.query_lens_pcp_full.gpu[:num_reqs] num_prefill_reqs = (ori_query_lens > self.decode_threshold).sum().item() num_decode_reqs = num_reqs - num_prefill_reqs - num_decode_reqs_flatten = num_decode_reqs * self.decode_threshold + num_decode_reqs_flatten = \ + ori_query_lens_cpu[:num_decode_reqs].sum().item() blk_table_tensor[ num_decode_reqs_flatten:num_decode_reqs_flatten + num_prefill_reqs].copy_( @@ -1001,6 +1002,11 @@ def _prepare_inputs( ori_query_lens[:num_decode_reqs], dim=0)) common_attn_metadata.block_table_tensor = \ 
blk_table_tensor[:num_decode_reqs_flatten + num_prefill_reqs] + if 'pad_size' in locals() and pad_size > 0: + ori_query_lens_cpu[-pad_size:] = \ + torch.full([pad_size], ori_query_lens_cpu[-pad_size - 1].item()) + self.long_seq_metadata.max_query_len_pcp_full = \ + ori_query_lens_cpu.max().item() if self.speculative_config and \ self.spec_decode_common_attn_metadata is None: diff --git a/vllm_ascend/worker/pcp_utils.py b/vllm_ascend/worker/pcp_utils.py index ab6e577c4be..c52575f3c93 100644 --- a/vllm_ascend/worker/pcp_utils.py +++ b/vllm_ascend/worker/pcp_utils.py @@ -104,7 +104,7 @@ def __init__( self.positions_pcp_full_np = self.positions_pcp_full.numpy() self.query_lens_pcp_full = CpuGpuBuffer(self.max_num_reqs, dtype=torch.int32, - device="cpu", + device=device, pin_memory=True) def _get_cumsum_and_arange( @@ -394,7 +394,7 @@ def generate_pcp_mtp_input( num_tokens_ori = sum(list(num_scheduled_tokens.values())) num_tokens_mtp = \ num_tokens_ori + num_reqs * (self.decode_threshold - 2) - num_tokens_mtp_pad = num_tokens_mtp * self.pcp_size + num_tokens_mtp_pad = num_tokens_mtp * self.pcp_world_size req_indices_split = np.array_split(req_indices, cu_num_tokens)[:num_reqs] positions_split = np.array_split(positions_np, From 25c8cbf457377d49ebfe93f8acacc6e009e6e38e Mon Sep 17 00:00:00 2001 From: zhenwenqi2024 Date: Mon, 22 Dec 2025 22:53:34 +0800 Subject: [PATCH 24/43] cleancode Signed-off-by: zhenwenqi2024 --- vllm_ascend/worker/model_runner_v1.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 9fa5b7e83c1..48dbb6a6dc4 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -985,8 +985,10 @@ def _prepare_inputs( # (num_reqs_d + num_reqs_p, max_num_blocks), # flattened block_table: [d0, d0, d1, d1, p0, p1, p2] # (num_reqs_d * decode_threshold + num_reqs_p, max_num_blocks), - ori_query_lens_cpu = 
self.pcp_manager.query_lens_pcp_full.cpu[:num_reqs] - ori_query_lens = self.pcp_manager.query_lens_pcp_full.gpu[:num_reqs] + ori_query_lens_cpu = self.pcp_manager.query_lens_pcp_full.cpu[: + num_reqs] + ori_query_lens = self.pcp_manager.query_lens_pcp_full.gpu[: + num_reqs] num_prefill_reqs = (ori_query_lens > self.decode_threshold).sum().item() num_decode_reqs = num_reqs - num_prefill_reqs From a6f610874ccb97d4a2b285683c32d7d519517189 Mon Sep 17 00:00:00 2001 From: zhenwenqi2024 Date: Tue, 23 Dec 2025 09:39:01 +0800 Subject: [PATCH 25/43] cleancode Signed-off-by: zhenwenqi2024 --- vllm_ascend/worker/model_runner_v1.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index b3fa3091b40..0897b167ef4 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -354,6 +354,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): # None in the first PP rank. The rest are set after load_model. 
self.intermediate_tensors: IntermediateTensors | None = None self.reorder_batch_threshold: int | None = None + self.long_seq_metadata = None def _init_device_properties(self) -> None: self.num_sms = None From 1c3a2d69c220a2f9ebed6eb90fe05adf487c5b28 Mon Sep 17 00:00:00 2001 From: zhenwenqi2024 Date: Tue, 23 Dec 2025 12:49:44 +0800 Subject: [PATCH 26/43] cleancode Signed-off-by: zhenwenqi2024 --- vllm_ascend/worker/model_runner_v1.py | 1 + vllm_ascend/worker/pcp_utils.py | 35 ++++++++++++++++++++++++--- 2 files changed, 32 insertions(+), 4 deletions(-) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 20c4a0193e8..0d1badf2970 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -985,6 +985,7 @@ def _prepare_inputs( max_query_len=max_num_scheduled_tokens, decode_token_per_req=self.decode_token_per_req, prefill_context_parallel_metadata=self.long_seq_metadata, + max_seq_len=0, ) if self.speculative_config and self.pcp_size * self.dcp_size > 1: diff --git a/vllm_ascend/worker/pcp_utils.py b/vllm_ascend/worker/pcp_utils.py index 61ec3d6b90f..8d7f27f5bbd 100644 --- a/vllm_ascend/worker/pcp_utils.py +++ b/vllm_ascend/worker/pcp_utils.py @@ -499,8 +499,10 @@ def generate_pcp_metadata(self, total_num_scheduled_tokens, query_lens, from vllm_ascend.attention.utils import \ AscendPrefillContextParallelMetadata num_reqs = input_batch.num_reqs or query_lens.size(0) - num_decodes = sum(input_batch.num_computed_tokens_cpu[:num_reqs] >= - input_batch.num_prompt_tokens[:num_reqs]) + query_lens = self.query_lens_pcp_full.cpu[:num_reqs] \ + if self.pcp_world_size > 1 and self.speculative_config else query_lens + num_decodes = (query_lens <= self.decode_threshold).sum().item() + num_prefills = num_reqs - num_decodes num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_world_size self.num_actual_tokens_pcp_padded = num_actual_tokens_pcp_padded long_seq_metadata = None @@ -518,16 
+520,41 @@ def generate_pcp_metadata(self, total_num_scheduled_tokens, query_lens, dtype=torch.int32, ) # For pcp + spec decode, we flatten seq_lens - # to avoid irregular spec_attn_mask shape + # to avoid irregular spec_attn_mask shape. + # Same as block_table, we flatten decode seq_lens to query_lens, + # and keep prefill seq_lens unchanged. for decode_idx in range(self.decode_threshold): num_computed_tokens_of_pcp_dcp[ self.decode_threshold - 1 - decode_idx::self.decode_threshold] = \ self._get_cp_local_seq_lens( - torch.tensor(context_lens), + torch.tensor(context_lens) - decode_idx, self.pcp_world_size, self.dcp_world_size, self.vllm_config.parallel_config.cp_kv_cache_interleave_size, ) + if self.decode_threshold > 1: + num_computed_tokens_of_pcp_dcp_list = [] + if num_decodes: + num_decodes_flatten = \ + query_lens[:num_decodes].sum().item() + if query_lens[:num_decodes].min().item( + ) == self.decode_threshold: + decode_flatten_idx = list(range(num_decodes_flatten)) + else: + decode_flatten_idx = [] + for req_id in range(num_decodes): + offset = (req_id + 1) * self.decode_threshold + decode_flatten_idx += \ + list(range(offset - query_lens[req_id], offset)) + num_computed_tokens_of_pcp_dcp_list.append( + num_computed_tokens_of_pcp_dcp[decode_flatten_idx]) + if num_prefills: + num_computed_tokens_of_pcp_dcp_list.append( + num_computed_tokens_of_pcp_dcp[ + (num_decodes + 1) * self.decode_threshold - + 1::self.decode_threshold]) + num_computed_tokens_of_pcp_dcp = torch.cat( + num_computed_tokens_of_pcp_dcp_list, dim=0) long_seq_metadata = AscendPrefillContextParallelMetadata( num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded, num_computed_tokens_of_pcp_dcp=num_computed_tokens_of_pcp_dcp. 
From f0b64840c4d1a7b3c729e34dfe983d03d411f757 Mon Sep 17 00:00:00 2001 From: zhenwenqi2024 Date: Tue, 23 Dec 2025 15:47:45 +0800 Subject: [PATCH 27/43] cleancode Signed-off-by: zhenwenqi2024 --- vllm_ascend/spec_decode/mtp_proposer.py | 4 ++-- vllm_ascend/worker/pcp_utils.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py index 6a1e022cf8e..772ba7ba434 100644 --- a/vllm_ascend/spec_decode/mtp_proposer.py +++ b/vllm_ascend/spec_decode/mtp_proposer.py @@ -893,13 +893,13 @@ def _propose( # (_generate_pcp_mtp_input), and use updated slot_indices # to get corresponding slot_mapping in each step. num_reject_tokens = torch.tensor( - self.runner.cu_num_tokens_pcp_full, + self.runner.pcp_manager.cu_num_tokens_pcp_full, dtype=torch.int32).to( self.device) - ori_last_token_indices - 1 num_accept_tokens = \ query_lens_d.to(self.device) - num_reject_tokens ori_seq_len = attn_metadata_i.seq_lens - mtp_slot_mapping = self.runner.mtp_slot_pad + mtp_slot_mapping = self.runner.pcp_manager.mtp_slot_pad # slot_mapping index base offset: # scheduled tokens + pre-allocated mtp tokens + accepted tokens diff --git a/vllm_ascend/worker/pcp_utils.py b/vllm_ascend/worker/pcp_utils.py index 8d7f27f5bbd..5b32569f2e2 100644 --- a/vllm_ascend/worker/pcp_utils.py +++ b/vllm_ascend/worker/pcp_utils.py @@ -499,9 +499,9 @@ def generate_pcp_metadata(self, total_num_scheduled_tokens, query_lens, from vllm_ascend.attention.utils import \ AscendPrefillContextParallelMetadata num_reqs = input_batch.num_reqs or query_lens.size(0) - query_lens = self.query_lens_pcp_full.cpu[:num_reqs] \ + query_lens_new = self.query_lens_pcp_full.cpu[:num_reqs] \ if self.pcp_world_size > 1 and self.speculative_config else query_lens - num_decodes = (query_lens <= self.decode_threshold).sum().item() + num_decodes = (query_lens_new <= self.decode_threshold).sum().item() num_prefills = num_reqs - num_decodes 
num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_world_size self.num_actual_tokens_pcp_padded = num_actual_tokens_pcp_padded From 8bcdbc3fc757cd188ec16f19012e17e6ea34ab79 Mon Sep 17 00:00:00 2001 From: zhenwenqi2024 Date: Wed, 24 Dec 2025 13:29:15 +0800 Subject: [PATCH 28/43] cleancode Signed-off-by: zhenwenqi2024 --- vllm_ascend/worker/model_runner_v1.py | 8 +++++--- vllm_ascend/worker/pcp_utils.py | 5 ++++- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 1709ba7607c..3e7213be097 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -563,7 +563,6 @@ def _prepare_inputs( ) # Re-update after PCP split sequences. total_num_scheduled_tokens = sum(num_scheduled_tokens) - scheduler_output.total_num_scheduled_tokens = total_num_scheduled_tokens req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens) cu_num_tokens, _ = self._get_cumsum_and_arange( @@ -1015,13 +1014,16 @@ def _prepare_inputs( ori_query_lens[:num_decode_reqs], dim=0)) common_attn_metadata.block_table_tensor = \ blk_table_tensor[:num_decode_reqs_flatten + num_prefill_reqs] + assert self.long_seq_metadata is not None + self.long_seq_metadata.query_lens_pcp_full_cpu = ori_query_lens_cpu + if 'pad_size' in locals() and pad_size > 0: ori_query_lens_cpu[-pad_size:] = \ torch.full([pad_size], ori_query_lens_cpu[-pad_size - 1].item()) - assert self.long_seq_metadata is not None self.long_seq_metadata.max_query_len_pcp_full = \ ori_query_lens_cpu.max().item() - self.long_seq_metadata.query_lens_pcp_full_cpu = ori_query_lens_cpu + + if self.speculative_config and \ self.spec_decode_common_attn_metadata is None: diff --git a/vllm_ascend/worker/pcp_utils.py b/vllm_ascend/worker/pcp_utils.py index 5b32569f2e2..6e660dc87fd 100644 --- a/vllm_ascend/worker/pcp_utils.py +++ b/vllm_ascend/worker/pcp_utils.py @@ -217,7 +217,10 @@ def 
update_tokens_for_pcp( self.pcp_unpad_mask_cpu[:pcp_padded_arange.shape[0]] = ( pcp_padded_arange < np.repeat(num_scheduled_tokens, num_padded_scheduled_tokens)) - + unpad_mask_decode = self.pcp_unpad_mask_cpu[:num_decode_tokens * self.pcp_world_size] + unpad_mask_decode = unpad_mask_decode.reshape([-1, self.pcp_world_size]) + unpad_mask_decode[:, 0] = True + unpad_mask_decode[:, 1:] = False pcp_tokens = num_padded_scheduled_tokens // self.pcp_world_size # Compute per-request "chunk sizes" for the head/tail splitting. From f009edb03d5d4cf1e27eca1495488b6b3738c078 Mon Sep 17 00:00:00 2001 From: zhenwenqi2024 Date: Wed, 24 Dec 2025 14:46:36 +0800 Subject: [PATCH 29/43] cleancode Signed-off-by: zhenwenqi2024 --- .pre-commit-config.yaml | 8 ++++---- vllm_ascend/spec_decode/mtp_proposer.py | 2 +- vllm_ascend/worker/model_runner_v1.py | 6 +++--- vllm_ascend/worker/pcp_utils.py | 8 +++++--- 4 files changed, 13 insertions(+), 11 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 82bde178ec0..b16317c7466 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -55,10 +55,10 @@ repos: hooks: - id: pymarkdown args: [fix] -- repo: https://github.com/rhysd/actionlint - rev: v1.7.7 - hooks: - - id: actionlint +# - repo: https://github.com/rhysd/actionlint +# rev: v1.7.7 +# hooks: +# - id: actionlint - repo: local hooks: # For local development, you can run mypy using tools/mypy.sh script if needed. 
diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py index d002dfea8a7..1d970c7583d 100644 --- a/vllm_ascend/spec_decode/mtp_proposer.py +++ b/vllm_ascend/spec_decode/mtp_proposer.py @@ -272,7 +272,7 @@ def dummy_run(self, if self.pcp_size * self.dcp_size > 1: # update long_seq related params and flatten block_table common_attn_metadata.prefill_context_parallel_metadata = \ - self.runner.long_seq_metadata + self.runner.pcp_manager.long_seq_metadata common_attn_metadata.block_table_tensor = \ self.runner.input_batch.block_table[0].get_device_tensor()[ :num_reqs * self.decode_threshold] diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 3e7213be097..7acb3f1712f 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -1016,14 +1016,14 @@ def _prepare_inputs( blk_table_tensor[:num_decode_reqs_flatten + num_prefill_reqs] assert self.long_seq_metadata is not None self.long_seq_metadata.query_lens_pcp_full_cpu = ori_query_lens_cpu - + if 'pad_size' in locals() and pad_size > 0: ori_query_lens_cpu[-pad_size:] = \ torch.full([pad_size], ori_query_lens_cpu[-pad_size - 1].item()) self.long_seq_metadata.max_query_len_pcp_full = \ ori_query_lens_cpu.max().item() - - + + if self.speculative_config and \ self.spec_decode_common_attn_metadata is None: diff --git a/vllm_ascend/worker/pcp_utils.py b/vllm_ascend/worker/pcp_utils.py index 6e660dc87fd..26c1c551240 100644 --- a/vllm_ascend/worker/pcp_utils.py +++ b/vllm_ascend/worker/pcp_utils.py @@ -217,8 +217,10 @@ def update_tokens_for_pcp( self.pcp_unpad_mask_cpu[:pcp_padded_arange.shape[0]] = ( pcp_padded_arange < np.repeat(num_scheduled_tokens, num_padded_scheduled_tokens)) - unpad_mask_decode = self.pcp_unpad_mask_cpu[:num_decode_tokens * self.pcp_world_size] - unpad_mask_decode = unpad_mask_decode.reshape([-1, self.pcp_world_size]) + unpad_mask_decode = self.pcp_unpad_mask_cpu[:num_decode_tokens * + 
self.pcp_world_size] + unpad_mask_decode = unpad_mask_decode.reshape( + [-1, self.pcp_world_size]) unpad_mask_decode[:, 0] = True unpad_mask_decode[:, 1:] = False pcp_tokens = num_padded_scheduled_tokens // self.pcp_world_size @@ -681,5 +683,5 @@ def _list_to_tensor(lst, device, dtype=torch.int32): 'tail_attn_nomask_seqlens'] long_seq_metadata.pcp_prefill_mask = self.extra_long_seq_kwargs[ 'pcp_prefill_mask'] - self.long_seq_metadata = long_seq_metadata + self.long_seq_metadata = long_seq_metadata return long_seq_metadata From e07b34c73689c724b02e37d58a763186dc3d3cd4 Mon Sep 17 00:00:00 2001 From: zhenwenqi2024 Date: Wed, 24 Dec 2025 15:26:48 +0800 Subject: [PATCH 30/43] cleancode Signed-off-by: zhenwenqi2024 --- tests/ut/spec_decode/test_mtp_proposer.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/ut/spec_decode/test_mtp_proposer.py b/tests/ut/spec_decode/test_mtp_proposer.py index 9e9eb295282..a943cb8133a 100644 --- a/tests/ut/spec_decode/test_mtp_proposer.py +++ b/tests/ut/spec_decode/test_mtp_proposer.py @@ -217,21 +217,22 @@ def test_generate_token_ids(self, mock_cpu_gpu_buffer): mock_deps.runner.spec_decode_common_attn_metadata = MagicMock() mock_deps.runner.pcp_size = 2 mock_deps.runner.dcp_size = 1 - mock_deps.runner.input_ids_pcp_full = CpuGpuBuffer( + mock_deps.runner.pcp_manager = MagicMock() + mock_deps.runner.pcp_manager.input_ids_pcp_full = CpuGpuBuffer( 32, dtype=torch.int32, pin_memory=False, device='cpu', ) - mock_deps.runner.input_ids_pcp_full.cpu = \ + mock_deps.runner.pcp_manager.input_ids_pcp_full.cpu = \ torch.arange(32, dtype=torch.int32) - mock_deps.runner.query_start_loc_pcp_full = CpuGpuBuffer( + mock_deps.runner.pcp_manager.query_start_loc_pcp_full = CpuGpuBuffer( 5, dtype=torch.int32, pin_memory=False, device='cpu', ) - mock_deps.runner.query_start_loc_pcp_full.cpu = \ + mock_deps.runner.pcp_manager.query_start_loc_pcp_full.cpu = \ torch.tensor([0, 8, 16, 24, 32]) mock_deps.positions = torch.arange(16, 
dtype=torch.int32) mock_deps.hidden_states = torch.zeros(16, 4096, dtype=torch.float16) From b9214c6ace4f8e2f61e54866f025aa6734973fad Mon Sep 17 00:00:00 2001 From: zhenwenqi2024 Date: Wed, 24 Dec 2025 15:29:05 +0800 Subject: [PATCH 31/43] cleancode Signed-off-by: zhenwenqi2024 --- .pre-commit-config.yaml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b16317c7466..82bde178ec0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -55,10 +55,10 @@ repos: hooks: - id: pymarkdown args: [fix] -# - repo: https://github.com/rhysd/actionlint -# rev: v1.7.7 -# hooks: -# - id: actionlint +- repo: https://github.com/rhysd/actionlint + rev: v1.7.7 + hooks: + - id: actionlint - repo: local hooks: # For local development, you can run mypy using tools/mypy.sh script if needed. From d08b5dee0b4811ef21b0fc72b4bef95452f62c1a Mon Sep 17 00:00:00 2001 From: zhenwenqi2024 Date: Wed, 24 Dec 2025 21:15:46 +0800 Subject: [PATCH 32/43] cleancode Signed-off-by: zhenwenqi2024 --- vllm_ascend/spec_decode/mtp_proposer.py | 4 ++-- vllm_ascend/worker/model_runner_v1.py | 2 +- vllm_ascend/worker/pcp_utils.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py index 1d970c7583d..e09d5b7e177 100644 --- a/vllm_ascend/spec_decode/mtp_proposer.py +++ b/vllm_ascend/spec_decode/mtp_proposer.py @@ -397,7 +397,7 @@ def generate_token_ids(self, # update pcp related params if self.pcp_size > 1: token_indices_to_sample = \ - query_start_loc_pcp_full_cpu[1:num_reqs + 1] - 1 + query_start_loc_pcp_full[1:num_reqs + 1] - 1 target_token_ids = input_ids_pcp_full[:num_scheduled_tokens] target_positions = positions[:num_scheduled_tokens] target_hidden_states = hidden_states @@ -985,7 +985,7 @@ def _propose( self.hidden_states[:hidden_states.shape[0]] = hidden_states if self.pcp_size * self.dcp_size > 1: # update local 
seq_len and batch_seq_mask - num_computed_tokens_of_pcp_dcp = self.runner._get_cp_local_seq_lens( + num_computed_tokens_of_pcp_dcp = self.runner.pcp_manager._get_cp_local_seq_lens( ori_seq_len + step + 1, self.pcp_size, self.dcp_size, diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 7acb3f1712f..b4a32645b83 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -1133,7 +1133,7 @@ def _build_attn_state(self, num_reqs, num_scheduled_tokens, # sharing kv across layers need to read the kvcache, # directly return chunked prefill in this scenario return AscendAttentionState.ChunkedPrefill - if np.array_equal(self.seq_lens.np[:num_reqs], num_scheduled_tokens): + if np.all(self.input_batch.num_computed_tokens_cpu[:num_reqs] == 0): attn_state = AscendAttentionState.PrefillNoCache # We assume it is the decode stage, where prefill occurs but only one token is not hit in cache. elif np.all(num_scheduled_tokens == 1): diff --git a/vllm_ascend/worker/pcp_utils.py b/vllm_ascend/worker/pcp_utils.py index 26c1c551240..20cd90c47b6 100644 --- a/vllm_ascend/worker/pcp_utils.py +++ b/vllm_ascend/worker/pcp_utils.py @@ -100,12 +100,12 @@ def __init__( self.positions_pcp_full = torch.zeros(self.max_num_tokens, dtype=torch.int64, device="cpu", - pin_memory=True) + pin_memory=pin_memory) self.positions_pcp_full_np = self.positions_pcp_full.numpy() self.query_lens_pcp_full = CpuGpuBuffer(self.max_num_reqs, dtype=torch.int32, device=device, - pin_memory=True) + pin_memory=pin_memory) def _get_cumsum_and_arange( self, From bb82d79b3535df93c07b88265d396a003bf1b17f Mon Sep 17 00:00:00 2001 From: zhenwenqi2024 Date: Thu, 25 Dec 2025 23:39:28 +0800 Subject: [PATCH 33/43] cleancode Signed-off-by: zhenwenqi2024 --- vllm_ascend/worker/model_runner_v1.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 
6f4b13a84c1..af6e9c29257 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -919,7 +919,7 @@ def _prepare_inputs( if self.pcp_size == 1: slot_mapping[ total_num_scheduled_tokens:num_input_tokens].fill_(-1) - slot_mapping = blk_table.slot_mapping.gpu + slot_mapping = blk_table.slot_mapping.gpu if self.pcp_size > 1: self.long_seq_metadata = self.pcp_manager.generate_pcp_metadata( total_num_scheduled_tokens, self.query_lens, From e221d076c97bad7e06c59c172c9e5798edb764f8 Mon Sep 17 00:00:00 2001 From: zhenwenqi2024 Date: Fri, 26 Dec 2025 14:51:22 +0800 Subject: [PATCH 34/43] cleancode Signed-off-by: zhenwenqi2024 --- vllm_ascend/worker/model_runner_v1.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index af6e9c29257..f734b190828 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -210,7 +210,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): if self.pcp_size > 1: self.model_config.max_model_len += 2 * self.pcp_size * self.max_num_reqs max_buffer_num_tokens = self.max_num_tokens - if self.pcp_size > 1: + if self.pcp_size * self.dp_size > 1: max_buffer_num_tokens = (self.max_num_tokens + self.max_num_reqs * 2 * self.pcp_size) self.pcp_manager = PCPManager( @@ -916,11 +916,11 @@ def _prepare_inputs( blk_table_tensor = blk_table.get_device_tensor() slot_mapping = blk_table.slot_mapping.gpu[: maybe_pcp_full_tokens] - if self.pcp_size == 1: + if self.pcp_size * self.dcp_size == 1: slot_mapping[ total_num_scheduled_tokens:num_input_tokens].fill_(-1) slot_mapping = blk_table.slot_mapping.gpu - if self.pcp_size > 1: + if self.pcp_size * self.dcp_size > 1: self.long_seq_metadata = self.pcp_manager.generate_pcp_metadata( total_num_scheduled_tokens, self.query_lens, self.attn_mask, self.input_batch) @@ -1778,7 +1778,7 @@ def _build_dummy_attn_metadata( 
self.cp_kv_recover_idx = torch.zeros(self.max_num_tokens, dtype=torch.int32, device=self.device) - long_seq_metadata = None if self.pcp_size == 1 else self.pcp_manager.generate_pcp_metadata( + long_seq_metadata = None if self.pcp_size * self.dp_size == 1 else self.pcp_manager.generate_pcp_metadata( num_tokens, self.query_lens, self.attn_mask, self.input_batch) if long_seq_metadata is not None: From 3d4669b7d5df4aa47c248ddfe35e92dfb6b63fd7 Mon Sep 17 00:00:00 2001 From: zhenwenqi2024 Date: Sun, 28 Dec 2025 17:06:07 +0800 Subject: [PATCH 35/43] cleancode Signed-off-by: zhenwenqi2024 --- vllm_ascend/ops/linear_op.py | 1 + vllm_ascend/worker/model_runner_v1.py | 7 ------- 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/vllm_ascend/ops/linear_op.py b/vllm_ascend/ops/linear_op.py index 674dab54e0c..83db355ad19 100644 --- a/vllm_ascend/ops/linear_op.py +++ b/vllm_ascend/ops/linear_op.py @@ -158,6 +158,7 @@ class MLPColumnParallelOp(CustomColumnParallelOp): def __init__(self, layer): super().__init__(layer) + print("layer is }", layer) @property def comm_group(self): diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 1e663cb288f..3b64794c945 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -228,9 +228,6 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): dtype=torch.int32) self.positions = self._make_buffer(max_buffer_num_tokens, dtype=torch.int64) - # Ascend-specific configurations - self.ascend_config = get_ascend_config() - set_weight_prefetch_method(self.ascend_config.weight_prefetch_config) self.sampler = AscendSampler() self.attn_mask = None self.attn_state = None @@ -1132,10 +1129,6 @@ def _generate_process_reqs_hidden_states(self, maybe_padded_num_tokens, def _build_attn_state(self, num_reqs, num_scheduled_tokens, num_valid_tokens): - if self.shared_kv_cache_layers is not None: - # sharing kv across layers need to read the kvcache, - # 
directly return chunked prefill in this scenario - return AscendAttentionState.ChunkedPrefill if np.all(self.input_batch.num_computed_tokens_cpu[:num_reqs] == 0): attn_state = AscendAttentionState.PrefillNoCache # We assume it is the decode stage, where prefill occurs but only one token is not hit in cache. From 7ae893d168e78966240131396f48845fea0c3bd7 Mon Sep 17 00:00:00 2001 From: zhenwenqi2024 Date: Sun, 28 Dec 2025 17:07:27 +0800 Subject: [PATCH 36/43] cleancode Signed-off-by: zhenwenqi2024 --- vllm_ascend/ops/linear_op.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm_ascend/ops/linear_op.py b/vllm_ascend/ops/linear_op.py index 83db355ad19..674dab54e0c 100644 --- a/vllm_ascend/ops/linear_op.py +++ b/vllm_ascend/ops/linear_op.py @@ -158,7 +158,6 @@ class MLPColumnParallelOp(CustomColumnParallelOp): def __init__(self, layer): super().__init__(layer) - print("layer is }", layer) @property def comm_group(self): From 1218d5d304c8aa2963a6a08a805031338ac01ec2 Mon Sep 17 00:00:00 2001 From: zhenwenqi2024 Date: Mon, 29 Dec 2025 09:42:30 +0800 Subject: [PATCH 37/43] cleancode Signed-off-by: zhenwenqi2024 --- vllm_ascend/worker/model_runner_v1.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 3b64794c945..8918fd7cad7 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -209,7 +209,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): if self.pcp_size > 1: self.model_config.max_model_len += 2 * self.pcp_size * self.max_num_reqs max_buffer_num_tokens = self.max_num_tokens - if self.pcp_size * self.dp_size > 1: + if self.pcp_size * self.dcp_size > 1: max_buffer_num_tokens = (self.max_num_tokens + self.max_num_reqs * 2 * self.pcp_size) self.pcp_manager = PCPManager( @@ -1770,7 +1770,7 @@ def _build_dummy_attn_metadata( self.cp_kv_recover_idx = torch.zeros(self.max_num_tokens, dtype=torch.int32, 
device=self.device) - long_seq_metadata = None if self.pcp_size * self.dp_size == 1 else self.pcp_manager.generate_pcp_metadata( + long_seq_metadata = None if self.pcp_size * self.dcp_size == 1 else self.pcp_manager.generate_pcp_metadata( num_tokens, self.query_lens, self.attn_mask, self.input_batch) if long_seq_metadata is not None: From 588547aecd51fdb4f5513d97572c618be5842308 Mon Sep 17 00:00:00 2001 From: zhenwenqi2024 Date: Mon, 29 Dec 2025 15:49:05 +0800 Subject: [PATCH 38/43] cleancode Signed-off-by: zhenwenqi2024 --- vllm_ascend/worker/pcp_utils.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/vllm_ascend/worker/pcp_utils.py b/vllm_ascend/worker/pcp_utils.py index 20cd90c47b6..369c77e8995 100644 --- a/vllm_ascend/worker/pcp_utils.py +++ b/vllm_ascend/worker/pcp_utils.py @@ -489,15 +489,15 @@ def generate_kv_idx(self, scheduler_output, input_batch): rank * chunk_size + num_added_recover_tokens]) - cp_kv_recover_idx_for_chunk = torch.from_numpy( - np.concatenate(self.cp_kv_recover_idx_for_chunk)).to( - device=self.device) - cp_kv_recover_idx_for_chunk.copy_(torch.tensor( - np.array( - self.cp_kv_recover_idx_for_chunk).flatten().tolist()), - non_blocking=True) - self.cp_kv_recover_idx_for_chunk = cp_kv_recover_idx_for_chunk.to( - torch.float32).argsort().to(torch.int32) + cp_kv_recover_idx_for_chunk = torch.from_numpy( + np.concatenate(self.cp_kv_recover_idx_for_chunk)).to( + device=self.device) + cp_kv_recover_idx_for_chunk.copy_(torch.tensor( + np.array( + self.cp_kv_recover_idx_for_chunk).flatten().tolist()), + non_blocking=True) + self.cp_kv_recover_idx_for_chunk = cp_kv_recover_idx_for_chunk.to( + torch.float32).argsort().to(torch.int32) def generate_pcp_metadata(self, total_num_scheduled_tokens, query_lens, attn_mask, input_batch): From cd8166457a4f396cea8c96ced863ac0c32b7cbc0 Mon Sep 17 00:00:00 2001 From: zhenwenqi2024 Date: Mon, 29 Dec 2025 16:31:19 +0800 Subject: [PATCH 39/43] cleancode Signed-off-by: 
zhenwenqi2024 --- vllm_ascend/worker/pcp_utils.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/vllm_ascend/worker/pcp_utils.py b/vllm_ascend/worker/pcp_utils.py index 369c77e8995..851993876a8 100644 --- a/vllm_ascend/worker/pcp_utils.py +++ b/vllm_ascend/worker/pcp_utils.py @@ -490,11 +490,10 @@ def generate_kv_idx(self, scheduler_output, input_batch): num_added_recover_tokens]) cp_kv_recover_idx_for_chunk = torch.from_numpy( - np.concatenate(self.cp_kv_recover_idx_for_chunk)).to( - device=self.device) + np.concatenate( + self.cp_kv_recover_idx_for_chunk)).to(device=self.device) cp_kv_recover_idx_for_chunk.copy_(torch.tensor( - np.array( - self.cp_kv_recover_idx_for_chunk).flatten().tolist()), + np.array(self.cp_kv_recover_idx_for_chunk).flatten().tolist()), non_blocking=True) self.cp_kv_recover_idx_for_chunk = cp_kv_recover_idx_for_chunk.to( torch.float32).argsort().to(torch.int32) From 4e30363b622c313cdc39bfc484736ea5ddab750b Mon Sep 17 00:00:00 2001 From: zhenwenqi2024 Date: Tue, 30 Dec 2025 01:03:44 +0800 Subject: [PATCH 40/43] cleancode Signed-off-by: zhenwenqi2024 --- tests/e2e/multicard/long_sequence/test_accuracy.py | 2 +- vllm_ascend/worker/model_runner_v1.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/e2e/multicard/long_sequence/test_accuracy.py b/tests/e2e/multicard/long_sequence/test_accuracy.py index af11189992b..09afb0874f4 100644 --- a/tests/e2e/multicard/long_sequence/test_accuracy.py +++ b/tests/e2e/multicard/long_sequence/test_accuracy.py @@ -71,7 +71,7 @@ def test_models_long_sequence_output_between_tp_and_cp( "prefill_context_parallel_size": 2, "compilation_config": { "cudagraph_mode": "FULL_DECODE_ONLY", - "cudagraph_capture_sizes": [4, 8, 24, 48, 60] + "cudagraph_capture_sizes": [2, 4, 8, 24, 48, 60] }, } tp_kwargs = { diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index f4e9208de1c..22d4783ca04 100644 --- 
a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -945,6 +945,7 @@ def _prepare_inputs( self.long_seq_metadata = self.pcp_manager.generate_pcp_metadata( total_num_scheduled_tokens, self.query_lens, self.attn_mask, self.input_batch) + blk_table.slot_mapping.gpu[maybe_pcp_full_tokens:].fill_(-1) slot_mapping = slot_mapping[:maybe_pcp_full_tokens] slot_mapping = self.pcp_manager.get_padded_slot_mapping( total_num_scheduled_tokens, From 8908b9b86a008090df681f85e36d69bc3e32bf1a Mon Sep 17 00:00:00 2001 From: zhenwenqi2024 Date: Tue, 30 Dec 2025 01:06:05 +0800 Subject: [PATCH 41/43] cleancode Signed-off-by: zhenwenqi2024 --- tests/e2e/multicard/long_sequence/test_accuracy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/e2e/multicard/long_sequence/test_accuracy.py b/tests/e2e/multicard/long_sequence/test_accuracy.py index 09afb0874f4..af11189992b 100644 --- a/tests/e2e/multicard/long_sequence/test_accuracy.py +++ b/tests/e2e/multicard/long_sequence/test_accuracy.py @@ -71,7 +71,7 @@ def test_models_long_sequence_output_between_tp_and_cp( "prefill_context_parallel_size": 2, "compilation_config": { "cudagraph_mode": "FULL_DECODE_ONLY", - "cudagraph_capture_sizes": [2, 4, 8, 24, 48, 60] + "cudagraph_capture_sizes": [4, 8, 24, 48, 60] }, } tp_kwargs = { From a7dff03901e8cf9b950f4f61440588bc2111ba6f Mon Sep 17 00:00:00 2001 From: zhenwenqi2024 Date: Tue, 30 Dec 2025 08:54:42 +0800 Subject: [PATCH 42/43] cleancode Signed-off-by: zhenwenqi2024 --- vllm_ascend/worker/pcp_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_ascend/worker/pcp_utils.py b/vllm_ascend/worker/pcp_utils.py index 851993876a8..b892f59713d 100644 --- a/vllm_ascend/worker/pcp_utils.py +++ b/vllm_ascend/worker/pcp_utils.py @@ -88,7 +88,7 @@ def __init__( range(self.max_num_tokens * self.pcp_world_size * self.dcp_world_size + self.pcp_world_size * self.dcp_world_size * self.max_num_reqs)) - if self.speculative_config 
and self.pcp_world_size > 1: + if self.speculative_config and self.pcp_world_size * self.dcp_world_size > 1: self.input_ids_pcp_full = CpuGpuBuffer(self.max_num_tokens, dtype=torch.int32, device=device, From d32806209cfe5e02f0a311f22106802fe1cedbf0 Mon Sep 17 00:00:00 2001 From: zhenwenqi2024 Date: Tue, 30 Dec 2025 14:59:58 +0800 Subject: [PATCH 43/43] cleancode Signed-off-by: zhenwenqi2024 --- vllm_ascend/worker/model_runner_v1.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 86ed70c5d15..319c7e41e0c 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -559,7 +559,7 @@ def _prepare_inputs( self.input_batch.block_table.commit_slot_mapping( total_num_scheduled_tokens) # for pcp, prefill mtp should use origin scheduleroutput , - if self.speculative_config and self.pcp_size > 1: + if self.speculative_config and self.pcp_size * self.dcp_size > 1: self.pcp_manager.generate_pcp_mtp_input( num_reqs, total_num_scheduled_tokens, scheduler_output.num_scheduled_tokens, with_prefill,