diff --git a/vllm_ascend/worker/worker.py b/vllm_ascend/worker/worker.py
index ebf2cdf7262..e6afa675339 100644
--- a/vllm_ascend/worker/worker.py
+++ b/vllm_ascend/worker/worker.py
@@ -31,7 +31,7 @@
 from vllm.distributed import ensure_model_parallel_initialized, init_distributed_environment
 from vllm.distributed.ec_transfer import ensure_ec_transfer_initialized
 from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized, get_kv_transfer_group, has_kv_transfer_group
-from vllm.distributed.parallel_state import get_pp_group, get_tp_group
+from vllm.distributed.parallel_state import Handle, get_pp_group, get_tp_group
 from vllm.logger import logger
 from vllm.lora.request import LoRARequest
 from vllm.sequence import IntermediateTensors
@@ -42,6 +42,7 @@
 from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
 from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, DraftTokenIds, ModelRunnerOutput
+from vllm.v1.worker.gpu_worker import AsyncIntermediateTensors
 from vllm.v1.worker.worker_base import WorkerBase
 from vllm.v1.worker.workspace import init_workspace_manager
 
@@ -134,6 +135,7 @@ def __init__(
 
             WEIGHT_LOADER_V2_SUPPORTED.remove("UnquantizedLinearMethod")
         self.use_v2_model_runner = envs_vllm.VLLM_USE_V2_MODEL_RUNNER
+        self._pp_send_work: list[Handle] = []
 
         ascend_compilation_config = get_ascend_config().ascend_compilation_config
         if ascend_compilation_config.enable_npugraph_ex and ascend_compilation_config.enable_static_kernel:
@@ -377,6 +379,11 @@ def execute_model(
         if envs_ascend.MSMONITOR_USE_DAEMON:
             dp.step()
 
+        if self._pp_send_work:
+            for handle in self._pp_send_work:
+                handle.wait()
+            self._pp_send_work = []
+
         intermediate_tensors = None
         forward_pass = scheduler_output.total_num_scheduled_tokens > 0
         if forward_pass and not get_pp_group().is_first_rank:
@@ -386,8 +393,14 @@ def execute_model(
                 all_gather_group = None
             else:
                 all_gather_group = get_tp_group()
-            intermediate_tensors = IntermediateTensors(
-                get_pp_group().recv_tensor_dict(all_gather_group=all_gather_group)
+            tensor_dict, comm_handles, comm_postprocess = get_pp_group().irecv_tensor_dict(
+                all_gather_group=all_gather_group
+            )
+            assert tensor_dict is not None
+            intermediate_tensors = AsyncIntermediateTensors(
+                tensor_dict,
+                comm_handles=comm_handles,
+                comm_postprocess=comm_postprocess,
             )
 
         output = self.model_runner.execute_model(scheduler_output, intermediate_tensors)
@@ -403,7 +416,10 @@ def execute_model(
             all_gather_group = None
         else:
             all_gather_group = get_tp_group()
-        get_pp_group().send_tensor_dict(output.tensors, all_gather_group=all_gather_group)
+        self._pp_send_work = get_pp_group().isend_tensor_dict(
+            output.tensors,
+            all_gather_group=all_gather_group,
+        )
 
         kv_connector_output = output.kv_connector_output
         if not kv_connector_output: