Skip to content

Commit

Permalink
fix kl weight check for function based modules
Browse files Browse the repository at this point in the history
  • Loading branch information
justjhong committed Oct 26, 2021
1 parent 75ea35d commit d8bd388
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions scvi/train/_trainingplans.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pyro
import pytorch_lightning as pl
import torch
from pyro.nn import PyroModule
from torch.optim.lr_scheduler import ReduceLROnPlateau

from scvi import _CONSTANTS
Expand Down Expand Up @@ -656,9 +657,13 @@ def __init__(
self.pyro_guide = self.module.guide
self.pyro_model = self.module.model

self.use_kl_weight = (
"kl_weight" in signature(self.pyro_model.forward).parameters
)
self.use_kl_weight = False
if isinstance(self.pyro_model, PyroModule):
self.use_kl_weight = (
"kl_weight" in signature(self.pyro_model.forward).parameters
)
elif callable(self.pyro_model):
self.use_kl_weight = "kl_weight" in signature(self.pyro_model).parameters

self.svi = pyro.infer.SVI(
model=self.pyro_model,
Expand Down

0 comments on commit d8bd388

Please sign in to comment.