From 191099100b031218fd4b3146d61a20d31f1e51d8 Mon Sep 17 00:00:00 2001 From: zhuohuan Date: Mon, 2 Feb 2026 14:11:48 +0800 Subject: [PATCH] [fix]:fix pp errors when applying flashcomm1/sp Signed-off-by: zhuohuan --- vllm_ascend/worker/model_runner_v1.py | 40 ++++++++++++++++++--------- 1 file changed, 27 insertions(+), 13 deletions(-) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index af41b31a45e..7457dbbfeea 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -935,6 +935,28 @@ def _calc_spec_decode_metadata( logits_indices=logits_indices, ) + def sync_and_slice_intermediate_tensors( + self, + num_tokens: int, + intermediate_tensors: IntermediateTensors | None, + sync_self: bool, + ) -> IntermediateTensors: + assert self.intermediate_tensors is not None + + tp = self.vllm_config.parallel_config.tensor_parallel_size + is_rs = enable_sp(self.vllm_config) + # When sequence parallelism/flashcomm1 is enabled, the intermediate tensor is sharded + # across tensor parallel ranks, so each rank only needs its own slice. + if sync_self: + assert intermediate_tensors is not None + for k, v in intermediate_tensors.items(): + copy_len = num_tokens // tp if is_rs else num_tokens + self.intermediate_tensors[k][:copy_len].copy_(v[:copy_len], non_blocking=True) + + return IntermediateTensors( + {k: v[: num_tokens // tp] if is_rs else v[:num_tokens] for k, v in self.intermediate_tensors.items()} + ) + # TODO: Once the PCP features are complete, it will fully inherit the classes from the VLLM community. def propose_draft_token_ids( self, @@ -2300,23 +2322,15 @@ def _dummy_run( if get_pp_group().is_first_rank: intermediate_tensors = None else: - # When PP and flashcomm1 are enabled, during dummy_run the estimated space should divide num_tokens by - # tp_size; otherwise, on non-first PP ranks it would effectively perform an extra all-gather, leading - # to incorrect memory estimation and potentially causing OOM. - intermediate_tokens = num_tokens_padded - if enable_sp(): - tp_size = get_tensor_model_parallel_world_size() - intermediate_tokens = (num_tokens_padded + tp_size - 1) // tp_size if self.intermediate_tensors is None: - max_actual_tokens = self.max_num_tokens - if enable_sp(): - max_actual_tokens = (self.max_num_tokens + tp_size - 1) // tp_size + tp_size = get_tensor_model_parallel_world_size() + max_actual_tokens = ( + (self.max_num_tokens + tp_size - 1) // tp_size if enable_sp() else self.max_num_tokens + ) self.intermediate_tensors = self.model.make_empty_intermediate_tensors( batch_size=max_actual_tokens, dtype=self.dtype, device=self.device ) - intermediate_tensors = IntermediateTensors( - {k: v[:intermediate_tokens] for k, v in self.intermediate_tensors.items()} - ) + intermediate_tensors = self.sync_and_slice_intermediate_tensors(num_tokens_padded, None, False) need_dummy_logits = not is_profile and lmhead_tp_enable() max_num_reqs_across_dp = max_num_reqs * self.uniform_decode_query_len