From d3894e1ba65b668a0e6127035625acc5eb3b3b67 Mon Sep 17 00:00:00 2001
From: Chuin Wei Tan <87742566+cw-tan@users.noreply.github.com>
Date: Thu, 31 Oct 2024 12:53:24 -0400
Subject: [PATCH] Enable autograd graph to propagate after multi-device syncing for loss functions in `ddp` (#2754)

Co-authored-by: Nicki Skafte Detlefsen
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 CHANGELOG.md                              |  3 ++
 docs/source/pages/overview.rst            |  5 +++
 src/torchmetrics/utilities/distributed.py | 32 ++++++++++-------
 tests/unittests/bases/test_ddp.py         | 42 +++++++++++++++++++++++
 tests/unittests/conftest.py               | 15 +++++++-
 5 files changed, 83 insertions(+), 14 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 55f945dc37e..98b55fab971 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -27,6 +27,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 - Added `Dice` metric to segmentation metrics ([#2725](https://github.com/Lightning-AI/torchmetrics/pull/2725))
 
+- Added support for propagation of the autograd graph in ddp setting ([#2754](https://github.com/Lightning-AI/torchmetrics/pull/2754))
+
+
 ### Changed
 
 - Changed naming and input order arguments in `KLDivergence` ([#2800](https://github.com/Lightning-AI/torchmetrics/pull/2800))
diff --git a/docs/source/pages/overview.rst b/docs/source/pages/overview.rst
index 5dabc545e50..34d0dcbd6fc 100644
--- a/docs/source/pages/overview.rst
+++ b/docs/source/pages/overview.rst
@@ -492,6 +492,11 @@ In practice this means that:
 
 A functional metric is differentiable if its corresponding modular metric is differentiable.
 
+.. caution::
+    For PyTorch versions 2.1 or higher, differentiation in DDP mode is enabled, allowing autograd graph
+    propagation after the ``all_gather`` operation. This is useful for synchronizing metrics used as
+    loss functions in a DDP setting.
+
 ***************************************
 Metrics and hyperparameter optimization
 ***************************************
diff --git a/src/torchmetrics/utilities/distributed.py b/src/torchmetrics/utilities/distributed.py
index 455d64c4ae0..90239b46af0 100644
--- a/src/torchmetrics/utilities/distributed.py
+++ b/src/torchmetrics/utilities/distributed.py
@@ -89,8 +89,11 @@ def class_reduce(
 
 
 def _simple_gather_all_tensors(result: Tensor, group: Any, world_size: int) -> List[Tensor]:
-    gathered_result = [torch.zeros_like(result) for _ in range(world_size)]
-    torch.distributed.all_gather(gathered_result, result, group)
+    with torch.no_grad():
+        gathered_result = [torch.zeros_like(result) for _ in range(world_size)]
+        torch.distributed.all_gather(gathered_result, result, group)
+    # to propagate autograd graph from local rank
+    gathered_result[torch.distributed.get_rank(group)] = result
     return gathered_result
 
 
@@ -133,15 +136,18 @@ def gather_all_tensors(result: Tensor, group: Optional[Any] = None) -> List[Tens
         return _simple_gather_all_tensors(result, group, world_size)
 
     # 3. If not, we need to pad each local tensor to maximum size, gather and then truncate
-    pad_dims = []
-    pad_by = (max_size - local_size).detach().cpu()
-    for val in reversed(pad_by):
-        pad_dims.append(0)
-        pad_dims.append(val.item())
-    result_padded = F.pad(result, pad_dims)
-    gathered_result = [torch.zeros_like(result_padded) for _ in range(world_size)]
-    torch.distributed.all_gather(gathered_result, result_padded, group)
-    for idx, item_size in enumerate(local_sizes):
-        slice_param = [slice(dim_size) for dim_size in item_size]
-        gathered_result[idx] = gathered_result[idx][slice_param]
+    with torch.no_grad():
+        pad_dims = []
+        pad_by = (max_size - local_size).detach().cpu()
+        for val in reversed(pad_by):
+            pad_dims.append(0)
+            pad_dims.append(val.item())
+        result_padded = F.pad(result, pad_dims)
+        gathered_result = [torch.zeros_like(result_padded) for _ in range(world_size)]
+        torch.distributed.all_gather(gathered_result, result_padded, group)
+        for idx, item_size in enumerate(local_sizes):
+            slice_param = [slice(dim_size) for dim_size in item_size]
+            gathered_result[idx] = gathered_result[idx][slice_param]
+    # to propagate autograd graph from local rank
+    gathered_result[torch.distributed.get_rank(group)] = result
     return gathered_result
diff --git a/tests/unittests/bases/test_ddp.py b/tests/unittests/bases/test_ddp.py
index c057d0cbdf8..07dee96f4da 100644
--- a/tests/unittests/bases/test_ddp.py
+++ b/tests/unittests/bases/test_ddp.py
@@ -27,6 +27,7 @@
 from unittests import NUM_PROCESSES, USE_PYTEST_POOL
 from unittests._helpers import seed_all
 from unittests._helpers.testers import DummyListMetric, DummyMetric, DummyMetricSum
+from unittests.conftest import setup_ddp
 
 seed_all(42)
 
@@ -105,6 +106,47 @@ def test_ddp(process):
     pytest.pool.map(process, range(NUM_PROCESSES))
 
 
+def _test_ddp_gather_all_autograd_same_shape(rank: int, worldsize: int = NUM_PROCESSES) -> None:
+    """Test that ddp gather preserves local rank's autograd graph for same-shaped tensors across ranks."""
+    setup_ddp(rank, worldsize)
+    x = (rank + 1) * torch.ones(10, requires_grad=True)
+
+    # random linear transformation, it should really not matter what we do here
+    a, b = torch.randn(1), torch.randn(1)
+    y = a * x + b  # gradient of y w.r.t. x is a
+
+    result = gather_all_tensors(y)
+    assert len(result) == worldsize
+    grad = torch.autograd.grad(result[rank].sum(), x)[0]
+    assert torch.allclose(grad, a * torch.ones_like(x))
+
+
+def _test_ddp_gather_all_autograd_different_shape(rank: int, worldsize: int = NUM_PROCESSES) -> None:
+    """Test that ddp gather preserves local rank's autograd graph for differently-shaped tensors across ranks."""
+    setup_ddp(rank, worldsize)
+    x = (rank + 1) * torch.ones(rank + 1, 2 - rank, requires_grad=True)
+
+    # random linear transformation, it should really not matter what we do here
+    a, b = torch.randn(1), torch.randn(1)
+    y = a * x + b  # gradient of y w.r.t. x is a
+
+    result = gather_all_tensors(y)
+    assert len(result) == worldsize
+    grad = torch.autograd.grad(result[rank].sum(), x)[0]
+    assert torch.allclose(grad, a * torch.ones_like(x))
+
+
+@pytest.mark.DDP()
+@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
+@pytest.mark.skipif(not USE_PYTEST_POOL, reason="DDP pool is not available.")
+@pytest.mark.parametrize(
+    "process", [_test_ddp_gather_all_autograd_same_shape, _test_ddp_gather_all_autograd_different_shape]
+)
+def test_ddp_autograd(process):
+    """Test ddp functions for autograd compatibility."""
+    pytest.pool.map(process, range(NUM_PROCESSES))
+
+
 def _test_non_contiguous_tensors(rank):
     class DummyCatMetric(Metric):
         full_state_update = True
diff --git a/tests/unittests/conftest.py b/tests/unittests/conftest.py
index f09f884adeb..58967ba2521 100644
--- a/tests/unittests/conftest.py
+++ b/tests/unittests/conftest.py
@@ -45,7 +45,17 @@ def use_deterministic_algorithms():  # noqa: PT004
 
 
 def setup_ddp(rank, world_size):
-    """Initialize ddp environment."""
+    """Initialize ddp environment.
+
+    If a particular test relies on the order of the processes in the pool to be [0, 1, 2, ...], then this function
+    should be called inside the test to ensure that the processes are initialized in the same order they are used in
+    the tests.
+
+    Args:
+        rank: the rank of the process
+        world_size: the number of processes
+
+    """
     global CURRENT_PORT
 
     os.environ["MASTER_ADDR"] = "localhost"
@@ -55,6 +65,9 @@ def setup_ddp(rank, world_size):
     if CURRENT_PORT > MAX_PORT:
         CURRENT_PORT = START_PORT
 
+    if torch.distributed.group.WORLD is not None:  # if already initialized, destroy the process group
+        torch.distributed.destroy_process_group()
+
     if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"):
         torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)
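
The docs caution added above mentions synchronizing metrics used as loss functions in DDP. As a minimal usage sketch of that idea (not part of the patch), the snippet below gathers a per-rank scalar loss with the modified gather_all_tensors and averages it; only the local rank's entry carries an autograd graph, so backward() on the mean still reaches the local computation. The helper name ddp_mean_loss is illustrative only.

# Minimal usage sketch (assumed, not part of the patch): `ddp_mean_loss` is a
# hypothetical helper that turns per-rank losses into one DDP-synced scalar.
import torch
from torchmetrics.utilities.distributed import gather_all_tensors


def ddp_mean_loss(local_loss: torch.Tensor) -> torch.Tensor:
    """Average per-rank losses while keeping the local rank's autograd graph."""
    gathered = gather_all_tensors(local_loss)  # list of world_size tensors
    # only gathered[rank] carries a grad_fn, so backward() reaches the local graph
    return torch.stack(gathered).mean()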
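
The new tests run inside the DDP process pool; the sketch below is a quicker single-process check of the same behavior. It assumes the patch above is applied, the gloo backend is available, and port 29500 is free on localhost; none of these values come from the patch itself.

# Single-process smoke check (assumptions: patched torchmetrics, gloo backend,
# free port 29500 on localhost).
import os

import torch
import torch.distributed as dist

from torchmetrics.utilities.distributed import gather_all_tensors

os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

x = torch.ones(10, requires_grad=True)
y = 3.0 * x  # gradient of y w.r.t. x is 3
result = gather_all_tensors(y)
# the local slot keeps the autograd graph, so the gradient flows back to x
grad = torch.autograd.grad(result[dist.get_rank()].sum(), x)[0]
assert torch.allclose(grad, 3.0 * torch.ones_like(x))
dist.destroy_process_group()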