22 | 22 |
23 | 23 | if torch.distributed.is_available():
24 | 24 |     from torch.distributed import ReduceOp
| 25 | +    from torch.distributed import group
25 | 26 | else:
26 | 27 |     class ReduceOp:
27 | 28 |         SUM = None
28 | 29 |
| 30 | +    class group:
| 31 | +        WORLD = None
| 32 | +
29 | 33 |
30 | 34 | def rank_zero_only(fn):
31 | 35 |
@@ -155,3 +159,54 @@ def sync_ddp(
155 | 159 |         result = result / torch.distributed.get_world_size(group)
156 | 160 |
157 | 161 |     return result
| 162 | +
| 163 | +
| 164 | +class AllGatherGrad(torch.autograd.Function):
| 165 | +    @staticmethod
| 166 | +    def forward(ctx, tensor, group=group.WORLD):
| 167 | +        ctx.group = group
| 168 | +
| 169 | +        gathered_tensor = [
| 170 | +            torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size(group))
| 171 | +        ]
| 172 | +
| 173 | +        torch.distributed.all_gather(gathered_tensor, tensor, group=group)
| 174 | +        gathered_tensor = torch.stack(gathered_tensor, dim=0)
| 175 | +
| 176 | +        return gathered_tensor
| 177 | +
| 178 | +    @staticmethod
| 179 | +    def backward(ctx, *grad_output):
| 180 | +        grad_output = torch.cat(grad_output)
| 181 | +
| 182 | +        torch.distributed.all_reduce(
| 183 | +            grad_output,
| 184 | +            op=torch.distributed.ReduceOp.SUM,
| 185 | +            async_op=False,
| 186 | +            group=ctx.group
| 187 | +        )
| 188 | +
| 189 | +        return grad_output[torch.distributed.get_rank()], None
| 190 | + |
| 191 | + |
| 192 | +def all_gather_ddp_if_available(
| 193 | +    tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False
| 194 | +) -> torch.Tensor:
| 195 | +    """
| 196 | +    Gather a tensor from several distributed processes.
| 197 | +
| 198 | +    Args:
| 199 | +        tensor: tensor of shape (batch, ...)
| 200 | +        group: the process group to gather results from. Defaults to all processes (world)
| 201 | +        sync_grads: flag that allows users to synchronize gradients for the all_gather operation
| 202 | +
| 203 | +    Return:
| 204 | +        A tensor of shape (world_size, batch, ...)
| 205 | +    """
| 206 | +    if torch.distributed.is_available() and torch.distributed.is_initialized():
| 207 | +        if sync_grads:
| 208 | +            return AllGatherGrad.apply(tensor, group)
| 209 | +        else:
| 210 | +            with torch.no_grad():
| 211 | +                return AllGatherGrad.apply(tensor, group)
| 212 | +    return tensor
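
For context, here is a minimal usage sketch (not part of the diff) of the new helper. It assumes a hypothetical `demo.py` launched with one process per device via `torch.distributed.launch`, and that the helper is importable from `pytorch_lightning.utilities.distributed`, which is where this file appears to live. `sync_grads=True` routes the call through `AllGatherGrad`, so the gathered tensor stays attached to the autograd graph; the default path wraps the same call in `torch.no_grad()`.

```python
# Hypothetical demo.py: gather a per-rank tensor and backprop through the result.
# Example launch: python -m torch.distributed.launch --nproc_per_node=2 demo.py
import os

import torch
import torch.distributed

# Assumed import path; adjust to wherever this module lives in your checkout.
from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available


def demo_step() -> None:
    # One (batch, features) tensor per process; requires_grad so we can verify sync_grads.
    local = torch.randn(4, 8, requires_grad=True)

    # With an initialized process group this returns a (world_size, 4, 8) tensor;
    # without one the helper is a no-op and simply returns `local`.
    gathered = all_gather_ddp_if_available(local, sync_grads=True)

    # sync_grads=True keeps the op differentiable, so gradients flow back to `local`.
    gathered.sum().backward()
    print(local.grad.shape)  # torch.Size([4, 8]) on every rank


if __name__ == "__main__":
    if int(os.environ.get("WORLD_SIZE", "1")) > 1:
        # torch.distributed.launch sets MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE for env:// init.
        torch.distributed.init_process_group(backend="gloo")
    demo_step()
```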