
Commit

tests
SkafteNicki committed Oct 31, 2024
1 parent 8d2c27e commit 48e699b
Showing 2 changed files with 41 additions and 64 deletions.
90 changes: 27 additions & 63 deletions tests/unittests/bases/test_ddp.py
@@ -27,6 +27,7 @@
 from unittests import NUM_PROCESSES, USE_PYTEST_POOL
 from unittests._helpers import seed_all
 from unittests._helpers.testers import DummyListMetric, DummyMetric, DummyMetricSum
+from unittests.conftest import setup_ddp
 
 seed_all(42)
 
@@ -105,80 +106,43 @@ def test_ddp(process):
     pytest.pool.map(process, range(NUM_PROCESSES))
 
 
-def _test_ddp_gather_autograd_same_shape(rank: int, worldsize: int = NUM_PROCESSES) -> None:
-    """Test that ddp gather preserves local rank's autograd graph for same-shaped tensors across ranks.
-
-    This function tests that ``torchmetrics.utilities.distributed.gather_all_tensors`` works as intended in
-    preserving the local rank's autograd graph upon the gather. The function compares derivative values obtained
-    with the local rank results from the ``gather_all_tensors`` output and the original local rank tensor.
-    This test only considers tensors of the same shape across different ranks.
-
-    Note that this test only works for torch>=2.0.
-    """
-    tensor = torch.randn(50, dtype=torch.float64, requires_grad=True)
-    result = gather_all_tensors(tensor)
+def _test_ddp_gather_all_autograd_same_shape(rank: int, worldsize: int = NUM_PROCESSES) -> None:
+    """Test that ddp gather preserves local rank's autograd graph for same-shaped tensors across ranks."""
+    setup_ddp(rank, worldsize)
+    x = (rank + 1) * torch.ones(10, requires_grad=True)
+
+    # random linear transformation, it should really not matter what we do here
+    a, b = torch.randn(1), torch.randn(1)
+    y = a * x + b  # gradient of y w.r.t. x is a
+
+    result = gather_all_tensors(y)
     assert len(result) == worldsize
-    scalar1 = 0
-    scalar2 = 0
-    for idx in range(worldsize):
-        W = torch.randn_like(result[idx], requires_grad=False)
-        if idx == rank:
-            assert torch.allclose(result[idx], tensor)
-            scalar1 = scalar1 + torch.sum(tensor * W)
-        else:
-            scalar1 = scalar1 + torch.sum(result[idx] * W)
-        scalar2 = scalar2 + torch.sum(result[idx] * W)
-    gradient1 = torch.autograd.grad(scalar1, [tensor], retain_graph=True)[0]
-    gradient2 = torch.autograd.grad(scalar2, [tensor])[0]
-    assert torch.allclose(gradient1, gradient2)
+    grad = torch.autograd.grad(result[rank].sum(), x)[0]
+    assert torch.allclose(grad, a * torch.ones_like(x))
 
 
-def _test_ddp_gather_autograd_different_shape(rank: int, worldsize: int = NUM_PROCESSES) -> None:
-    """Test that ddp gather preserves local rank's autograd graph for differently-shaped tensors across ranks.
-    This function tests that ``torchmetrics.utilities.distributed.gather_all_tensors`` works as intended in
-    preserving the local rank's autograd graph upon the gather. The function compares derivative values obtained
-    with the local rank results from the ``gather_all_tensors`` output and the original local rank tensor.
-    This test considers tensors of different shapes across different ranks.
-    Note that this test only works for torch>=2.0.
-    """
-    tensor = torch.randn(rank + 1, 2 - rank, dtype=torch.float64, requires_grad=True)
-    result = gather_all_tensors(tensor)
+def _test_ddp_gather_all_autograd_different_shape(rank: int, worldsize: int = NUM_PROCESSES) -> None:
+    """Test that ddp gather preserves local rank's autograd graph for differently-shaped tensors across ranks."""
+    setup_ddp(rank, worldsize)
+    x = (rank + 1) * torch.ones(rank + 1, 2 - rank, requires_grad=True)
+
+    # random linear transformation, it should really not matter what we do here
+    a, b = torch.randn(1), torch.randn(1)
+    y = a * x + b  # gradient of y w.r.t. x is a
+
+    result = gather_all_tensors(y)
     assert len(result) == worldsize
-    scalar1 = 0
-    scalar2 = 0
-    for idx in range(worldsize):
-        W = torch.randn_like(result[idx], requires_grad=False)
-        if idx == rank:
-            assert torch.allclose(result[idx], tensor)
-            scalar1 = scalar1 + torch.sum(tensor * W)
-        else:
-            scalar1 = scalar1 + torch.sum(result[idx] * W)
-        scalar2 = scalar2 + torch.sum(result[idx] * W)
-    gradient1 = torch.autograd.grad(scalar1, [tensor], retain_graph=True)[0]
-    gradient2 = torch.autograd.grad(scalar2, [tensor])[0]
-    assert torch.allclose(gradient1, gradient2)
+    grad = torch.autograd.grad(result[rank].sum(), x)[0]
+    assert torch.allclose(grad, a * torch.ones_like(x))
 
 
 @pytest.mark.DDP()
 @pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
 @pytest.mark.skipif(not USE_PYTEST_POOL, reason="DDP pool is not available.")
-# @pytest.mark.skipif(not _TORCH_GREATER_EQUAL_2_1, reason="test only works on newer torch versions")
 @pytest.mark.parametrize(
-    "process",
-    [
-        _test_ddp_gather_autograd_same_shape,
-        _test_ddp_gather_autograd_different_shape,
-    ],
-)
-@pytest.mark.parametrize(
-    "index",
-    [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
+    "process", [_test_ddp_gather_all_autograd_same_shape, _test_ddp_gather_all_autograd_different_shape]
 )
-def test_ddp_autograd(process, index):
+def test_ddp_autograd(process):
     """Test ddp functions for autograd compatibility."""
     pytest.pool.map(process, range(NUM_PROCESSES))

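Aside (illustrative, not part of the commit): the rewritten tests rely on the fact that for an elementwise linear map y = a * x + b, the gradient of y.sum() with respect to x equals a at every element, so comparing the gradient of result[rank].sum() against a * torch.ones_like(x) confirms that gather_all_tensors kept the local rank's slot attached to the local autograd graph. A minimal single-process sketch of that identity, using only plain torch.autograd:

# Single-process illustration of the gradient identity the DDP tests above check.
import torch

x = torch.ones(10, requires_grad=True)
a, b = torch.randn(1), torch.randn(1)
y = a * x + b  # elementwise linear map

grad = torch.autograd.grad(y.sum(), x)[0]
assert torch.allclose(grad, a * torch.ones_like(x))  # d(sum(y))/dx is a everywhere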
15 changes: 14 additions & 1 deletion tests/unittests/conftest.py
@@ -45,7 +45,17 @@ def use_deterministic_algorithms(): # noqa: PT004
 
 
 def setup_ddp(rank, world_size):
-    """Initialize ddp environment."""
+    """Initialize ddp environment.
+
+    If a particular test relies on the order of the processes in the pool to be [0, 1, 2, ...], then this function
+    should be called inside the test to ensure that the processes are initialized in the same order they are used in
+    the tests.
+
+    Args:
+        rank: the rank of the process
+        world_size: the number of processes
+
+    """
     global CURRENT_PORT
 
     os.environ["MASTER_ADDR"] = "localhost"
@@ -55,6 +65,9 @@ def setup_ddp(rank, world_size):
     if CURRENT_PORT > MAX_PORT:
         CURRENT_PORT = START_PORT
 
+    if torch.distributed.group.WORLD is not None:  # if already initialized, destroy the process group
+        torch.distributed.destroy_process_group()
+
     if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"):
         torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)

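For context, the expanded setup_ddp docstring describes re-initializing the process group so that rank assignment follows the order in which the pooled processes call it. A hypothetical per-rank helper (name and import path assumed, not taken from the commit) would use it like this:

# Hypothetical per-rank check (illustrative only): call setup_ddp first so this
# process is registered as `rank`, then confirm the assignment via all_gather.
import torch
import torch.distributed as dist

from unittests import NUM_PROCESSES
from unittests.conftest import setup_ddp


def _check_rank_order(rank: int, world_size: int = NUM_PROCESSES) -> None:
    setup_ddp(rank, world_size)  # destroys any stale group, then re-initializes with this rank
    assert dist.get_rank() == rank
    gathered = [torch.zeros(1) for _ in range(world_size)]
    dist.all_gather(gathered, torch.tensor([float(rank)]))
    assert [int(t.item()) for t in gathered] == list(range(world_size))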
