From 48e699b6dfaa7a5fe46d8a984722f155c510c4d5 Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Thu, 31 Oct 2024 15:17:51 +0100
Subject: [PATCH] tests

---
 tests/unittests/bases/test_ddp.py | 90 ++++++++++---------------------
 tests/unittests/conftest.py       | 15 +++++-
 2 files changed, 41 insertions(+), 64 deletions(-)

diff --git a/tests/unittests/bases/test_ddp.py b/tests/unittests/bases/test_ddp.py
index fb44d6ee353..07dee96f4da 100644
--- a/tests/unittests/bases/test_ddp.py
+++ b/tests/unittests/bases/test_ddp.py
@@ -27,6 +27,7 @@
 from unittests import NUM_PROCESSES, USE_PYTEST_POOL
 from unittests._helpers import seed_all
 from unittests._helpers.testers import DummyListMetric, DummyMetric, DummyMetricSum
+from unittests.conftest import setup_ddp
 
 seed_all(42)
 
@@ -105,80 +106,43 @@ def test_ddp(process):
     pytest.pool.map(process, range(NUM_PROCESSES))
 
 
-def _test_ddp_gather_autograd_same_shape(rank: int, worldsize: int = NUM_PROCESSES) -> None:
-    """Test that ddp gather preserves local rank's autograd graph for same-shaped tensors across ranks.
+def _test_ddp_gather_all_autograd_same_shape(rank: int, worldsize: int = NUM_PROCESSES) -> None:
+    """Test that ddp gather preserves local rank's autograd graph for same-shaped tensors across ranks."""
+    setup_ddp(rank, worldsize)
+    x = (rank + 1) * torch.ones(10, requires_grad=True)
 
-    This function tests that ``torchmetrics.utilities.distributed.gather_all_tensors`` works as intended in
-    preserving the local rank's autograd graph upon the gather. The function compares derivative values obtained
-    with the local rank results from the ``gather_all_tensors`` output and the original local rank tensor.
-    This test only considers tensors of the same shape across different ranks.
+    # random linear transformation, it should really not matter what we do here
+    a, b = torch.randn(1), torch.randn(1)
+    y = a * x + b  # gradient of y w.r.t. x is a
 
-    Note that this test only works for torch>=2.0.
-
-    """
-    tensor = torch.randn(50, dtype=torch.float64, requires_grad=True)
-    result = gather_all_tensors(tensor)
+    result = gather_all_tensors(y)
     assert len(result) == worldsize
-    scalar1 = 0
-    scalar2 = 0
-    for idx in range(worldsize):
-        W = torch.randn_like(result[idx], requires_grad=False)
-        if idx == rank:
-            assert torch.allclose(result[idx], tensor)
-            scalar1 = scalar1 + torch.sum(tensor * W)
-        else:
-            scalar1 = scalar1 + torch.sum(result[idx] * W)
-        scalar2 = scalar2 + torch.sum(result[idx] * W)
-    gradient1 = torch.autograd.grad(scalar1, [tensor], retain_graph=True)[0]
-    gradient2 = torch.autograd.grad(scalar2, [tensor])[0]
-    assert torch.allclose(gradient1, gradient2)
-
-
-def _test_ddp_gather_autograd_different_shape(rank: int, worldsize: int = NUM_PROCESSES) -> None:
-    """Test that ddp gather preserves local rank's autograd graph for differently-shaped tensors across ranks.
-
-    This function tests that ``torchmetrics.utilities.distributed.gather_all_tensors`` works as intended in
-    preserving the local rank's autograd graph upon the gather. The function compares derivative values obtained
-    with the local rank results from the ``gather_all_tensors`` output and the original local rank tensor.
-    This test considers tensors of different shapes across different ranks.
-
-    Note that this test only works for torch>=2.0.
-
-    """
-    tensor = torch.randn(rank + 1, 2 - rank, dtype=torch.float64, requires_grad=True)
-    result = gather_all_tensors(tensor)
+    grad = torch.autograd.grad(result[rank].sum(), x)[0]
+    assert torch.allclose(grad, a * torch.ones_like(x))
+
+
+def _test_ddp_gather_all_autograd_different_shape(rank: int, worldsize: int = NUM_PROCESSES) -> None:
+    """Test that ddp gather preserves local rank's autograd graph for differently-shaped tensors across ranks."""
+    setup_ddp(rank, worldsize)
+    x = (rank + 1) * torch.ones(rank + 1, 2 - rank, requires_grad=True)
+
+    # random linear transformation, it should really not matter what we do here
+    a, b = torch.randn(1), torch.randn(1)
+    y = a * x + b  # gradient of y w.r.t. x is a
+
+    result = gather_all_tensors(y)
     assert len(result) == worldsize
-    scalar1 = 0
-    scalar2 = 0
-    for idx in range(worldsize):
-        W = torch.randn_like(result[idx], requires_grad=False)
-        if idx == rank:
-            assert torch.allclose(result[idx], tensor)
-            scalar1 = scalar1 + torch.sum(tensor * W)
-        else:
-            scalar1 = scalar1 + torch.sum(result[idx] * W)
-        scalar2 = scalar2 + torch.sum(result[idx] * W)
-    gradient1 = torch.autograd.grad(scalar1, [tensor], retain_graph=True)[0]
-    gradient2 = torch.autograd.grad(scalar2, [tensor])[0]
-    assert torch.allclose(gradient1, gradient2)
+    grad = torch.autograd.grad(result[rank].sum(), x)[0]
+    assert torch.allclose(grad, a * torch.ones_like(x))
 
 
 @pytest.mark.DDP()
 @pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
 @pytest.mark.skipif(not USE_PYTEST_POOL, reason="DDP pool is not available.")
-# @pytest.mark.skipif(not _TORCH_GREATER_EQUAL_2_1, reason="test only works on newer torch versions")
-@pytest.mark.parametrize(
-    "process",
-    [
-        _test_ddp_gather_autograd_same_shape,
-        _test_ddp_gather_autograd_different_shape,
-    ],
-)
 @pytest.mark.parametrize(
-    "index",
-    [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
+    "process", [_test_ddp_gather_all_autograd_same_shape, _test_ddp_gather_all_autograd_different_shape]
 )
-def test_ddp_autograd(process, index):
+def test_ddp_autograd(process):
     """Test ddp functions for autograd compatibility."""
     pytest.pool.map(process, range(NUM_PROCESSES))
 
diff --git a/tests/unittests/conftest.py b/tests/unittests/conftest.py
index f09f884adeb..58967ba2521 100644
--- a/tests/unittests/conftest.py
+++ b/tests/unittests/conftest.py
@@ -45,7 +45,17 @@ def use_deterministic_algorithms():  # noqa: PT004
 
 
 def setup_ddp(rank, world_size):
-    """Initialize ddp environment."""
+    """Initialize ddp environment.
+
+    If a particular test relies on the order of the processes in the pool to be [0, 1, 2, ...], then this function
+    should be called inside the test to ensure that the processes are initialized in the same order they are used in
+    the tests.
+
+    Args:
+        rank: the rank of the process
+        world_size: the number of processes
+
+    """
     global CURRENT_PORT
 
     os.environ["MASTER_ADDR"] = "localhost"
@@ -55,6 +65,9 @@ def setup_ddp(rank, world_size):
     if CURRENT_PORT > MAX_PORT:
         CURRENT_PORT = START_PORT
 
+    if torch.distributed.group.WORLD is not None:  # if already initialized, destroy the process group
+        torch.distributed.destroy_process_group()
+
     if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"):
         torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)
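Note for reviewers: the snippet below is a minimal standalone sketch, not part of the patch, that reproduces the gradient check the reworked tests perform, using torch.multiprocessing instead of the pytest pool. The inline environment setup is a hypothetical stand-in for unittests.conftest.setup_ddp, and the world size of 2 and port 29500 are assumptions rather than values taken from the repository; gather_all_tensors is the real torchmetrics.utilities.distributed helper exercised by the patch.

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from torchmetrics.utilities.distributed import gather_all_tensors

WORLD_SIZE = 2  # assumed process count for this sketch


def _worker(rank: int) -> None:
    # Hypothetical stand-in for unittests.conftest.setup_ddp: gloo backend, fixed port.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"  # assumed free port
    dist.init_process_group("gloo", rank=rank, world_size=WORLD_SIZE)

    # Same pattern as the patched tests: y = a * x + b, so dy/dx == a elementwise.
    x = (rank + 1) * torch.ones(10, requires_grad=True)
    a, b = torch.randn(1), torch.randn(1)
    result = gather_all_tensors(a * x + b)

    # Only the local rank's slice stays connected to this process's autograd graph.
    grad = torch.autograd.grad(result[rank].sum(), x)[0]
    assert torch.allclose(grad, a * torch.ones_like(x))

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_worker, nprocs=WORLD_SIZE)

Because y = a * x + b, the gradient of the gathered local slice with respect to x should be exactly a in every position, which is what both this sketch and the new tests assert.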