
Commit

small changes to distributed
SkafteNicki committed Oct 31, 2024
1 parent 8b263ae commit 8d2c27e
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions src/torchmetrics/utilities/distributed.py
@@ -92,8 +92,7 @@ def _simple_gather_all_tensors(result: Tensor, group: Any, world_size: int) -> List[Tensor]:
     with torch.no_grad():
         gathered_result = [torch.zeros_like(result) for _ in range(world_size)]
         torch.distributed.all_gather(gathered_result, result, group)
-    # to propagate autograd graph from local rank (achieves intended effect for torch> 2.0)
-    # if _TORCH_GREATER_EQUAL_2_1:
+    # to propagate autograd graph from local rank
     gathered_result[torch.distributed.get_rank(group)] = result
     return gathered_result

@@ -149,7 +148,6 @@ def gather_all_tensors(result: Tensor, group: Optional[Any] = None) -> List[Tensor]:
     for idx, item_size in enumerate(local_sizes):
         slice_param = [slice(dim_size) for dim_size in item_size]
         gathered_result[idx] = gathered_result[idx][slice_param]
-    # to propagate autograd graph from local rank (achieves intended effect for torch> 2.0)
-    # if _TORCH_GREATER_EQUAL_2_1:
+    # to propagate autograd graph from local rank
     gathered_result[torch.distributed.get_rank(group)] = result
     return gathered_result
