From 412e88d92f41f6496cdb8f331d31e71be7772c40 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Sun, 9 Jun 2024 17:07:48 -0700
Subject: [PATCH 01/13] stash

---
 vllm/distributed/parallel_state.py | 44 ++++++++++++++++++++++++++++++
 1 file changed, 44 insertions(+)

diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index 0ebd7a15eab9..05640cb8f1c7 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -376,3 +376,47 @@ def destroy_model_parallel():
     _PP_DEVICE_GROUP = None
     global _PP_GLOBAL_RANKS
     _PP_GLOBAL_RANKS = None
+
+def is_in_the_same_node(pg: ProcessGroup):
+    """
+    This is a collective operation that checks if all processes in the group
+    are in the same node. It tests if all processes are attached to the same
+    memory system (shared access to shared memory).
+    """
+    assert torch.distributed.get_backend(group) != torch.distributed.Backend.NCCL, (
+        "is_in_the_same_node should be tested with a non-NCCL group.")
+    # local rank inside the group
+    rank = torch.distributed.get_rank(group=pg)
+    world_size = torch.distributed.get_world_size(group=pg)
+    is_in_the_same_node = torch.tensor([0] * world_size, dtype=torch.int32)
+    # global ranks of the processes in the group
+    ranks = torch.distributed.get_process_group_ranks(pg)
+    magic_message = b"magic_message"
+    shm = None
+    if rank == 0:
+        # create a shared memory segment
+        from multiprocessing import shared_memory
+        shm = shared_memory.SharedMemory(create=True, size=128)
+        shm.buf[:len(magic_message)] = magic_message
+        torch.distributed.broadcast_object_list([shm.name], src=ranks[0], group=pg)
+        is_in_the_same_node[0] = 1
+    else:
+        recv = [None]
+        torch.distributed.broadcast_object_list(recv, src=ranks[0], group=pg)
+        name = recv[0]
+        try:
+            shm = shared_memory.SharedMemory(name=name)
+            if shm.buf[:len(magic_message)] == magic_message:
+                is_in_the_same_node[rank] = 1
+        except Exception:
+            pass
+
+    torch.distributed.barrier(group=pg)
+    if shm:
+        shm.close()
+    torch.distributed.barrier(group=pg)
+    if rank == 0:
+        shm.unlink()
+
+    torch.distributed.all_reduce(is_in_the_same_node, group=pg)
+    return is_in_the_same_node.sum().item() == world_size
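Note that this first "stash" commit still carries two bugs that later patches fix: the assert references the undefined name `group` instead of the parameter `pg`, and the `else` branch uses `shared_memory` before it is imported (the import lives inside the `if rank == 0:` branch). The core idea is a shared-memory handshake: rank 0 creates a named shared memory segment, writes a magic byte string into it, and broadcasts the segment name; every other rank tries to attach to that name and compare the contents, and a final all-reduce confirms that all ranks succeeded. A minimal single-process sketch of that primitive (illustrative only, not the PR's code; the broadcast is replaced by a local variable):

    from multiprocessing import shared_memory

    MAGIC = b"magic_message"

    # "rank 0": create the segment and publish its name
    writer = shared_memory.SharedMemory(create=True, size=128)
    writer.buf[:len(MAGIC)] = MAGIC
    name = writer.name  # the patch ships this name via broadcast_object_list

    # "another rank" on the same machine: attach by name and verify contents.
    # On a different machine this open would fail with an OSError instead.
    reader = shared_memory.SharedMemory(name=name)
    same_node = bytes(reader.buf[:len(MAGIC)]) == MAGIC

    reader.close()
    writer.close()
    writer.unlink()
    assert same_node

The resource_tracker caveat handled in the next patch applies to the reader here as well.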
""" - assert torch.distributed.get_backend(group) != torch.distributed.Backend.NCCL, ( - "is_in_the_same_node should be tested with a non-NCCL group.") + assert torch.distributed.get_backend( + pg) != torch.distributed.Backend.NCCL, ( + "is_in_the_same_node should be tested with a non-NCCL group.") # local rank inside the group rank = torch.distributed.get_rank(group=pg) world_size = torch.distributed.get_world_size(group=pg) + + # local tensor in each process to store the result is_in_the_same_node = torch.tensor([0] * world_size, dtype=torch.int32) + # global ranks of the processes in the group ranks = torch.distributed.get_process_group_ranks(pg) + magic_message = b"magic_message" shm = None - if rank == 0: - # create a shared memory segment - from multiprocessing import shared_memory - shm = shared_memory.SharedMemory(create=True, size=128) - shm.buf[:len(magic_message)] = magic_message - torch.distributed.broadcast_object_list([shm.name], src=ranks[0], group=pg) - is_in_the_same_node[0] = 1 - else: - recv = [None] - torch.distributed.broadcast_object_list(recv, src=ranks[0], group=pg) - name = recv[0] - try: + + with contextlib.suppress(Exception): + if rank == 0: + # create a shared memory segment + shm = shared_memory.SharedMemory(create=True, size=128) + shm.buf[:len(magic_message)] = magic_message + torch.distributed.broadcast_object_list([shm.name], + src=ranks[0], + group=pg) + is_in_the_same_node[0] = 1 + else: + # try to open the shared memory segment + recv = [None] + torch.distributed.broadcast_object_list(recv, + src=ranks[0], + group=pg) + name = recv[0] shm = shared_memory.SharedMemory(name=name) if shm.buf[:len(magic_message)] == magic_message: is_in_the_same_node[rank] = 1 - except Exception: - pass - - torch.distributed.barrier(group=pg) if shm: shm.close() torch.distributed.barrier(group=pg) + + # clean up the shared memory segment if rank == 0: - shm.unlink() - + if shm: + shm.unlink() + else: + if shm: + with contextlib.suppress(Exception): + # fix to https://stackoverflow.com/q/62748654/9191338 + resource_tracker.unregister( + shm._name, "shared_memory") # type: ignore[attr-defined] torch.distributed.all_reduce(is_in_the_same_node, group=pg) + return is_in_the_same_node.sum().item() == world_size From eb02cf31bd2dad721c7692cfdf999a9b7e1e5ec9 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 9 Jun 2024 20:22:19 -0700 Subject: [PATCH 03/13] add tests --- .buildkite/test-pipeline.yaml | 1 + tests/distributed/test_same_node.py | 7 +++++++ 2 files changed, 8 insertions(+) create mode 100644 tests/distributed/test_same_node.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index b48ef31bc416..56ee53e82911 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -37,6 +37,7 @@ steps: working_dir: "/vllm-workspace/tests" num_gpus: 2 commands: + - torchrun --nproc-per-node=4 distributed/test_same_node.py - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py diff --git a/tests/distributed/test_same_node.py b/tests/distributed/test_same_node.py new file mode 100644 index 000000000000..99895e61b8bb --- /dev/null +++ b/tests/distributed/test_same_node.py @@ -0,0 +1,7 @@ +from 
From eb02cf31bd2dad721c7692cfdf999a9b7e1e5ec9 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Sun, 9 Jun 2024 20:22:19 -0700
Subject: [PATCH 03/13] add tests

---
 .buildkite/test-pipeline.yaml       | 1 +
 tests/distributed/test_same_node.py | 7 +++++++
 2 files changed, 8 insertions(+)
 create mode 100644 tests/distributed/test_same_node.py

diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index b48ef31bc416..56ee53e82911 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -37,6 +37,7 @@ steps:
   working_dir: "/vllm-workspace/tests"
   num_gpus: 2
   commands:
+  - torchrun --nproc-per-node=4 distributed/test_same_node.py
   - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
   - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
   - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py
diff --git a/tests/distributed/test_same_node.py b/tests/distributed/test_same_node.py
new file mode 100644
index 000000000000..99895e61b8bb
--- /dev/null
+++ b/tests/distributed/test_same_node.py
@@ -0,0 +1,7 @@
+from vllm.distributed.parallel_state import is_in_the_same_node
+torch.distributed.init_process_group(backend="gloo")
+ans = is_in_the_same_node(torch.distributed.group.WORLD)
+if ans:
+    exit(0)
+else:
+    exit(1)
\ No newline at end of file

From 6c95a690dcffa1377bcb7159dd262ccd1fac03b8 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Sun, 9 Jun 2024 20:34:01 -0700
Subject: [PATCH 04/13] fix end of line

---
 tests/distributed/test_same_node.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/distributed/test_same_node.py b/tests/distributed/test_same_node.py
index 99895e61b8bb..d82337f595bf 100644
--- a/tests/distributed/test_same_node.py
+++ b/tests/distributed/test_same_node.py
@@ -4,4 +4,4 @@
 if ans:
     exit(0)
 else:
-    exit(1)
\ No newline at end of file
+    exit(1)

From 4eefa0e7bf43ee035c847118ce287889f24d5f73 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Sun, 9 Jun 2024 20:35:13 -0700
Subject: [PATCH 05/13] fix lint

---
 tests/distributed/test_same_node.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/distributed/test_same_node.py b/tests/distributed/test_same_node.py
index d82337f595bf..08978b9f0f29 100644
--- a/tests/distributed/test_same_node.py
+++ b/tests/distributed/test_same_node.py
@@ -1,4 +1,5 @@
 from vllm.distributed.parallel_state import is_in_the_same_node
+
 torch.distributed.init_process_group(backend="gloo")
 ans = is_in_the_same_node(torch.distributed.group.WORLD)
 if ans:

From abd7473cab817d62f344f7411b017b1641634eca Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Sun, 9 Jun 2024 20:36:03 -0700
Subject: [PATCH 06/13] fix var name

---
 tests/distributed/test_same_node.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/distributed/test_same_node.py b/tests/distributed/test_same_node.py
index 08978b9f0f29..b2a25c3117a2 100644
--- a/tests/distributed/test_same_node.py
+++ b/tests/distributed/test_same_node.py
@@ -1,8 +1,8 @@
 from vllm.distributed.parallel_state import is_in_the_same_node
 
 torch.distributed.init_process_group(backend="gloo")
-ans = is_in_the_same_node(torch.distributed.group.WORLD)
-if ans:
+test_result = is_in_the_same_node(torch.distributed.group.WORLD)
+if test_result:
     exit(0)
 else:
     exit(1)

From 43d6625a5ffc0138d2efbb789ef704aee364adf8 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Sun, 9 Jun 2024 20:36:33 -0700
Subject: [PATCH 07/13] fix import

---
 tests/distributed/test_same_node.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/tests/distributed/test_same_node.py b/tests/distributed/test_same_node.py
index b2a25c3117a2..ee80588928b2 100644
--- a/tests/distributed/test_same_node.py
+++ b/tests/distributed/test_same_node.py
@@ -1,3 +1,5 @@
+import torch
+
 from vllm.distributed.parallel_state import is_in_the_same_node
 
 torch.distributed.init_process_group(backend="gloo")
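A note on `backend="gloo"` in the test: `is_in_the_same_node` asserts a non-NCCL group because it broadcasts a Python object (the segment name) and runs host-side logic that does not need GPUs. In a real deployment, the usual pattern (assumed here for illustration) is to keep a CPU/gloo group next to the NCCL group and hand the gloo one to host-side collectives; the script would be launched with e.g. `torchrun --nproc-per-node=2`, mirroring the pipeline entry above:

    import torch.distributed as dist

    # hypothetical setup: NCCL for GPU tensors, gloo for host-side metadata
    dist.init_process_group(backend="nccl")
    ranks = list(range(dist.get_world_size()))
    cpu_group = dist.new_group(ranks, backend="gloo")
    # is_in_the_same_node(cpu_group) is safe; passing the NCCL group would
    # trip the assert at the top of the function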
From 32fb1c9eedc03fc998e398f34ab2e2820c4d50f6 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Sun, 9 Jun 2024 23:54:45 -0700
Subject: [PATCH 08/13] use same host in custom allreduce

---
 .../device_communicators/custom_all_reduce.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py
index a3902aecb379..4a45ac7dd8e9 100644
--- a/vllm/distributed/device_communicators/custom_all_reduce.py
+++ b/vllm/distributed/device_communicators/custom_all_reduce.py
@@ -9,7 +9,7 @@
 from vllm.distributed.device_communicators.custom_all_reduce_utils import (
     gpu_p2p_access_check)
 from vllm.distributed.parallel_state import (
-    get_local_rank, get_tensor_model_parallel_cpu_group)
+    get_local_rank, get_tensor_model_parallel_cpu_group, is_in_the_same_node)
 from vllm.logger import init_logger
 
 try:
@@ -108,6 +108,13 @@ def __init__(self,
         assert dist.get_backend(group) != dist.Backend.NCCL, (
             "CustomAllreduce should be attached to a non-NCCL group.")
 
+        if not is_in_the_same_node(group):
+            # No need to initialize custom allreduce for multi-node case.
+            logger.warning(
+                "Custom allreduce is disabled because this process group"
+                " spans across nodes.")
+            return
+
         rank = dist.get_rank(group=self.group)
         world_size = dist.get_world_size(group=self.group)
         if world_size == 1:

From c60a0237a8314a7bf71b6239cd13231e29a8dfef Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Sun, 9 Jun 2024 23:56:34 -0700
Subject: [PATCH 09/13] update tests

---
 .buildkite/test-pipeline.yaml       | 2 +-
 tests/distributed/test_same_node.py | 9 +++++----
 2 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 56ee53e82911..6b12d19ba611 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -37,7 +37,7 @@ steps:
   working_dir: "/vllm-workspace/tests"
   num_gpus: 2
   commands:
-  - torchrun --nproc-per-node=4 distributed/test_same_node.py
+  - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py
   - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
   - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
   - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py
diff --git a/tests/distributed/test_same_node.py b/tests/distributed/test_same_node.py
index ee80588928b2..4880bab79069 100644
--- a/tests/distributed/test_same_node.py
+++ b/tests/distributed/test_same_node.py
@@ -1,10 +1,11 @@
+import os
+
 import torch
 
 from vllm.distributed.parallel_state import is_in_the_same_node
 
 torch.distributed.init_process_group(backend="gloo")
 test_result = is_in_the_same_node(torch.distributed.group.WORLD)
-if test_result:
-    exit(0)
-else:
-    exit(1)
+
+expected = os.environ.get("VLLM_TEST_SAME_HOST", "1") == "1"
+assert test_result == expected, f"Expected {expected}, got {test_result}"

From 6179cb5ea172b2058a65bb833fdea3d8edb257f6 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Mon, 10 Jun 2024 00:36:20 -0700
Subject: [PATCH 10/13] narrow down the error suppression to OSError

---
 vllm/distributed/parallel_state.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index 6df41628a3b2..9abe42798fcb 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -402,7 +402,7 @@ def is_in_the_same_node(pg: ProcessGroup):
     magic_message = b"magic_message"
     shm = None
 
-    with contextlib.suppress(Exception):
+    with contextlib.suppress(OSError):
         if rank == 0:
             # create a shared memory segment
             shm = shared_memory.SharedMemory(create=True, size=128)
@@ -431,7 +431,7 @@ def is_in_the_same_node(pg: ProcessGroup):
             shm.unlink()
     else:
         if shm:
-            with contextlib.suppress(Exception):
+            with contextlib.suppress(OSError):
                 # fix to https://stackoverflow.com/q/62748654/9191338
                 resource_tracker.unregister(
                     shm._name, "shared_memory")  # type: ignore[attr-defined]
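Narrowing the suppression from `Exception` to `OSError` keeps the intended failure mode (a cross-node attempt to open a segment that does not exist on the local machine) while letting genuine bugs propagate. A failed `SharedMemory(name=...)` open surfaces as `FileNotFoundError`, a subclass of `OSError`. A short demonstration:

    import contextlib
    from multiprocessing import shared_memory

    with contextlib.suppress(OSError):
        # raises FileNotFoundError when the segment does not exist on this
        # machine; by design that is swallowed, unlike, say, a TypeError
        shm = shared_memory.SharedMemory(name="psm_does_not_exist")
    # execution continues here whether or not the open succeeded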
From 182f58cb0c280285efaee1139b333dc20cc0008d Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Mon, 10 Jun 2024 01:10:13 -0700
Subject: [PATCH 11/13] ensure close by finally

---
 vllm/distributed/parallel_state.py | 45 ++++++++++++++++--------------
 1 file changed, 24 insertions(+), 21 deletions(-)

diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index 9abe42798fcb..59fc89d2c864 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -402,27 +402,30 @@ def is_in_the_same_node(pg: ProcessGroup):
     magic_message = b"magic_message"
     shm = None
 
-    with contextlib.suppress(OSError):
-        if rank == 0:
-            # create a shared memory segment
-            shm = shared_memory.SharedMemory(create=True, size=128)
-            shm.buf[:len(magic_message)] = magic_message
-            torch.distributed.broadcast_object_list([shm.name],
-                                                    src=ranks[0],
-                                                    group=pg)
-            is_in_the_same_node[0] = 1
-        else:
-            # try to open the shared memory segment
-            recv = [None]
-            torch.distributed.broadcast_object_list(recv,
-                                                    src=ranks[0],
-                                                    group=pg)
-            name = recv[0]
-            shm = shared_memory.SharedMemory(name=name)
-            if shm.buf[:len(magic_message)] == magic_message:
-                is_in_the_same_node[rank] = 1
-    if shm:
-        shm.close()
+    try:
+        with contextlib.suppress(OSError):
+            if rank == 0:
+                # create a shared memory segment
+                shm = shared_memory.SharedMemory(create=True, size=128)
+                shm.buf[:len(magic_message)] = magic_message
+                torch.distributed.broadcast_object_list([shm.name],
+                                                        src=ranks[0],
+                                                        group=pg)
+                is_in_the_same_node[0] = 1
+            else:
+                # try to open the shared memory segment
+                recv = [None]
+                torch.distributed.broadcast_object_list(recv,
+                                                        src=ranks[0],
+                                                        group=pg)
+                name = recv[0]
+                shm = shared_memory.SharedMemory(name=name)
+                if shm.buf[:len(magic_message)] == magic_message:
+                    is_in_the_same_node[rank] = 1
+    finally:
+        if shm:
+            shm.close()
+
     torch.distributed.barrier(group=pg)
 
     # clean up the shared memory segment

From 602fdcfe72be1c63f19abb3d18190762d4460c62 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Mon, 10 Jun 2024 10:42:27 -0700
Subject: [PATCH 12/13] add suppress in cleanup

---
 vllm/distributed/parallel_state.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index 59fc89d2c864..9015e9bff261 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -429,12 +429,12 @@ def is_in_the_same_node(pg: ProcessGroup):
     torch.distributed.barrier(group=pg)
 
     # clean up the shared memory segment
-    if rank == 0:
-        if shm:
-            shm.unlink()
-    else:
-        if shm:
-            with contextlib.suppress(OSError):
+    with contextlib.suppress(OSError):
+        if rank == 0:
+            if shm:
+                shm.unlink()
+        else:
+            if shm:
                 # fix to https://stackoverflow.com/q/62748654/9191338
                 resource_tracker.unregister(
                     shm._name, "shared_memory")  # type: ignore[attr-defined]
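Patches 11 and 12 together settle the cleanup discipline: attach inside a `try`, close in a `finally` so that an exception not caught by `suppress` cannot leak the handle, and wrap the unlink/unregister step in its own `suppress` because cleanup failures are non-fatal. The skeleton, reduced to a runnable sketch with the create-or-attach logic stubbed out by a hypothetical helper:

    import contextlib
    from multiprocessing import shared_memory

    def acquire_segment() -> shared_memory.SharedMemory:
        # stand-in for the create-or-attach logic in the patch
        return shared_memory.SharedMemory(create=True, size=128)

    shm = None
    try:
        with contextlib.suppress(OSError):
            shm = acquire_segment()
    finally:
        # runs even if a non-OSError escapes the suppress block above
        if shm:
            shm.close()
    with contextlib.suppress(OSError):
        if shm:
            shm.unlink()  # only the creating process should unlink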
is_in_the_same_node: %s", e) finally: if shm: shm.close()