Skip to content

Commit

Permalink
fix: nesting scheduler stepping mechanism only if something is passed
Browse files Browse the repository at this point in the history
  • Loading branch information
laserkelvin committed Jul 1, 2024
1 parent 7f726b9 commit b888ee8
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions matsciml/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2009,14 +2009,15 @@ def training_step(
# step learning rate schedulers at the end of epochs
if self.trainer.is_last_batch:
schedulers = self.lr_schedulers()
if not isinstance(schedulers, list):
schedulers = [schedulers]
for s in schedulers:
# for schedulers that need a metric
if isinstance(s, lr_scheduler.ReduceLROnPlateau):
s.step(loss, self.current_epoch)
else:
s.step(epoch=self.current_epoch)
if schedulers is not None:
if not isinstance(schedulers, list):
schedulers = [schedulers]
for s in schedulers:
# for schedulers that need a metric
if isinstance(s, lr_scheduler.ReduceLROnPlateau):
s.step(loss, self.current_epoch)
else:
s.step(epoch=self.current_epoch)
if self.hparams.log_embeddings and "embeddings" in batch:
self._log_embedding(batch["embeddings"])
return loss_dict
Expand Down

0 comments on commit b888ee8

Please sign in to comment.