Autograd with DDP #2745

Closed
cw-tan opened this issue Sep 16, 2024 · 2 comments · Fixed by #2754
Labels
enhancement New feature or request

Comments

@cw-tan
Contributor

cw-tan commented Sep 16, 2024

I have a setup with PyTorch Lightning where I'm using custom torchmetrics.Metric objects as loss function contributions. Now I want to be able to do this with DDP by setting dist_sync_on_step=True, but the gradients are not propagated through the all_gather. All I want is for the tensor on the current process to keep its autograd graph for the backward pass after the syncing operations.

I've only just begun looking into distributed stuff in torch, so I'm not experienced in these matters. But following the forward() call of Metric (at each training batch step), it calls _forward_reduce_state_update(), which calls the compute() function wrapped by _wrap_compute(), which does sync(), which finally calls _sync_dist(). And it looks like the syncing uses torchmetrics.utilities.distributed.gather_all_tensors.

I just wanted to ask if it is possible to achieve what I want by modifying _simple_gather_all_tensors? The function is reproduced below for reference.

def _simple_gather_all_tensors(result: Tensor, group: Any, world_size: int) -> List[Tensor]:
    gathered_result = [torch.zeros_like(result) for _ in range(world_size)]
    torch.distributed.all_gather(gathered_result, result, group)
    return gathered_result

I'm guessing that result still carries the autograd graph. My naive hope is that we can just update gathered_result with the input result (carrying the autograd graph) to achieve the desired effect.

For context, in my use case batches can have very inhomogeneous numels, so each device could hold error tensors with very different numels, and taking a mean of per-device MeanSquaredErrors may not be ideal. Ideally, if the syncing keeps the autograd graph, the per-step loss would be the "true" metric as per its definition and the gradients would be consistent with that definition (so syncing is done once for each loss metric contribution, and once for the backward at each training step, I think).
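To make the concern about inhomogeneous numels concrete, here is a minimal sketch in plain Python. The numbers and the names (`mse`, `errors_per_device`) are made up for illustration and are not part of torchmetrics; the point is only that averaging per-device MSEs weights each device equally, while the metric's definition weights each element equally.

```python
def mse(errors):
    """Mean squared error of a flat list of residuals."""
    return sum(e * e for e in errors) / len(errors)

# Hypothetical residuals: device 0 holds 1 element, device 1 holds 3.
errors_per_device = [[2.0], [0.0, 0.0, 0.0]]

# Averaging the per-device metrics gives every device the same weight.
mean_of_mses = sum(mse(e) for e in errors_per_device) / len(errors_per_device)

# Gathering all residuals first gives every element the same weight.
all_errors = [e for device in errors_per_device for e in device]
global_mse = mse(all_errors)

print(mean_of_mses, global_mse)
```

With these numbers the two quantities differ by a factor of two, and the gap grows as the element counts per device become more unbalanced.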

Thank you!

@cw-tan cw-tan added the enhancement New feature or request label Sep 16, 2024

Hi! Thanks for your contribution! Great first issue!

@cw-tan
Contributor Author

cw-tan commented Sep 16, 2024

It looks like adding gathered_result[torch.distributed.get_rank(group)] = result has worked for me so far, i.e.

def _simple_gather_all_tensors(result: Tensor, group: Any, world_size: int) -> List[Tensor]:
    gathered_result = [torch.zeros_like(result) for _ in range(world_size)]
    torch.distributed.all_gather(gathered_result, result, group)
    gathered_result[torch.distributed.get_rank(group)] = result
    return gathered_result

found here.
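The effect of that one-line change can be checked without launching multiple processes. The sketch below is an assumption-laden stand-in, not the torchmetrics code: `simulated_gather` is a hypothetical helper that mimics how the gathered buffers end up detached from autograd (fresh `zeros_like` buffers filled without gradient tracking), and `rank` is passed in by hand instead of coming from `torch.distributed.get_rank(group)`.

```python
import torch

def simulated_gather(result, world_size):
    # Hypothetical stand-in for torch.distributed.all_gather: the output
    # buffers are new tensors filled in place, so they carry no autograd graph.
    gathered = [torch.zeros_like(result) for _ in range(world_size)]
    with torch.no_grad():
        for buf in gathered:
            buf.copy_(result)
    return gathered

def simulated_gather_keep_local_grad(result, rank, world_size):
    gathered = simulated_gather(result, world_size)
    # Same idea as the change above: put the local tensor (with its graph)
    # back into this rank's slot.
    gathered[rank] = result
    return gathered

w = torch.tensor([1.0, 2.0], requires_grad=True)
local = w * 3.0  # local result that depends on a trainable tensor

plain = simulated_gather(local, world_size=2)
patched = simulated_gather_keep_local_grad(local, rank=0, world_size=2)

# Without the change, nothing downstream of the gather is differentiable.
plain_requires_grad = torch.stack(plain).sum().requires_grad

# With the change, gradients flow back through the local slot only.
loss = torch.stack(patched).sum()
loss.backward()
print(plain_requires_grad, w.grad)
```

In this sketch only the local slot contributes to `w.grad`; the entries received from other ranks act as constants in the backward pass on each process.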
