Skip to content

Commit

Permalink
Backport PR scverse#1235: 0.14.3 fixes: Do not filter out additional …
Browse files Browse the repository at this point in the history
…tensors on load, move docstring processors to public API (scverse#1240)
  • Loading branch information
justjhong authored Oct 22, 2021
1 parent a4901af commit 6584927
Show file tree
Hide file tree
Showing 27 changed files with 192 additions and 123 deletions.
5 changes: 3 additions & 2 deletions docs/api/developer.rst
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ These classes should be used to construct user-facing model classes.
model.base.PyroSviTrainMixin
model.base.PyroSampleMixin
model.base.PyroJitGuideWarmup

model.base.DifferentialComputation

Module
------

Expand Down Expand Up @@ -179,5 +180,5 @@ Utility functions used by scvi-tools.
:toctree: reference/
:nosignatures:

utils.DifferentialComputation
utils.track
utils.setup_anndata_dsp
1 change: 1 addition & 0 deletions docs/release_notes/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Version 0.14
.. toctree::
:maxdepth: 2

v0.14.3
v0.14.2
v0.14.1
v0.14.0
Expand Down
21 changes: 21 additions & 0 deletions docs/release_notes/v0.14.3.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
New in 0.14.3 (2021-10-19)
--------------------------

Bug fix.

Changes
~~~~~~~
- Bug fix to :func:`~scvi.model.base.BaseModelClass` to retain tensors registered by :class:`~scvi.data.register_tensor_from_anndata` (`#1235`_).
- Expose an instance of our ``DocstringProcessor`` to aid in documenting derived implementations of ``setup_anndata`` method (`#1235`_).

Contributors
~~~~~~~~~~~~
- `@adamgayoso`_
- `@jjhong922`_
- `@watiss`_

.. _`@adamgayoso`: https://github.com/adamgayoso
.. _`@jjhong922`: https://github.com/jjhong922
.. _`@watiss`: https://github.com/watiss

.. _`#1235` : https://github.com/YosefLab/scvi-tools/pull/1235
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ packages = [
{include = "scvi"},
]
readme = "README.md"
version = "0.14.2"
version = "0.14.3"

[tool.poetry.dependencies]
anndata = ">=0.7.5"
Expand Down
4 changes: 2 additions & 2 deletions scvi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from ._settings import settings

# this import needs to come after prior imports to prevent circular import
from . import data, model, external
from . import data, model, external, utils

# https://github.com/python-poetry/poetry/pull/2366#issuecomment-652418094
# https://github.com/python-poetry/poetry/issues/144#issuecomment-623927302
Expand All @@ -21,4 +21,4 @@
settings.verbosity = logging.INFO
test_var = "test"

__all__ = ["settings", "_CONSTANTS", "data", "model", "external"]
__all__ = ["settings", "_CONSTANTS", "data", "model", "external", "utils"]
2 changes: 1 addition & 1 deletion scvi/external/cellassign/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@

import scvi
from scvi import _CONSTANTS
from scvi._docs import setup_anndata_dsp
from scvi.data import register_tensor_from_anndata
from scvi.data._anndata import _setup_anndata
from scvi.dataloaders import DataSplitter
from scvi.external.cellassign._module import CellAssignModule
from scvi.model.base import BaseModelClass, UnsupervisedTrainingMixin
from scvi.train import LoudEarlyStopping, TrainingPlan, TrainRunner
from scvi.utils import setup_anndata_dsp

logger = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion scvi/external/gimvi/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from torch.utils.data import DataLoader

from scvi import _CONSTANTS
from scvi._docs import setup_anndata_dsp
from scvi.data import transfer_anndata_setup
from scvi.data._anndata import _setup_anndata
from scvi.dataloaders import DataSplitter
Expand All @@ -22,6 +21,7 @@
)
from scvi.model.base import BaseModelClass, VAEMixin
from scvi.train import Trainer
from scvi.utils import setup_anndata_dsp

from ._module import JVAE
from ._task import GIMVITrainingPlan
Expand Down
2 changes: 1 addition & 1 deletion scvi/external/solo/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@
from anndata import AnnData

from scvi import _CONSTANTS
from scvi._docs import setup_anndata_dsp
from scvi.data._anndata import _setup_anndata, get_from_registry, transfer_anndata_setup
from scvi.dataloaders import DataSplitter
from scvi.model import SCVI
from scvi.model.base import BaseModelClass
from scvi.module import Classifier
from scvi.module.base import auto_move_data
from scvi.train import ClassifierTrainingPlan, LoudEarlyStopping, TrainRunner
from scvi.utils import setup_anndata_dsp

logger = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion scvi/external/stereoscope/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
from anndata import AnnData

from scvi._compat import Literal
from scvi._docs import setup_anndata_dsp
from scvi.data import register_tensor_from_anndata
from scvi.data._anndata import _setup_anndata
from scvi.external.stereoscope._module import RNADeconv, SpatialDeconv
from scvi.model.base import BaseModelClass, UnsupervisedTrainingMixin
from scvi.utils import setup_anndata_dsp

logger = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion scvi/model/_amortizedlda.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
from anndata import AnnData

