Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions aiter/dist/communication_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@


def tensor_model_parallel_all_reduce(
input_: torch.Tensor, open_fp8_quant: bool = False
input_: torch.Tensor, use_new: bool = False, open_fp8_quant: bool = False
) -> torch.Tensor:
"""All-reduce the input tensor across model parallel group."""
return get_tp_group().all_reduce(input_, open_fp8_quant)
return get_tp_group().all_reduce(input_, use_new, open_fp8_quant)


def tensor_model_parallel_fused_allreduce_rmsnorm(
Expand Down
6 changes: 4 additions & 2 deletions aiter/dist/device_communicators/communicator_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,9 @@ def __init__(
self.all2all_manager.__class__.__name__,
)

def all_reduce(self, input_, ca_fp8_quant: bool = False) -> torch.Tensor:
def all_reduce(
self, input_, use_new: bool = False, ca_fp8_quant: bool = False
) -> torch.Tensor:
# always try quick reduce first, then custom allreduce,
# and then pynccl. (quick reduce just for ROCM MI3*)
qr_comm = self.qr_comm
Expand All @@ -137,7 +139,7 @@ def all_reduce(self, input_, ca_fp8_quant: bool = False) -> torch.Tensor:
and not ca_comm.disabled
and ca_comm.should_custom_ar(input_)
):
out = ca_comm.custom_all_reduce(input_, ca_fp8_quant)
out = ca_comm.custom_all_reduce(input_, use_new, ca_fp8_quant)
assert out is not None
return out
symm_mem_comm = self.symm_mem_comm
Expand Down
11 changes: 8 additions & 3 deletions aiter/dist/device_communicators/custom_all_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,7 @@ def all_reduce(
inp: torch.Tensor,
*,
out: Optional[torch.Tensor] = None,
use_new: bool = False,
open_fp8_quant: bool = False,
registered: bool = False,
):
Expand All @@ -281,21 +282,25 @@ def all_reduce(
self._ptr,
inp,
out,
use_new,
open_fp8_quant,
None if registered else self.buffer,
)
return out

