Backport PR #1446: Simplify changing the training plan for pyro (#1470)
Co-authored-by: Vitalii Kleshchevnikov <[email protected]>
meeseeksmachine and vitkl authored Mar 28, 2022
1 parent 4258a93 commit fc7688e
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions scvi/model/base/_pyromixin.py
@@ -64,6 +64,7 @@ def train(
         batch_size: int = 128,
         early_stopping: bool = False,
         lr: Optional[float] = None,
+        training_plan: PyroTrainingPlan = PyroTrainingPlan,
         plan_kwargs: Optional[dict] = None,
         **trainer_kwargs,
     ):
@@ -92,8 +93,10 @@ def train(
         lr
             Optimiser learning rate (default optimiser is :class:`~pyro.optim.ClippedAdam`).
             Specifying optimiser via plan_kwargs overrides this choice of lr.
+        training_plan
+            Training plan :class:`~scvi.train.PyroTrainingPlan`.
         plan_kwargs
-            Keyword args for :class:`~scvi.train.TrainingPlan`. Keyword arguments passed to
+            Keyword args for :class:`~scvi.train.PyroTrainingPlan`. Keyword arguments passed to
             `train()` will overwrite values present in `plan_kwargs`, when appropriate.
         **trainer_kwargs
             Other keyword args for :class:`~scvi.train.Trainer`.
@@ -123,7 +126,7 @@ def train(
             batch_size=batch_size,
             use_gpu=use_gpu,
         )
-        training_plan = PyroTrainingPlan(pyro_module=self.module, **plan_kwargs)
+        training_plan = training_plan(self.module, **plan_kwargs)

         es = "early_stopping"
         trainer_kwargs[es] = (
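The change above turns the training plan into a parameter: `train()` now accepts a plan *class* (defaulting to `PyroTrainingPlan`) and instantiates it internally, so callers can swap in a custom subclass without overriding `train()`. A minimal, self-contained sketch of that pattern follows; the names (`PyroTrainingPlan`, `CustomPlan`, `train`) are stand-ins for illustration, not the actual scvi-tools API:

```python
# Sketch of the "pass the class, instantiate inside" pattern from this commit.
# All names here are illustrative stand-ins, not real scvi-tools classes.

class PyroTrainingPlan:
    """Stand-in for the default training plan."""

    def __init__(self, pyro_module, **plan_kwargs):
        self.module = pyro_module
        self.plan_kwargs = plan_kwargs


class CustomPlan(PyroTrainingPlan):
    """A user-defined plan, e.g. with a different loss or optimiser setup."""


def train(module, training_plan=PyroTrainingPlan, plan_kwargs=None):
    plan_kwargs = plan_kwargs or {}
    # Mirrors `training_plan = training_plan(self.module, **plan_kwargs)`
    # in the diff: the class received as an argument is instantiated here.
    return training_plan(module, **plan_kwargs)


# Callers pass the class itself, not an instance:
plan = train(object(), training_plan=CustomPlan, plan_kwargs={"lr": 1e-3})
```

Keeping instantiation inside `train()` lets the method keep control of which module and keyword arguments the plan receives, while the caller only chooses the plan's type.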
