From 18a35e5ae8c8b324c80037edb75f985c1ff146b4 Mon Sep 17 00:00:00 2001 From: amd-ruitang3 Date: Thu, 13 Nov 2025 06:29:26 +0000 Subject: [PATCH 1/5] [fix]: fused_ar_rms interface Signed-off-by: amd-ruitang3 --- aiter/dist/communication_op.py | 4 +- .../device_communicators/communicator_cuda.py | 4 +- .../device_communicators/custom_all_reduce.py | 16 ++++-- aiter/dist/parallel_state.py | 32 +++++++++--- aiter/ops/custom_all_reduce.py | 1 + csrc/include/custom_all_reduce.cuh | 51 ++++++++++++------- csrc/include/custom_all_reduce.h | 1 + csrc/include/rocm_ops.hpp | 1 + csrc/kernels/custom_all_reduce.cu | 12 +++-- op_tests/multigpu_tests/test_fused_ar_rms.py | 14 +++-- 10 files changed, 96 insertions(+), 40 deletions(-) diff --git a/aiter/dist/communication_op.py b/aiter/dist/communication_op.py index 06f6a5d293..7350d05572 100644 --- a/aiter/dist/communication_op.py +++ b/aiter/dist/communication_op.py @@ -31,9 +31,9 @@ def tensor_model_parallel_all_reduce( def tensor_model_parallel_fused_allreduce_rmsnorm( - input_: torch.Tensor, weight_: torch.Tensor, eps: float + input_: torch.Tensor, residual_inp_: torch.Tensor, weight_: torch.Tensor, eps: float ) -> tuple[torch.Tensor, torch.Tensor]: - return get_tp_group().fused_allreduce_rmsnorm(input_, weight_, eps) + return get_tp_group().fused_allreduce_rmsnorm(input_, residual_inp_, weight_, eps) def tensor_model_parallel_custom_all_gather(input_: torch.Tensor) -> torch.Tensor: diff --git a/aiter/dist/device_communicators/communicator_cuda.py b/aiter/dist/device_communicators/communicator_cuda.py index 14a12ded20..7dc4308d2a 100644 --- a/aiter/dist/device_communicators/communicator_cuda.py +++ b/aiter/dist/device_communicators/communicator_cuda.py @@ -159,7 +159,7 @@ def all_reduce(self, input_, ca_fp8_quant: bool = False) -> torch.Tensor: return out def fused_allreduce_rmsnorm( - self, input_, weight_, eps + self, input_, res_inp_, weight_, eps ) -> tuple[torch.Tensor, torch.Tensor]: n = input_.shape[-1] can_use_fuse_ar_rms = ( @@ -174,7 +174,7 @@ def fused_allreduce_rmsnorm( and ca_comm.should_custom_ar(input_) and can_use_fuse_ar_rms ): - res_out, out = ca_comm.custom_fused_ar_rms(input_, weight_, eps) + res_out, out = ca_comm.custom_fused_ar_rms(input_, res_inp_, weight_, eps) assert out is not None assert res_out is not None return res_out, out diff --git a/aiter/dist/device_communicators/custom_all_reduce.py b/aiter/dist/device_communicators/custom_all_reduce.py index abf2794588..b00446a8f4 100644 --- a/aiter/dist/device_communicators/custom_all_reduce.py +++ b/aiter/dist/device_communicators/custom_all_reduce.py @@ -339,6 +339,7 @@ def custom_all_gather(self, inp: torch.Tensor) -> Optional[torch.Tensor]: def fused_ar_rms( self, inp: torch.Tensor, + res_inp: torch.Tensor, *, res_out: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, @@ -353,6 +354,7 @@ def fused_ar_rms( ops.fused_allreduce_rmsnorm( self._ptr, inp, + res_inp, res_out, out, w, @@ -362,18 +364,26 @@ def fused_ar_rms( return res_out, out def custom_fused_ar_rms( - self, input: torch.Tensor, weight: torch.Tensor, eps: float + self, + input: torch.Tensor, + residual_inp: torch.Tensor, + weight: torch.Tensor, + eps: float, ) -> Optional[torch.Tensor]: # when custom allreduce is disabled, this will be None if self.disabled or not self.should_custom_ar(input): return None if self._IS_CAPTURING: if torch.cuda.is_current_stream_capturing(): - return self.fused_ar_rms(input, w=weight, eps=eps, registered=True) + return self.fused_ar_rms( + input, residual_inp, w=weight, eps=eps, registered=True + ) else: return torch.zeros_like(input), torch.zeros_like(input) else: - return self.fused_ar_rms(input, w=weight, eps=eps, registered=False) + return self.fused_ar_rms( + input, residual_inp, w=weight, eps=eps, registered=False + ) def close(self): if not self.disabled and self._ptr: diff --git a/aiter/dist/parallel_state.py b/aiter/dist/parallel_state.py index 5c6bab34f5..25e1d514e5 100644 --- a/aiter/dist/parallel_state.py +++ b/aiter/dist/parallel_state.py @@ -120,20 +120,28 @@ def all_reduce_( def fused_allreduce_rmsnorm_fake( - inp: torch.Tensor, w: torch.Tensor, eps: float, group_name: str + inp: torch.Tensor, + res_inp: torch.Tensor, + w: torch.Tensor, + eps: float, + group_name: str, ) -> torch.Tensor: return torch.empty_like(inp) @torch_compile_guard(gen_fake=fused_allreduce_rmsnorm_fake) def fused_allreduce_rmsnorm_( - inp: torch.Tensor, w: torch.Tensor, eps: float, group_name: str + inp: torch.Tensor, + res_inp: torch.Tensor, + w: torch.Tensor, + eps: float, + group_name: str, ) -> tuple[torch.Tensor, torch.Tensor]: assert group_name in _groups, f"Group {group_name} is not found." group = _groups[group_name]() if group is None: raise ValueError(f"Group {group_name} is destroyed.") - return group._fused_allreduce_rmsnorm_out_place(inp, w, eps) + return group._fused_allreduce_rmsnorm_out_place(inp, res_inp, w, eps) if supports_custom_op(): @@ -347,18 +355,28 @@ def _all_reduce_out_place( return self.device_communicator.all_reduce(input_, ca_fp8_quant) def fused_allreduce_rmsnorm( - self, input_: torch.Tensor, weight_: torch.Tensor, eps: float + self, + input_: torch.Tensor, + residual_inp_: torch.Tensor, + weight_: torch.Tensor, + eps: float, ) -> tuple[torch.Tensor, torch.Tensor]: return fused_allreduce_rmsnorm_( - input_, weight_, eps, group_name=self.unique_name + input_, residual_inp_, weight_, eps, group_name=self.unique_name ) def _fused_allreduce_rmsnorm_out_place( - self, input_: torch.Tensor, weight_: torch.Tensor, eps: float + self, + input_: torch.Tensor, + residual_inp_: torch.Tensor, + weight_: torch.Tensor, + eps: float, ) -> tuple[torch.Tensor, torch.Tensor]: if self.device_communicator is None: raise ValueError("No device communicator found") - return self.device_communicator.fused_allreduce_rmsnorm(input_, weight_, eps) + return self.device_communicator.fused_allreduce_rmsnorm( + input_, residual_inp_, weight_, eps + ) def _all_gather_out_place(self, input_: torch.Tensor) -> torch.Tensor: ca_comm = self.device_communicator.ca_comm diff --git a/aiter/ops/custom_all_reduce.py b/aiter/ops/custom_all_reduce.py index 34a6c7f100..9f139c498a 100644 --- a/aiter/ops/custom_all_reduce.py +++ b/aiter/ops/custom_all_reduce.py @@ -45,6 +45,7 @@ def all_gather_unreg( def fused_allreduce_rmsnorm( _fa: int, inp: torch.Tensor, + res_inp: torch.Tensor, res_out: torch.Tensor, out: torch.Tensor, w: torch.Tensor, diff --git a/csrc/include/custom_all_reduce.cuh b/csrc/include/custom_all_reduce.cuh index 29d60f2d2c..0ce722f6b4 100644 --- a/csrc/include/custom_all_reduce.cuh +++ b/csrc/include/custom_all_reduce.cuh @@ -935,9 +935,9 @@ namespace aiter #pragma unroll for (int i = 0; i < pack_size; ++i) { - float res_x = ck_tile::type_convert(input_reg.data[i]); + // float res_x = ck_tile::type_convert(input_reg.data[i]); float sum_x = *(reinterpret_cast(&tmp_smem[0]) + lane_id * pack_size + i); - rslt.data[i] = ck_tile::type_convert(res_x + sum_x); + rslt.data[i] = ck_tile::type_convert(sum_x); } tmps[warp_id][(rank * part + bid) * tnum_gpu + lane_id] = rslt; } @@ -976,6 +976,7 @@ namespace aiter template __global__ void __launch_bounds__(tnum, 1) local_device_load_rmsnorm_naive( RankSignals sg, + T* __restrict__ residual_inp, T* __restrict__ residual_out, T* __restrict__ results, T* __restrict__ weight, @@ -1000,14 +1001,18 @@ namespace aiter for (int n_iter = 0; n_iter < n_loop; ++n_iter) { int read_idx = bid * n_loop * blockDim.x + n_iter * blockDim.x + threadIdx.x; - rmsnorm_inp[n_iter] = tmps[read_idx]; + P reduce_out_pack = tmps[read_idx]; + P residual_inp_pack = *(reinterpret_cast(residual_inp) + read_idx); w_arr[n_iter] = *(reinterpret_cast(weight) + n_iter * blockDim.x + threadIdx.x); A reduce_pack; #pragma unroll for (int i = 0; i < pack_size; ++i) { - float ar_elem = ck_tile::type_convert(rmsnorm_inp[n_iter].data[i]); - reduce_pack.data[i] = ar_elem * ar_elem; + float res_inp = ck_tile::type_convert(residual_inp_pack.data[i]); + float ar_out = ck_tile::type_convert(reduce_out_pack.data[i]); + float rms_inp = res_inp + ar_out; + rmsnorm_inp[n_iter].data[i] = ck_tile::type_convert(rms_inp); + reduce_pack.data[i] = rms_inp * rms_inp; } square_sum += packReduce(reduce_pack); } @@ -1041,6 +1046,7 @@ namespace aiter template __global__ void __launch_bounds__(tnum, 1) local_device_load_rmsnorm( RankSignals sg, + T* __restrict__ residual_inp, T* __restrict__ residual_out, T* __restrict__ results, T* __restrict__ weight, @@ -1067,14 +1073,18 @@ namespace aiter if (n_iter * tnum + threadIdx.x < (n / pack_size)) { int read_idx = bid * (n / pack_size) + n_iter * tnum + threadIdx.x; - rmsnorm_inp[n_iter] = tmps[read_idx]; + P reduce_out_pack = tmps[read_idx]; + P residual_inp_pack = *(reinterpret_cast(residual_inp) + read_idx); w_arr[n_iter] = *(reinterpret_cast(weight) + n_iter * tnum + threadIdx.x); A reduce_pack; #pragma unroll for (int i = 0; i < pack_size; ++i) { - float ar_elem = ck_tile::type_convert(rmsnorm_inp[n_iter].data[i]); - reduce_pack.data[i] = ar_elem * ar_elem; + float ar_out = ck_tile::type_convert(reduce_out_pack.data[i]); + float res_inp = ck_tile::type_convert(residual_inp_pack.data[i]); + float rms_inp = ar_out + res_inp; + rmsnorm_inp[n_iter].data[i] = ck_tile::type_convert(rms_inp); + reduce_pack.data[i] = rms_inp * rms_inp; } square_sum += packReduce(reduce_pack); } @@ -1108,6 +1118,7 @@ namespace aiter template __global__ void __launch_bounds__(256, 1) local_device_load_rmsnorm_512n( RankSignals sg, + T* __restrict__ residual_inp, T* __restrict__ residual_out, T* __restrict__ results, T* __restrict__ weight, @@ -1134,14 +1145,18 @@ namespace aiter for (int n_iter = 0; n_iter < n_loop; ++n_iter) { int read_idx = bid * 64 * n_loop + n_iter * 64 + lane_id; - rmsnorm_inp[n_iter] = tmps[read_idx]; + P reduce_out_pack = tmps[read_idx]; + P residual_inp_pack = *(reinterpret_cast(residual_inp) + read_idx); w_arr[n_iter] = *(reinterpret_cast(weight) + n_iter * 64 + lane_id); A reduce_pack; #pragma unroll for (int i = 0; i < pack_size; ++i) { - float ar_elem = ck_tile::type_convert(rmsnorm_inp[n_iter].data[i]); - reduce_pack.data[i] = ar_elem * ar_elem; + float ar_out = ck_tile::type_convert(reduce_out_pack.data[i]); + float res_inp = ck_tile::type_convert(residual_inp_pack.data[i]); + float rms_inp = ar_out + res_inp; + rmsnorm_inp[n_iter].data[i] = ck_tile::type_convert(rms_inp); + reduce_pack.data[i] = rms_inp * rms_inp; } float tmp_sum = packReduce(reduce_pack); square_sum += tmp_sum; @@ -1587,7 +1602,7 @@ namespace aiter } template - void dispatchFusedAllReduceRMSNorm(hipStream_t stream, T* input, T* residual_out, T* output, T* weight, float eps, int m, int n) + void dispatchFusedAllReduceRMSNorm(hipStream_t stream, T* input, T* residual_inp, T* residual_out, T* output, T* weight, float eps, int m, int n) { auto d = packed_t::P::size; int size = m * n; @@ -1633,12 +1648,12 @@ namespace aiter grid.x = naive_grid_size < num_cu * occupancy ? naive_grid_size : num_cu * occupancy; }; -#define launch_fused_allreduce_rmsnorm(template_kernel) \ - do \ - { \ - auto kernel_ptr = reinterpret_cast(template_kernel); \ - setGrid(naive_grid_size, kernel_ptr); \ - template_kernel<<>>(sg_, residual_out, output, weight, eps, rank_, m, n); \ +#define launch_fused_allreduce_rmsnorm(template_kernel) \ + do \ + { \ + auto kernel_ptr = reinterpret_cast(template_kernel); \ + setGrid(naive_grid_size, kernel_ptr); \ + template_kernel<<>>(sg_, residual_inp, residual_out, output, weight, eps, rank_, m, n); \ } while (0) if (n_bytes % 1024 == 0) diff --git a/csrc/include/custom_all_reduce.h b/csrc/include/custom_all_reduce.h index 98cc9ec4e2..a5b8a712a5 100644 --- a/csrc/include/custom_all_reduce.h +++ b/csrc/include/custom_all_reduce.h @@ -40,6 +40,7 @@ void all_gather_unreg(fptr_t _fa, torch::Tensor& out); void fused_allreduce_rmsnorm(fptr_t _fa, torch::Tensor& inp, + torch::Tensor& res_inp, torch::Tensor& res_out, torch::Tensor& out, torch::Tensor& w, diff --git a/csrc/include/rocm_ops.hpp b/csrc/include/rocm_ops.hpp index a30bb352d4..6aca386744 100644 --- a/csrc/include/rocm_ops.hpp +++ b/csrc/include/rocm_ops.hpp @@ -327,6 +327,7 @@ namespace py = pybind11; &aiter::fused_allreduce_rmsnorm, \ py::arg("_fa"), \ py::arg("inp"), \ + py::arg("res_inp"), \ py::arg("res_out"), \ py::arg("out"), \ py::arg("w"), \ diff --git a/csrc/kernels/custom_all_reduce.cu b/csrc/kernels/custom_all_reduce.cu index 02d7a674ce..e3e7ea86c6 100644 --- a/csrc/kernels/custom_all_reduce.cu +++ b/csrc/kernels/custom_all_reduce.cu @@ -217,7 +217,7 @@ void all_gather_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer, } void _fused_allreduce_rmsnorm( - fptr_t _fa, torch::Tensor& inp, torch::Tensor& residual_out, torch::Tensor& out, torch::Tensor& w, int eps, int m, int n, hipStream_t stream) + fptr_t _fa, torch::Tensor& inp, torch::Tensor& residual_inp, torch::Tensor& residual_out, torch::Tensor& out, torch::Tensor& w, int eps, int m, int n, hipStream_t stream) { auto fa = reinterpret_cast(_fa); TORCH_CHECK(_is_weak_contiguous(out)); @@ -226,6 +226,7 @@ void _fused_allreduce_rmsnorm( case at::ScalarType::Float: { fa->dispatchFusedAllReduceRMSNorm(stream, reinterpret_cast(inp.data_ptr()), + reinterpret_cast(residual_inp.data_ptr()), reinterpret_cast(residual_out.data_ptr()), reinterpret_cast(out.data_ptr()), reinterpret_cast(w.data_ptr()), @@ -235,6 +236,7 @@ void _fused_allreduce_rmsnorm( case at::ScalarType::Half: { fa->dispatchFusedAllReduceRMSNorm(stream, reinterpret_cast(inp.data_ptr()), + reinterpret_cast(residual_inp.data_ptr()), reinterpret_cast(residual_out.data_ptr()), reinterpret_cast(out.data_ptr()), reinterpret_cast(w.data_ptr()), @@ -245,6 +247,7 @@ void _fused_allreduce_rmsnorm( case at::ScalarType::BFloat16: { fa->dispatchFusedAllReduceRMSNorm<__hip_bfloat16>(stream, reinterpret_cast<__hip_bfloat16*>(inp.data_ptr()), + reinterpret_cast<__hip_bfloat16*>(residual_inp.data_ptr()), reinterpret_cast<__hip_bfloat16*>(residual_out.data_ptr()), reinterpret_cast<__hip_bfloat16*>(out.data_ptr()), reinterpret_cast<__hip_bfloat16*>(w.data_ptr()), @@ -259,6 +262,7 @@ void _fused_allreduce_rmsnorm( void fused_allreduce_rmsnorm(fptr_t _fa, torch::Tensor& inp, + torch::Tensor& res_inp, torch::Tensor& res_out, torch::Tensor& out, torch::Tensor& w, @@ -268,7 +272,9 @@ void fused_allreduce_rmsnorm(fptr_t _fa, const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(inp)); auto stream = c10::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream(); TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); + TORCH_CHECK_EQ(inp.scalar_type(), res_inp.scalar_type()); TORCH_CHECK_EQ(inp.numel(), out.numel()); + TORCH_CHECK_EQ(inp.numel(), res_inp.numel()); int n = w.numel(); int m = inp.numel() / n; @@ -282,11 +288,11 @@ void fused_allreduce_rmsnorm(fptr_t _fa, input_size, hipMemcpyDeviceToDevice, stream)); - _fused_allreduce_rmsnorm(_fa, reg_buffer.value(), res_out, out, w, eps, m, n, stream); + _fused_allreduce_rmsnorm(_fa, reg_buffer.value(), res_inp, res_out, out, w, eps, m, n, stream); } else { - _fused_allreduce_rmsnorm(_fa, inp, res_out, out, w, eps, m, n, stream); + _fused_allreduce_rmsnorm(_fa, inp, res_inp, res_out, out, w, eps, m, n, stream); } } diff --git a/op_tests/multigpu_tests/test_fused_ar_rms.py b/op_tests/multigpu_tests/test_fused_ar_rms.py index 3d15e257bc..eb9a258914 100644 --- a/op_tests/multigpu_tests/test_fused_ar_rms.py +++ b/op_tests/multigpu_tests/test_fused_ar_rms.py @@ -63,7 +63,7 @@ def fused_ar_rmsnorm(tp_size, pp_size, rankID, x, weight, eps, withGraph=False): with graph_capture() as gc: with torch.cuda.graph(graph, stream=gc.stream): res_out, out = tensor_model_parallel_fused_allreduce_rmsnorm( - x, weight, eps + x, x, weight, eps ) out.fill_(0) res_out.fill_(0) @@ -78,7 +78,9 @@ def run_ca(): @perftest() def run_ca(x): - res_out, out = tensor_model_parallel_fused_allreduce_rmsnorm(x, weight, eps) + res_out, out = tensor_model_parallel_fused_allreduce_rmsnorm( + x, x, weight, eps + ) return out out = run_ca(x) @@ -117,7 +119,9 @@ def get_acc_value_with_cudagraph(tp_size, pp_size, rankID, x, weight, eps, loop_ with graph_capture() as gc: with torch.cuda.graph(graph, stream=gc.stream): # out = torch.empty_like(x) - res_out, out = tensor_model_parallel_fused_allreduce_rmsnorm(x, weight, eps) + res_out, out = tensor_model_parallel_fused_allreduce_rmsnorm( + x, x, weight, eps + ) out.fill_(0) def run_ca(): @@ -158,7 +162,7 @@ def get_acc_value_only(tp_size, pp_size, rankID, x, weight, eps, loop_time=1): torch.cuda.synchronize() for i in range(loop_time): - res, out = tensor_model_parallel_fused_allreduce_rmsnorm(x, weight, eps) + res, out = tensor_model_parallel_fused_allreduce_rmsnorm(x, x, weight, eps) # destroy if dist.is_initialized(): @@ -439,7 +443,7 @@ def acc_test_cudagraph_on(tp_size, pp_size, shape, dtype, loop_time=1): l_dtype = ["bf16"] l_shape = [ # (4096, 2048) - (64, 7168) + (128, 1024) # (64, 512 * 99) # (16, 512) ] From 2551ce9c5f4d8d27803ba4c26f837c471942a8eb Mon Sep 17 00:00:00 2001 From: amd-ruitang3 Date: Thu, 13 Nov 2025 06:37:11 +0000 Subject: [PATCH 2/5] delete comment Signed-off-by: amd-ruitang3 --- csrc/include/custom_all_reduce.cuh | 1 - 1 file changed, 1 deletion(-) diff --git a/csrc/include/custom_all_reduce.cuh b/csrc/include/custom_all_reduce.cuh index 0ce722f6b4..11b34cfc88 100644 --- a/csrc/include/custom_all_reduce.cuh +++ b/csrc/include/custom_all_reduce.cuh @@ -935,7 +935,6 @@ namespace aiter #pragma unroll for (int i = 0; i < pack_size; ++i) { - // float res_x = ck_tile::type_convert(input_reg.data[i]); float sum_x = *(reinterpret_cast(&tmp_smem[0]) + lane_id * pack_size + i); rslt.data[i] = ck_tile::type_convert(sum_x); } From 78e1523869602b16557c552c56b0b5f9871dd322 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 13 Nov 2025 06:41:06 +0000 Subject: [PATCH 3/5] change ut case Signed-off-by: root --- op_tests/multigpu_tests/test_fused_ar_rms.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/op_tests/multigpu_tests/test_fused_ar_rms.py b/op_tests/multigpu_tests/test_fused_ar_rms.py index eb9a258914..f1159f361b 100644 --- a/op_tests/multigpu_tests/test_fused_ar_rms.py +++ b/op_tests/multigpu_tests/test_fused_ar_rms.py @@ -442,10 +442,7 @@ def acc_test_cudagraph_on(tp_size, pp_size, shape, dtype, loop_time=1): l_dtype = ["bf16"] l_shape = [ - # (4096, 2048) - (128, 1024) - # (64, 512 * 99) - # (16, 512) + (64, 7168) ] l_tp = [8] l_pp = [1] From b6804984d17a31310dadac77640a27f157f3de4b Mon Sep 17 00:00:00 2001 From: root Date: Thu, 13 Nov 2025 08:34:53 +0000 Subject: [PATCH 4/5] fix ut format err Signed-off-by: root --- op_tests/multigpu_tests/test_fused_ar_rms.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/op_tests/multigpu_tests/test_fused_ar_rms.py b/op_tests/multigpu_tests/test_fused_ar_rms.py index f1159f361b..86f7160dba 100644 --- a/op_tests/multigpu_tests/test_fused_ar_rms.py +++ b/op_tests/multigpu_tests/test_fused_ar_rms.py @@ -441,9 +441,7 @@ def acc_test_cudagraph_on(tp_size, pp_size, shape, dtype, loop_time=1): # checkAllclose(cpu_rslt[i], ar_rslt[i].to(ref)) l_dtype = ["bf16"] -l_shape = [ - (64, 7168) -] +l_shape = [(64, 7168)] l_tp = [8] l_pp = [1] l_graph = [True, False] From a43abfd637f884965a276b1d2b0a373dd9840d43 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 13 Nov 2025 14:41:21 +0000 Subject: [PATCH 5/5] [fix]: ar acc err Signed-off-by: root --- csrc/include/custom_all_reduce.cuh | 1 + 1 file changed, 1 insertion(+) diff --git a/csrc/include/custom_all_reduce.cuh b/csrc/include/custom_all_reduce.cuh index 11b34cfc88..cbe19249a9 100644 --- a/csrc/include/custom_all_reduce.cuh +++ b/csrc/include/custom_all_reduce.cuh @@ -511,6 +511,7 @@ namespace aiter } tmp_out[idx - start] = write_reg; } + __syncthreads(); } end_sync(sg, self_sg, rank);