def custom_all_reduce(
self, input: torch.Tensor, open_fp8_quant: bool = False
self, input: torch.Tensor, use_new: bool = False, open_fp8_quant: bool = False
) -> Optional[torch.Tensor]:
# when custom allreduce is disabled, this will be None
if self.disabled or not self.should_custom_ar(input):
return None
if self._IS_CAPTURING:
if torch.cuda.is_current_stream_capturing():
return self.all_reduce(
input, open_fp8_quant=open_fp8_quant, registered=True
input,
use_new=use_new,
open_fp8_quant=open_fp8_quant,
registered=True,
)
else:
# if warm up, mimic the allocation pattern
Expand All @@ -307,7 +312,7 @@ def custom_all_reduce(
# be small(<=1% of overall latency) compared to the performance
# gains of using custom kernels
return self.all_reduce(
input, open_fp8_quant=open_fp8_quant, registered=False
input, use_new=use_new, open_fp8_quant=open_fp8_quant, registered=False
)

def all_gather_reg(self, inp: torch.Tensor, out: torch.Tensor = None):
Expand Down
15 changes: 9 additions & 6 deletions aiter/dist/parallel_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,13 +110,13 @@ def all_reduce_fake(
# There is same name all_reduce in aiter.op, use Alias
@torch_compile_guard(gen_fake=all_reduce_fake)
def all_reduce_(
tensor: torch.Tensor, group_name: str, ca_fp8_quant: bool
tensor: torch.Tensor, group_name: str, ca_use_new: bool, ca_fp8_quant: bool
) -> torch.Tensor:
assert group_name in _groups, f"Group {group_name} is not found."
group = _groups[group_name]()
if group is None:
raise ValueError(f"Group {group_name} is destroyed.")
return group._all_reduce_out_place(tensor, ca_fp8_quant)
return group._all_reduce_out_place(tensor, ca_use_new, ca_fp8_quant)


def fused_allreduce_rmsnorm_fake(
Expand Down Expand Up @@ -323,7 +323,7 @@ def graph_capture(
yield graph_capture_context

def all_reduce(
self, input_: torch.Tensor, ca_fp8_quant: bool = False
self, input_: torch.Tensor, ca_use_new: bool = False, ca_fp8_quant: bool = False
) -> torch.Tensor:
"""
User-facing all-reduce function before we actually call the
Expand All @@ -344,15 +344,18 @@ def all_reduce(
return input_

return all_reduce_(
input_, group_name=self.unique_name, ca_fp8_quant=ca_fp8_quant
input_,
group_name=self.unique_name,
ca_use_new=ca_use_new,
ca_fp8_quant=ca_fp8_quant,
)

def _all_reduce_out_place(
self, input_: torch.Tensor, ca_fp8_quant: bool
self, input_: torch.Tensor, ca_use_new: bool, ca_fp8_quant: bool
) -> torch.Tensor:
if self.device_communicator is None:
raise ValueError("No device communicator found")
return self.device_communicator.all_reduce(input_, ca_fp8_quant)
return self.device_communicator.all_reduce(input_, ca_use_new, ca_fp8_quant)

def fused_allreduce_rmsnorm(
self,
Expand Down
1 change: 1 addition & 0 deletions aiter/ops/custom_all_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def all_reduce(
_fa: int,
inp: torch.Tensor,
out: torch.Tensor,
use_new: bool,
open_fp8_quant: bool,
reg_buffer: Optional[torch.Tensor] = None,
) -> None: ...
Expand Down
107 changes: 75 additions & 32 deletions csrc/include/custom_all_reduce.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1457,7 +1457,7 @@ namespace aiter
* will cause contention on NVLink bus.
*/
template <typename T>
void allreduce(hipStream_t stream, T *input, T *output, int size,
void allreduce(hipStream_t stream, T *input, T *output, int size, bool use_new = false,
#ifndef USE_ROCM
int threads = 512, int block_limit = 20){
#else
Expand All @@ -1479,32 +1479,36 @@ namespace aiter

auto bytes = size * sizeof(T);
size /= d;
int blocks = 16;
bool call_1stage = false;
bool call_2stage = false;
if (world_size_ == 2)
{
call_1stage = true;
}
else if (full_nvlink_)

// use new version of allreduce kernel
if (use_new)
{
if ((world_size_ <= 4 && bytes < 160 * 1024) || (world_size_ <= 8 && bytes < 80 * 1024))
int blocks = 16;
bool call_1stage = false;
bool call_2stage = false;
if (world_size_ == 2)
{
call_1stage = true;
}
else
else if (full_nvlink_)
{
call_2stage = true;
if ((world_size_ <= 4 && bytes < 160 * 1024) || (world_size_ <= 8 && bytes < 80 * 1024))
{
call_1stage = true;
}
else
{
call_2stage = true;
}
}
if (call_1stage)
{
blocks = std::min(kMaxBlocks, (size + (threads / world_size_) - 1) / (threads / world_size_));
}
else if (call_2stage)
{
blocks = std::min(kMaxBlocks, (size / world_size_ + (threads / world_size_) - 1) / (threads / world_size_));
}
}
if (call_1stage)
{
blocks = std::min(kMaxBlocks, (size + (threads / world_size_) - 1) / (threads / world_size_));
}
else if (call_2stage)
{
blocks = std::min(kMaxBlocks, (size / world_size_ + (threads / world_size_) - 1) / (threads / world_size_));
}

#define KL(ngpus, name) \
name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
Expand Down Expand Up @@ -1537,17 +1541,56 @@ namespace aiter
break; \
}

switch (world_size_)
{
REDUCE_CASE(2)
REDUCE_CASE(4)
REDUCE_CASE(6)
REDUCE_CASE(8)
default:
throw std::runtime_error(
"custom allreduce only supports num gpus in (2,4,6,8). Actual num "
"gpus = " +
std::to_string(world_size_));
switch (world_size_)
{
REDUCE_CASE(2)
REDUCE_CASE(4)
REDUCE_CASE(6)
REDUCE_CASE(8)
default:
throw std::runtime_error(
"custom allreduce only supports num gpus in (2,4,6,8). Actual num "
"gpus = " +
std::to_string(world_size_));
}
}
else // use vllm allreduce kernel
{
int blocks = std::min(block_limit, (size + threads - 1) / threads);
#define VLLM_REDUCE_CASE(ngpus) \
case ngpus: \
{ \
if (world_size_ == 2) \
{ \
KL(ngpus, cross_device_reduce_1stage); \
} \
else if (full_nvlink_) \
{ \
if ((world_size_ <= 4 && bytes < 512 * 1024) || \
(world_size_ <= 8 && bytes < 256 * 1024)) \
{ \
KL(ngpus, cross_device_reduce_1stage_naive); \
} \
else \
{ \
KL(ngpus, cross_device_reduce_2stage_naive); \
} \
} \
break; \
}

switch (world_size_)
{
VLLM_REDUCE_CASE(2)
VLLM_REDUCE_CASE(4)
VLLM_REDUCE_CASE(6)
VLLM_REDUCE_CASE(8)
default:
throw std::runtime_error(
"custom allreduce only supports num gpus in (2,4,6,8). Actual num "
"gpus = " +
std::to_string(world_size_));
}
}
#undef REDUCE_CASE
#undef KL
Expand Down
1 change: 1 addition & 0 deletions csrc/include/custom_all_reduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ fptr_t init_custom_ar(torch::Tensor& meta,
void all_reduce(fptr_t _fa,
torch::Tensor& inp,
torch::Tensor& out,
bool use_new,
bool open_fp8_quant,
std::optional<torch::Tensor> reg_buffer);
void all_gather_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
Expand Down
1 change: 1 addition & 0 deletions csrc/include/rocm_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,7 @@ namespace py = pybind11;
py::arg("_fa"), \
py::arg("inp"), \
py::arg("out"), \
py::arg("use_new"), \
py::arg("open_fp8_quant"), \
py::arg("reg_buffer") = std::nullopt); \
m.def("fused_allreduce_rmsnorm", \
Expand Down
13 changes: 7 additions & 6 deletions csrc/kernels/custom_all_reduce.cu
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ bool _is_weak_contiguous(torch::Tensor& t)
}

void _all_reduce(
fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, hipStream_t stream, bool open_fp8_quant)
fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, hipStream_t stream, bool use_new, bool open_fp8_quant)
{
auto fa = reinterpret_cast<aiter::CustomAllreduce*>(_fa);
TORCH_CHECK(_is_weak_contiguous(out));
Expand All @@ -91,7 +91,7 @@ void _all_reduce(
fa->allreduce<float>(stream,
reinterpret_cast<float*>(inp.data_ptr()),
reinterpret_cast<float*>(out.data_ptr()),
out.numel());
out.numel(), use_new);
break;
}
case at::ScalarType::Half: {
Expand All @@ -111,7 +111,7 @@ void _all_reduce(
fa->allreduce<half>(stream,
reinterpret_cast<half*>(inp.data_ptr()),
reinterpret_cast<half*>(out.data_ptr()),
out.numel());
out.numel(), use_new);
}
break;
}
Expand All @@ -120,7 +120,7 @@ void _all_reduce(
fa->allreduce<__hip_bfloat16>(stream,
reinterpret_cast<__hip_bfloat16*>(inp.data_ptr()),
reinterpret_cast<__hip_bfloat16*>(out.data_ptr()),
out.numel());
out.numel(), use_new);
break;
}
#endif
Expand All @@ -132,6 +132,7 @@ void _all_reduce(
void all_reduce(fptr_t _fa,
torch::Tensor& inp,
torch::Tensor& out,
bool use_new,
bool open_fp8_quant,
std::optional<torch::Tensor> reg_buffer)
{
Expand All @@ -150,11 +151,11 @@ void all_reduce(fptr_t _fa,
input_size,
hipMemcpyDeviceToDevice,
stream));
_all_reduce(_fa, reg_buffer.value(), out, stream, open_fp8_quant);
_all_reduce(_fa, reg_buffer.value(), out, stream, use_new, open_fp8_quant);
}
else
{
_all_reduce(_fa, inp, out, stream, open_fp8_quant);
_all_reduce(_fa, inp, out, stream, use_new, open_fp8_quant);
}


Expand Down