Skip to content

Commit

Permalink
Sync loss tracker for Pretrainer
Browse files Browse the repository at this point in the history
  • Loading branch information
mdw771 committed Mar 7, 2024
1 parent 9fe42c4 commit 4c2387f
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions generic_trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -977,10 +977,11 @@ def run_training_epoch(self):
self.loss_tracker['lrs'].append(self.scheduler.get_last_lr()[0])

n_batches += 1
# Divide cumulative loss by number of batches-- sli inaccurate because last batch is different size
self.loss_tracker.update_losses([l / n_batches for l in losses], type='loss', epoch=self.current_epoch)
losses = [self.communicate_value_across_ranks(l / n_batches, mode='average') for l in losses]
self.loss_tracker.update_losses(losses, type='loss', epoch=self.current_epoch)

self.save_model_and_states_checkpoint()
if self.rank == 0:
self.save_model_and_states_checkpoint()

if self.configs.post_training_epoch_hook is not None:
self.configs.post_training_epoch_hook()
Expand All @@ -997,8 +998,9 @@ def run_validation(self):
logger.warning('Validation set might be too small that at least 1 rank did not get any validation data.')
n_batches = np.max([n_batches, 1])
last_best_val_loss = self.loss_tracker['best_val_loss']
is_best = self.loss_tracker.update_losses([l / n_batches for l in losses],
epoch=self.current_epoch, type='val_loss')

losses = [self.communicate_value_across_ranks(l / n_batches, mode='average') for l in losses]
is_best = self.loss_tracker.update_losses(losses, epoch=self.current_epoch, type='val_loss')
if self.rank == 0:
self.write_training_info()

Expand Down

0 comments on commit 4c2387f

Please sign in to comment.