Proper way to log things when using DDP #6501
-
         Hi, I was wondering what is the proper way of logging metrics when using DDP. I noticed that if I want to print something inside validation_epoch_end it gets printed once per process when running on multiple GPUs. I understand that I can solve the printing by checking self.global_rank == 0 and only printing on that rank. Here is a code snippet from my use case. I would like to be able to report f1, precision and recall on the entire validation dataset and I am wondering what is the correct way of doing it when using DDP.

    def _process_epoch_outputs(self,
                               outputs: List[Dict[str, Any]]
                               ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Creates and returns tensors containing all labels and predictions
        Goes over the outputs accumulated from every batch, detaches the
        necessary tensors and stacks them together.
        Args:
            outputs (List[Dict])
        """
        all_labels = []
        all_predictions = []
        for output in outputs:
            for labels in output['labels'].detach():
                all_labels.append(labels)
            for predictions in output['predictions'].detach():
                all_predictions.append(predictions)
        all_labels = torch.stack(all_labels).long().cpu()
        all_predictions = torch.stack(all_predictions).cpu()
        return all_predictions, all_labels

    def validation_epoch_end(self, outputs: List[Dict[str, Any]]) -> None:
        """Logs f1, precision and recall on the validation set."""
        if self.global_rank == 0:
            print(f'Validation Epoch: {self.current_epoch}')
        predictions, labels = self._process_epoch_outputs(outputs)
        for i, name in enumerate(self.label_columns):
            f1, prec, recall, t = metrics.get_f1_prec_recall(predictions[:, i],
                                                             labels[:, i],
                                                             threshold=None)
            self.logger.experiment.add_scalar(f'{name}_f1/Val',
                                              f1,
                                              self.current_epoch)
            self.logger.experiment.add_scalar(f'{name}_Precision/Val',
                                              prec,
                                              self.current_epoch)
            self.logger.experiment.add_scalar(f'{name}_Recall/Val',
                                              recall,
                                              self.current_epoch)
            if self.global_rank == 0:
                print((f'F1: {f1}, Precision: {prec}, '
                       f'Recall: {recall}, Threshold {t}'))
-
         I have the same question, and have not been able to get sufficient clarity from the docs about how logging works during distributed training. I found the suggestion to use the sync_dist flag when logging, but the docs do not explain what it actually does.
-
         I have the same problem. I managed to log the synced metric by calling metric.compute(), but the value is not identical to the one the checkpoint callback uses. Details can be found in #6352, and there is a related issue as well.
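A minimal sketch of the two logging patterns being compared here, not taken from the thread: it assumes torchmetrics >= 0.11 (BinaryF1Score from torchmetrics.classification) and a LightningModule that creates self.val_f1 = BinaryF1Score() in __init__, with the usual imports and a forward defined as in the snippet above. Logging the metric object leaves the synced compute() and the reset to Lightning, while calling compute() manually returns the synced value but makes you responsible for the logging key and the reset, which is one place where a mismatch with the checkpoint callback can creep in.

    def validation_step(self, batch, batch_idx):
        x, y = batch
        preds = torch.sigmoid(self(x))
        # Option A: update the metric and log the metric object itself.
        # Lightning runs the cross-process-synced compute() at epoch end and
        # resets the state, so a ModelCheckpoint monitoring 'val_f1' sees the
        # same synced value that ends up in the logger.
        self.val_f1(preds, y.int())
        self.log('val_f1', self.val_f1, on_step=False, on_epoch=True)

    # Option B: call compute() yourself. compute() synchronizes the metric
    # state across processes, but the logging key and the reset() are then
    # your responsibility, and the key the callback monitors must match
    # what you log.
    # def validation_epoch_end(self, outputs):
    #     self.log('val_f1', self.val_f1.compute())
    #     self.val_f1.reset()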
-
         @williamFalcon any chance someone can help us with this?
-
         @edenafek could you please take a look at the above issue?
-
         Hi all,
 Sorry we have not got back to you in time, let me try to answer some of your questions:
 Is validation_epoch_end only called on rank 0? No, it is called by all processes.
 What does the sync_dist flag do? Here is the essential code:
 https://github.com/PyTorchLightning/pytorch-lightning/blob/a72a7992a283f2eb5183d129a8cf6466903f1dc8/pytorch_lightning/core/step_result.py#L108-L115
 If sync_dist=True then it will by default call the sync_ddp function, which will sum the value across all processes using torch.distributed.all_reduce:
 https://github.com/PyTorchLightning/pytorch-lightning/blob/a72a7992a283f2eb5183d129a8cf6466903f1dc8/pytorch_lightning/utilities/distributed.py#L120
 Use this …
 For the printing, recommended is using either the rank_zero_info utility (from pytorch_lightning.utilities import rank_zero_info) or the rank_zero_only decorator.
 Our own metrics have custom synchronization going on. Any metric will automatically synchronize between different processes whenever metric.compute() is called.
 Not sure this answers all questions.
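Pulling those answers together for the original snippet, here is a rough sketch rather than an official recommendation: it reuses the question's own helpers (_process_epoch_outputs, label_columns and metrics.get_f1_prec_recall are assumed to exist on the module), prints once via rank_zero_info, and logs with sync_dist=True. Keep in mind that sync_dist reduces the already-computed per-rank scalars; it does not recompute f1/precision/recall over the pooled predictions of all ranks, so for exact dataset-level numbers a torchmetrics metric that syncs its state on compute() may be the cleaner route.

    # Rough sketch: the validation_epoch_end from the question, rewritten with
    # rank_zero_info and sync_dist=True. The helpers it calls are the ones
    # defined in the question and are assumed to be available on the module.
    from pytorch_lightning.utilities import rank_zero_info

    def validation_epoch_end(self, outputs):
        # rank_zero_info only emits on global rank 0, so this prints once even
        # though every process runs validation_epoch_end.
        rank_zero_info(f'Validation Epoch: {self.current_epoch}')

        predictions, labels = self._process_epoch_outputs(outputs)
        for i, name in enumerate(self.label_columns):
            f1, prec, recall, t = metrics.get_f1_prec_recall(predictions[:, i],
                                                             labels[:, i],
                                                             threshold=None)

            # Move the scalars onto the module's device so the distributed
            # backend (e.g. NCCL) can all_reduce them when sync_dist=True.
            f1, prec, recall = (torch.as_tensor(v, device=self.device)
                                for v in (f1, prec, recall))

            # sync_dist=True reduces each value across processes before it is
            # logged, so the logger no longer records a single GPU's number.
            # Note: this combines per-rank scalars; it does not recompute the
            # metrics over the predictions pooled from every rank.
            self.log(f'{name}_f1/Val', f1, sync_dist=True)
            self.log(f'{name}_Precision/Val', prec, sync_dist=True)
            self.log(f'{name}_Recall/Val', recall, sync_dist=True)

            rank_zero_info(f'F1: {f1}, Precision: {prec}, '
                           f'Recall: {recall}, Threshold {t}')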