Skip to content

Commit 8adf32a

Browse files
andoorvejimpang
authored and committed
[Bugfix] Fix for multinode crash on 4 PP (vllm-project#6495)
Signed-off-by: Muralidhar Andoorveedu <[email protected]>
1 parent 56e4af4 commit 8adf32a

File tree

2 files changed

+17
-5
lines changed

2 files changed

+17
-5
lines changed

tests/distributed/test_pipeline_parallel.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,12 @@
44

55

66
@pytest.mark.parametrize(
7-
"TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME",
8-
[
7+
"TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME", [
98
(2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B"),
109
(2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B"),
1110
(1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B"),
12-
# TODO: figure out why PP=4 tests are flaky
13-
# (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B"),
14-
# (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B"),
11+
(1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B"),
12+
(1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B"),
1513
])
1614
def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME):
1715
pp_args = [

vllm/executor/ray_gpu_executor.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,13 +224,27 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
224224
# broadcasted to.
225225
self.non_driver_workers: List[RayWorkerWrapper] = []
226226

227+
tp_driver_worker_ranks = []
228+
non_driver_worker_ranks = []
227229
for idx, rank in enumerate(worker_ranks[1:]):
228230
# We need to skip the driver worker, which we
229231
# do by skipping worker_ranks[0] which is always 0.
230232
if rank % self.parallel_config.tensor_parallel_size == 0:
231233
self.tp_driver_workers.append(self.workers[idx])
234+
tp_driver_worker_ranks.append(rank)
232235
else:
233236
self.non_driver_workers.append(self.workers[idx])
237+
non_driver_worker_ranks.append(rank)
238+
239+
# Enforce rank order for correct rank to return final output.
240+
self.tp_driver_workers = [
241+
worker for _, worker in sorted(
242+
zip(tp_driver_worker_ranks, self.tp_driver_workers))
243+
]
244+
self.non_driver_workers = [
245+
worker for _, worker in sorted(
246+
zip(non_driver_worker_ranks, self.non_driver_workers))
247+
]
234248

235249
def _driver_execute_model(
236250
self, execute_model_req: Optional[ExecuteModelRequest]

0 commit comments

Comments
 (0)