From 69c6cc10e8458deeccd0644535c3f3bb11c630b7 Mon Sep 17 00:00:00 2001 From: ganyi Date: Thu, 31 Jul 2025 11:03:04 +0800 Subject: [PATCH 1/2] adopt the new changes on disaggregated pd from vllm main branch Signed-off-by: ganyi --- vllm_ascend/worker/model_runner_v1.py | 9 ++++++++- vllm_ascend/worker/worker_v1.py | 19 +++++++++++++++---- 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 9782e1765e4..a057fa8ae0f 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -1445,6 +1445,8 @@ def _pool( hidden_states: torch.Tensor, num_scheduled_tokens: int, num_scheduled_tokens_np: np.ndarray, + finished_sending: Optional[set[str]], + finished_receiving: Optional[set[str]], ) -> ModelRunnerOutput: assert self.input_batch.num_reqs ==\ len(self.input_batch.pooling_params), \ @@ -1479,6 +1481,8 @@ def _pool( logprobs=None, prompt_logprobs_dict={}, pooler_output=pooler_output, + finished_sending=finished_sending, + finished_recving=finished_receiving ) @torch.inference_mode() @@ -1515,6 +1519,9 @@ def execute_model( if not get_pp_group().is_last_rank: # For mid-pipeline stages, return the hidden states. if not broadcast_pp_output: + if finished_sending or finished_recving: + hidden_states.finished_sending = finished_sending + hidden_states.finished_recving = finished_recving return hidden_states assert isinstance(hidden_states, IntermediateTensors) get_pp_group().send_tensor_dict( @@ -1523,7 +1530,7 @@ def execute_model( else: if self.input_batch.pooling_params: return self._pool(hidden_states, num_scheduled_tokens, - num_scheduled_tokens_np) + num_scheduled_tokens_np, finished_sending, finished_recving) sample_hidden_states = hidden_states[logits_indices] logits = self.model.compute_logits(sample_hidden_states, None) if broadcast_pp_output: diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py index 15dba8b38d2..c405ed92451 100644 --- a/vllm_ascend/worker/worker_v1.py +++ b/vllm_ascend/worker/worker_v1.py @@ -18,6 +18,7 @@ # from typing import Optional +import copy import torch import torch.nn as nn @@ -27,7 +28,8 @@ from vllm.config import VllmConfig from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) -from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized +from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized, + has_kv_transfer_group) from vllm.distributed.parallel_state import get_pp_group, get_tp_group from vllm.logger import logger from vllm.lora.request import LoRARequest @@ -35,7 +37,7 @@ from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, GiB_bytes from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec -from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput from vllm.v1.worker.worker_base import WorkerBase from vllm_ascend.ascend_config import init_ascend_config @@ -204,9 +206,18 @@ def execute_model( assert isinstance(output, IntermediateTensors) get_pp_group().send_tensor_dict(output.tensors, all_gather_group=get_tp_group()) - return None + if not has_kv_transfer_group(): + return None + + new_output = EMPTY_MODEL_RUNNER_OUTPUT + if output.finished_sending or output.finished_recving: + new_output = copy.copy(new_output) + new_output.finished_sending = output.finished_sending + new_output.finished_recving = output.finished_recving + output = new_output + assert isinstance(output, ModelRunnerOutput) - return output if self.is_driver_worker else None + return output def load_model(self) -> None: if self.vllm_config.model_config.enable_sleep_mode: From e61d0ebae339e0f5a8e8e6d4a626087412e06c98 Mon Sep 17 00:00:00 2001 From: ganyi Date: Fri, 1 Aug 2025 09:40:48 +0800 Subject: [PATCH 2/2] fix lint Signed-off-by: ganyi --- vllm_ascend/worker/model_runner_v1.py | 6 +++--- vllm_ascend/worker/worker_v1.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index a057fa8ae0f..6f4d3cda9ff 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -1482,8 +1482,7 @@ def _pool( prompt_logprobs_dict={}, pooler_output=pooler_output, finished_sending=finished_sending, - finished_recving=finished_receiving - ) + finished_recving=finished_receiving) @torch.inference_mode() def execute_model( @@ -1530,7 +1529,8 @@ def execute_model( else: if self.input_batch.pooling_params: return self._pool(hidden_states, num_scheduled_tokens, - num_scheduled_tokens_np, finished_sending, finished_recving) + num_scheduled_tokens_np, + finished_sending, finished_recving) sample_hidden_states = hidden_states[logits_indices] logits = self.model.compute_logits(sample_hidden_states, None) if broadcast_pp_output: diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py index c405ed92451..4988ef4689f 100644 --- a/vllm_ascend/worker/worker_v1.py +++ b/vllm_ascend/worker/worker_v1.py @@ -17,8 +17,8 @@ # Adapted from vllm-project/vllm/vllm/worker/gpu_worker.py # -from typing import Optional import copy +from typing import Optional import torch import torch.nn as nn