Skip to content

Commit

Permalink
Merge pull request #1025 from YosefLab/scanvi_heuristic
Browse files Browse the repository at this point in the history
if scanvi was pretrained, follow old heuristic
  • Loading branch information
adamgayoso authored Apr 12, 2021
2 parents d0c1d1c + 30fea2f commit 2aeeb1d
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions scvi/model/_scanvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ def __init__(
gene_likelihood,
)
self.init_params_ = self._get_init_params(locals())
self.was_pretrained = False

@classmethod
def from_scvi_model(
Expand Down Expand Up @@ -192,6 +193,7 @@ def from_scvi_model(
)
scvi_state_dict = scvi_model.module.state_dict()
scanvi_model.module.load_state_dict(scvi_state_dict, strict=False)
scanvi_model.was_pretrained = True

return scanvi_model

Expand Down Expand Up @@ -343,6 +345,9 @@ def train(
n_cells = self.adata.n_obs
max_epochs = np.min([round((20000 / n_cells) * 400), 400])

if self.was_pretrained:
max_epochs = int(np.min([10, np.max([2, round(max_epochs / 3.0)])]))

logger.info("Training for {} epochs.".format(max_epochs))

plan_kwargs = {} if plan_kwargs is None else plan_kwargs
Expand Down

0 comments on commit 2aeeb1d

Please sign in to comment.