diff --git a/aiter/dist/communication_op.py b/aiter/dist/communication_op.py index f88716c827..06f6a5d293 100644 --- a/aiter/dist/communication_op.py +++ b/aiter/dist/communication_op.py @@ -32,7 +32,7 @@ def tensor_model_parallel_all_reduce( def tensor_model_parallel_fused_allreduce_rmsnorm( input_: torch.Tensor, weight_: torch.Tensor, eps: float -) -> torch.Tensor: +) -> tuple[torch.Tensor, torch.Tensor]: return get_tp_group().fused_allreduce_rmsnorm(input_, weight_, eps) diff --git a/aiter/dist/device_communicators/communicator_cuda.py b/aiter/dist/device_communicators/communicator_cuda.py index 339337cc65..14a12ded20 100644 --- a/aiter/dist/device_communicators/communicator_cuda.py +++ b/aiter/dist/device_communicators/communicator_cuda.py @@ -158,10 +158,14 @@ def all_reduce(self, input_, ca_fp8_quant: bool = False) -> torch.Tensor: torch.distributed.all_reduce(out, group=self.device_group) return out - def fused_allreduce_rmsnorm(self, input_, weight_, eps) -> torch.Tensor: + def fused_allreduce_rmsnorm( + self, input_, weight_, eps + ) -> tuple[torch.Tensor, torch.Tensor]: n = input_.shape[-1] can_use_fuse_ar_rms = ( - n <= 16384 and input_.numel() * input_.element_size() < 8 * 1024 * 8192 + n <= 16384 + and input_.numel() * input_.element_size() < 8 * 1024 * 8192 + and self.world_size != 6 ) ca_comm = self.ca_comm if ( @@ -170,11 +174,12 @@ def fused_allreduce_rmsnorm(self, input_, weight_, eps) -> torch.Tensor: and ca_comm.should_custom_ar(input_) and can_use_fuse_ar_rms ): - out = ca_comm.custom_fused_ar_rms(input_, weight_, eps) + res_out, out = ca_comm.custom_fused_ar_rms(input_, weight_, eps) assert out is not None - return out + assert res_out is not None + return res_out, out # call split kernel - ar_out = all_reduce(input_) + ar_out = self.all_reduce(input_) out = torch.empty_like(ar_out) residual_out = torch.empty_like(ar_out) from aiter import rmsnorm2d_fwd_with_add @@ -188,7 +193,7 @@ def fused_allreduce_rmsnorm(self, input_, weight_, eps) -> torch.Tensor: eps, 0, ) - return out + return residual_out, out def reduce_scatter(self, input_: torch.Tensor, dim: int = -1): world_size = self.world_size diff --git a/aiter/dist/device_communicators/custom_all_reduce.py b/aiter/dist/device_communicators/custom_all_reduce.py index 7a788b9046..2df5f2b606 100644 --- a/aiter/dist/device_communicators/custom_all_reduce.py +++ b/aiter/dist/device_communicators/custom_all_reduce.py @@ -340,6 +340,7 @@ def fused_ar_rms( self, inp: torch.Tensor, *, + res_out: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, w: torch.Tensor, eps: float, @@ -347,15 +348,18 @@ def fused_ar_rms( ): if out is None: out = torch.empty_like(inp) + if res_out is None: + res_out = torch.empty_like(inp) ops.fused_allreduce_rmsnorm( self._ptr, inp, + res_out, out, w, eps, None if registered else self.buffer, ) - return out + return res_out, out def custom_fused_ar_rms( self, input: torch.Tensor, weight: torch.Tensor, eps: float @@ -367,7 +371,7 @@ def custom_fused_ar_rms( if torch.cuda.is_current_stream_capturing(): return self.fused_ar_rms(input, w=weight, eps=eps, registered=True) else: - return torch.empty_like(input) + return torch.empty_like(input), torch.empty_like(input) else: return self.fused_ar_rms(input, w=weight, eps=eps, registered=False) diff --git a/aiter/dist/parallel_state.py b/aiter/dist/parallel_state.py index ec438224c0..5c6bab34f5 100644 --- a/aiter/dist/parallel_state.py +++ b/aiter/dist/parallel_state.py @@ -128,7 +128,7 @@ def fused_allreduce_rmsnorm_fake( @torch_compile_guard(gen_fake=fused_allreduce_rmsnorm_fake) def fused_allreduce_rmsnorm_( inp: torch.Tensor, w: torch.Tensor, eps: float, group_name: str -) -> torch.Tensor: +) -> tuple[torch.Tensor, torch.Tensor]: assert group_name in _groups, f"Group {group_name} is not found." group = _groups[group_name]() if group is None: @@ -348,14 +348,14 @@ def _all_reduce_out_place( def fused_allreduce_rmsnorm( self, input_: torch.Tensor, weight_: torch.Tensor, eps: float - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, torch.Tensor]: return fused_allreduce_rmsnorm_( input_, weight_, eps, group_name=self.unique_name ) def _fused_allreduce_rmsnorm_out_place( self, input_: torch.Tensor, weight_: torch.Tensor, eps: float - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, torch.Tensor]: if self.device_communicator is None: raise ValueError("No device communicator found") return self.device_communicator.fused_allreduce_rmsnorm(input_, weight_, eps) @@ -861,6 +861,7 @@ def get_pp_group() -> GroupCoordinator: from typing import Optional + _DP: Optional[GroupCoordinator] = None diff --git a/aiter/ops/custom_all_reduce.py b/aiter/ops/custom_all_reduce.py index e11bbb23b2..34a6c7f100 100644 --- a/aiter/ops/custom_all_reduce.py +++ b/aiter/ops/custom_all_reduce.py @@ -45,6 +45,7 @@ def all_gather_unreg( def fused_allreduce_rmsnorm( _fa: int, inp: torch.Tensor, + res_out: torch.Tensor, out: torch.Tensor, w: torch.Tensor, eps: float, diff --git a/csrc/include/custom_all_reduce.cuh b/csrc/include/custom_all_reduce.cuh index a03ebe99f0..29d60f2d2c 100644 --- a/csrc/include/custom_all_reduce.cuh +++ b/csrc/include/custom_all_reduce.cuh @@ -976,6 +976,7 @@ namespace aiter template __global__ void __launch_bounds__(tnum, 1) local_device_load_rmsnorm_naive( RankSignals sg, + T* __restrict__ residual_out, T* __restrict__ results, T* __restrict__ weight, float eps, @@ -1028,6 +1029,7 @@ namespace aiter } int write_idx = bid * n_loop * blockDim.x + n_iter * blockDim.x + threadIdx.x; *(reinterpret_cast(results) + write_idx) = rmsnorm_rslt; + *(reinterpret_cast(residual_out) + write_idx) = rmsnorm_inp[n_iter]; } } } @@ -1039,6 +1041,7 @@ namespace aiter template __global__ void __launch_bounds__(tnum, 1) local_device_load_rmsnorm( RankSignals sg, + T* __restrict__ residual_out, T* __restrict__ results, T* __restrict__ weight, float eps, @@ -1096,6 +1099,7 @@ namespace aiter } int write_idx = bid * (n / pack_size) + n_iter * tnum + threadIdx.x; *(reinterpret_cast(results) + write_idx) = rmsnorm_rslt; + *(reinterpret_cast(residual_out) + write_idx) = rmsnorm_inp[n_iter]; } } } @@ -1104,6 +1108,7 @@ namespace aiter template __global__ void __launch_bounds__(256, 1) local_device_load_rmsnorm_512n( RankSignals sg, + T* __restrict__ residual_out, T* __restrict__ results, T* __restrict__ weight, float eps, @@ -1156,6 +1161,7 @@ namespace aiter } int write_idx = bid * 64 * n_loop + n_iter * 64 + lane_id; *(reinterpret_cast(results) + write_idx) = rmsnorm_rslt; + *(reinterpret_cast(residual_out) + write_idx) = rmsnorm_inp[n_iter]; } } } @@ -1489,17 +1495,17 @@ namespace aiter name<<>>(ptrs, sg_, self_sg_, output, \ rank_, size); -#define dispatch(ngpus, name) \ - do \ - { \ - if (bytes % 128 == 0) \ - { \ - KL(ngpus, name) \ - } \ - else \ - { \ - KL(ngpus, name##_naive) \ - } \ +#define dispatch(ngpus, name) \ + do \ + { \ + if (bytes % 128 == 0 && world_size_ != 6) \ + { \ + KL(ngpus, name) \ + } \ + else \ + { \ + KL(ngpus, name##_naive) \ + } \ } while(0) #define REDUCE_CASE(ngpus) \ @@ -1581,7 +1587,7 @@ namespace aiter } template - void dispatchFusedAllReduceRMSNorm(hipStream_t stream, T* input, T* output, T* weight, float eps, int m, int n) + void dispatchFusedAllReduceRMSNorm(hipStream_t stream, T* input, T* residual_out, T* output, T* weight, float eps, int m, int n) { auto d = packed_t::P::size; int size = m * n; @@ -1627,12 +1633,12 @@ namespace aiter grid.x = naive_grid_size < num_cu * occupancy ? naive_grid_size : num_cu * occupancy; }; -#define launch_fused_allreduce_rmsnorm(template_kernel) \ - do \ - { \ - auto kernel_ptr = reinterpret_cast(template_kernel); \ - setGrid(naive_grid_size, kernel_ptr); \ - template_kernel<<>>(sg_, output, weight, eps, rank_, m, n); \ +#define launch_fused_allreduce_rmsnorm(template_kernel) \ + do \ + { \ + auto kernel_ptr = reinterpret_cast(template_kernel); \ + setGrid(naive_grid_size, kernel_ptr); \ + template_kernel<<>>(sg_, residual_out, output, weight, eps, rank_, m, n); \ } while (0) if (n_bytes % 1024 == 0) diff --git a/csrc/include/custom_all_reduce.h b/csrc/include/custom_all_reduce.h index aa2ba561bb..98cc9ec4e2 100644 --- a/csrc/include/custom_all_reduce.h +++ b/csrc/include/custom_all_reduce.h @@ -40,6 +40,7 @@ void all_gather_unreg(fptr_t _fa, torch::Tensor& out); void fused_allreduce_rmsnorm(fptr_t _fa, torch::Tensor& inp, + torch::Tensor& res_out, torch::Tensor& out, torch::Tensor& w, float eps, diff --git a/csrc/include/rocm_ops.hpp b/csrc/include/rocm_ops.hpp index 7926085f17..899020a1dd 100644 --- a/csrc/include/rocm_ops.hpp +++ b/csrc/include/rocm_ops.hpp @@ -327,6 +327,7 @@ namespace py = pybind11; &aiter::fused_allreduce_rmsnorm, \ py::arg("_fa"), \ py::arg("inp"), \ + py::arg("res_out"), \ py::arg("out"), \ py::arg("w"), \ py::arg("eps"), \ diff --git a/csrc/kernels/custom_all_reduce.cu b/csrc/kernels/custom_all_reduce.cu index 7fddb6ed45..02d7a674ce 100644 --- a/csrc/kernels/custom_all_reduce.cu +++ b/csrc/kernels/custom_all_reduce.cu @@ -217,7 +217,7 @@ void all_gather_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer, } void _fused_allreduce_rmsnorm( - fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, torch::Tensor& w, int eps, int m, int n, hipStream_t stream) + fptr_t _fa, torch::Tensor& inp, torch::Tensor& residual_out, torch::Tensor& out, torch::Tensor& w, int eps, int m, int n, hipStream_t stream) { auto fa = reinterpret_cast(_fa); TORCH_CHECK(_is_weak_contiguous(out)); @@ -226,6 +226,7 @@ void _fused_allreduce_rmsnorm( case at::ScalarType::Float: { fa->dispatchFusedAllReduceRMSNorm(stream, reinterpret_cast(inp.data_ptr()), + reinterpret_cast(residual_out.data_ptr()), reinterpret_cast(out.data_ptr()), reinterpret_cast(w.data_ptr()), eps, m, n); @@ -234,6 +235,7 @@ void _fused_allreduce_rmsnorm( case at::ScalarType::Half: { fa->dispatchFusedAllReduceRMSNorm(stream, reinterpret_cast(inp.data_ptr()), + reinterpret_cast(residual_out.data_ptr()), reinterpret_cast(out.data_ptr()), reinterpret_cast(w.data_ptr()), eps, m, n); @@ -243,6 +245,7 @@ void _fused_allreduce_rmsnorm( case at::ScalarType::BFloat16: { fa->dispatchFusedAllReduceRMSNorm<__hip_bfloat16>(stream, reinterpret_cast<__hip_bfloat16*>(inp.data_ptr()), + reinterpret_cast<__hip_bfloat16*>(residual_out.data_ptr()), reinterpret_cast<__hip_bfloat16*>(out.data_ptr()), reinterpret_cast<__hip_bfloat16*>(w.data_ptr()), eps, m, n); @@ -256,6 +259,7 @@ void _fused_allreduce_rmsnorm( void fused_allreduce_rmsnorm(fptr_t _fa, torch::Tensor& inp, + torch::Tensor& res_out, torch::Tensor& out, torch::Tensor& w, float eps, @@ -278,11 +282,11 @@ void fused_allreduce_rmsnorm(fptr_t _fa, input_size, hipMemcpyDeviceToDevice, stream)); - _fused_allreduce_rmsnorm(_fa, reg_buffer.value(), out, w, eps, m, n, stream); + _fused_allreduce_rmsnorm(_fa, reg_buffer.value(), res_out, out, w, eps, m, n, stream); } else { - _fused_allreduce_rmsnorm(_fa, inp, out, w, eps, m, n, stream); + _fused_allreduce_rmsnorm(_fa, inp, res_out, out, w, eps, m, n, stream); } } diff --git a/op_tests/multigpu_tests/test_fused_ar_rms.py b/op_tests/multigpu_tests/test_fused_ar_rms.py index 33e0e16b93..3d15e257bc 100644 --- a/op_tests/multigpu_tests/test_fused_ar_rms.py +++ b/op_tests/multigpu_tests/test_fused_ar_rms.py @@ -62,8 +62,11 @@ def fused_ar_rmsnorm(tp_size, pp_size, rankID, x, weight, eps, withGraph=False): graph = torch.cuda.CUDAGraph() with graph_capture() as gc: with torch.cuda.graph(graph, stream=gc.stream): - out = tensor_model_parallel_fused_allreduce_rmsnorm(x, weight, eps) + res_out, out = tensor_model_parallel_fused_allreduce_rmsnorm( + x, weight, eps + ) out.fill_(0) + res_out.fill_(0) @perftest() def run_ca(): @@ -75,7 +78,8 @@ def run_ca(): @perftest() def run_ca(x): - return tensor_model_parallel_fused_allreduce_rmsnorm(x, weight, eps) + res_out, out = tensor_model_parallel_fused_allreduce_rmsnorm(x, weight, eps) + return out out = run_ca(x) @@ -113,7 +117,7 @@ def get_acc_value_with_cudagraph(tp_size, pp_size, rankID, x, weight, eps, loop_ with graph_capture() as gc: with torch.cuda.graph(graph, stream=gc.stream): # out = torch.empty_like(x) - out = tensor_model_parallel_fused_allreduce_rmsnorm(x, weight, eps) + res_out, out = tensor_model_parallel_fused_allreduce_rmsnorm(x, weight, eps) out.fill_(0) def run_ca(): @@ -154,7 +158,7 @@ def get_acc_value_only(tp_size, pp_size, rankID, x, weight, eps, loop_time=1): torch.cuda.synchronize() for i in range(loop_time): - out = tensor_model_parallel_fused_allreduce_rmsnorm(x, weight, eps) + res, out = tensor_model_parallel_fused_allreduce_rmsnorm(x, weight, eps) # destroy if dist.is_initialized(): @@ -238,19 +242,6 @@ def run_ca(x): return out -def run_cu(input, weight, eps, device_id): - device = f"cuda:{device_id}" - input = input.to(device) - weight = weight.to(device) - - @perftest() - def compute(): - output = torch.empty_like(input) - aiter.rms_norm_cu(output, input, weight, eps) - - return compute() - - @benchmark() def test_split_ar_rmsnorm(tp_size, pp_size, shape, dtype, withGraph=False): os.environ["MASTER_ADDR"] = "127.0.0.1" @@ -275,7 +266,6 @@ def test_split_ar_rmsnorm(tp_size, pp_size, shape, dtype, withGraph=False): pool.apply_async( split_ar_rmsnorm, args=(tp_size, pp_size, i, x, weight, eps, withGraph) ) - # pool.apply_async(run_cu, args=(x, weight, eps, i)) ) pool.close() pool.join() @@ -320,7 +310,6 @@ def test_fused_ar_rmsnorm(tp_size, pp_size, shape, dtype, withGraph=False): pool.apply_async( fused_ar_rmsnorm, args=(tp_size, pp_size, i, x, weight, eps, withGraph) ) - # pool.apply_async(run_cu, args=(x, weight, eps, i)) ) pool.close() pool.join()