from scvi._constants import _CONSTANTS
from scvi._docs import setup_anndata_dsp
from scvi.data._anndata import _setup_anndata
from scvi.module import AmortizedLDAPyroModule
from scvi.utils import setup_anndata_dsp

from .base import BaseModelClass, PyroSviTrainMixin

Expand Down
2 changes: 1 addition & 1 deletion scvi/model/_autozi.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@

from scvi import _CONSTANTS
from scvi._compat import Literal
from scvi._docs import setup_anndata_dsp
from scvi.data._anndata import _setup_anndata
from scvi.model._utils import _init_library_size
from scvi.model.base import UnsupervisedTrainingMixin
from scvi.module import AutoZIVAE
from scvi.utils import setup_anndata_dsp

from .base import BaseModelClass, VAEMixin

Expand Down
2 changes: 1 addition & 1 deletion scvi/model/_condscvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from anndata import AnnData

from scvi import _CONSTANTS
from scvi._docs import setup_anndata_dsp
from scvi.data._anndata import _setup_anndata
from scvi.model.base import (
BaseModelClass,
Expand All @@ -16,6 +15,7 @@
VAEMixin,
)
from scvi.module import VAEC
from scvi.utils import setup_anndata_dsp

logger = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion scvi/model/_destvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
import torch
from anndata import AnnData

from scvi._docs import setup_anndata_dsp
from scvi.data import register_tensor_from_anndata
from scvi.data._anndata import _setup_anndata
from scvi.model import CondSCVI
from scvi.model.base import BaseModelClass, UnsupervisedTrainingMixin
from scvi.module import MRDeconv
from scvi.utils import setup_anndata_dsp

logger = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion scvi/model/_linear_scvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
from anndata import AnnData

from scvi._compat import Literal
from scvi._docs import setup_anndata_dsp
from scvi.data._anndata import _setup_anndata
from scvi.model._utils import _get_var_names_from_setup_anndata, _init_library_size
from scvi.model.base import UnsupervisedTrainingMixin
from scvi.module import LDVAE
from scvi.utils import setup_anndata_dsp

from .base import BaseModelClass, RNASeqMixin, VAEMixin

Expand Down
6 changes: 3 additions & 3 deletions scvi/model/_multivi.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

from scvi import _CONSTANTS
from scvi._compat import Literal
from scvi._docs import doc_differential_expression, setup_anndata_dsp
from scvi._utils import _doc_params
from scvi.data._anndata import _setup_anndata
from scvi.dataloaders import DataSplitter
Expand All @@ -25,6 +24,7 @@
from scvi.module import MULTIVAE
from scvi.train import AdversarialTrainingPlan, TrainRunner
from scvi.train._callbacks import SaveBestState
from scvi.utils._docstrings import doc_differential_expression, setup_anndata_dsp

