TorchMetrics, Pytorch Lightning and DataParallel #528
-
In the TorchMetrics in PyTorch Lightning section of the documentation there is the following warning:
Is this also true in distributed data parallel mode (ddp/ddp2)? My code reports correct metric values only if I follow the instructions above.
Replies: 1 comment
-
No, in `ddp` it should not be necessary. You still need to sync the metric at some point between the different devices, which is automatically done when `metric.compute()` is called. Therefore, something like this should still work:

```python
def training_step(self, batch, batch_idx):
    ...
    self.metric.update(preds, target)
    ...

def training_epoch_end(self, outputs):
    val = self.metric.compute()  # this will sync the metric between devices
    self.log("metric", val)
    self.metric.reset()
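```

To make the reply's snippet self-contained, here is a minimal sketch of a full LightningModule, written against the Lightning 1.x-era hook used above (`training_epoch_end`) with an illustrative `torchmetrics.Accuracy` metric; the model, metric choice, and hyperparameters are assumptions, not from the thread. Assigning the metric as an attribute of the module is what lets Lightning move it to the correct device and sync its state under ddp.

```python
import torch
from torch import nn
import pytorch_lightning as pl
import torchmetrics


class LitClassifier(pl.LightningModule):
    # Illustrative module: a linear classifier over 32 input features.
    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.model = nn.Linear(32, num_classes)
        # Assigning the metric as an attribute registers it as a submodule,
        # so it is moved to the right device and its state can be synced in ddp.
        self.metric = torchmetrics.Accuracy()

    def training_step(self, batch, batch_idx):
        x, target = batch
        preds = self.model(x)
        loss = nn.functional.cross_entropy(preds, target)
        # Accumulate per-process metric state; no manual sync needed here.
        self.metric.update(preds, target)
        return loss

    def training_epoch_end(self, outputs):
        # compute() gathers the accumulated state from all ddp processes.
        self.log("metric", self.metric.compute())
        self.metric.reset()

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
```

In newer releases the same idea applies, but `training_epoch_end` has been replaced by `on_train_epoch_end` and `torchmetrics.Accuracy` requires a `task` argument. The TorchMetrics integration also allows logging the metric object itself (e.g. `self.log("metric", self.metric, on_epoch=True)`), in which case Lightning takes care of calling `compute()` and `reset()` at the end of the epoch.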