diff --git a/aiter/dist/communication_op.py b/aiter/dist/communication_op.py index 7350d05572..7ab6359a87 100644 --- a/aiter/dist/communication_op.py +++ b/aiter/dist/communication_op.py @@ -24,10 +24,10 @@ def tensor_model_parallel_all_reduce( - input_: torch.Tensor, open_fp8_quant: bool = False + input_: torch.Tensor, use_new: bool = False, open_fp8_quant: bool = False ) -> torch.Tensor: """All-reduce the input tensor across model parallel group.""" - return get_tp_group().all_reduce(input_, open_fp8_quant) + return get_tp_group().all_reduce(input_, use_new, open_fp8_quant) def tensor_model_parallel_fused_allreduce_rmsnorm( diff --git a/aiter/dist/device_communicators/communicator_cuda.py b/aiter/dist/device_communicators/communicator_cuda.py index 7dc4308d2a..dae87675ed 100644 --- a/aiter/dist/device_communicators/communicator_cuda.py +++ b/aiter/dist/device_communicators/communicator_cuda.py @@ -118,7 +118,9 @@ def __init__( self.all2all_manager.__class__.__name__, ) - def all_reduce(self, input_, ca_fp8_quant: bool = False) -> torch.Tensor: + def all_reduce( + self, input_, use_new: bool = False, ca_fp8_quant: bool = False + ) -> torch.Tensor: # always try quick reduce first, then custom allreduce, # and then pynccl. 
(quick reduce just for ROCM MI3*) qr_comm = self.qr_comm @@ -137,7 +139,7 @@ def all_reduce(self, input_, ca_fp8_quant: bool = False) -> torch.Tensor: and not ca_comm.disabled and ca_comm.should_custom_ar(input_) ): - out = ca_comm.custom_all_reduce(input_, ca_fp8_quant) + out = ca_comm.custom_all_reduce(input_, use_new, ca_fp8_quant) assert out is not None return out symm_mem_comm = self.symm_mem_comm diff --git a/aiter/dist/device_communicators/custom_all_reduce.py b/aiter/dist/device_communicators/custom_all_reduce.py index b00446a8f4..30c999c5bc 100644 --- a/aiter/dist/device_communicators/custom_all_reduce.py +++ b/aiter/dist/device_communicators/custom_all_reduce.py @@ -266,6 +266,7 @@ def all_reduce( inp: torch.Tensor, *, out: Optional[torch.Tensor] = None, + use_new: bool = False, open_fp8_quant: bool = False, registered: bool = False, ): @@ -281,13 +282,14 @@ def all_reduce( self._ptr, inp, out, + use_new, open_fp8_quant, None if registered else self.buffer, ) return out def custom_all_reduce( - self, input: torch.Tensor, open_fp8_quant: bool = False + self, input: torch.Tensor, use_new: bool = False, open_fp8_quant: bool = False ) -> Optional[torch.Tensor]: # when custom allreduce is disabled, this will be None if self.disabled or not self.should_custom_ar(input): @@ -295,7 +297,10 @@ def custom_all_reduce( if self._IS_CAPTURING: if torch.cuda.is_current_stream_capturing(): return self.all_reduce( - input, open_fp8_quant=open_fp8_quant, registered=True + input, + use_new=use_new, + open_fp8_quant=open_fp8_quant, + registered=True, ) else: # if warm up, mimic the allocation pattern @@ -307,7 +312,7 @@ def custom_all_reduce( # be small(<=1% of overall latency) compared to the performance # gains of using custom kernels return self.all_reduce( - input, open_fp8_quant=open_fp8_quant, registered=False + input, use_new=use_new, open_fp8_quant=open_fp8_quant, registered=False ) def all_gather_reg(self, inp: torch.Tensor, out: torch.Tensor = None): diff --git 
a/aiter/dist/parallel_state.py b/aiter/dist/parallel_state.py index 1d74e136ac..b56c26d7ab 100644 --- a/aiter/dist/parallel_state.py +++ b/aiter/dist/parallel_state.py @@ -110,13 +110,13 @@ def all_reduce_fake( # There is same name all_reduce in aiter.op, use Alias @torch_compile_guard(gen_fake=all_reduce_fake) def all_reduce_( - tensor: torch.Tensor, group_name: str, ca_fp8_quant: bool + tensor: torch.Tensor, group_name: str, ca_use_new: bool, ca_fp8_quant: bool ) -> torch.Tensor: assert group_name in _groups, f"Group {group_name} is not found." group = _groups[group_name]() if group is None: raise ValueError(f"Group {group_name} is destroyed.") - return group._all_reduce_out_place(tensor, ca_fp8_quant) + return group._all_reduce_out_place(tensor, ca_use_new, ca_fp8_quant) def fused_allreduce_rmsnorm_fake( @@ -323,7 +323,7 @@ def graph_capture( yield graph_capture_context def all_reduce( - self, input_: torch.Tensor, ca_fp8_quant: bool = False + self, input_: torch.Tensor, ca_use_new: bool = False, ca_fp8_quant: bool = False ) -> torch.Tensor: """ User-facing all-reduce function before we actually call the @@ -344,15 +344,18 @@ def all_reduce( return input_ return all_reduce_( - input_, group_name=self.unique_name, ca_fp8_quant=ca_fp8_quant + input_, + group_name=self.unique_name, + ca_use_new=ca_use_new, + ca_fp8_quant=ca_fp8_quant, ) def _all_reduce_out_place( - self, input_: torch.Tensor, ca_fp8_quant: bool + self, input_: torch.Tensor, ca_use_new: bool, ca_fp8_quant: bool ) -> torch.Tensor: if self.device_communicator is None: raise ValueError("No device communicator found") - return self.device_communicator.all_reduce(input_, ca_fp8_quant) + return self.device_communicator.all_reduce(input_, ca_use_new, ca_fp8_quant) def fused_allreduce_rmsnorm( self, diff --git a/aiter/ops/custom_all_reduce.py b/aiter/ops/custom_all_reduce.py index 9f139c498a..d9066e1ede 100644 --- a/aiter/ops/custom_all_reduce.py +++ b/aiter/ops/custom_all_reduce.py @@ -26,6 +26,7 @@ def 
all_reduce( _fa: int, inp: torch.Tensor, out: torch.Tensor, + use_new: bool, open_fp8_quant: bool, reg_buffer: Optional[torch.Tensor] = None, ) -> None: ... diff --git a/csrc/include/custom_all_reduce.cuh b/csrc/include/custom_all_reduce.cuh index cbe19249a9..41c9b1a212 100644 --- a/csrc/include/custom_all_reduce.cuh +++ b/csrc/include/custom_all_reduce.cuh @@ -1457,7 +1457,7 @@ namespace aiter * will cause contention on NVLink bus. */ template <typename T> - void allreduce(hipStream_t stream, T *input, T *output, int size, + void allreduce(hipStream_t stream, T *input, T *output, int size, bool use_new = false, #ifndef USE_ROCM int threads = 512, int block_limit = 20){ #else @@ -1479,32 +1479,36 @@ namespace aiter auto bytes = size * sizeof(T); size /= d; - int blocks = 16; - bool call_1stage = false; - bool call_2stage = false; - if (world_size_ == 2) - { - call_1stage = true; - } - else if (full_nvlink_) + + // use new version of allreduce kernel + if (use_new) { - if ((world_size_ <= 4 && bytes < 160 * 1024) || (world_size_ <= 8 && bytes < 80 * 1024)) + int blocks = 16; + bool call_1stage = false; + bool call_2stage = false; + if (world_size_ == 2) { call_1stage = true; } - else + else if (full_nvlink_) { - call_2stage = true; + if ((world_size_ <= 4 && bytes < 160 * 1024) || (world_size_ <= 8 && bytes < 80 * 1024)) + { + call_1stage = true; + } + else + { + call_2stage = true; + } + } + if (call_1stage) + { + blocks = std::min(kMaxBlocks, (size + (threads / world_size_) - 1) / (threads / world_size_)); + } + else if (call_2stage) + { + blocks = std::min(kMaxBlocks, (size / world_size_ + (threads / world_size_) - 1) / (threads / world_size_)); } - } - if (call_1stage) - { - blocks = std::min(kMaxBlocks, (size + (threads / world_size_) - 1) / (threads / world_size_)); - } - else if (call_2stage) - { - blocks = std::min(kMaxBlocks, (size / world_size_ + (threads / world_size_) - 1) / (threads / world_size_)); - } #define KL(ngpus, name) \ name<<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output,
\ @@ -1537,17 +1541,56 @@ namespace aiter break; \ } - switch (world_size_) - { - REDUCE_CASE(2) - REDUCE_CASE(4) - REDUCE_CASE(6) - REDUCE_CASE(8) - default: - throw std::runtime_error( - "custom allreduce only supports num gpus in (2,4,6,8). Actual num " - "gpus = " + - std::to_string(world_size_)); + switch (world_size_) + { + REDUCE_CASE(2) + REDUCE_CASE(4) + REDUCE_CASE(6) + REDUCE_CASE(8) + default: + throw std::runtime_error( + "custom allreduce only supports num gpus in (2,4,6,8). Actual num " + "gpus = " + + std::to_string(world_size_)); + } + } + else // use vllm allreduce kernel + { + int blocks = std::min(block_limit, (size + threads - 1) / threads); +#define VLLM_REDUCE_CASE(ngpus) \ + case ngpus: \ + { \ + if (world_size_ == 2) \ + { \ + KL(ngpus, cross_device_reduce_1stage); \ + } \ + else if (full_nvlink_) \ + { \ + if ((world_size_ <= 4 && bytes < 512 * 1024) || \ + (world_size_ <= 8 && bytes < 256 * 1024)) \ + { \ + KL(ngpus, cross_device_reduce_1stage_naive); \ + } \ + else \ + { \ + KL(ngpus, cross_device_reduce_2stage_naive); \ + } \ + } \ + break; \ + } + + switch (world_size_) + { + VLLM_REDUCE_CASE(2) + VLLM_REDUCE_CASE(4) + VLLM_REDUCE_CASE(6) + VLLM_REDUCE_CASE(8) + default: + throw std::runtime_error( + "custom allreduce only supports num gpus in (2,4,6,8). 
Actual num " + "gpus = " + std::to_string(world_size_)); + } } #undef REDUCE_CASE #undef KL diff --git a/csrc/include/custom_all_reduce.h b/csrc/include/custom_all_reduce.h index a5b8a712a5..2eb2f179d3 100644 --- a/csrc/include/custom_all_reduce.h +++ b/csrc/include/custom_all_reduce.h @@ -31,6 +31,7 @@ fptr_t init_custom_ar(torch::Tensor& meta, void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, + bool use_new, bool open_fp8_quant, std::optional<torch::Tensor> reg_buffer); void all_gather_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out); diff --git a/csrc/include/rocm_ops.hpp b/csrc/include/rocm_ops.hpp index 8cab4d1dfd..328173bbc2 100644 --- a/csrc/include/rocm_ops.hpp +++ b/csrc/include/rocm_ops.hpp @@ -321,6 +321,7 @@ namespace py = pybind11; py::arg("_fa"), \ py::arg("inp"), \ py::arg("out"), \ + py::arg("use_new"), \ py::arg("open_fp8_quant"), \ py::arg("reg_buffer") = std::nullopt); \ m.def("fused_allreduce_rmsnorm", \ diff --git a/csrc/kernels/custom_all_reduce.cu b/csrc/kernels/custom_all_reduce.cu index e3e7ea86c6..4c067afb13 100644 --- a/csrc/kernels/custom_all_reduce.cu +++ b/csrc/kernels/custom_all_reduce.cu @@ -81,7 +81,7 @@ bool _is_weak_contiguous(torch::Tensor& t) } void _all_reduce( - fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, hipStream_t stream, bool open_fp8_quant) + fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, hipStream_t stream, bool use_new, bool open_fp8_quant) { auto fa = reinterpret_cast<aiter::CustomAllreduce*>(_fa); TORCH_CHECK(_is_weak_contiguous(out)); @@ -91,7 +91,7 @@ void _all_reduce( fa->allreduce<float>(stream, reinterpret_cast<float*>(inp.data_ptr()), reinterpret_cast<float*>(out.data_ptr()), - out.numel()); + out.numel(), use_new); break; } case at::ScalarType::Half: { @@ -111,7 +111,7 @@ void _all_reduce( fa->allreduce<half>(stream, reinterpret_cast<half*>(inp.data_ptr()), reinterpret_cast<half*>(out.data_ptr()), - out.numel()); + out.numel(), use_new); } break; } @@ -120,7 +120,7 @@ void _all_reduce( fa->allreduce<__hip_bfloat16>(stream,
reinterpret_cast<__hip_bfloat16*>(inp.data_ptr()), reinterpret_cast<__hip_bfloat16*>(out.data_ptr()), - out.numel()); + out.numel(), use_new); break; } #endif @@ -132,6 +132,7 @@ void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, + bool use_new, bool open_fp8_quant, std::optional<torch::Tensor> reg_buffer) { @@ -150,11 +151,11 @@ void all_reduce(fptr_t _fa, input_size, hipMemcpyDeviceToDevice, stream)); - _all_reduce(_fa, reg_buffer.value(), out, stream, open_fp8_quant); + _all_reduce(_fa, reg_buffer.value(), out, stream, use_new, open_fp8_quant); } else { - _all_reduce(_fa, inp, out, stream, open_fp8_quant); + _all_reduce(_fa, inp, out, stream, use_new, open_fp8_quant); }