From ba370dec985f7f44bfcb5d4d58a1fb88839a17f9 Mon Sep 17 00:00:00 2001
From: Shengyao Zhuang <46237844+ArvinZhuang@users.noreply.github.com>
Date: Thu, 25 Mar 2021 19:37:58 +1000
Subject: [PATCH] Match the number of outputs of backward with forward for
 AllGatherGrad (#6625)

---
 pytorch_lightning/utilities/distributed.py |  2 +-
 tests/utilities/test_all_gather_grad.py    | 23 ++++++++++++++++++++++
 2 files changed, 24 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py
index 9e47af26f53d5..3877f774b7cd8 100644
--- a/pytorch_lightning/utilities/distributed.py
+++ b/pytorch_lightning/utilities/distributed.py
@@ -187,7 +187,7 @@ def backward(ctx, *grad_output):
 
         torch.distributed.all_reduce(grad_output, op=torch.distributed.ReduceOp.SUM, async_op=False, group=ctx.group)
 
-        return grad_output[torch.distributed.get_rank()]
+        return grad_output[torch.distributed.get_rank()], None
 
 
 def all_gather_ddp_if_available(
diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py
index f82cfc94bcce2..86b977cfff029 100644
--- a/tests/utilities/test_all_gather_grad.py
+++ b/tests/utilities/test_all_gather_grad.py
@@ -96,3 +96,26 @@ def training_epoch_end(self, outputs) -> None:
 
     trainer.fit(model)
     assert model.training_epoch_end_called
+
+
+@RunIf(min_gpus=2, skip_windows=True, special=True)
+def test_all_gather_sync_grads(tmpdir):
+
+    class TestModel(BoringModel):
+
+        training_step_called = False
+
+        def training_step(self, batch, batch_idx):
+            self.training_step_called = True
+            tensor = torch.rand(2, 2, requires_grad=True, device=self.device)
+            gathered_tensor = self.all_gather(tensor, sync_grads=True)
+            assert gathered_tensor.shape == torch.Size([2, 2, 2])
+
+            loss = gathered_tensor.sum()
+
+            return loss
+
+    model = TestModel()
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, gpus=2)
+    trainer.fit(model)
+    assert model.training_step_called
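
Note on the fix: PyTorch's torch.autograd.Function contract requires backward
to return exactly one value per argument passed to forward (excluding ctx),
with None in the positions of non-differentiable inputs. AllGatherGrad's
forward takes both a tensor and a process group, so its backward must return
two values; returning the gradient alone makes autograd raise a gradient-count
mismatch at runtime when sync_grads=True. Below is a minimal sketch of that
contract using a hypothetical ScaleBy function that is not part of this patch:

import torch


class ScaleBy(torch.autograd.Function):
    """Toy Function whose forward takes a tensor and a non-tensor argument."""

    @staticmethod
    def forward(ctx, tensor, factor):
        # Two forward inputs (besides ctx) -> backward must return two outputs.
        ctx.factor = factor
        return tensor * factor

    @staticmethod
    def backward(ctx, grad_output):
        # Gradient for `tensor`, and None for the non-differentiable `factor`,
        # mirroring the `, None` this patch adds for AllGatherGrad's group arg.
        return grad_output * ctx.factor, None


x = torch.ones(2, requires_grad=True)
ScaleBy.apply(x, 3.0).sum().backward()
print(x.grad)  # tensor([3., 3.])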