diff --git a/vllm/compilation/passes/fusion/allreduce_rms_fusion.py b/vllm/compilation/passes/fusion/allreduce_rms_fusion.py index 6cb0c8f49f3d..83423c02a6f1 100644 --- a/vllm/compilation/passes/fusion/allreduce_rms_fusion.py +++ b/vllm/compilation/passes/fusion/allreduce_rms_fusion.py @@ -939,7 +939,7 @@ def replacement(self): def _replacement( input: torch.Tensor, weight: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor]: - residual = torch.empty_like(input) + residual = torch.zeros_like(input) allreduce = self.FUSED_AR_RMSNORM_OP( input_=input, residual=residual,