Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,18 +38,20 @@ def __init__(
CustomAllreduce,
)

# from aiter.dist.device_communicators.pynccl import PyNcclCommunicator

# from aiter.dist.device_communicators.symm_mem import SymmMemCommunicator

self.pynccl_comm = None
# if self.world_size > 1:
# self.pynccl_comm = PyNcclCommunicator(
# group=self.cpu_group,
# device=self.device,
# )
# if is_symmetric_memory_enabled():
# register_nccl_symmetric_ops(self.pynccl_comm)
if self.world_size > 1:
from aiter.dist.device_communicators.communicator_pynccl import (
PyNcclCommunicator,
)

self.pynccl_comm = PyNcclCommunicator(
group=self.cpu_group,
device=self.device,
)
# if is_symmetric_memory_enabled():
# register_nccl_symmetric_ops(self.pynccl_comm)

self.ca_comm: CustomAllreduce | None = None
self.qr_comm = None
Expand All @@ -70,8 +72,7 @@ def __init__(
# ),
)

# if current_platform.is_rocm():
if True and self.world_size > 1:
if self.world_size > 1:
from aiter.dist.device_communicators.quick_all_reduce import (
QuickAllReduce,
)
Expand Down Expand Up @@ -118,14 +119,6 @@ def __init__(
)

def all_reduce(self, input_, ca_fp8_quant: bool = False) -> torch.Tensor:
# since currently we perform copy input -> symm_input -> out-of-place AR
# return symm_output, we don't need to check if input is symmetric
if self.pynccl_comm is not None and should_nccl_symm_mem_allreduce(
self.pynccl_comm.world_size, input_
):
out = torch.ops.vllm.all_reduce_symmetric_with_copy(input_)
if out is not None:
return out
# always try quick reduce first, then custom allreduce,
# and then pynccl. (quick reduce just for ROCM MI3*)
qr_comm = self.qr_comm
Expand Down Expand Up @@ -153,19 +146,16 @@ def all_reduce(self, input_, ca_fp8_quant: bool = False) -> torch.Tensor:
assert out is not None
return out
pynccl_comm = self.pynccl_comm
if pynccl_comm is None or pynccl_comm.disabled:
out = input_.clone()
torch.distributed.all_reduce(out, group=self.device_group)
if pynccl_comm is not None and not pynccl_comm.disabled:
out = pynccl_comm.all_reduce(input_)
assert out is not None
return out
assert pynccl_comm is not None
out = pynccl_comm.all_reduce(input_)
if out is None:
# fall back to the default all-reduce using PyTorch.
# this usually happens during testing.
# when we run the model, allreduce only happens for the TP
# group, where we always have either custom allreduce or pynccl.
out = input_.clone()
torch.distributed.all_reduce(out, group=self.device_group)
# fall back to the default all-reduce using PyTorch.
# this usually happens during testing.
# when we run the model, allreduce only happens for the TP
# group, where we always have either custom allreduce or pynccl.
out = input_.clone()
torch.distributed.all_reduce(out, group=self.device_group)
return out

def reduce_scatter(self, input_: torch.Tensor, dim: int = -1):
Expand Down Expand Up @@ -259,6 +249,8 @@ def recv(
def destroy(self):
if self.pynccl_comm is not None:
self.pynccl_comm = None
if self.qr_comm is not None:
self.qr_comm = None
if self.ca_comm is not None:
self.ca_comm = None
if self.all2all_manager is not None:
Expand Down
Loading