I have a setup with torch.Lightning where I'm using custom torchmetrics.Metric objects as loss function contributions. Now I want to do the same with ddp by setting dist_sync_on_step=True, but the gradients are not propagated through the all_gather. All I want is for the tensor on the current process to keep its autograd graph for the backward pass after the syncing operations. I've only just begun looking into distributed stuff in torch, so I'm not experienced in these matters. Following the forward() call of Metric (at each training batch step), it calls _forward_reduce_state_update(), which calls the compute() function wrapped by _wrap_compute(), which does sync(), which finally calls _sync_dist(). It looks like the syncing uses torchmetrics.utilities.distributed.gather_all_tensors.
I just wanted to ask if it is possible to achieve what I want by modifying _simple_gather_all_tensors (here)? _simple_gather_all_tensors presented here for reference.
I'm guessing that result still carries the autograd graph. My naive hope is that we can just update gathered_result with the input result (carrying the autograd graph) to achieve the desired effect.
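A minimal sketch of that idea, assuming all ranks hold equal-shaped tensors. The names reattach_local and gather_keep_local_grad are hypothetical and are not part of torchmetrics. The tensors written by torch.distributed.all_gather are not connected to the autograd graph, so the sketch swaps the entry for the current rank back to the original result. The entries from other ranks stay detached, so each process only backpropagates through its own contribution.

```python
import torch
import torch.distributed as dist


def reattach_local(gathered, result, rank):
    """Swap the detached copy at index `rank` for the graph-carrying `result`."""
    out = list(gathered)
    out[rank] = result
    return out


def gather_keep_local_grad(result, group=None):
    """Hypothetical gather for equal-shaped tensors that keeps the local graph."""
    if not (dist.is_available() and dist.is_initialized()):
        return [result]  # single process: nothing to gather
    world_size = dist.get_world_size(group)
    gathered = [torch.zeros_like(result) for _ in range(world_size)]
    # The tensors filled in here are not tracked by autograd.
    dist.all_gather(gathered, result, group=group)
    return reattach_local(gathered, result, dist.get_rank(group))


# Single-process illustration: `remote` stands in for data sent by another rank.
local = torch.tensor([1.0, 2.0], requires_grad=True)
result = local * 3.0                     # carries a graph back to `local`
remote = torch.tensor([10.0, 20.0])      # assumed payload from rank 1
synced = reattach_local([result.detach(), remote], result, rank=0)

loss = torch.cat(synced).sum()
loss.backward()                          # gradient reaches `local` via synced[0]
```

In the illustration only the rank-0 entry contributes to local.grad; the stand-in for the other rank behaves like a constant.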
For context, in my use case batches can have very inhomogeneous numels, so each device could hold error tensors with very different numels, and taking a mean of per-device MeanSquaredErrors may not be ideal. Ideally, if the syncing keeps the autograd graph, the per-step loss would be the "true" metric as per its definition and the gradients would be consistent with that definition (so syncing is done once for each loss metric contribution, and once for the backward at each training step, I think).
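To make the concern concrete, here is a small plain-Python illustration with made-up error values (two devices, one holding a single element and the other holding three). It compares averaging the per-device MSEs against the MSE computed over all elements together.

```python
def mse(errors):
    """Mean squared error of one device's error values."""
    return sum(e * e for e in errors) / len(errors)


def mean_of_mses(per_device):
    """Average the per-device MSEs, ignoring how many elements each device has."""
    return sum(mse(errs) for errs in per_device) / len(per_device)


def global_mse(per_device):
    """MSE over all elements from all devices, weighted by element count."""
    total = sum(e * e for errs in per_device for e in errs)
    count = sum(len(errs) for errs in per_device)
    return total / count


# Made-up errors: device 0 has 1 element, device 1 has 3 elements.
per_device = [[2.0], [0.0, 0.0, 0.0]]

averaged = mean_of_mses(per_device)   # (4.0 + 0.0) / 2
pooled = global_mse(per_device)       # 4.0 / 4
```

With these values the averaged result is 2.0 while the pooled result is 1.0, so the two definitions diverge as soon as the numels differ across devices.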
Thank you!