diff --git a/aiter/dist/device_communicators/custom_all_reduce.py b/aiter/dist/device_communicators/custom_all_reduce.py
index 2df5f2b606..abf2794588 100644
--- a/aiter/dist/device_communicators/custom_all_reduce.py
+++ b/aiter/dist/device_communicators/custom_all_reduce.py
@@ -300,7 +300,7 @@ def custom_all_reduce(
             else:
                 # if warm up, mimic the allocation pattern
                 # since custom allreduce is out-of-place
-                return torch.empty_like(input)
+                return torch.zeros_like(input)
         else:
             # note: outside of cuda graph context,
             # custom allreduce incurs a cost of cudaMemcpy, which should
@@ -332,7 +332,7 @@ def custom_all_gather(self, inp: torch.Tensor) -> Optional[torch.Tensor]:
                 return self.all_gather_reg(inp)
             else:
                 print("allgather capture hipgraph error")
-                return torch.empty_like(inp)
+                return torch.zeros_like(inp)
         else:
             return self.all_gather_unreg(inp)
 
@@ -371,7 +371,7 @@ def custom_fused_ar_rms(
             if torch.cuda.is_current_stream_capturing():
                 return self.fused_ar_rms(input, w=weight, eps=eps, registered=True)
             else:
-                return torch.empty_like(input), torch.empty_like(input)
+                return torch.zeros_like(input), torch.zeros_like(input)
         else:
             return self.fused_ar_rms(input, w=weight, eps=eps, registered=False)
 