Skip to content

Commit

Permalink
Enable autograd graph to propagate after multi-device syncing for los…
Browse files Browse the repository at this point in the history
…s functions in `ddp` (#2754)

Co-authored-by: Nicki Skafte Detlefsen <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Oct 31, 2024
1 parent abdd2c4 commit d3894e1
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 14 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `Dice` metric to segmentation metrics ([#2725](https://github.com/Lightning-AI/torchmetrics/pull/2725))


- Added support for propagation of the autograd graph in ddp setting ([#2754](https://github.com/Lightning-AI/torchmetrics/pull/2754))


### Changed

- Changed naming and input order arguments in `KLDivergence` ([#2800](https://github.com/Lightning-AI/torchmetrics/pull/2800))
Expand Down
5 changes: 5 additions & 0 deletions docs/source/pages/overview.rst
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,11 @@ In practice this means that:
A functional metric is differentiable if its corresponding modular metric is differentiable.

.. caution::
For PyTorch versions 2.1 or higher, differentiation in DDP mode is enabled, allowing autograd graph
propagation after the ``all_gather`` operation. This is useful for synchronizing metrics used as
loss functions in a DDP setting.

***************************************
Metrics and hyperparameter optimization
***************************************
Expand Down
32 changes: 19 additions & 13 deletions src/torchmetrics/utilities/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,11 @@ def class_reduce(


def _simple_gather_all_tensors(result: Tensor, group: Any, world_size: int) -> List[Tensor]:
gathered_result = [torch.zeros_like(result) for _ in range(world_size)]
torch.distributed.all_gather(gathered_result, result, group)
with torch.no_grad():
gathered_result = [torch.zeros_like(result) for _ in range(world_size)]
torch.distributed.all_gather(gathered_result, result, group)
# to propagate autograd graph from local rank
gathered_result[torch.distributed.get_rank(group)] = result
return gathered_result


Expand Down Expand Up @@ -133,15 +136,18 @@ def gather_all_tensors(result: Tensor, group: Optional[Any] = None) -> List[Tens
return _simple_gather_all_tensors(result, group, world_size)

# 3. If not, we need to pad each local tensor to maximum size, gather and then truncate
pad_dims = []
pad_by = (max_size - local_size).detach().cpu()
for val in reversed(pad_by):
pad_dims.append(0)
pad_dims.append(val.item())
result_padded = F.pad(result, pad_dims)
gathered_result = [torch.zeros_like(result_padded) for _ in range(world_size)]
torch.distributed.all_gather(gathered_result, result_padded, group)
for idx, item_size in enumerate(local_sizes):
slice_param = [slice(dim_size) for dim_size in item_size]
gathered_result[idx] = gathered_result[idx][slice_param]
with torch.no_grad():
pad_dims = []
pad_by = (max_size - local_size).detach().cpu()
for val in reversed(pad_by):
pad_dims.append(0)
pad_dims.append(val.item())
result_padded = F.pad(result, pad_dims)
gathered_result = [torch.zeros_like(result_padded) for _ in range(world_size)]
torch.distributed.all_gather(gathered_result, result_padded, group)
for idx, item_size in enumerate(local_sizes):
slice_param = [slice(dim_size) for dim_size in item_size]
gathered_result[idx] = gathered_result[idx][slice_param]
# to propagate autograd graph from local rank
gathered_result[torch.distributed.get_rank(group)] = result
return gathered_result
42 changes: 42 additions & 0 deletions tests/unittests/bases/test_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from unittests import NUM_PROCESSES, USE_PYTEST_POOL
from unittests._helpers import seed_all
from unittests._helpers.testers import DummyListMetric, DummyMetric, DummyMetricSum
from unittests.conftest import setup_ddp

seed_all(42)

Expand Down Expand Up @@ -105,6 +106,47 @@ def test_ddp(process):
pytest.pool.map(process, range(NUM_PROCESSES))


def _test_ddp_gather_all_autograd_same_shape(rank: int, worldsize: int = NUM_PROCESSES) -> None:
"""Test that ddp gather preserves local rank's autograd graph for same-shaped tensors across ranks."""
setup_ddp(rank, worldsize)
x = (rank + 1) * torch.ones(10, requires_grad=True)

# random linear transformation, it should really not matter what we do here
a, b = torch.randn(1), torch.randn(1)
y = a * x + b # gradient of y w.r.t. x is a

result = gather_all_tensors(y)
assert len(result) == worldsize
grad = torch.autograd.grad(result[rank].sum(), x)[0]
assert torch.allclose(grad, a * torch.ones_like(x))


def _test_ddp_gather_all_autograd_different_shape(rank: int, worldsize: int = NUM_PROCESSES) -> None:
"""Test that ddp gather preserves local rank's autograd graph for differently-shaped tensors across ranks."""
setup_ddp(rank, worldsize)
x = (rank + 1) * torch.ones(rank + 1, 2 - rank, requires_grad=True)

# random linear transformation, it should really not matter what we do here
a, b = torch.randn(1), torch.randn(1)
y = a * x + b # gradient of y w.r.t. x is a

result = gather_all_tensors(y)
assert len(result) == worldsize
grad = torch.autograd.grad(result[rank].sum(), x)[0]
assert torch.allclose(grad, a * torch.ones_like(x))


@pytest.mark.DDP()
@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
@pytest.mark.skipif(not USE_PYTEST_POOL, reason="DDP pool is not available.")
@pytest.mark.parametrize(
"process", [_test_ddp_gather_all_autograd_same_shape, _test_ddp_gather_all_autograd_different_shape]
)
def test_ddp_autograd(process):
"""Test ddp functions for autograd compatibility."""
pytest.pool.map(process, range(NUM_PROCESSES))


def _test_non_contiguous_tensors(rank):
class DummyCatMetric(Metric):
full_state_update = True
Expand Down
15 changes: 14 additions & 1 deletion tests/unittests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,17 @@ def use_deterministic_algorithms(): # noqa: PT004


def setup_ddp(rank, world_size):
"""Initialize ddp environment."""
"""Initialize ddp environment.
If a particular test relies on the order of the processes in the pool to be [0, 1, 2, ...], then this function
should be called inside the test to ensure that the processes are initialized in the same order they are used in
the tests.
Args:
rank: the rank of the process
world_size: the number of processes
"""
global CURRENT_PORT

os.environ["MASTER_ADDR"] = "localhost"
Expand All @@ -55,6 +65,9 @@ def setup_ddp(rank, world_size):
if CURRENT_PORT > MAX_PORT:
CURRENT_PORT = START_PORT

if torch.distributed.group.WORLD is not None: # if already initialized, destroy the process group
torch.distributed.destroy_process_group()

if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"):
torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)

Expand Down

0 comments on commit d3894e1

Please sign in to comment.