diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py
index 30ee9d1f8a1e..9b4b08118813 100644
--- a/vllm/distributed/device_communicators/custom_all_reduce.py
+++ b/vllm/distributed/device_communicators/custom_all_reduce.py
@@ -149,6 +149,13 @@ def __init__(self,
         dist.all_gather(gather_list, tensor, group=self.group)
         physical_device_ids = [t.item() for t in gather_list]
 
+        # Detect a multi-node setup: device ids restart from 0 on each node,
+        # e.g. 3 nodes (8 GPUs total) may gather [0,1,2,3,0,1,0,1], so the
+        # max id + 1 is smaller than world_size when ranks span nodes.
+        if sorted(physical_device_ids)[-1] + 1 != world_size:
+            logger.warning(
+                "Custom allreduce is disabled because this feature is "
+                "not intended for multi node use case.")
+            return
+
         # test nvlink first, this will filter out most of the cases
         # where custom allreduce is not supported
         # this checks hardware and driver support for NVLink