Skip to content

Commit

Permalink
fix extend history when using non default logger
Browse files Browse the repository at this point in the history
  • Loading branch information
adamgayoso committed Jul 15, 2021
1 parent f9a0559 commit 4bc2b81
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions scvi/train/_trainrunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,13 @@ def __call__(self):

def _update_history(self):
# model is being further trained
if self.model.history_ is not None:
# this was set to true during first training session
if self.model.is_trained_ is True:
# if not using the default logger (e.g., tensorboard)
if not isinstance(self.model.history_, dict):
warnings.warn(
"Training history cannot be updated. Replacing old history with new history."
"Training history cannot be updated. Logger can be accessed from model.trainer.logger"
)
self.model.history_ = self.trainer.logger.history
return
else:
new_history = self.trainer.logger.history
Expand All @@ -110,4 +110,8 @@ def _update_history(self):
self.model.history_[key].index.name = val.index.name
else:
# set history_ attribute if it exists
self.model.history_ = self.trainer.logger.history
# other pytorch lightning loggers might not have history attr
try:
self.model.history_ = self.trainer.logger.history
except AttributeError:
self.history_ = None

0 comments on commit 4bc2b81

Please sign in to comment.