Skip to content

Commit

Permalink
update save best state callback
Browse files Browse the repository at this point in the history
  • Loading branch information
adamgayoso committed Jan 21, 2021
1 parent b157f75 commit 156287a
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 4 deletions.
8 changes: 4 additions & 4 deletions scvi/lightning/_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ class SaveBestState(Callback):

def __init__(
self,
monitor="val_loss",
monitor: str = "elbo_validation",
mode: str = "min",
verbose=False,
mode="auto",
period=1,
):
super().__init__()
Expand Down Expand Up @@ -100,8 +100,8 @@ def on_epoch_end(self, trainer, pl_module):

if self.verbose:
rank_zero_info(
f"\nEpoch {trainer.current_epoch:05d}: {self.monitor} reached"
f" {current:0.5f} (best {self.best_model_metric_val:0.5f})"
f"\nEpoch {trainer.current_epoch:05d}: {self.monitor} reached."
f" Model best state updated."
)

def on_train_end(self, trainer, pl_module):
Expand Down
12 changes: 12 additions & 0 deletions tests/models/test_lightning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from scvi.data import synthetic_iid
from scvi.model import SCVI
from scvi.lightning._callbacks import SaveBestState


def test_save_best_state_callback(save_path):

n_latent = 5
adata = synthetic_iid()
model = SCVI(adata, n_latent=n_latent)
callbacks = [SaveBestState(verbose=True)]
model.train(3, check_val_every_n_epoch=1, train_size=0.5, callbacks=callbacks)

0 comments on commit 156287a

Please sign in to comment.