diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 9782e1765e4..6f4d3cda9ff 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -1445,6 +1445,8 @@ def _pool( hidden_states: torch.Tensor, num_scheduled_tokens: int, num_scheduled_tokens_np: np.ndarray, + finished_sending: Optional[set[str]], + finished_receiving: Optional[set[str]], ) -> ModelRunnerOutput: assert self.input_batch.num_reqs ==\ len(self.input_batch.pooling_params), \ @@ -1479,7 +1481,8 @@ def _pool( logprobs=None, prompt_logprobs_dict={}, pooler_output=pooler_output, - ) + finished_sending=finished_sending, + finished_recving=finished_receiving) @torch.inference_mode() def execute_model( @@ -1515,6 +1518,9 @@ def execute_model( if not get_pp_group().is_last_rank: # For mid-pipeline stages, return the hidden states. if not broadcast_pp_output: + if finished_sending or finished_recving: + hidden_states.finished_sending = finished_sending + hidden_states.finished_recving = finished_recving return hidden_states assert isinstance(hidden_states, IntermediateTensors) get_pp_group().send_tensor_dict( @@ -1523,7 +1529,8 @@ def execute_model( else: if self.input_batch.pooling_params: return self._pool(hidden_states, num_scheduled_tokens, - num_scheduled_tokens_np) + num_scheduled_tokens_np, + finished_sending, finished_recving) sample_hidden_states = hidden_states[logits_indices] logits = self.model.compute_logits(sample_hidden_states, None) if broadcast_pp_output: diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py index 15dba8b38d2..4988ef4689f 100644 --- a/vllm_ascend/worker/worker_v1.py +++ b/vllm_ascend/worker/worker_v1.py @@ -17,6 +17,7 @@ # Adapted from vllm-project/vllm/vllm/worker/gpu_worker.py # +import copy from typing import Optional import torch @@ -27,7 +28,8 @@ from vllm.config import VllmConfig from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) -from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized +from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized, + has_kv_transfer_group) from vllm.distributed.parallel_state import get_pp_group, get_tp_group from vllm.logger import logger from vllm.lora.request import LoRARequest @@ -35,7 +37,7 @@ from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, GiB_bytes from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec -from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput from vllm.v1.worker.worker_base import WorkerBase from vllm_ascend.ascend_config import init_ascend_config @@ -204,9 +206,18 @@ def execute_model( assert isinstance(output, IntermediateTensors) get_pp_group().send_tensor_dict(output.tensors, all_gather_group=get_tp_group()) - return None + if not has_kv_transfer_group(): + return None + + new_output = EMPTY_MODEL_RUNNER_OUTPUT + if output.finished_sending or output.finished_recving: + new_output = copy.copy(new_output) + new_output.finished_sending = output.finished_sending + new_output.finished_recving = output.finished_recving + output = new_output + assert isinstance(output, ModelRunnerOutput) - return output if self.is_driver_worker else None + return output def load_model(self) -> None: if self.vllm_config.model_config.enable_sleep_mode: