diff --git a/aiter/dist/device_communicators/cuda_communicator.py b/aiter/dist/device_communicators/communicator_cuda.py similarity index 88% rename from aiter/dist/device_communicators/cuda_communicator.py rename to aiter/dist/device_communicators/communicator_cuda.py index d5e8c7629b..97c3eefb71 100644 --- a/aiter/dist/device_communicators/cuda_communicator.py +++ b/aiter/dist/device_communicators/communicator_cuda.py @@ -38,18 +38,20 @@ def __init__( CustomAllreduce, ) - # from aiter.dist.device_communicators.pynccl import PyNcclCommunicator - # from aiter.dist.device_communicators.symm_mem import SymmMemCommunicator self.pynccl_comm = None - # if self.world_size > 1: - # self.pynccl_comm = PyNcclCommunicator( - # group=self.cpu_group, - # device=self.device, - # ) - # if is_symmetric_memory_enabled(): - # register_nccl_symmetric_ops(self.pynccl_comm) + if self.world_size > 1: + from aiter.dist.device_communicators.communicator_pynccl import ( + PyNcclCommunicator, + ) + + self.pynccl_comm = PyNcclCommunicator( + group=self.cpu_group, + device=self.device, + ) + # if is_symmetric_memory_enabled(): + # register_nccl_symmetric_ops(self.pynccl_comm) self.ca_comm: CustomAllreduce | None = None self.qr_comm = None @@ -70,8 +72,7 @@ def __init__( # ), ) - # if current_platform.is_rocm(): - if True and self.world_size > 1: + if self.world_size > 1: from aiter.dist.device_communicators.quick_all_reduce import ( QuickAllReduce, ) @@ -118,14 +119,6 @@ def __init__( ) def all_reduce(self, input_, ca_fp8_quant: bool = False) -> torch.Tensor: - # since currently we perform copy input -> symm_input -> out-of-place AR - # return symm_output, we don't need to check if input is symmetric - if self.pynccl_comm is not None and should_nccl_symm_mem_allreduce( - self.pynccl_comm.world_size, input_ - ): - out = torch.ops.vllm.all_reduce_symmetric_with_copy(input_) - if out is not None: - return out # always try quick reduce first, then custom allreduce, # and then pynccl. (quick reduce just for ROCM MI3*) qr_comm = self.qr_comm @@ -153,19 +146,16 @@ def all_reduce(self, input_, ca_fp8_quant: bool = False) -> torch.Tensor: assert out is not None return out pynccl_comm = self.pynccl_comm - if pynccl_comm is None or pynccl_comm.disabled: - out = input_.clone() - torch.distributed.all_reduce(out, group=self.device_group) + if pynccl_comm is not None and not pynccl_comm.disabled: + out = pynccl_comm.all_reduce(input_) + assert out is not None return out - assert pynccl_comm is not None - out = pynccl_comm.all_reduce(input_) - if out is None: - # fall back to the default all-reduce using PyTorch. - # this usually happens during testing. - # when we run the model, allreduce only happens for the TP - # group, where we always have either custom allreduce or pynccl. - out = input_.clone() - torch.distributed.all_reduce(out, group=self.device_group) + # fall back to the default all-reduce using PyTorch. + # this usually happens during testing. + # when we run the model, allreduce only happens for the TP + # group, where we always have either custom allreduce or pynccl. 
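+        # The clone keeps this fallback out of place, matching the quick/custom/
+        # pynccl branches above, so the caller's input_ tensor is never modified.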
+ out = input_.clone() + torch.distributed.all_reduce(out, group=self.device_group) return out def reduce_scatter(self, input_: torch.Tensor, dim: int = -1): @@ -259,6 +249,8 @@ def recv( def destroy(self): if self.pynccl_comm is not None: self.pynccl_comm = None + if self.qr_comm is not None: + self.qr_comm = None if self.ca_comm is not None: self.ca_comm = None if self.all2all_manager is not None: diff --git a/aiter/dist/device_communicators/communicator_pynccl.py b/aiter/dist/device_communicators/communicator_pynccl.py new file mode 100644 index 0000000000..c0bfaf358d --- /dev/null +++ b/aiter/dist/device_communicators/communicator_pynccl.py @@ -0,0 +1,381 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +# ===================== import region ===================== +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup, ReduceOp + +from .pynccl_wrapper import ( + NCCLLibrary, + buffer_type, + cudaStream_t, + ncclComm_t, + ncclDataTypeEnum, + ncclRedOpTypeEnum, + ncclUniqueId, +) +from aiter import logger + +current_stream = torch.cuda.current_stream +_NCCL_SYMM_OPS_REGISTERED = False + + +# def register_nccl_symmetric_ops(pynccl_comm): +# from vllm.distributed.device_communicators.pynccl_allocator import ( +# nccl_symm_mem_context, +# ) +# from vllm.utils.torch_utils import direct_register_custom_op + +# global _NCCL_SYMM_OPS_REGISTERED +# if _NCCL_SYMM_OPS_REGISTERED: +# return +# _NCCL_SYMM_OPS_REGISTERED = True + +# def all_reduce_symmetric_with_copy_impl(input_tensor: torch.Tensor) -> torch.Tensor: +# with nccl_symm_mem_context(pynccl_comm): +# symm_input = torch.empty_like(input_tensor) +# symm_output = torch.empty_like(input_tensor) +# symm_input.copy_(input_tensor) +# symm_output = pynccl_comm.all_reduce(symm_input, symm_output) +# return symm_output + +# def all_reduce_symmetric_with_copy_fake(input_tensor: torch.Tensor) -> torch.Tensor: +# return torch.empty_like(input_tensor) + +# direct_register_custom_op( +# op_name="all_reduce_symmetric_with_copy", +# op_func=all_reduce_symmetric_with_copy_impl, +# fake_impl=all_reduce_symmetric_with_copy_fake, +# ) + + +class PyNcclCommunicator: + def __init__( + self, + group: ProcessGroup, + device: int | str | torch.device, + library_path: str | None = None, + ): + """ + Args: + group: the process group to work on. If None, it will use the + default process group. + device: the device to bind the PyNcclCommunicator to. If None, + it will be bound to f"cuda:{local_rank}". + library_path: the path to the NCCL library. If None, it will + use the default library path. + It is the caller's responsibility to make sure each communicator + is bind to a unique device. + """ + if isinstance(group, ProcessGroup): + assert dist.is_initialized() + assert ( + dist.get_backend(group) != dist.Backend.NCCL + ), "PyNcclCommunicator should be attached to a non-NCCL group." + # note: this rank is the rank in the group + self.rank = dist.get_rank(group) + self.world_size = dist.get_world_size(group) + else: + self.rank = group.rank + self.world_size = group.world_size + + self.group = group + + # if world_size == 1, no need to create communicator + if self.world_size == 1: + self.available = False + self.disabled = True + return + try: + self.nccl = NCCLLibrary(library_path) + except Exception as e: + print(f"Failed to load NCCL library: {e}") + # disable because of missing NCCL library + # e.g. 
in a non-GPU environment + self.available = False + self.disabled = True + return + + self.available = True + self.disabled = False + + self.nccl_version = self.nccl.ncclGetRawVersion() + if self.rank == 0: + # get the unique id from NCCL + self.unique_id = self.nccl.ncclGetUniqueId() + logger.info(f"load NCCL version: {self.nccl_version}") + else: + # construct an empty unique id + self.unique_id = ncclUniqueId() + + if isinstance(group, ProcessGroup): + tensor = torch.ByteTensor(list(self.unique_id.internal)) + ranks = dist.get_process_group_ranks(group) + # arg `src` in `broadcast` is the global rank + dist.broadcast(tensor, src=ranks[0], group=group) + byte_list = tensor.tolist() + for i, byte in enumerate(byte_list): + self.unique_id.internal[i] = byte + else: + self.unique_id = group.broadcast_obj(self.unique_id, src=0) + if isinstance(device, int): + device = torch.device(f"cuda:{device}") + elif isinstance(device, str): + device = torch.device(device) + # now `device` is a `torch.device` object + assert isinstance(device, torch.device) + self.device = device + # nccl communicator and stream will use this device + # `torch.cuda.device` is a context manager that changes the + # current cuda device to the specified one + with torch.cuda.device(device): + self.comm: ncclComm_t = self.nccl.ncclCommInitRank( + self.world_size, self.unique_id, self.rank + ) + + stream = current_stream() + # A small all_reduce for warmup. + data = torch.zeros(1, device=device) + self.all_reduce(data) + stream.synchronize() + del data + + def all_reduce( + self, + in_tensor: torch.Tensor, + out_tensor: torch.Tensor = None, + op: ReduceOp = ReduceOp.SUM, + stream=None, + ) -> torch.Tensor: + if self.disabled: + return None + # nccl communicator created on a specific device + # will only work on tensors on the same device + # otherwise it will cause "illegal memory access" + assert in_tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {in_tensor.device}" + ) + + if out_tensor is None: + out_tensor = torch.empty_like(in_tensor) + + if stream is None: + stream = current_stream() + self.nccl.ncclAllReduce( + buffer_type(in_tensor.data_ptr()), + buffer_type(out_tensor.data_ptr()), + in_tensor.numel(), + ncclDataTypeEnum.from_torch(in_tensor.dtype), + ncclRedOpTypeEnum.from_torch(op), + self.comm, + cudaStream_t(stream.cuda_stream), + ) + return out_tensor + + def all_gather( + self, output_tensor: torch.Tensor, input_tensor: torch.Tensor, stream=None + ): + if self.disabled: + return + # nccl communicator created on a specific device + # will only work on tensors on the same device + # otherwise it will cause "illegal memory access" + assert input_tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {input_tensor.device}" + ) + if stream is None: + stream = current_stream() + self.nccl.ncclAllGather( + buffer_type(input_tensor.data_ptr()), + buffer_type(output_tensor.data_ptr()), + input_tensor.numel(), + ncclDataTypeEnum.from_torch(input_tensor.dtype), + self.comm, + cudaStream_t(stream.cuda_stream), + ) + + def all_gatherv( + self, + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + sizes: list[int], + stream=None, + ): + if self.disabled: + return + # nccl communicator created on a specific device + # will only work on tensors on the same device + # otherwise it will cause "illegal memory access" + assert input_tensor.device == self.device, ( + 
f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {input_tensor.device}" + ) + if stream is None: + stream = current_stream() + assert output_tensor.shape[0] == sum(sizes) + split_offset = 0 + self.nccl.ncclGroupStart() + for root, split_size in enumerate(sizes): + dst_slice = output_tensor[split_offset : split_offset + split_size] + self.nccl.ncclBroadcast( + buffer_type(input_tensor.data_ptr()), + buffer_type(dst_slice.data_ptr()), + dst_slice.numel(), + ncclDataTypeEnum.from_torch(input_tensor.dtype), + root, + self.comm, + cudaStream_t(stream.cuda_stream), + ) + split_offset += split_size + self.nccl.ncclGroupEnd() + + def reduce_scatter( + self, + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + op: ReduceOp = ReduceOp.SUM, + stream=None, + ): + if self.disabled: + return + # nccl communicator created on a specific device + # will only work on tensors on the same device + # otherwise it will cause "illegal memory access" + assert input_tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {input_tensor.device}" + ) + if stream is None: + stream = current_stream() + self.nccl.ncclReduceScatter( + buffer_type(input_tensor.data_ptr()), + buffer_type(output_tensor.data_ptr()), + output_tensor.numel(), + ncclDataTypeEnum.from_torch(input_tensor.dtype), + ncclRedOpTypeEnum.from_torch(op), + self.comm, + cudaStream_t(stream.cuda_stream), + ) + + def reduce_scatterv( + self, + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + sizes: list[int], + op: ReduceOp = ReduceOp.SUM, + stream=None, + ): + if self.disabled: + return + # nccl communicator created on a specific device + # will only work on tensors on the same device + # otherwise it will cause "illegal memory access" + assert input_tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {input_tensor.device}" + ) + if stream is None: + stream = current_stream() + + split_offset = 0 + self.nccl.ncclGroupStart() + for root, split_size in enumerate(sizes): + chunk = input_tensor[split_offset : split_offset + split_size, ...] 
+ self.nccl.ncclReduce( + buffer_type(chunk.data_ptr()), + buffer_type(output_tensor.data_ptr()), + chunk.numel(), + ncclDataTypeEnum.from_torch(input_tensor.dtype), + ncclRedOpTypeEnum.from_torch(op), + root, + self.comm, + cudaStream_t(stream.cuda_stream), + ) + split_offset += split_size + self.nccl.ncclGroupEnd() + + def send(self, tensor: torch.Tensor, dst: int, stream=None): + if self.disabled: + return + assert tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {tensor.device}" + ) + if stream is None: + stream = current_stream() + self.nccl.ncclSend( + buffer_type(tensor.data_ptr()), + tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), + dst, + self.comm, + cudaStream_t(stream.cuda_stream), + ) + + def recv(self, tensor: torch.Tensor, src: int, stream=None): + if self.disabled: + return + assert tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {tensor.device}" + ) + if stream is None: + stream = current_stream() + self.nccl.ncclRecv( + buffer_type(tensor.data_ptr()), + tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), + src, + self.comm, + cudaStream_t(stream.cuda_stream), + ) + + def broadcast(self, tensor: torch.Tensor, src: int, stream=None): + if self.disabled: + return + assert tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {tensor.device}" + ) + if stream is None: + stream = current_stream() + if src == self.rank: + sendbuff = buffer_type(tensor.data_ptr()) + # NCCL requires the sender also to have a receive buffer + recvbuff = buffer_type(tensor.data_ptr()) + else: + sendbuff = buffer_type() + recvbuff = buffer_type(tensor.data_ptr()) + self.nccl.ncclBroadcast( + sendbuff, + recvbuff, + tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), + src, + self.comm, + cudaStream_t(stream.cuda_stream), + ) + + def group_start(self): + self.nccl.ncclGroupStart() + + def group_end(self): + self.nccl.ncclGroupEnd() + + def register_comm_window(self, tensor: torch.Tensor): + return self.nccl.ncclCommWindowRegister( + self.comm, + buffer_type(tensor.data_ptr()), + tensor.numel() * tensor.element_size(), + 1, + ) + + def register_comm_window_raw(self, ptr: int, size: int): + return self.nccl.ncclCommWindowRegister(self.comm, buffer_type(ptr), size, 1) + + def deregister_comm_window(self, window): + return self.nccl.ncclCommWindowDeregister(self.comm, window) diff --git a/aiter/dist/device_communicators/custom_all_reduce.py b/aiter/dist/device_communicators/custom_all_reduce.py index 84600ef011..12e6ee9a56 100644 --- a/aiter/dist/device_communicators/custom_all_reduce.py +++ b/aiter/dist/device_communicators/custom_all_reduce.py @@ -258,7 +258,7 @@ def should_custom_ar(self, inp: torch.Tensor): # for 4 or more non NVLink-capable GPUs, custom allreduce provides # little performance improvement over NCCL. 
if self.world_size == 2 or self.fully_connected: - return inp_size < self.max_size + return inp_size <= self.max_size return False def all_reduce( diff --git a/aiter/dist/device_communicators/pynccl_wrapper.py b/aiter/dist/device_communicators/pynccl_wrapper.py new file mode 100644 index 0000000000..fbe261e558 --- /dev/null +++ b/aiter/dist/device_communicators/pynccl_wrapper.py @@ -0,0 +1,557 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# This file is a pure Python wrapper for the NCCL library. +# The main purpose is to use NCCL combined with CUDA graph. +# Before writing this script, we tried the following approach: +# 1. We tried to use `cupy`, it calls NCCL correctly, but `cupy` itself +# often gets stuck when initializing the NCCL communicator. +# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce` +# contains many other potential cuda APIs, that are not allowed during +# capturing the CUDA graph. For further details, please check +# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ . +# +# Another rejected idea is to write a C/C++ binding for NCCL. It is usually +# doable, but we often encounter issues related with nccl versions, and need +# to switch between different versions of NCCL. See +# https://github.com/NVIDIA/nccl/issues/1234 for more details. +# A C/C++ binding is not flexible enough to handle this. It requires +# recompilation of the code every time we want to switch between different +# versions. This current implementation, with a **pure** Python wrapper, is +# more flexible. We can easily switch between different versions of NCCL by +# changing the environment variable `VLLM_NCCL_SO_PATH`, or the `so_file` +# variable in the code. 
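+#
+# Minimal usage sketch (illustrative only; it relies solely on names defined
+# below in this file). On ROCm the default shared object is "librccl.so.1";
+# a different NCCL/RCCL build can be selected by passing its path:
+#
+#     lib = NCCLLibrary()                # or NCCLLibrary("/path/to/librccl.so")
+#     print(lib.ncclGetVersion())        # raw 21903 -> "2.19.3"
+#     unique_id = lib.ncclGetUniqueId()  # rank 0 broadcasts this to its peers
+#
+# Every wrapped call goes through NCCL_CHECK, which raises RuntimeError with
+# the ncclGetErrorString message for any nonzero ncclResult_t.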
+ +import ctypes +import platform +from dataclasses import dataclass +from typing import Any + +import torch +from torch.distributed import ReduceOp +from aiter import logger + + +# === export types and functions from nccl to Python === +# for the original nccl definition, please check +# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in + +ncclResult_t = ctypes.c_int +ncclComm_t = ctypes.c_void_p +ncclWindow_t = ctypes.c_void_p + + +class ncclUniqueId(ctypes.Structure): + _fields_ = [("internal", ctypes.c_byte * 128)] + + +cudaStream_t = ctypes.c_void_p +buffer_type = ctypes.c_void_p + +ncclDataType_t = ctypes.c_int + + +class ncclDataTypeEnum: + ncclInt8 = 0 + ncclChar = 0 + ncclUint8 = 1 + ncclInt32 = 2 + ncclInt = 2 + ncclUint32 = 3 + ncclInt64 = 4 + ncclUint64 = 5 + ncclFloat16 = 6 + ncclHalf = 6 + ncclFloat32 = 7 + ncclFloat = 7 + ncclFloat64 = 8 + ncclDouble = 8 + ncclBfloat16 = 9 + ncclNumTypes = 10 + + @classmethod + def from_torch(cls, dtype: torch.dtype) -> int: + if dtype == torch.int8: + return cls.ncclInt8 + if dtype == torch.uint8: + return cls.ncclUint8 + if dtype == torch.int32: + return cls.ncclInt32 + if dtype == torch.int64: + return cls.ncclInt64 + if dtype == torch.float16: + return cls.ncclFloat16 + if dtype == torch.float32: + return cls.ncclFloat32 + if dtype == torch.float64: + return cls.ncclFloat64 + if dtype == torch.bfloat16: + return cls.ncclBfloat16 + raise ValueError(f"Unsupported dtype: {dtype}") + + +ncclRedOp_t = ctypes.c_int + + +class ncclRedOpTypeEnum: + ncclSum = 0 + ncclProd = 1 + ncclMax = 2 + ncclMin = 3 + ncclAvg = 4 + ncclNumOps = 5 + + @classmethod + def from_torch(cls, op: ReduceOp) -> int: + if op == ReduceOp.SUM: + return cls.ncclSum + if op == ReduceOp.PRODUCT: + return cls.ncclProd + if op == ReduceOp.MAX: + return cls.ncclMax + if op == ReduceOp.MIN: + return cls.ncclMin + if op == ReduceOp.AVG: + return cls.ncclAvg + raise ValueError(f"Unsupported op: {op}") + + +@dataclass +class Function: + name: str + restype: Any + argtypes: list[Any] + + +class NCCLLibrary: + exported_functions = [ + # const char* ncclGetErrorString(ncclResult_t result) + Function("ncclGetErrorString", ctypes.c_char_p, [ncclResult_t]), + # ncclResult_t ncclGetVersion(int *version); + Function("ncclGetVersion", ncclResult_t, [ctypes.POINTER(ctypes.c_int)]), + # ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId); + Function("ncclGetUniqueId", ncclResult_t, [ctypes.POINTER(ncclUniqueId)]), + # ncclResult_t ncclCommInitRank( + # ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank); + # note that ncclComm_t is a pointer type, so the first argument + # is a pointer to a pointer + Function( + "ncclCommInitRank", + ncclResult_t, + [ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId, ctypes.c_int], + ), + # ncclResult_t ncclAllReduce( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, + # cudaStream_t stream); + # note that cudaStream_t is a pointer type, so the last argument + # is a pointer + Function( + "ncclAllReduce", + ncclResult_t, + [ + buffer_type, + buffer_type, + ctypes.c_size_t, + ncclDataType_t, + ncclRedOp_t, + ncclComm_t, + cudaStream_t, + ], + ), + # ncclResult_t ncclReduce( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, ncclRedOp_t op, int root, + # ncclComm_t comm, cudaStream_t stream); + # note that cudaStream_t is a pointer type, so the last argument + # is a pointer + Function( + "ncclReduce", + ncclResult_t, + [ + buffer_type, + 
buffer_type, + ctypes.c_size_t, + ncclDataType_t, + ncclRedOp_t, + ctypes.c_int, + ncclComm_t, + cudaStream_t, + ], + ), + # ncclResult_t ncclAllGather( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, ncclComm_t comm, + # cudaStream_t stream); + # note that cudaStream_t is a pointer type, so the last argument + # is a pointer + Function( + "ncclAllGather", + ncclResult_t, + [ + buffer_type, + buffer_type, + ctypes.c_size_t, + ncclDataType_t, + ncclComm_t, + cudaStream_t, + ], + ), + # ncclResult_t ncclReduceScatter( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, + # cudaStream_t stream); + # note that cudaStream_t is a pointer type, so the last argument + # is a pointer + Function( + "ncclReduceScatter", + ncclResult_t, + [ + buffer_type, + buffer_type, + ctypes.c_size_t, + ncclDataType_t, + ncclRedOp_t, + ncclComm_t, + cudaStream_t, + ], + ), + # ncclResult_t ncclSend( + # const void* sendbuff, size_t count, ncclDataType_t datatype, + # int dest, ncclComm_t comm, cudaStream_t stream); + Function( + "ncclSend", + ncclResult_t, + [ + buffer_type, + ctypes.c_size_t, + ncclDataType_t, + ctypes.c_int, + ncclComm_t, + cudaStream_t, + ], + ), + # ncclResult_t ncclRecv( + # void* recvbuff, size_t count, ncclDataType_t datatype, + # int src, ncclComm_t comm, cudaStream_t stream); + Function( + "ncclRecv", + ncclResult_t, + [ + buffer_type, + ctypes.c_size_t, + ncclDataType_t, + ctypes.c_int, + ncclComm_t, + cudaStream_t, + ], + ), + # ncclResult_t ncclBroadcast( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, int root, ncclComm_t comm, + # cudaStream_t stream); + Function( + "ncclBroadcast", + ncclResult_t, + [ + buffer_type, + buffer_type, + ctypes.c_size_t, + ncclDataType_t, + ctypes.c_int, + ncclComm_t, + cudaStream_t, + ], + ), + # be cautious! this is a collective call, it will block until all + # processes in the communicator have called this function. + # because Python object destruction can happen in random order, + # it is better not to call it at all. 
+ # ncclResult_t ncclCommDestroy(ncclComm_t comm); + Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]), + # ncclResult_t ncclGroupStart(); + Function("ncclGroupStart", ncclResult_t, []), + # ncclResult_t ncclGroupEnd(); + Function("ncclGroupEnd", ncclResult_t, []), + # ncclResult_t ncclCommWindowRegister( + # ncclComm_t comm, void* buff, size_t size, + # ncclWindow_t* win, int winFlags); + # Function( + # "ncclCommWindowRegister", + # ncclResult_t, + # [ + # ncclComm_t, + # buffer_type, + # ctypes.c_size_t, + # ctypes.POINTER(ncclWindow_t), + # ctypes.c_int, + # ], + # ), + # # ncclResult_t ncclCommWindowDeregister( + # # ncclComm_t comm, ncclWindow_t win); + # Function("ncclCommWindowDeregister", ncclResult_t, [ncclComm_t, ncclWindow_t]), + ] + + # class attribute to store the mapping from the path to the library + # to avoid loading the same library multiple times + path_to_library_cache: dict[str, Any] = {} + + # class attribute to store the mapping from library path + # to the corresponding dictionary + path_to_dict_mapping: dict[str, dict[str, Any]] = {} + + def __init__(self, so_file: str | None = None): + so_file = so_file or "librccl.so.1" + + try: + if so_file not in NCCLLibrary.path_to_dict_mapping: + lib = ctypes.CDLL(so_file) + NCCLLibrary.path_to_library_cache[so_file] = lib + self.lib = NCCLLibrary.path_to_library_cache[so_file] + except Exception as e: + logger.error( + "Failed to load NCCL library from %s. " + "It is expected if you are not running on NVIDIA/AMD GPUs." + "Otherwise, the nccl library might not exist, be corrupted " + "or it does not support the current platform %s. " + "If you already have the library, please set the " + "environment variable VLLM_NCCL_SO_PATH" + " to point to the correct nccl library path.", + so_file, + platform.platform(), + ) + raise e + + if so_file not in NCCLLibrary.path_to_dict_mapping: + _funcs: dict[str, Any] = {} + for func in NCCLLibrary.exported_functions: + try: + f = getattr(self.lib, func.name) + f.restype = func.restype + f.argtypes = func.argtypes + _funcs[func.name] = f + except AttributeError: + if func.name in [ + "ncclCommWindowRegister", + "ncclCommWindowDeregister", + ]: + logger.warning( + "The symbol %s is not found in the NCCL " + "library %s. 
To enable VLLM_USE_NCCL_SYMM_MEM " + " please update your NCCL version to >= " + "2.27.03.", + func.name, + so_file, + ) + # Having an exception here on ROCm platform is + # not allowed during graph capturing + continue + raise + NCCLLibrary.path_to_dict_mapping[so_file] = _funcs + self._funcs = NCCLLibrary.path_to_dict_mapping[so_file] + + def ncclGetErrorString(self, result: ncclResult_t) -> str: + return self._funcs["ncclGetErrorString"](result).decode("utf-8") + + def NCCL_CHECK(self, result: ncclResult_t) -> None: + if result != 0: + error_str = self.ncclGetErrorString(result) + raise RuntimeError(f"NCCL error: {error_str}") + + def ncclGetRawVersion(self) -> int: + version = ctypes.c_int() + self.NCCL_CHECK(self._funcs["ncclGetVersion"](ctypes.byref(version))) + # something like 21903 + return version.value + + def ncclGetVersion(self) -> str: + version_str = str(self.ncclGetRawVersion()) + # something like 21903 --> "2.19.3" + major = version_str[0].lstrip("0") + minor = version_str[1:3].lstrip("0") + patch = version_str[3:].lstrip("0") + return f"{major}.{minor}.{patch}" + + def ncclGetUniqueId(self) -> ncclUniqueId: + unique_id = ncclUniqueId() + self.NCCL_CHECK(self._funcs["ncclGetUniqueId"](ctypes.byref(unique_id))) + return unique_id + + def unique_id_from_bytes(self, data: bytes) -> ncclUniqueId: + if len(data) != 128: + raise ValueError( + f"Expected 128 bytes for ncclUniqueId, got {len(data)} bytes" + ) + unique_id = ncclUniqueId() + ctypes.memmove(ctypes.addressof(unique_id.internal), data, 128) + return unique_id + + def ncclCommInitRank( + self, world_size: int, unique_id: ncclUniqueId, rank: int + ) -> ncclComm_t: + comm = ncclComm_t() + self.NCCL_CHECK( + self._funcs["ncclCommInitRank"]( + ctypes.byref(comm), world_size, unique_id, rank + ) + ) + return comm + + def ncclAllReduce( + self, + sendbuff: buffer_type, + recvbuff: buffer_type, + count: int, + datatype: int, + op: int, + comm: ncclComm_t, + stream: cudaStream_t, + ) -> None: + # `datatype` actually should be `ncclDataType_t` + # and `op` should be `ncclRedOp_t` + # both are aliases of `ctypes.c_int` + # when we pass int to a function, it will be converted to `ctypes.c_int` + # by ctypes automatically + self.NCCL_CHECK( + self._funcs["ncclAllReduce"]( + sendbuff, recvbuff, count, datatype, op, comm, stream + ) + ) + + def ncclReduce( + self, + sendbuff: buffer_type, + recvbuff: buffer_type, + count: int, + datatype: int, + op: int, + root: int, + comm: ncclComm_t, + stream: cudaStream_t, + ) -> None: + # `datatype` actually should be `ncclDataType_t` + # and `op` should be `ncclRedOp_t` + # both are aliases of `ctypes.c_int` + # when we pass int to a function, it will be converted to `ctypes.c_int` + # by ctypes automatically + self.NCCL_CHECK( + self._funcs["ncclReduce"]( + sendbuff, recvbuff, count, datatype, op, root, comm, stream + ) + ) + + def ncclReduceScatter( + self, + sendbuff: buffer_type, + recvbuff: buffer_type, + count: int, + datatype: int, + op: int, + comm: ncclComm_t, + stream: cudaStream_t, + ) -> None: + # `datatype` actually should be `ncclDataType_t` + # and `op` should be `ncclRedOp_t` + # both are aliases of `ctypes.c_int` + # when we pass int to a function, it will be converted to `ctypes.c_int` + # by ctypes automatically + self.NCCL_CHECK( + self._funcs["ncclReduceScatter"]( + sendbuff, recvbuff, count, datatype, op, comm, stream + ) + ) + + def ncclAllGather( + self, + sendbuff: buffer_type, + recvbuff: buffer_type, + count: int, + datatype: int, + comm: ncclComm_t, + stream: 
cudaStream_t, + ) -> None: + # `datatype` actually should be `ncclDataType_t` + # which is an aliases of `ctypes.c_int` + # when we pass int to a function, it will be converted to `ctypes.c_int` + # by ctypes automatically + self.NCCL_CHECK( + self._funcs["ncclAllGather"]( + sendbuff, recvbuff, count, datatype, comm, stream + ) + ) + + def ncclSend( + self, + sendbuff: buffer_type, + count: int, + datatype: int, + dest: int, + comm: ncclComm_t, + stream: cudaStream_t, + ) -> None: + self.NCCL_CHECK( + self._funcs["ncclSend"](sendbuff, count, datatype, dest, comm, stream) + ) + + def ncclRecv( + self, + recvbuff: buffer_type, + count: int, + datatype: int, + src: int, + comm: ncclComm_t, + stream: cudaStream_t, + ) -> None: + self.NCCL_CHECK( + self._funcs["ncclRecv"](recvbuff, count, datatype, src, comm, stream) + ) + + def ncclBroadcast( + self, + sendbuff: buffer_type, + recvbuff: buffer_type, + count: int, + datatype: int, + root: int, + comm: ncclComm_t, + stream: cudaStream_t, + ) -> None: + self.NCCL_CHECK( + self._funcs["ncclBroadcast"]( + sendbuff, recvbuff, count, datatype, root, comm, stream + ) + ) + + def ncclCommDestroy(self, comm: ncclComm_t) -> None: + self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm)) + + def ncclGroupStart(self) -> None: + self.NCCL_CHECK(self._funcs["ncclGroupStart"]()) + + def ncclGroupEnd(self) -> None: + self.NCCL_CHECK(self._funcs["ncclGroupEnd"]()) + + def ncclCommWindowRegister( + self, comm: ncclComm_t, buff: buffer_type, size: int, win_flags: int + ) -> ncclWindow_t: + window = ncclWindow_t() + self.NCCL_CHECK( + self._funcs["ncclCommWindowRegister"]( + comm, buff, size, ctypes.byref(window), win_flags + ) + ) + return window + + def ncclCommWindowDeregister(self, comm: ncclComm_t, window: ncclWindow_t) -> None: + self.NCCL_CHECK(self._funcs["ncclCommWindowDeregister"](comm, window)) + + +__all__ = [ + "NCCLLibrary", + "ncclDataTypeEnum", + "ncclRedOpTypeEnum", + "ncclUniqueId", + "ncclComm_t", + "cudaStream_t", + "buffer_type", +] diff --git a/aiter/dist/parallel_state.py b/aiter/dist/parallel_state.py index 611d9583a8..4ddf94ce70 100644 --- a/aiter/dist/parallel_state.py +++ b/aiter/dist/parallel_state.py @@ -213,7 +213,7 @@ def __init__( ) self.device_communicator = None if use_device_communicator and self.world_size > 1: - from .device_communicators.cuda_communicator import CudaCommunicator + from .device_communicators.communicator_cuda import CudaCommunicator self.device_communicator = CudaCommunicator( cpu_group=self.cpu_group, @@ -277,7 +277,7 @@ def graph_capture( # only cuda uses this function, # so we don't abstract it into the base class maybe_ca_context = nullcontext() - from aiter.dist.device_communicators.cuda_communicator import ( + from aiter.dist.device_communicators.communicator_cuda import ( CudaCommunicator, ) @@ -329,7 +329,7 @@ def _all_reduce_out_place( return self.device_communicator.all_reduce(input_, ca_fp8_quant) def _all_gather_out_place(self, input_: torch.Tensor) -> torch.Tensor: - ca_comm = self.ca_comm + ca_comm = self.device_communicator.ca_comm assert ca_comm is not None assert not ca_comm.disabled out = ca_comm.custom_all_gather(input_) diff --git a/aiter/ops/communication.py b/aiter/ops/communication.py index e5940dae9f..2620a13fe5 100644 --- a/aiter/ops/communication.py +++ b/aiter/ops/communication.py @@ -77,7 +77,7 @@ def all_reduce_rmsnorm( input: Tensor, residual_in: Tensor, weight: Tensor, bias: Tensor, epsilon: float ): tp_grp = get_tp_group() - ca = tp_grp.ca_comm + ca = 
tp_grp.device_communicator.ca_comm return aiter.all_reduce_rmsnorm_( input, @@ -101,7 +101,7 @@ def all_reduce_rmsnorm_quant( epsilon: float, ): tp_grp = get_tp_group() - ca = tp_grp.ca_comm + ca = tp_grp.device_communicator.ca_comm return aiter.all_reduce_rmsnorm_quant_( input,