diff --git a/tests/unittests/bases/test_ddp.py b/tests/unittests/bases/test_ddp.py index 5af7123e942..68e9286421b 100644 --- a/tests/unittests/bases/test_ddp.py +++ b/tests/unittests/bases/test_ddp.py @@ -172,6 +172,7 @@ def _test_ddp_gather_autograd_same_shape(rank: int, worldsize: int = NUM_PROCESS This test only considers tensors of the same shape across different ranks. Note that this test only works for torch>=2.0. + """ tensor = torch.ones(50, requires_grad=True) result = gather_all_tensors(tensor) @@ -198,6 +199,7 @@ def _test_ddp_gather_autograd_different_shape(rank: int, worldsize: int = NUM_PR This test considers tensors of different shapes across different ranks. Note that this test only works for torch>=2.0. + """ tensor = torch.ones(rank + 1, 2 - rank, requires_grad=True) result = gather_all_tensors(tensor)