From dc2427328c02fc869d1d46eab7c6c52e2dcf2bcc Mon Sep 17 00:00:00 2001
From: roy
Date: Sun, 19 May 2024 10:29:10 +0800
Subject: [PATCH 1/4] fix

---
 .../device_communicators/custom_all_reduce.py | 13 +------------
 1 file changed, 1 insertion(+), 12 deletions(-)

diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py
index 30ee9d1f8a1e..fcfdc07ce324 100644
--- a/vllm/distributed/device_communicators/custom_all_reduce.py
+++ b/vllm/distributed/device_communicators/custom_all_reduce.py
@@ -138,21 +138,10 @@ def __init__(self,
         else:
             device_ids = list(range(torch.cuda.device_count()))
 
-        physical_device_id = device_ids[device.index]
-        tensor = torch.tensor([physical_device_id],
-                              dtype=torch.int,
-                              device="cpu")
-        gather_list = [
-            torch.tensor([0], dtype=torch.int, device="cpu")
-            for _ in range(world_size)
-        ]
-        dist.all_gather(gather_list, tensor, group=self.group)
-        physical_device_ids = [t.item() for t in gather_list]
-
         # test nvlink first, this will filter out most of the cases
         # where custom allreduce is not supported
         # this checks hardware and driver support for NVLink
-        full_nvlink = _is_full_nvlink(physical_device_ids)
+        full_nvlink = _is_full_nvlink(device_ids)
         if world_size > 2 and not full_nvlink:
             logger.warning(
                 "Custom allreduce is disabled because it's not supported on"

From d3f396dd027f99838c2904fcc60c1382a9f23e6a Mon Sep 17 00:00:00 2001
From: roy
Date: Sun, 19 May 2024 13:46:13 +0800
Subject: [PATCH 2/4] apply comments

---
 .../device_communicators/custom_all_reduce.py | 19 ++++++++++++++++++-
 1 file changed, 18 insertions(+), 1 deletion(-)

diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py
index fcfdc07ce324..08b9257b58a7 100644
--- a/vllm/distributed/device_communicators/custom_all_reduce.py
+++ b/vllm/distributed/device_communicators/custom_all_reduce.py
@@ -138,10 +138,27 @@ def __init__(self,
         else:
             device_ids = list(range(torch.cuda.device_count()))
 
+        if len(device_ids) < world_size:
+            logger.warning(
+                "Custom allreduce is disabled because this feature is "
+                "not intended for multi node use case.")
+            return
+
+        physical_device_id = device_ids[device.index]
+        tensor = torch.tensor([physical_device_id],
+                              dtype=torch.int,
+                              device="cpu")
+        gather_list = [
+            torch.tensor([0], dtype=torch.int, device="cpu")
+            for _ in range(world_size)
+        ]
+        dist.all_gather(gather_list, tensor, group=self.group)
+        physical_device_ids = [t.item() for t in gather_list]
+
         # test nvlink first, this will filter out most of the cases
         # where custom allreduce is not supported
         # this checks hardware and driver support for NVLink
-        full_nvlink = _is_full_nvlink(device_ids)
+        full_nvlink = _is_full_nvlink(physical_device_ids)
         if world_size > 2 and not full_nvlink:
             logger.warning(
                 "Custom allreduce is disabled because it's not supported on"
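Patch 2 feeds the gathered physical_device_ids, rather than the logical device_ids, into _is_full_nvlink because NVML enumerates physical GPUs and is not affected by CUDA_VISIBLE_DEVICES, so a topology query made with logical indices can end up inspecting the wrong devices. The snippet below is only a rough sketch of such a pairwise NVLink peer-to-peer check via pynvml, written to illustrate that point; the function name is made up and this is not the actual _is_full_nvlink implementation used by vLLM.

    # Sketch only: pairwise NVLink P2P check over *physical* GPU indices.
    # NVML ignores CUDA_VISIBLE_DEVICES, so logical indices would be wrong here.
    from typing import List

    import pynvml


    def is_full_nvlink_sketch(physical_device_ids: List[int]) -> bool:
        pynvml.nvmlInit()
        try:
            handles = [
                pynvml.nvmlDeviceGetHandleByIndex(i)
                for i in physical_device_ids
            ]
            # Every pair of GPUs must report NVLink P2P support.
            for i, first in enumerate(handles):
                for second in handles[i + 1:]:
                    try:
                        status = pynvml.nvmlDeviceGetP2PStatus(
                            first, second, pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
                    except pynvml.NVMLError:
                        return False
                    if status != pynvml.NVML_P2P_STATUS_OK:
                        return False
            return True
        finally:
            pynvml.nvmlShutdown()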
From 6689ec81f49f6f6ce73e4411220c2d04b32e09a9 Mon Sep 17 00:00:00 2001
From: roy
Date: Sun, 19 May 2024 14:30:03 +0800
Subject: [PATCH 3/4] fix check

---
 .../device_communicators/custom_all_reduce.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py
index 08b9257b58a7..0b8fc895ac54 100644
--- a/vllm/distributed/device_communicators/custom_all_reduce.py
+++ b/vllm/distributed/device_communicators/custom_all_reduce.py
@@ -138,12 +138,6 @@ def __init__(self,
         else:
             device_ids = list(range(torch.cuda.device_count()))
 
-        if len(device_ids) < world_size:
-            logger.warning(
-                "Custom allreduce is disabled because this feature is "
-                "not intended for multi node use case.")
-            return
-
         physical_device_id = device_ids[device.index]
         tensor = torch.tensor([physical_device_id],
                               dtype=torch.int,
@@ -155,6 +149,13 @@ def __init__(self,
         dist.all_gather(gather_list, tensor, group=self.group)
         physical_device_ids = [t.item() for t in gather_list]
 
+        # 3 nodes will be like [0,1,2,3,0,1,0,1]
+        if physical_device_ids[-1] + 1 != world_size:
+            logger.warning(
+                "Custom allreduce is disabled because this feature is "
+                "not intended for multi node use case.")
+            return
+
         # test nvlink first, this will filter out most of the cases
         # where custom allreduce is not supported
         # this checks hardware and driver support for NVLink

From 4a092ba5e6c95c5b8fb2b03cf0437b042d52cc36 Mon Sep 17 00:00:00 2001
From: roy
Date: Sun, 19 May 2024 14:35:27 +0800
Subject: [PATCH 4/4] add sort

---
 vllm/distributed/device_communicators/custom_all_reduce.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py
index 0b8fc895ac54..9b4b08118813 100644
--- a/vllm/distributed/device_communicators/custom_all_reduce.py
+++ b/vllm/distributed/device_communicators/custom_all_reduce.py
@@ -150,7 +150,7 @@ def __init__(self,
         physical_device_ids = [t.item() for t in gather_list]
 
         # 3 nodes will be like [0,1,2,3,0,1,0,1]
-        if physical_device_ids[-1] + 1 != world_size:
+        if sorted(physical_device_ids)[-1] + 1 != world_size:
             logger.warning(
                 "Custom allreduce is disabled because this feature is "
                 "not intended for multi node use case.")
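To make the multi-node heuristic from patches 3 and 4 concrete: every rank contributes its physical GPU id to the all_gather, so on a single node the gathered list is a permutation of 0 through world_size - 1, while across several nodes the ids repeat per node and the largest id typically falls short of world_size - 1. The sort added in patch 4 matters because the gathered list is ordered by rank, not by id, so the largest value is not guaranteed to sit at the end. The standalone sketch below replays that check with made-up id lists (no torch.distributed involved); the helper name is invented for illustration.

    # Standalone illustration of the check from patches 3 and 4.
    # In the real code the id list comes from dist.all_gather of each rank's
    # physical device id; the example lists here are made up.
    from typing import List


    def looks_like_single_node(physical_device_ids: List[int],
                               world_size: int) -> bool:
        # Single node: the gathered ids cover 0..world_size-1, so the largest
        # id is world_size - 1. Multi node: ids repeat per node
        # (e.g. 3 nodes -> [0,1,2,3,0,1,0,1]) and the maximum falls short.
        return sorted(physical_device_ids)[-1] + 1 == world_size


    print(looks_like_single_node([0, 1, 2, 3, 4, 5, 6, 7], 8))  # True: one node
    print(looks_like_single_node([0, 1, 2, 3, 0, 1, 0, 1], 8))  # False: 3 nodes
    print(looks_like_single_node([3, 2, 1, 0], 4))  # True: sort handles rank order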