diff --git a/aiter/dist/parallel_state.py b/aiter/dist/parallel_state.py
index bb00aa5710..b2cfcef68e 100644
--- a/aiter/dist/parallel_state.py
+++ b/aiter/dist/parallel_state.py
@@ -123,9 +123,8 @@ def fused_allreduce_rmsnorm_fake(
     w: torch.Tensor,
     eps: float,
     group_name: str,
-) -> torch.Tensor:
-    return torch.empty_like(inp)
-
+) -> tuple[torch.Tensor, torch.Tensor]:
+    return torch.empty_like(res_inp), torch.empty_like(inp)
 
 @torch_compile_guard(gen_fake=fused_allreduce_rmsnorm_fake)
 def fused_allreduce_rmsnorm_(