diff --git a/generic_trainer/trainer.py b/generic_trainer/trainer.py index a5beb19..851c24f 100644 --- a/generic_trainer/trainer.py +++ b/generic_trainer/trainer.py @@ -977,10 +977,11 @@ def run_training_epoch(self): self.loss_tracker['lrs'].append(self.scheduler.get_last_lr()[0]) n_batches += 1 - # Divide cumulative loss by number of batches-- sli inaccurate because last batch is different size - self.loss_tracker.update_losses([l / n_batches for l in losses], type='loss', epoch=self.current_epoch) + losses = [self.communicate_value_across_ranks(l / n_batches, mode='average') for l in losses] + self.loss_tracker.update_losses(losses, type='loss', epoch=self.current_epoch) - self.save_model_and_states_checkpoint() + if self.rank == 0: + self.save_model_and_states_checkpoint() if self.configs.post_training_epoch_hook is not None: self.configs.post_training_epoch_hook() @@ -997,8 +998,9 @@ def run_validation(self): logger.warning('Validation set might be too small that at least 1 rank did not get any validation data.') n_batches = np.max([n_batches, 1]) last_best_val_loss = self.loss_tracker['best_val_loss'] - is_best = self.loss_tracker.update_losses([l / n_batches for l in losses], - epoch=self.current_epoch, type='val_loss') + + losses = [self.communicate_value_across_ranks(l / n_batches, mode='average') for l in losses] + is_best = self.loss_tracker.update_losses(losses, epoch=self.current_epoch, type='val_loss') if self.rank == 0: self.write_training_info()