diff --git a/scvi/model/_scanvi.py b/scvi/model/_scanvi.py index 38b5939dea..1ba1fad49b 100644 --- a/scvi/model/_scanvi.py +++ b/scvi/model/_scanvi.py @@ -345,8 +345,8 @@ def train( n_cells = self.adata.n_obs max_epochs = np.min([round((20000 / n_cells) * 400), 400]) - if self.was_pretrained: - max_epochs = int(np.min([10, np.max([2, round(max_epochs / 3.0)])])) + if self.was_pretrained: + max_epochs = int(np.min([10, np.max([2, round(max_epochs / 3.0)])])) logger.info("Training for {} epochs.".format(max_epochs))