
Commit 127454a

ananyahjha93 and Borda authored

All gather with grads (Lightning-AI#5012)

* all_gather
* ddp
* horovod
* grad tests
* fixed ddp
* ddp fixed, removed tpu, horovod for now
* changelog
* windows fix
* windows fix
* removed batch from ctx
* all_gather
* ddp
* horovod
* grad tests
* fixed ddp
* ddp fixed, removed tpu, horovod for now
* changelog
* windows fix
* windows fix
* removed batch from ctx
* removed code duplication
* merge

Co-authored-by: Jirka Borovec <[email protected]>

1 parent ee9b3fe, commit 127454a

11 files changed (+211 -2 lines)

CHANGELOG.md (+4)

@@ -6,6 +6,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ## Unreleased
 
+### Added
+
+- Added `all_gather` method to `LightningModule` which allows gradient based tensor synchronizations for use-cases such as negative sampling. ([#5012](https://github.com/PyTorchLightning/pytorch-lightning/pull/5012))
+
 ### Fixed
 
 - Fixed `LoggerConnector` to have logged metrics on root device in DP ([#4138](https://github.com/PyTorchLightning/pytorch-lightning/pull/4138))
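To make the use-case above concrete, here is a minimal sketch of gradient-synchronized gathering inside a `training_step`, assuming a DDP run. `NegativeSamplingModule`, its toy linear encoder, and the contrastive loss are illustrative assumptions, not part of this commit:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl


class NegativeSamplingModule(pl.LightningModule):
    """Hypothetical module: samples from other ranks act as negatives."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(32, 128)  # stand-in encoder

    def training_step(self, batch, batch_idx):
        # Local embeddings, shape (batch, dim).
        z = F.normalize(self.encoder(batch), dim=-1)

        # Gather from every process WITH gradients; under DDP the result has
        # shape (world_size, batch, dim), flattened here into one candidate
        # pool of (world_size * batch, dim).
        pool = self.all_gather(z, sync_grads=True).flatten(0, 1)

        # Each local sample should match its own copy in the global pool;
        # every other pooled sample serves as a negative.
        logits = z @ pool.t()
        target = torch.arange(z.size(0), device=z.device) + z.size(0) * self.trainer.global_rank
        return F.cross_entropy(logits, target)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
```

Without `sync_grads=True` the gathered pool is produced under `no_grad`, so the negatives would contribute no gradient back to the encoders on other ranks.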

pytorch_lightning/accelerators/accelerator.py (+14)

@@ -172,6 +172,20 @@ def sync_tensor(self,
         """
         raise NotImplementedError()
 
+    def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False):
+        """
+        Function to gather a tensor from several distributed processes
+
+        Args:
+            tensor: tensor of shape (batch, ...)
+            group: the process group to gather results from. Defaults to all processes (world)
+            sync_grads: flag that allows users to synchronize gradients for all_gather op
+
+        Return:
+            A tensor of shape (world_size, batch, ...)
+        """
+        raise NotImplementedError()
+
     def optimizer_state(self, optimizer: Optimizer) -> dict:
         """
         Returns state of an optimizer. Allows for syncing/collating optimizer state from processes in custom
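This docstring fixes the contract that every backend below implements: each process contributes a tensor of shape `(batch, ...)` and gets back a tensor of shape `(world_size, batch, ...)`. A self-contained sketch of that contract using raw `torch.distributed` on CPU; the gloo backend, address, and port are arbitrary choices for illustration:

```python
import os
import torch


def _check_shape_contract(rank: int, world_size: int):
    # Minimal CPU (gloo) process group; address/port are arbitrary for the demo.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)

    tensor = torch.full((4, 2), float(rank))  # local shape: (batch, ...)
    gathered = [torch.zeros_like(tensor) for _ in range(world_size)]
    torch.distributed.all_gather(gathered, tensor)
    stacked = torch.stack(gathered, dim=0)

    # One leading world_size dimension is added, as the docstring promises,
    # and slot r holds the tensor contributed by rank r.
    assert stacked.shape == (world_size, 4, 2)
    assert stacked[rank].eq(rank).all()

    torch.distributed.destroy_process_group()


if __name__ == "__main__":
    torch.multiprocessing.spawn(_check_shape_contract, args=(2,), nprocs=2)
```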

pytorch_lightning/accelerators/ddp2_accelerator.py (+15 -1)

@@ -25,7 +25,7 @@
 from pytorch_lightning.distributed.dist import LightningDistributed
 from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
 from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
-from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available
+from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available, all_gather_ddp_if_available
 
 if HYDRA_AVAILABLE:
     from hydra.core.hydra_config import HydraConfig
@@ -234,6 +234,20 @@ def sync_tensor(self,
                     reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
         return sync_ddp_if_available(tensor, group, reduce_op)
 
+    def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False):
+        """
+        Function to gather a tensor from several distributed processes
+
+        Args:
+            tensor: tensor of shape (batch, ...)
+            group: the process group to gather results from. Defaults to all processes (world)
+            sync_grads: flag that allows users to synchronize gradients for all_gather op
+
+        Return:
+            A tensor of shape (world_size, batch, ...)
+        """
+        return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads)
+
     def get_reference_model(self, model) -> LightningModule:
         return self.ddp_plugin.get_model_from_plugin(model)
pytorch_lightning/accelerators/ddp_accelerator.py (+15)

@@ -29,6 +29,7 @@
 from pytorch_lightning.distributed.dist import LightningDistributed
 from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
 from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
+from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available
 from pytorch_lightning.utilities.distributed import find_free_network_port, rank_zero_only, sync_ddp_if_available
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.seed import seed_everything
@@ -333,6 +334,20 @@ def sync_tensor(self,
         """
         return sync_ddp_if_available(tensor, group, reduce_op)
 
+    def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False):
+        """
+        Function to gather a tensor from several distributed processes
+
+        Args:
+            tensor: tensor of shape (batch, ...)
+            group: the process group to gather results from. Defaults to all processes (world)
+            sync_grads: flag that allows users to synchronize gradients for all_gather op
+
+        Return:
+            A tensor of shape (world_size, batch, ...)
+        """
+        return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads)
+
     def get_reference_model(self, model) -> LightningModule:
         return self.ddp_plugin.get_model_from_plugin(model)

pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py (+15)

@@ -31,6 +31,7 @@
     rank_zero_only,
     rank_zero_warn,
     sync_ddp_if_available,
+    all_gather_ddp_if_available,
 )
 
 if HYDRA_AVAILABLE:
@@ -261,6 +262,20 @@ def sync_tensor(self,
                     reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
         return sync_ddp_if_available(tensor, group, reduce_op)
 
+    def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False):
+        """
+        Function to gather a tensor from several distributed processes
+
+        Args:
+            tensor: tensor of shape (batch, ...)
+            group: the process group to gather results from. Defaults to all processes (world)
+            sync_grads: flag that allows users to synchronize gradients for all_gather op
+
+        Return:
+            A tensor of shape (world_size, batch, ...)
+        """
+        return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads)
+
     def get_reference_model(self, model) -> LightningModule:
         return self.ddp_plugin.get_model_from_plugin(model)

pytorch_lightning/accelerators/ddp_hpc_accelerator.py (+15 -1)

@@ -25,7 +25,7 @@
 from pytorch_lightning.distributed.dist import LightningDistributed
 from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
 from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
-from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available
+from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available, all_gather_ddp_if_available
 
 if HYDRA_AVAILABLE:
     from hydra.core.hydra_config import HydraConfig
@@ -225,6 +225,20 @@ def sync_tensor(self,
                     reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
         return sync_ddp_if_available(tensor, group, reduce_op)
 
+    def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False):
+        """
+        Function to gather a tensor from several distributed processes
+
+        Args:
+            tensor: tensor of shape (batch, ...)
+            group: the process group to gather results from. Defaults to all processes (world)
+            sync_grads: flag that allows users to synchronize gradients for all_gather op
+
+        Return:
+            A tensor of shape (world_size, batch, ...)
+        """
+        return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads)
+
     def get_reference_model(self, model) -> LightningModule:
         return self.ddp_plugin.get_model_from_plugin(model)

pytorch_lightning/accelerators/ddp_spawn_accelerator.py (+15)

@@ -34,6 +34,7 @@
     rank_zero_only,
     rank_zero_warn,
     sync_ddp_if_available,
+    all_gather_ddp_if_available,
 )
 from pytorch_lightning.utilities.seed import seed_everything
 
@@ -293,6 +294,20 @@ def sync_tensor(self,
                     reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
         return sync_ddp_if_available(tensor, group, reduce_op)
 
+    def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False):
+        """
+        Function to gather a tensor from several distributed processes
+
+        Args:
+            tensor: tensor of shape (batch, ...)
+            group: the process group to gather results from. Defaults to all processes (world)
+            sync_grads: flag that allows users to synchronize gradients for all_gather op
+
+        Return:
+            A tensor of shape (world_size, batch, ...)
+        """
+        return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads)
+
     def get_reference_model(self, model) -> LightningModule:
         return self.ddp_plugin.get_model_from_plugin(model)

pytorch_lightning/core/lightning.py (+18)

@@ -365,6 +365,24 @@ def __auto_choose_log_on_epoch(self, on_epoch):
 
         return on_epoch
 
+    def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False):
+        r"""
+        Allows users to call ``self.all_gather()`` from the LightningModule, thus making
+        the ``all_gather`` operation accelerator agnostic.
+
+        ``all_gather`` is a function provided by accelerators to gather a tensor from several
+        distributed processes.
+
+        Args:
+            tensor: tensor of shape (batch, ...)
+            group: the process group to gather results from. Defaults to all processes (world)
+            sync_grads: flag that allows users to synchronize gradients for all_gather op
+
+        Return:
+            A tensor of shape (world_size, batch, ...)
+        """
+        return self.trainer.accelerator_backend.all_gather(tensor, group=group, sync_grads=sync_grads)
+
     def forward(self, *args, **kwargs):
         r"""
         Same as :meth:`torch.nn.Module.forward()`, however in Lightning you want this to define
pytorch_lightning/utilities/__init__.py (+1)

@@ -21,6 +21,7 @@
 
 from pytorch_lightning.utilities.apply_func import move_data_to_device
 from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_only, rank_zero_warn
+from pytorch_lightning.utilities.distributed import AllGatherGrad
 from pytorch_lightning.utilities.parsing import AttributeDict, flatten_dict, is_picklable
 from pytorch_lightning.utilities.xla_device_utils import XLA_AVAILABLE, XLADeviceUtils
 

pytorch_lightning/utilities/distributed.py (+55)

@@ -22,10 +22,14 @@
 
 if torch.distributed.is_available():
     from torch.distributed import ReduceOp
+    from torch.distributed import group
 else:
     class ReduceOp:
         SUM = None
 
+    class group:
+        WORLD = None
+
 
 def rank_zero_only(fn):
 
@@ -155,3 +159,54 @@ def sync_ddp(
         result = result / torch.distributed.get_world_size(group)
 
     return result
+
+
+class AllGatherGrad(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, tensor, group=group.WORLD):
+        ctx.group = group
+
+        gathered_tensor = [
+            torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())
+        ]
+
+        torch.distributed.all_gather(gathered_tensor, tensor, group=group)
+        gathered_tensor = torch.stack(gathered_tensor, dim=0)
+
+        return gathered_tensor
+
+    @staticmethod
+    def backward(ctx, *grad_output):
+        grad_output = torch.cat(grad_output)
+
+        torch.distributed.all_reduce(
+            grad_output,
+            op=torch.distributed.ReduceOp.SUM,
+            async_op=False,
+            group=ctx.group
+        )
+
+        return grad_output[torch.distributed.get_rank()]
+
+
+def all_gather_ddp_if_available(
+    tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False
+) -> torch.Tensor:
+    """
+    Function to gather a tensor from several distributed processes
+
+    Args:
+        tensor: tensor of shape (batch, ...)
+        group: the process group to gather results from. Defaults to all processes (world)
+        sync_grads: flag that allows users to synchronize gradients for all_gather op
+
+    Return:
+        A tensor of shape (world_size, batch, ...)
+    """
+    if torch.distributed.is_available() and torch.distributed.is_initialized():
+        if sync_grads:
+            return AllGatherGrad.apply(tensor, group)
+        else:
+            with torch.no_grad():
+                return AllGatherGrad.apply(tensor, group)
+    return tensor
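A note on the gradient flow, since it is the point of this change: `AllGatherGrad.forward` places rank `r`'s tensor into slot `r` of the stacked `(world_size, batch, ...)` output on every process, so after the backward pass each process holds a gradient for every slot. The true gradient of the local tensor is the sum over processes of the gradient each one computed for the local slot, which is exactly what the `all_reduce` followed by `grad_output[torch.distributed.get_rank()]` produces. Worked check with `world_size = 3`: if rank `r` computes `loss_r = (gathered * r).sum()`, its gradient for every slot is `r`; the all-reduce yields `0 + 1 + 2 = 3`, so every element of each local tensor receives a gradient of 3. The new test below fills its expected `grad1`/`grad2` with precisely `torch.arange(world_size).sum()` for this reason.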
tests/utilities/test_all_gather_grad.py (new file, +44)

@@ -0,0 +1,44 @@
+import os
+import pytest
+import sys
+import torch
+import torch.nn as nn
+
+from pytorch_lightning.utilities import AllGatherGrad
+
+
+def setup_ddp(rank, world_size):
+    """ Setup ddp environment """
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = "8088"
+
+    if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"):
+        torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)
+
+
+def _test_all_gather_ddp(rank, world_size):
+    setup_ddp(rank, world_size)
+
+    tensor1 = torch.ones(8, requires_grad=True)
+    tensor2 = torch.ones((8, 16, 32), requires_grad=True)
+
+    tensor1_gathered = AllGatherGrad.apply(tensor1)
+    tensor2_gathered = AllGatherGrad.apply(tensor2)
+
+    tensor1_gathered = tensor1_gathered * rank
+    tensor2_gathered = tensor2_gathered * rank
+
+    tensor1_gathered.sum().backward()
+    tensor2_gathered.sum().backward()
+
+    grad1 = torch.zeros_like(tensor1.grad).fill_(torch.arange(world_size).sum().float())
+    grad2 = torch.zeros_like(tensor2.grad).fill_(torch.arange(world_size).sum().float())
+
+    assert torch.allclose(grad1, tensor1.grad)
+    assert torch.allclose(grad2, tensor2.grad)
+
+
+@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
+def test_all_gather_ddp():
+    world_size = 3
+    torch.multiprocessing.spawn(_test_all_gather_ddp, args=(world_size,), nprocs=world_size)
