Skip to content

Commit

Permalink
Backport PR #1434: add arches query data prep
Browse files Browse the repository at this point in the history
  • Loading branch information
adamgayoso authored and meeseeksmachine committed Mar 15, 2022
1 parent 7455878 commit 778c8ea
Show file tree
Hide file tree
Showing 6 changed files with 185 additions and 18 deletions.
30 changes: 26 additions & 4 deletions scvi/data/_compat.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from copy import deepcopy
from typing import Optional

import numpy as np
from anndata import AnnData
Expand All @@ -10,6 +11,7 @@
from .fields import (
CategoricalJointObsField,
CategoricalObsField,
LabelsWithUnlabeledObsField,
LayerField,
NumericalJointObsField,
NumericalObsField,
Expand All @@ -27,7 +29,9 @@
}


def registry_from_setup_dict(setup_dict: dict) -> dict:
def registry_from_setup_dict(
setup_dict: dict, unlabeled_category: Optional[str] = None
) -> dict:
"""
Converts old setup dict format to new registry dict format.
Expand Down Expand Up @@ -80,6 +84,10 @@ def registry_from_setup_dict(setup_dict: dict) -> dict:
field_summary_stats[f"n_{new_registry_key}"] = summary_stats["n_batch"]
elif new_registry_key == REGISTRY_KEYS.LABELS_KEY:
field_summary_stats[f"n_{new_registry_key}"] = summary_stats["n_labels"]
if unlabeled_category is not None:
field_state_registry[
LabelsWithUnlabeledObsField.UNLABELED_CATEGORY
] = unlabeled_category
elif attr_name == _constants._ADATA_ATTRS.OBSM:
if new_registry_key == REGISTRY_KEYS.CONT_COVS_KEY:
columns = setup_dict["extra_continuous_keys"].copy()
Expand All @@ -106,7 +114,11 @@ def registry_from_setup_dict(setup_dict: dict) -> dict:


def manager_from_setup_dict(
cls, adata: AnnData, setup_dict: dict, **transfer_kwargs
cls,
adata: AnnData,
setup_dict: dict,
unlabeled_category: Optional[str] = None,
**transfer_kwargs,
) -> AnnDataManager:
"""
Creates an :class:`~scvi.data.AnnDataManager` given only a scvi-tools setup dictionary.
Expand Down Expand Up @@ -145,7 +157,15 @@ def manager_from_setup_dict(
elif attr_name == _constants._ADATA_ATTRS.OBS:
if new_registry_key in {REGISTRY_KEYS.BATCH_KEY, REGISTRY_KEYS.LABELS_KEY}:
original_key = categorical_mappings[attr_key]["original_key"]
field = CategoricalObsField(new_registry_key, original_key)
if (
unlabeled_category is not None
and new_registry_key == REGISTRY_KEYS.LABELS_KEY
):
field = LabelsWithUnlabeledObsField(
new_registry_key, original_key, unlabeled_category
)
else:
field = CategoricalObsField(new_registry_key, original_key)
setup_args[f"{new_registry_key}_key"] = original_key
elif new_registry_key == REGISTRY_KEYS.INDICES_KEY:
adata.obs[attr_key] = np.arange(adata.n_obs).astype("int64")
Expand Down Expand Up @@ -187,7 +207,9 @@ def manager_from_setup_dict(
}
adata_manager = AnnDataManager(fields=fields, setup_method_args=setup_method_args)

source_registry = registry_from_setup_dict(setup_dict)
source_registry = registry_from_setup_dict(
setup_dict, unlabeled_category=unlabeled_category
)
adata_manager.register_fields(
adata, source_registry=source_registry, **transfer_kwargs
)
Expand Down
1 change: 0 additions & 1 deletion scvi/data/fields/_scanvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ class LabelsWithUnlabeledObsField(CategoricalObsField):
"""

UNLABELED_CATEGORY = "unlabeled_category"
WAS_REMAPPED = "was_remapped"

def __init__(
self,
Expand Down
119 changes: 107 additions & 12 deletions scvi/model/base/_archesmixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,12 @@
from copy import deepcopy
from typing import Optional, Union

import anndata
import numpy as np
import pandas as pd
import torch
from anndata import AnnData
from scipy.sparse import csr_matrix

from scvi.data import _constants
from scvi.data._compat import manager_from_setup_dict
Expand All @@ -17,6 +21,8 @@

logger = logging.getLogger(__name__)

MIN_VAR_NAME_RATIO = 0.8


class ArchesMixin:
"""Universal scArches implementation."""
Expand Down Expand Up @@ -44,7 +50,7 @@ def load_query_data(
adata
AnnData organized in the same way as data used to train model.
It is not necessary to run setup_anndata,
as AnnData is validated against the saved `scvi` setup dictionary.
as AnnData is validated against the ``registry``.
reference_model
Either an already instantiated model of the same class, or a path to
saved outputs for reference model.
Expand All @@ -70,15 +76,10 @@ def load_query_data(
Whether to freeze classifier completely. Only applies to `SCANVI`.
"""
use_gpu, device = parse_use_gpu_arg(use_gpu)
if isinstance(reference_model, str):
attr_dict, var_names, load_state_dict, _ = _load_saved_files(
reference_model, load_adata=False, map_location=device
)
else:
attr_dict = reference_model._get_user_attributes()
attr_dict = {a[0]: a[1] for a in attr_dict if a[0][-1] == "_"}
var_names = reference_model.adata.var_names
load_state_dict = deepcopy(reference_model.module.state_dict())

attr_dict, var_names, load_state_dict = _get_loaded_data(
reference_model, device=device
)

if inplace_subset_query_vars:
logger.debug("Subsetting query vars to reference vars.")
Expand All @@ -87,9 +88,16 @@ def load_query_data(

if "scvi_setup_dict_" in attr_dict:
scvi_setup_dict = attr_dict.pop("scvi_setup_dict_")
# scanvi case
unlabeled_category_key = "unlabeled_category_"
unlabeled_category = attr_dict.get(unlabeled_category_key, None)
cls.register_manager(
manager_from_setup_dict(
cls, adata, scvi_setup_dict, extend_categories=True
cls,
adata,
scvi_setup_dict,
extend_categories=True,
unlabeled_category=unlabeled_category,
)
)
else:
Expand All @@ -112,7 +120,7 @@ def load_query_data(
adata,
source_registry=registry,
extend_categories=True,
**registry[_SETUP_ARGS_KEY]
**registry[_SETUP_ARGS_KEY],
)

model = _initialize_model(cls, adata, attr_dict)
Expand Down Expand Up @@ -155,6 +163,79 @@ def load_query_data(

return model

@staticmethod
def prepare_query_anndata(
    adata: AnnData,
    reference_model: Union[str, BaseModelClass],
    return_reference_var_names: bool = False,
    inplace: bool = True,
) -> Optional[Union[AnnData, pd.Index]]:
    """
    Prepare data for query integration.

    Pads the query data with all-zero columns for every reference feature it
    is missing, then subsets and reorders its features to match the reference
    var names.

    Parameters
    ----------
    adata
        AnnData organized in the same way as data used to train model.
        It is not necessary to run setup_anndata,
        as AnnData is validated against the ``registry``.
    reference_model
        Either an already instantiated model of the same class, or a path to
        saved outputs for reference model.
    return_reference_var_names
        Only load and return reference var names if True.
    inplace
        Whether to subset and rearrange query vars inplace or return new AnnData.

    Returns
    -------
    A ``pd.Index`` of reference var names if ``return_reference_var_names`` is True.
    Otherwise ``None`` when ``inplace`` is True (``adata`` itself is re-initialized
    from the prepared object), or a new AnnData ready to use in `load_query_data`
    when ``inplace`` is False.

    Raises
    ------
    ValueError
        If none of the reference var names are present in ``adata``.
    """
    _, var_names, _ = _get_loaded_data(reference_model)
    var_names = pd.Index(var_names)

    if return_reference_var_names:
        return var_names

    intersection = adata.var_names.intersection(var_names)
    inter_len = len(intersection)
    if inter_len == 0:
        raise ValueError(
            "No reference var names found in query data. "
            "Please rerun with return_reference_var_names=True "
            "to see reference var names."
        )

    ratio = inter_len / len(var_names)
    logger.info("Found %s%% reference vars in query data.", ratio * 100)
    if ratio < MIN_VAR_NAME_RATIO:
        # MIN_VAR_NAME_RATIO is a fraction, so render it with the percent
        # format type (":.0f" followed by a literal "%" would print "1%").
        warnings.warn(
            f"Query data contains less than {MIN_VAR_NAME_RATIO:.0%} of reference var names. "
            "This may result in poor performance."
        )

    genes_to_add = var_names.difference(adata.var_names)
    # Build the zero padding directly as an empty sparse matrix rather than
    # allocating a dense array of zeros first.
    adata_padding = AnnData(csr_matrix((adata.n_obs, len(genes_to_add))))
    adata_padding.var_names = genes_to_add
    adata_padding.obs_names = adata.obs_names
    # Concatenate along the variable axis so the padded features sit
    # alongside the original query features.
    adata_out = anndata.concat(
        [adata, adata_padding],
        axis=1,
        join="outer",
        index_unique=None,
        merge="unique",
    )
    # Keep only the reference features, in reference order.
    adata_out._inplace_subset_var(var_names)

    if inplace:
        adata._init_as_actual(adata_out, dtype=adata._X.dtype)
    else:
        return adata_out


def _set_params_online_update(
module,
Expand Down Expand Up @@ -229,3 +310,17 @@ def requires_grad(key):
par.requires_grad = True
else:
par.requires_grad = False


def _get_loaded_data(reference_model, device=None):
if isinstance(reference_model, str):
attr_dict, var_names, load_state_dict, _ = _load_saved_files(
reference_model, load_adata=False, map_location=device
)
else:
attr_dict = reference_model._get_user_attributes()
attr_dict = {a[0]: a[1] for a in attr_dict if a[0][-1] == "_"}
var_names = reference_model.adata.var_names
load_state_dict = deepcopy(reference_model.module.state_dict())

return attr_dict, var_names, load_state_dict
11 changes: 10 additions & 1 deletion scvi/model/base/_base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,7 +591,16 @@ def load(
# Legacy support for old setup dict format.
if "scvi_setup_dict_" in attr_dict:
scvi_setup_dict = attr_dict.pop("scvi_setup_dict_")
cls.register_manager(manager_from_setup_dict(cls, adata, scvi_setup_dict))
unlabeled_category_key = "unlabeled_category_"
unlabeled_category = attr_dict.get(unlabeled_category_key, None)
cls.register_manager(
manager_from_setup_dict(
cls,
adata,
scvi_setup_dict,
unlabeled_category=unlabeled_category,
)
)
else:
registry = attr_dict.pop("registry_")
if (
Expand Down
6 changes: 6 additions & 0 deletions scvi/model/base/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,12 @@ def _initialize_model(cls, adata, attr_dict):
kwargs = {k: v for (i, j) in kwargs.items() for (k, v) in j.items()}
non_kwargs.pop("use_cuda")

# backwards compat for scANVI
if "unlabeled_category" in non_kwargs.keys():
non_kwargs.pop("unlabeled_category")
if "pretrained_model" in non_kwargs.keys():
non_kwargs.pop("pretrained_model")

model = cls(adata, **non_kwargs, **kwargs)
for attr, val in attr_dict.items():
setattr(model, attr, val)
Expand Down
36 changes: 36 additions & 0 deletions tests/models/test_scarches.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,42 @@ def single_pass_for_online_update(model):
scvi_loss.loss.backward()


def test_data_prep(save_path):
    """Run ``SCVI.prepare_query_anndata`` before ``SCVI.load_query_data``, inplace and not."""
    latent_dim = 5
    reference = synthetic_iid()
    SCVI.setup_anndata(reference, batch_key="batch", labels_key="labels")
    model = SCVI(reference, n_latent=latent_dim)
    model.train(1, check_val_every_n_epoch=1)
    model_dir = os.path.join(save_path, "saved_model/")
    model.save(model_dir, overwrite=True)

    # Query generated with 110 genes; per the original note it has more genes
    # than the reference and the reference genes are all present in it.
    query = synthetic_iid(n_genes=110)
    renamed_batches = query.obs.batch.cat.rename_categories(["batch_2", "batch_3"])
    query.obs["batch"] = renamed_batches
    SCVI.prepare_query_anndata(query, model_dir)
    SCVI.load_query_data(query, model_dir)

    query_copy = SCVI.prepare_query_anndata(query, model_dir, inplace=False)
    SCVI.load_query_data(query_copy, model_dir)

    # Second query: 110 genes, with the first 10 var names replaced so that
    # those reference genes are missing from it.
    partial_query = synthetic_iid(n_genes=110)
    placeholder_names = [f"Random {i}" for i in range(10)]
    kept_names = partial_query.var_names[10:].to_list()
    partial_query.var_names = placeholder_names + kept_names

    SCVI.prepare_query_anndata(partial_query, model_dir)
    leading_names = partial_query.var_names[:10]
    # The missing genes should have been padded with zeros...
    assert np.sum(partial_query[:, leading_names].X) == 0
    # ...and placed in the same order as in the reference.
    np.testing.assert_equal(
        leading_names.to_numpy(), reference.var_names[:10].to_numpy()
    )
    SCVI.load_query_data(partial_query, model_dir)

    partial_copy = SCVI.prepare_query_anndata(partial_query, model_dir, inplace=False)
    SCVI.load_query_data(partial_copy, model_dir)


def test_scvi_online_update(save_path):
n_latent = 5
adata1 = synthetic_iid()
Expand Down

0 comments on commit 778c8ea

Please sign in to comment.