from .base import BaseModelClass, VAEMixin
from .base._utils import _de_core
Expand Down Expand Up @@ -661,7 +661,7 @@ def differential_accessibility(
two_sided
Whether to perform a two-sided test, or a one-sided test.
**kwargs
Keyword args for :func:`scvi.utils.DifferentialComputation.get_bayes_factors`
Keyword args for :meth:`scvi.model.base.DifferentialComputation.get_bayes_factors`
Returns
-------
Expand Down Expand Up @@ -782,7 +782,7 @@ def differential_expression(
----------
{doc_differential_expression}
**kwargs
Keyword args for :func:`scvi.utils.DifferentialComputation.get_bayes_factors`
Keyword args for :meth:`scvi.model.base.DifferentialComputation.get_bayes_factors`
Returns
-------
Expand Down
4 changes: 2 additions & 2 deletions scvi/model/_peakvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from scipy.sparse import csr_matrix, vstack

from scvi._compat import Literal
from scvi._docs import doc_differential_expression, setup_anndata_dsp
from scvi._utils import _doc_params
from scvi.data._anndata import _setup_anndata
from scvi.model._utils import (
Expand All @@ -20,6 +19,7 @@
from scvi.model.base import UnsupervisedTrainingMixin
from scvi.module import PEAKVAE
from scvi.train._callbacks import SaveBestState
from scvi.utils._docstrings import doc_differential_expression, setup_anndata_dsp

from .base import ArchesMixin, BaseModelClass, VAEMixin
from .base._utils import _de_core
Expand Down Expand Up @@ -434,7 +434,7 @@ def differential_accessibility(
two_sided
Whether to perform a two-sided test, or a one-sided test.
**kwargs
Keyword args for :func:`scvi.utils.DifferentialComputation.get_bayes_factors`
Keyword args for :meth:`scvi.model.base.DifferentialComputation.get_bayes_factors`
Returns
-------
Expand Down
2 changes: 1 addition & 1 deletion scvi/model/_scanvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

from scvi import _CONSTANTS
from scvi._compat import Literal
from scvi._docs import setup_anndata_dsp
from scvi.data._anndata import _make_obs_column_categorical, _setup_anndata
from scvi.dataloaders import (
AnnDataLoader,
Expand All @@ -21,6 +20,7 @@
from scvi.module import SCANVAE
from scvi.train import SemiSupervisedTrainingPlan, TrainRunner
from scvi.train._callbacks import SubSampleLabels
from scvi.utils import setup_anndata_dsp

from ._scvi import SCVI
from .base import ArchesMixin, BaseModelClass, RNASeqMixin, VAEMixin
Expand Down
2 changes: 1 addition & 1 deletion scvi/model/_scvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
from anndata import AnnData

from scvi._compat import Literal
from scvi._docs import setup_anndata_dsp
from scvi.data._anndata import _setup_anndata
from scvi.model._utils import _init_library_size
from scvi.model.base import UnsupervisedTrainingMixin
from scvi.module import VAE
from scvi.utils import setup_anndata_dsp

from .base import ArchesMixin, BaseModelClass, RNASeqMixin, VAEMixin

Expand Down
4 changes: 2 additions & 2 deletions scvi/model/_totalvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

from scvi import _CONSTANTS
from scvi._compat import Literal
from scvi._docs import doc_differential_expression, setup_anndata_dsp
from scvi._utils import _doc_params
from scvi.data import get_from_registry
from scvi.data._anndata import _setup_anndata
Expand All @@ -26,6 +25,7 @@
from scvi.model.base._utils import _de_core
from scvi.module import TOTALVAE
from scvi.train import AdversarialTrainingPlan, TrainRunner
from scvi.utils._docstrings import doc_differential_expression, setup_anndata_dsp

from .base import ArchesMixin, BaseModelClass, RNASeqMixin, VAEMixin

Expand Down Expand Up @@ -701,7 +701,7 @@ def differential_expression(
include_protein_background
Include the protein background component as part of the protein expression
**kwargs
Keyword args for :func:`scvi.utils.DifferentialComputation.get_bayes_factors`
Keyword args for :meth:`scvi.model.base.DifferentialComputation.get_bayes_factors`
Returns
-------
Expand Down
11 changes: 10 additions & 1 deletion scvi/model/base/_base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@
from rich.text import Text

from scvi import _CONSTANTS, settings
from scvi._docs import setup_anndata_dsp
from scvi.data import get_from_registry, transfer_anndata_setup
from scvi.data._anndata import _check_anndata_setup_equivalence
from scvi.data._utils import _check_nonnegative_integers
from scvi.dataloaders import AnnDataLoader
from scvi.model._utils import parse_use_gpu_arg
from scvi.module.base import PyroBaseModuleClass
from scvi.utils import setup_anndata_dsp

from ._utils import _initialize_model, _load_saved_files, _validate_var_names

Expand Down Expand Up @@ -348,6 +348,15 @@ def load(
) = _load_saved_files(dir_path, load_adata, map_location=device)
adata = new_adata if new_adata is not None else adata

# Filter out keys that are no longer populated by setup_anndata.
# TODO(jhong): remove hack with setup_anndata refactor.
deprecated_keys = {"local_l_mean", "local_l_var"}
scvi_setup_dict["data_registry"] = {
k: v
for k, v in scvi_setup_dict["data_registry"].items()
if k not in deprecated_keys
}

_validate_var_names(adata, var_names)
transfer_anndata_setup(scvi_setup_dict, adata)
model = _initialize_model(cls, adata, attr_dict)
Expand Down
File renamed without changes.
4 changes: 2 additions & 2 deletions scvi/model/base/_rnamixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@

from scvi import _CONSTANTS
from scvi._compat import Literal
from scvi._docs import doc_differential_expression
from scvi._utils import _doc_params
from scvi.utils._docstrings import doc_differential_expression

from .._utils import (
_get_batch_code_from_category,
Expand Down Expand Up @@ -193,7 +193,7 @@ def differential_expression(
----------
{doc_differential_expression}
**kwargs
Keyword args for :func:`scvi.utils.DifferentialComputation.get_bayes_factors`
Keyword args for :meth:`scvi.model.base.DifferentialComputation.get_bayes_factors`
Returns
-------
Expand Down
10 changes: 3 additions & 7 deletions scvi/model/base/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@
from anndata import AnnData, read

from scvi._compat import Literal
from scvi._constants import _CONSTANTS
from scvi.utils import DifferentialComputation, track
from scvi.utils import track

from ._differential import DifferentialComputation

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -41,11 +42,6 @@ def _load_saved_files(
attr_dict = pickle.load(handle)

scvi_setup_dict = attr_dict.pop("scvi_setup_dict_")
# Only retain keys in the data_registry that exist in _CONSTANTS.
# TODO(jhong): Remove once data registry refactored.
scvi_setup_dict["data_registry"] = {
k: v for k, v in scvi_setup_dict["data_registry"].items() if k in _CONSTANTS
}

model_state_dict = torch.load(model_path, map_location=map_location)

Expand Down
4 changes: 2 additions & 2 deletions scvi/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from ._differential import DifferentialComputation
from ._docstrings import setup_anndata_dsp
from ._track import track

__all__ = ["DifferentialComputation", "track"]
__all__ = ["track", "setup_anndata_dsp"]
Loading

0 comments on commit 6584927

Please sign in to comment.