Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable autograd graph to propagate after multi-device syncing for loss functions in ddp #2754

Merged
merged 35 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
8122e9f
propagate rank result to gathered result for autograd compatibility
cw-tan Sep 17, 2024
c2b6d19
add unittest for dpp gather autograd compatibility
cw-tan Sep 17, 2024
7dec9b4
Merge branch 'master' into all_gather_ad
SkafteNicki Oct 9, 2024
d1e64e4
changelog
SkafteNicki Oct 9, 2024
fc366b8
add to docs
SkafteNicki Oct 9, 2024
59c9ced
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 9, 2024
dab2bd9
Merge branch 'master' into all_gather_ad
SkafteNicki Oct 9, 2024
6f188a8
Apply suggestions from code review
SkafteNicki Oct 9, 2024
ebb4f4c
add missing import
SkafteNicki Oct 9, 2024
05b6e96
remove redundant functions
SkafteNicki Oct 9, 2024
86aceb6
Merge branch 'master' into all_gather_ad
SkafteNicki Oct 10, 2024
f854bf2
try no_grad for the all gather
cw-tan Oct 10, 2024
25ffff2
retry with all tested torch versions
cw-tan Oct 11, 2024
e82c70e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 11, 2024
b5f285d
incorporate trials
cw-tan Oct 11, 2024
4e1e836
Merge branch 'master' into all_gather_ad
SkafteNicki Oct 12, 2024
5b9f79d
Merge branch 'master' into all_gather_ad
Borda Oct 14, 2024
5164e1d
Merge branch 'master' into all_gather_ad
Borda Oct 14, 2024
91cff5e
lint
Borda Oct 14, 2024
8fdc912
Merge branch 'master' into all_gather_ad
SkafteNicki Oct 15, 2024
4c13d6c
try adding contiguous
cw-tan Oct 15, 2024
74bf6b2
Merge branch 'master' into all_gather_ad
Borda Oct 16, 2024
00935f1
Merge branch 'master' into all_gather_ad
cw-tan Oct 18, 2024
150251c
try using float64
cw-tan Oct 18, 2024
70967ba
Merge branch 'master' into all_gather_ad
cw-tan Oct 18, 2024
9b17d6f
try using random numbers
cw-tan Oct 19, 2024
6e476ea
Merge branch 'master' into all_gather_ad
Borda Oct 21, 2024
c20f07c
Merge branch 'master' into all_gather_ad
Borda Oct 22, 2024
2033395
Merge branch 'master' into all_gather_ad
Borda Oct 23, 2024
a424412
Merge branch 'master' into all_gather_ad
Borda Oct 30, 2024
8b263ae
fix changelog
SkafteNicki Oct 31, 2024
8d2c27e
small changes to distributed
SkafteNicki Oct 31, 2024
48e699b
tests
SkafteNicki Oct 31, 2024
ea37534
Merge branch 'master' into all_gather_ad
SkafteNicki Oct 31, 2024
5f29c4d
caution
Borda Oct 31, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `Dice` metric to segmentation metrics ([#2725](https://github.com/Lightning-AI/torchmetrics/pull/2725))


- Added support for propagation of the autograd graph in ddp setting ([#2754](https://github.com/Lightning-AI/torchmetrics/pull/2754))


### Changed

- Changed naming and input order arguments in `KLDivergence` ([#2800](https://github.com/Lightning-AI/torchmetrics/pull/2800))
Expand Down
5 changes: 5 additions & 0 deletions docs/source/pages/overview.rst
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,11 @@ In practice this means that:

A functional metric is differentiable if its corresponding modular metric is differentiable.

.. caution::
For PyTorch versions 2.1 or higher, differentiation in DDP mode is enabled, allowing autograd graph
propagation after the ``all_gather`` operation. This is useful for synchronizing metrics used as
loss functions in a DDP setting.

***************************************
Metrics and hyperparameter optimization
***************************************
Expand Down
32 changes: 19 additions & 13 deletions src/torchmetrics/utilities/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,11 @@ def class_reduce(


def _simple_gather_all_tensors(result: Tensor, group: Any, world_size: int) -> List[Tensor]:
gathered_result = [torch.zeros_like(result) for _ in range(world_size)]
torch.distributed.all_gather(gathered_result, result, group)
with torch.no_grad():
gathered_result = [torch.zeros_like(result) for _ in range(world_size)]
torch.distributed.all_gather(gathered_result, result, group)
# to propagate autograd graph from local rank
gathered_result[torch.distributed.get_rank(group)] = result
Borda marked this conversation as resolved.
Show resolved Hide resolved
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved
return gathered_result


Expand Down Expand Up @@ -133,15 +136,18 @@ def gather_all_tensors(result: Tensor, group: Optional[Any] = None) -> List[Tens
return _simple_gather_all_tensors(result, group, world_size)

# 3. If not, we need to pad each local tensor to maximum size, gather and then truncate
pad_dims = []
pad_by = (max_size - local_size).detach().cpu()
for val in reversed(pad_by):
pad_dims.append(0)
pad_dims.append(val.item())
result_padded = F.pad(result, pad_dims)
gathered_result = [torch.zeros_like(result_padded) for _ in range(world_size)]
torch.distributed.all_gather(gathered_result, result_padded, group)
for idx, item_size in enumerate(local_sizes):
slice_param = [slice(dim_size) for dim_size in item_size]
gathered_result[idx] = gathered_result[idx][slice_param]
with torch.no_grad():
pad_dims = []
pad_by = (max_size - local_size).detach().cpu()
for val in reversed(pad_by):
pad_dims.append(0)
pad_dims.append(val.item())
result_padded = F.pad(result, pad_dims)
gathered_result = [torch.zeros_like(result_padded) for _ in range(world_size)]
torch.distributed.all_gather(gathered_result, result_padded, group)
for idx, item_size in enumerate(local_sizes):
slice_param = [slice(dim_size) for dim_size in item_size]
gathered_result[idx] = gathered_result[idx][slice_param]
# to propagate autograd graph from local rank
gathered_result[torch.distributed.get_rank(group)] = result
Borda marked this conversation as resolved.
Show resolved Hide resolved
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved
return gathered_result
42 changes: 42 additions & 0 deletions tests/unittests/bases/test_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from unittests import NUM_PROCESSES, USE_PYTEST_POOL
from unittests._helpers import seed_all
from unittests._helpers.testers import DummyListMetric, DummyMetric, DummyMetricSum
from unittests.conftest import setup_ddp

seed_all(42)

Expand Down Expand Up @@ -105,6 +106,47 @@ def test_ddp(process):
pytest.pool.map(process, range(NUM_PROCESSES))


def _test_ddp_gather_all_autograd_same_shape(rank: int, worldsize: int = NUM_PROCESSES) -> None:
"""Test that ddp gather preserves local rank's autograd graph for same-shaped tensors across ranks."""
setup_ddp(rank, worldsize)
x = (rank + 1) * torch.ones(10, requires_grad=True)

# random linear transformation, it should really not matter what we do here
a, b = torch.randn(1), torch.randn(1)
y = a * x + b # gradient of y w.r.t. x is a

result = gather_all_tensors(y)
assert len(result) == worldsize
grad = torch.autograd.grad(result[rank].sum(), x)[0]
assert torch.allclose(grad, a * torch.ones_like(x))


def _test_ddp_gather_all_autograd_different_shape(rank: int, worldsize: int = NUM_PROCESSES) -> None:
"""Test that ddp gather preserves local rank's autograd graph for differently-shaped tensors across ranks."""
setup_ddp(rank, worldsize)
x = (rank + 1) * torch.ones(rank + 1, 2 - rank, requires_grad=True)

# random linear transformation, it should really not matter what we do here
a, b = torch.randn(1), torch.randn(1)
y = a * x + b # gradient of y w.r.t. x is a

result = gather_all_tensors(y)
assert len(result) == worldsize
grad = torch.autograd.grad(result[rank].sum(), x)[0]
assert torch.allclose(grad, a * torch.ones_like(x))


@pytest.mark.DDP()
@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
@pytest.mark.skipif(not USE_PYTEST_POOL, reason="DDP pool is not available.")
@pytest.mark.parametrize(
"process", [_test_ddp_gather_all_autograd_same_shape, _test_ddp_gather_all_autograd_different_shape]
)
def test_ddp_autograd(process):
"""Test ddp functions for autograd compatibility."""
pytest.pool.map(process, range(NUM_PROCESSES))


def _test_non_contiguous_tensors(rank):
class DummyCatMetric(Metric):
full_state_update = True
Expand Down
15 changes: 14 additions & 1 deletion tests/unittests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,17 @@ def use_deterministic_algorithms(): # noqa: PT004


def setup_ddp(rank, world_size):
"""Initialize ddp environment."""
"""Initialize ddp environment.

If a particular test relies on the order of the processes in the pool to be [0, 1, 2, ...], then this function
should be called inside the test to ensure that the processes are initialized in the same order they are used in
the tests.

Args:
rank: the rank of the process
world_size: the number of processes

"""
global CURRENT_PORT

os.environ["MASTER_ADDR"] = "localhost"
Expand All @@ -55,6 +65,9 @@ def setup_ddp(rank, world_size):
if CURRENT_PORT > MAX_PORT:
CURRENT_PORT = START_PORT

if torch.distributed.group.WORLD is not None: # if already initialized, destroy the process group
torch.distributed.destroy_process_group()

if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"):
torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)

Expand Down
Loading