diff --git a/scvi/lightning/_callbacks.py b/scvi/lightning/_callbacks.py index 6fcc13e8f5..337ec056c2 100644 --- a/scvi/lightning/_callbacks.py +++ b/scvi/lightning/_callbacks.py @@ -38,9 +38,9 @@ class SaveBestState(Callback): def __init__( self, - monitor="val_loss", + monitor: str = "elbo_validation", + mode: str = "min", verbose=False, - mode="auto", period=1, ): super().__init__() @@ -100,8 +100,8 @@ def on_epoch_end(self, trainer, pl_module): if self.verbose: rank_zero_info( - f"\nEpoch {trainer.current_epoch:05d}: {self.monitor} reached" - f" {current:0.5f} (best {self.best_model_metric_val:0.5f})" + f"\nEpoch {trainer.current_epoch:05d}: {self.monitor} reached." + f" Model best state updated." ) def on_train_end(self, trainer, pl_module): diff --git a/tests/models/test_lightning.py b/tests/models/test_lightning.py new file mode 100644 index 0000000000..542934b07f --- /dev/null +++ b/tests/models/test_lightning.py @@ -0,0 +1,12 @@ +from scvi.data import synthetic_iid +from scvi.model import SCVI +from scvi.lightning._callbacks import SaveBestState + + +def test_save_best_state_callback(save_path): + + n_latent = 5 + adata = synthetic_iid() + model = SCVI(adata, n_latent=n_latent) + callbacks = [SaveBestState(verbose=True)] + model.train(3, check_val_every_n_epoch=1, train_size=0.5, callbacks=callbacks)