24 changes: 20 additions & 4 deletions vllm_ascend/worker/worker.py
@@ -31,7 +31,7 @@
from vllm.distributed import ensure_model_parallel_initialized, init_distributed_environment
from vllm.distributed.ec_transfer import ensure_ec_transfer_initialized
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized, get_kv_transfer_group, has_kv_transfer_group
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.distributed.parallel_state import Handle, get_pp_group, get_tp_group
from vllm.logger import logger
from vllm.lora.request import LoRARequest
from vllm.sequence import IntermediateTensors
@@ -42,6 +42,7 @@
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, DraftTokenIds, ModelRunnerOutput
from vllm.v1.worker.gpu_worker import AsyncIntermediateTensors
from vllm.v1.worker.worker_base import WorkerBase
from vllm.v1.worker.workspace import init_workspace_manager

@@ -134,6 +135,7 @@ def __init__(
WEIGHT_LOADER_V2_SUPPORTED.remove("UnquantizedLinearMethod")

self.use_v2_model_runner = envs_vllm.VLLM_USE_V2_MODEL_RUNNER
self._pp_send_work: list[Handle] = []

ascend_compilation_config = get_ascend_config().ascend_compilation_config
if ascend_compilation_config.enable_npugraph_ex and ascend_compilation_config.enable_static_kernel:
@@ -377,6 +379,11 @@ def execute_model(
if envs_ascend.MSMONITOR_USE_DAEMON:
dp.step()

if self._pp_send_work:
for handle in self._pp_send_work:
handle.wait()
self._pp_send_work = []

intermediate_tensors = None
forward_pass = scheduler_output.total_num_scheduled_tokens > 0
if forward_pass and not get_pp_group().is_first_rank:
@@ -386,8 +393,14 @@
all_gather_group = None
else:
all_gather_group = get_tp_group()
intermediate_tensors = IntermediateTensors(
get_pp_group().recv_tensor_dict(all_gather_group=all_gather_group)
tensor_dict, comm_handles, comm_postprocess = get_pp_group().irecv_tensor_dict(
all_gather_group=all_gather_group
)
assert tensor_dict is not None
intermediate_tensors = AsyncIntermediateTensors(
tensor_dict,
comm_handles=comm_handles,
comm_postprocess=comm_postprocess,
)

output = self.model_runner.execute_model(scheduler_output, intermediate_tensors)
@@ -403,7 +416,10 @@
all_gather_group = None
else:
all_gather_group = get_tp_group()
get_pp_group().send_tensor_dict(output.tensors, all_gather_group=all_gather_group)
self._pp_send_work = get_pp_group().isend_tensor_dict(
output.tensors,
all_gather_group=all_gather_group,
)

kv_connector_output = output.kv_connector_output
if not kv_connector_output:
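
In short, the diff replaces the blocking `recv_tensor_dict`/`send_tensor_dict` calls with their non-blocking counterparts: the received tensors are wrapped in `AsyncIntermediateTensors` together with the communication handles and a post-processing callback, while the send handles are stashed in `self._pp_send_work` and only waited on at the top of the next `execute_model` call. Below is a minimal, self-contained sketch of that deferred-wait send pattern; the `FakeHandle`/`fake_isend` names are hypothetical stand-ins for illustration, not vLLM's real `Handle` or `isend_tensor_dict` objects.

```python
# Sketch only: stand-in handles simulate async sends so the control flow
# (send now, wait at the start of the next step) can be shown end to end.
import threading
import time
from dataclasses import dataclass


@dataclass
class FakeHandle:
    """Stand-in for an async communication work handle (hypothetical)."""

    _thread: threading.Thread

    def wait(self) -> None:
        # Block until the simulated transfer has finished.
        self._thread.join()


def fake_isend(payload: dict, latency_s: float = 0.01) -> FakeHandle:
    """Pretend to ship a tensor dict to the next pipeline stage; returns a waitable handle."""
    t = threading.Thread(target=time.sleep, args=(latency_s,))
    t.start()
    return FakeHandle(t)


class PipelineStageWorker:
    def __init__(self) -> None:
        # Handles from the previous step's async send; waited on lazily at the
        # start of the next step, mirroring self._pp_send_work in the diff.
        self._pp_send_work: list[FakeHandle] = []

    def execute_model(self, step: int) -> dict:
        # 1) Make sure last step's sends have completed before this step
        #    touches the output buffers again.
        for handle in self._pp_send_work:
            handle.wait()
        self._pp_send_work = []

        # 2) Run this step's forward pass (omitted; stand-in output below).
        output = {"hidden_states": f"tensor-for-step-{step}"}

        # 3) Start sending this step's intermediate tensors without blocking,
        #    and keep the handles instead of waiting here.
        self._pp_send_work = [fake_isend(output)]
        return output


if __name__ == "__main__":
    worker = PipelineStageWorker()
    for step in range(3):
        worker.execute_model(step)
    # Drain outstanding sends before shutdown.
    for handle in worker._pp_send_work:
        handle.wait()
```

The receive side follows the same idea: `AsyncIntermediateTensors` carries `comm_handles` and `comm_postprocess`, presumably so the model runner can defer completing the receive until the tensors are actually consumed instead of stalling `execute_model` up front.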