Skip to content

Commit

Permalink
Merge pull request #1059 from vitkl/pyro-mixin
Browse files Browse the repository at this point in the history
Adding mixin classes for pyro training and posterior sampling
  • Loading branch information
adamgayoso authored Jun 19, 2021
2 parents 3a6ca06 + 52759bb commit bc415d8
Show file tree
Hide file tree
Showing 9 changed files with 770 additions and 50 deletions.
15 changes: 15 additions & 0 deletions docs/_templates/class_no_inherited.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,21 @@
.. autoclass:: {{ objname }}
:show-inheritance:

{% block attributes %}
{% if attributes %}
.. rubric:: Attributes

.. autosummary::
:toctree: .
{% for item in attributes %}
{%- if item not in inherited_members%}
~{{ fullname }}.{{ item }}
{%- endif -%}
{%- endfor %}
{% endif %}
{% endblock %}


{% block methods %}
{% if methods %}
.. rubric:: Methods
Expand Down
3 changes: 3 additions & 0 deletions docs/api/developer.rst
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ These classes should be used to construct user-facing model classes.
model.base.RNASeqMixin
model.base.ArchesMixin
model.base.UnsupervisedTrainingMixin
model.base.PyroSviTrainMixin
model.base.PyroSampleMixin
model.base.PyroJitGuideWarmup

Module
------
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ sphinx = {version = ">=3.4", optional = true}
sphinx-autodoc-typehints = {version = "*", optional = true}
sphinx-gallery = {version = ">0.6", optional = true}
sphinx-tabs = {version = "*", optional = true}
sphinx_copybutton = {version = "*", optional = true}
sphinx_copybutton = {version = "<=0.3.1", optional = true}
torch = ">=1.8.0"
tqdm = ">=4.56.0"
typing_extensions = {version = "*", python = "<3.8"}
Expand Down
4 changes: 4 additions & 0 deletions scvi/model/base/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from ._archesmixin import ArchesMixin
from ._base_model import BaseModelClass
from ._pyromixin import PyroJitGuideWarmup, PyroSampleMixin, PyroSviTrainMixin
from ._rnamixin import RNASeqMixin
from ._training_mixin import UnsupervisedTrainingMixin
from ._vaemixin import VAEMixin
Expand All @@ -10,4 +11,7 @@
"RNASeqMixin",
"VAEMixin",
"UnsupervisedTrainingMixin",
"PyroSviTrainMixin",
"PyroSampleMixin",
"PyroJitGuideWarmup",
]
4 changes: 4 additions & 0 deletions scvi/model/base/_base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ def _make_data_loader(
Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.
shuffle
Whether observations are shuffled each iteration though
data_loader_class
Class to use for data loader
data_loader_kwargs
Kwargs to the class-specific data loader class
"""
Expand Down Expand Up @@ -344,8 +346,10 @@ def load(
model.module.load_state_dict(model_state_dict)
except RuntimeError as err:
if isinstance(model.module, PyroBaseModuleClass):
old_history = model.history_
logger.info("Preparing underlying module for load")
model.train(max_steps=1)
model.history_ = old_history
pyro.clear_param_store()
model.module.load_state_dict(model_state_dict)
else:
Expand Down
Loading

0 comments on commit bc415d8

Please sign in to comment.