51 changes: 1 addition & 50 deletions vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -440,56 +440,7 @@ def get_finished(self) -> tuple[set[str], set[str]]:
"Rank %s, get_finished: %s requests done sending "
"and %s requests done recving", self.rank, len(done_sending),
len(done_recving))

if self.world_size == 1:
return done_sending, done_recving

# In TP>1 setup, each rank exchanges KVs with its counterpart
# ranks independently. get_finished() runs in a worker creates
# the done_sending and done_recving sets that are sent to the
# scheduler via ModelRunnerOutput by Rank 0. To avoid race
# ensure trnxs are done before adding to finished, Ranks 1 to
# N-1 communicate to Rank 0 once their transaction is done.
# Rank 0 only returns finished once all ranks are complete.
if self.rank == 0:
for req_id in done_sending:
self._done_sending_count[req_id] += 1
for req_id in done_recving:
self._done_recving_count[req_id] += 1

# Update the counts of how many ranks have finished.
# Get notifies from other ranks that txns are done.
other_ranks_finished_ids: list[str] = []
for i in range(1, self.world_size):
other_ranks_finished_ids.extend(
self.tp_group.recv_object(src=i))
for req_id in other_ranks_finished_ids:
if (req_id in self._done_recving_count
or req_id in self._recving_transfers):
self._done_recving_count[req_id] += 1
else:
self._done_sending_count[req_id] += 1

# Return ids that have finished on all ranks to the scheduler.
all_done_sending: set[str] = set()
all_done_recving: set[str] = set()
for req_id in list(self._done_recving_count.keys()):
if self._done_recving_count[req_id] == self.world_size:
self._done_recving_count.pop(req_id)
all_done_recving.add(req_id)
for req_id in list(self._done_sending_count.keys()):
if self._done_sending_count[req_id] == self.world_size:
self._done_sending_count.pop(req_id)
all_done_sending.add(req_id)

return all_done_sending, all_done_recving

else:
finished_req_ids = list(done_recving.union(done_sending))
self.tp_group.send_object(finished_req_ids, dst=0)

# NOTE(rob): unused as only Rank 0 sends to sched.
return done_sending, done_recving
return done_sending, done_recving

def _get_new_notifs(self) -> set[str]:
"""Get req_ids which got a remote xfer message."""
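The block removed above did the all-ranks bookkeeping inside the connector itself: rank 0 counted its own finished transfers plus the ids reported by ranks 1..N-1, and only surfaced a request once every rank had completed it. A minimal standalone sketch of that counting invariant, for context only; the helper name and the per-rank input sets are illustrative and not taken from the diff:

from collections import defaultdict

def finished_on_all_ranks(per_rank_finished: list[set[str]]) -> set[str]:
    # Illustrative only: report a request once every rank has finished its
    # share of the KV transfer -- the invariant the removed connector code
    # maintained incrementally across get_finished() calls.
    world_size = len(per_rank_finished)
    counts: dict[str, int] = defaultdict(int)
    for rank_finished in per_rank_finished:
        for req_id in rank_finished:
            counts[req_id] += 1
    return {req_id for req_id, seen in counts.items() if seen == world_size}

# req "a" is done on both ranks, req "b" only on rank 1 -> only "a" is reported.
assert finished_on_all_ranks([{"a"}, {"a", "b"}]) == {"a"}
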
39 changes: 38 additions & 1 deletion vllm/v1/executor/multiproc_executor.py
@@ -8,6 +8,7 @@
import time
import traceback
import weakref
from collections import defaultdict
from concurrent.futures import Future
from dataclasses import dataclass
from enum import Enum, auto
@@ -61,6 +62,12 @@ def _init_executor(self) -> None:
# Set multiprocessing envs that are common to V0 and V1
set_multiprocessing_worker_envs(self.parallel_config)

# Track requests finishing in each worker
self._done_sending_count: dict[str, int] = defaultdict(
lambda: tensor_parallel_size)
self._done_recving_count: dict[str, int] = defaultdict(
lambda: tensor_parallel_size)

# Multiprocessing-based executor does not support multi-node setting.
# Since it only works for single node, we can use the loopback address
# 127.0.0.1 for communication.
@@ -139,7 +146,7 @@ def register_failure_callback(self, callback: FailureCallback):
else:
self.failure_callback = callback

def execute_model(
def execute_model_non_pdisagg(
self,
scheduler_output,
) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
@@ -149,6 +156,36 @@ def execute_model(
timeout=EXECUTE_MODEL_TIMEOUT_S)
return output

def execute_model(
Author:
We could selectively replace execute_model with this version in the constructor if PD disagg + TP>1 is in use.

self,
scheduler_output,
) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
outputs: list[ModelRunnerOutput] = self.collective_rpc(
"execute_model",
args=(scheduler_output, ),
timeout=EXECUTE_MODEL_TIMEOUT_S)
rank0_outputs = outputs[0]
if len(outputs) == 1:
return rank0_outputs

done_sending, done_recving = set(), set()
for output in outputs:
for req_id in output.finished_sending or ():
if self._done_sending_count[req_id] == 1:
del self._done_sending_count[req_id]
done_sending.add(req_id)
else:
self._done_sending_count[req_id] -= 1
for req_id in output.finished_recving or ():
if self._done_recving_count[req_id] == 1:
del self._done_recving_count[req_id]
done_recving.add(req_id)
else:
self._done_recving_count[req_id] -= 1
rank0_outputs.finished_recving = done_recving
rank0_outputs.finished_sending = done_sending
return rank0_outputs

def collective_rpc(self,
method: Union[str, Callable],
timeout: Optional[float] = None,
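The inline author comment on the new execute_model above suggests wiring this up conditionally rather than always paying the per-step aggregation cost. A rough sketch of what that could look like inside MultiprocExecutor._init_executor(); the kv_transfer_config check is an assumption about how PD disaggregation would be detected and is not part of this diff:

# Hypothetical wiring, not in this diff: choose the execute_model variant once,
# at construction time.
pd_disagg = self.vllm_config.kv_transfer_config is not None  # assumed detection
if not (pd_disagg and self.parallel_config.tensor_parallel_size > 1):
    # No cross-rank KV transfers to reconcile: keep the rank-0-only fast path.
    self.execute_model = self.execute_model_non_pdisagg  # type: ignore[method-assign]
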
10 changes: 9 additions & 1 deletion vllm/v1/worker/gpu_worker.py
@@ -1,12 +1,14 @@
# SPDX-License-Identifier: Apache-2.0
"""A GPU worker class."""
import copy
import gc
import os
from typing import TYPE_CHECKING, Optional

import torch
import torch.distributed
import torch.nn as nn
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT

import vllm.envs as envs
from vllm.config import VllmConfig
@@ -266,7 +268,13 @@ def execute_model(
scheduler_output: "SchedulerOutput",
) -> Optional[ModelRunnerOutput]:
output = self.model_runner.execute_model(scheduler_output)
return output if self.is_driver_worker else None
if not self.is_driver_worker:
# No need to transfer entire output for non-zero ranks.
new_output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
new_output.finished_recving = output.finished_recving
            new_output.finished_sending = output.finished_sending
output = new_output
return output

def profile(self, is_start: bool = True):
if self.profiler is None:
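One detail worth noting in the change above: EMPTY_MODEL_RUNNER_OUTPUT is a shared module-level sentinel, so the worker shallow-copies it and then rebinds the two finished_* attributes instead of mutating the shared object. A small illustration of why that is safe; DummyOutput is a stand-in, not the real ModelRunnerOutput:

import copy
from dataclasses import dataclass, field

@dataclass
class DummyOutput:  # stand-in for ModelRunnerOutput, illustration only
    finished_sending: set[str] = field(default_factory=set)
    finished_recving: set[str] = field(default_factory=set)

EMPTY = DummyOutput()                     # shared sentinel, like EMPTY_MODEL_RUNNER_OUTPUT
out = copy.copy(EMPTY)                    # cheap shallow copy per step
out.finished_sending = {"req-1"}          # rebind on the copy only
assert EMPTY.finished_sending == set()    # the shared sentinel stays untouched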