diff --git a/tests/unittests/bases/test_ddp.py b/tests/unittests/bases/test_ddp.py index c057d0cbdf8..4219ee52a56 100644 --- a/tests/unittests/bases/test_ddp.py +++ b/tests/unittests/bases/test_ddp.py @@ -77,6 +77,62 @@ def _test_ddp_gather_uneven_tensors_multidim(rank: int, worldsize: int = NUM_PRO assert (val == torch.ones_like(val)).all() +@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_2_1, reason="test only works on newer torch versions") +def _test_ddp_gather_autograd_same_shape(rank: int, worldsize: int = NUM_PROCESSES) -> None: + """Test that ddp gather preserves local rank's autograd graph for same-shaped tensors across ranks. + + This function tests that ``torchmetrics.utilities.distributed.gather_all_tensors`` works as intended in + preserving the local rank's autograd graph upon the gather. The function compares derivative values obtained + with the local rank results from the ``gather_all_tensors`` output and the original local rank tensor. + This test only considers tensors of the same shape across different ranks. + + Note that this test only works for torch>=2.0. + + """ + tensor = torch.ones(50, requires_grad=True) + result = gather_all_tensors(tensor) + assert len(result) == worldsize + scalar1 = 0 + scalar2 = 0 + for idx in range(worldsize): + if idx == rank: + scalar1 = scalar1 + torch.sum(tensor * torch.ones_like(tensor)) + else: + scalar1 = scalar1 + torch.sum(result[idx] * torch.ones_like(result[idx])) + scalar2 = scalar2 + torch.sum(result[idx] * torch.ones_like(result[idx])) + gradient1 = torch.autograd.grad(scalar1, [tensor], retain_graph=True)[0] + gradient2 = torch.autograd.grad(scalar2, [tensor])[0] + assert torch.allclose(gradient1, gradient2) + + +@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_2_1, reason="test only works on newer torch versions") +def _test_ddp_gather_autograd_different_shape(rank: int, worldsize: int = NUM_PROCESSES) -> None: + """Test that ddp gather preserves local rank's autograd graph for differently-shaped tensors across ranks. + + This function tests that ``torchmetrics.utilities.distributed.gather_all_tensors`` works as intended in + preserving the local rank's autograd graph upon the gather. The function compares derivative values obtained + with the local rank results from the ``gather_all_tensors`` output and the original local rank tensor. + This test considers tensors of different shapes across different ranks. + + Note that this test only works for torch>=2.0. + + """ + tensor = torch.ones(rank + 1, 2 - rank, requires_grad=True) + result = gather_all_tensors(tensor) + assert len(result) == worldsize + scalar1 = 0 + scalar2 = 0 + for idx in range(worldsize): + if idx == rank: + scalar1 = scalar1 + torch.sum(tensor * torch.ones_like(tensor)) + else: + scalar1 = scalar1 + torch.sum(result[idx] * torch.ones_like(result[idx])) + scalar2 = scalar2 + torch.sum(result[idx] * torch.ones_like(result[idx])) + gradient1 = torch.autograd.grad(scalar1, [tensor], retain_graph=True)[0] + gradient2 = torch.autograd.grad(scalar2, [tensor])[0] + assert torch.allclose(gradient1, gradient2) + + def _test_ddp_compositional_tensor(rank: int, worldsize: int = NUM_PROCESSES) -> None: dummy = DummyMetricSum() dummy._reductions = {"x": torch.sum} @@ -105,6 +161,76 @@ def test_ddp(process): pytest.pool.map(process, range(NUM_PROCESSES)) +def _test_ddp_gather_autograd_same_shape(rank: int, worldsize: int = NUM_PROCESSES) -> None: + """Test that ddp gather preserves local rank's autograd graph for same-shaped tensors across ranks. + + This function tests that ``torchmetrics.utilities.distributed.gather_all_tensors`` works as intended in + preserving the local rank's autograd graph upon the gather. The function compares derivative values obtained + with the local rank results from the ``gather_all_tensors`` output and the original local rank tensor. + This test only considers tensors of the same shape across different ranks. + + Note that this test only works for torch>=2.0. + + """ + tensor = torch.ones(50, requires_grad=True) + result = gather_all_tensors(tensor) + assert len(result) == worldsize + scalar1 = 0 + scalar2 = 0 + for idx in range(worldsize): + if idx == rank: + scalar1 = scalar1 + torch.sum(tensor * torch.ones_like(tensor)) + else: + scalar1 = scalar1 + torch.sum(result[idx] * torch.ones_like(result[idx])) + scalar2 = scalar2 + torch.sum(result[idx] * torch.ones_like(result[idx])) + gradient1 = torch.autograd.grad(scalar1, [tensor], retain_graph=True)[0] + gradient2 = torch.autograd.grad(scalar2, [tensor])[0] + assert torch.allclose(gradient1, gradient2) + + +def _test_ddp_gather_autograd_different_shape(rank: int, worldsize: int = NUM_PROCESSES) -> None: + """Test that ddp gather preserves local rank's autograd graph for differently-shaped tensors across ranks. + + This function tests that ``torchmetrics.utilities.distributed.gather_all_tensors`` works as intended in + preserving the local rank's autograd graph upon the gather. The function compares derivative values obtained + with the local rank results from the ``gather_all_tensors`` output and the original local rank tensor. + This test considers tensors of different shapes across different ranks. + + Note that this test only works for torch>=2.0. + + """ + tensor = torch.ones(rank + 1, 2 - rank, requires_grad=True) + result = gather_all_tensors(tensor) + assert len(result) == worldsize + scalar1 = 0 + scalar2 = 0 + for idx in range(worldsize): + if idx == rank: + scalar1 = scalar1 + torch.sum(tensor * torch.ones_like(tensor)) + else: + scalar1 = scalar1 + torch.sum(result[idx] * torch.ones_like(result[idx])) + scalar2 = scalar2 + torch.sum(result[idx] * torch.ones_like(result[idx])) + gradient1 = torch.autograd.grad(scalar1, [tensor], retain_graph=True)[0] + gradient2 = torch.autograd.grad(scalar2, [tensor])[0] + assert torch.allclose(gradient1, gradient2) + + +@pytest.mark.DDP() +@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows") +@pytest.mark.skipif(not USE_PYTEST_POOL, reason="DDP pool is not available.") +@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_2_1, reason="test only works on newer torch versions") +@pytest.mark.parametrize( + "process", + [ + _test_ddp_gather_autograd_same_shape, + _test_ddp_gather_autograd_different_shape, + ], +) +def test_ddp_autograd(process): + """Test ddp functions for autograd compatibility.""" + pytest.pool.map(process, range(NUM_PROCESSES)) + + def _test_non_contiguous_tensors(rank): class DummyCatMetric(Metric): full_state_update = True