diff --git a/vllm_ascend/attention/context_parallel/attention_cp.py b/vllm_ascend/attention/context_parallel/attention_cp.py index 275bb1ea6c7..29c0d8072a6 100644 --- a/vllm_ascend/attention/context_parallel/attention_cp.py +++ b/vllm_ascend/attention/context_parallel/attention_cp.py @@ -161,10 +161,13 @@ def build( (len(local_context_lens_allranks)), dtype=torch.int32, device=self.device) - cp_kv_recover_idx_for_chunk = common_long_seq_metadata.cp_kv_recover_idx_for_chunk kv_inverse_idx_for_chunk = torch.argsort( - cp_kv_recover_idx_for_chunk.to(torch.float32) - ) if cp_kv_recover_idx_for_chunk is not None else None + common_long_seq_metadata. + pcp_allgather_restore_idx[pcp_size * + num_decode_tokens:].to( + torch.float32)) + cp_kv_recover_idx_for_chunk = torch.argsort( + kv_inverse_idx_for_chunk) batch_chunk_seq_mask = ( local_context_lens_allranks[:, self.pcp_rank, diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py index a19a0ed4558..b99ebb6d664 100644 --- a/vllm_ascend/attention/utils.py +++ b/vllm_ascend/attention/utils.py @@ -40,8 +40,6 @@ def enable_cp(): class AscendPrefillContextParallelMetadata: pcp_allgather_restore_idx: torch.Tensor = None - cp_kv_recover_idx_for_chunk: torch.Tensor = None - num_actual_tokens_pcp_padded: int = 0 num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index fb771aac5d8..ff7bc89149e 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -553,9 +553,6 @@ def _prepare_inputs( self.num_spec_tokens) if self.pcp_size > 1: - if not self.vllm_config.model_config.use_mla: - self.pcp_manager.generate_kv_idx(scheduler_output, - self.input_batch) num_scheduled_tokens[: num_reqs], position_pcp = self.pcp_manager.update_tokens_for_pcp( num_scheduled_tokens[:num_reqs], diff --git a/vllm_ascend/worker/pcp_utils.py b/vllm_ascend/worker/pcp_utils.py index 40294807e55..cbff8a4b307 100644 --- a/vllm_ascend/worker/pcp_utils.py +++ b/vllm_ascend/worker/pcp_utils.py @@ -17,12 +17,11 @@ # Adapted from vllm-project/vllm/vllm/worker/worker.py # -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING import numpy as np import torch from vllm.config import VllmConfig -from vllm.utils.math_utils import cdiv from vllm.v1.utils import CpuGpuBuffer if TYPE_CHECKING: @@ -86,9 +85,6 @@ def __init__( ) self.num_actual_tokens_pcp_padded = 0 self.pcp_unpad_mask_cpu = self.pcp_unpad_mask_cpu_tensor.numpy() - self.cp_kv_recover_idx_for_chunk: List[List[int]] = [ - [] for _ in range(self.pcp_world_size) - ] self.full_indices = list( range(self.max_num_tokens * self.pcp_world_size * self.dcp_world_size + self.pcp_world_size * @@ -563,47 +559,6 @@ def _get_cp_local_seq_lens( [-1, pcp_world_size, dcp_world_size]) return dcp_local_seq_lens - def generate_kv_idx(self, scheduler_output, input_batch): - if not self.pcp_world_size > 1: - return - self.cp_kv_recover_idx_for_chunk = [[] - for _ in range(self.pcp_world_size) - ] - - for i, req_id in enumerate(input_batch.req_ids): - num_scheduled_token = scheduler_output.num_scheduled_tokens[req_id] - is_prefill = num_scheduled_token > self.decode_threshold - if is_prefill: - num_cp_padded_scheduled_tokens = cdiv( - num_scheduled_token, - 2 * self.pcp_world_size) * (2 * self.pcp_world_size) - chunk_size = num_cp_padded_scheduled_tokens // ( - 2 * self.pcp_world_size) - num_added_recover_tokens = len( - self.cp_kv_recover_idx_for_chunk[0]) * self.pcp_world_size - for rank in range(self.pcp_world_size): - self.cp_kv_recover_idx_for_chunk[rank].extend( - self.full_indices[rank * chunk_size + - num_added_recover_tokens:(rank + 1) * - chunk_size + - num_added_recover_tokens]) - self.cp_kv_recover_idx_for_chunk[rank].extend( - self.full_indices[num_cp_padded_scheduled_tokens - - (rank + 1) * chunk_size + - num_added_recover_tokens: - num_cp_padded_scheduled_tokens - - rank * chunk_size + - num_added_recover_tokens]) - - cp_kv_recover_idx_for_chunk = torch.from_numpy( - np.concatenate( - self.cp_kv_recover_idx_for_chunk)).to(device=self.device) - cp_kv_recover_idx_for_chunk.copy_(torch.tensor( - np.array(self.cp_kv_recover_idx_for_chunk).flatten().tolist()), - non_blocking=True) - self.cp_kv_recover_idx_for_chunk = cp_kv_recover_idx_for_chunk.to( - torch.float32).argsort().to(torch.int32) - def generate_pcp_metadata(self, total_num_scheduled_tokens, query_lens, input_batch, num_scheduled_tokens): from vllm_ascend.attention.utils import \ @@ -774,7 +729,6 @@ def generate_pcp_metadata(self, total_num_scheduled_tokens, query_lens, } long_seq_metadata.pcp_allgather_restore_idx = self.pcp_allgather_restore_idx.gpu[: num_actual_tokens_pcp_padded] - long_seq_metadata.cp_kv_recover_idx_for_chunk = self.cp_kv_recover_idx_for_chunk long_seq_metadata.q_head_idx_tensor = self.q_head_idx_tensor long_seq_metadata.q_tail_idx_tensor = self.q_tail_idx_tensor long_seq_metadata.q_full_idx = self.q_full_idx