diff --git a/scvi/model/_scanvi.py b/scvi/model/_scanvi.py index c3d44e567c..57c3c43bc7 100644 --- a/scvi/model/_scanvi.py +++ b/scvi/model/_scanvi.py @@ -145,6 +145,7 @@ def __init__( gene_likelihood, ) self.init_params_ = self._get_init_params(locals()) + self.was_pretrained = False @classmethod def from_scvi_model( @@ -192,6 +193,7 @@ def from_scvi_model( ) scvi_state_dict = scvi_model.module.state_dict() scanvi_model.module.load_state_dict(scvi_state_dict, strict=False) + scanvi_model.was_pretrained = True return scanvi_model @@ -343,6 +345,9 @@ def train( n_cells = self.adata.n_obs max_epochs = np.min([round((20000 / n_cells) * 400), 400]) + if self.was_pretrained: + max_epochs = int(np.min([10, np.max([2, round(max_epochs / 3.0)])])) + logger.info("Training for {} epochs.".format(max_epochs)) plan_kwargs = {} if plan_kwargs is None else plan_kwargs