Skip to content

Commit

Permalink
SCANVI bug fixes (scverse#1458)
Browse files Browse the repository at this point in the history
* address issue 1449

* fix scanvi covariate handling

* fix tests

* raise error on extra cat covs

* release note
  • Loading branch information
justjhong authored and nrclaudio committed Jun 21, 2022
1 parent 540fad3 commit f63bdb0
Show file tree
Hide file tree
Showing 8 changed files with 184 additions and 11 deletions.
18 changes: 18 additions & 0 deletions docs/release_notes/v0.15.3.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# New in 0.15.3 (2022-MM-DD)

## Changes

## Bug fixes

- Raise `NotImplementedError` when `categorical_covariate_keys` are used with {meth}`scvi.model.SCANVI.load_query_data`. ([#1458]).
- Fix behavior when `continuous_covariate_keys` are used with {meth}`scvi.model.SCANVI.classify`. ([#1458]).
- Unlabeled category values are automatically populated when {meth}`scvi.model.SCANVI.load_query_data` run on `adata_target` missing labels column. ([#1458]).

## Contributors

- [@jjhong922]
- [@adamgayoso]

[#1458]: https://github.com/YosefLab/scvi-tools/pull/1458
[@adamgayoso]: https://github.com/adamgayoso
[@jjhong922]: https://github.com/jjhong922
23 changes: 19 additions & 4 deletions scvi/data/fields/_scanvi.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import warnings
from typing import Optional, Union

import numpy as np
from anndata import AnnData
from pandas.api.types import CategoricalDtype

from scvi.data._utils import _make_column_categorical
from scvi.data._utils import _make_column_categorical, _set_data_in_registry

from ._obs_field import CategoricalObsField

Expand Down Expand Up @@ -68,9 +69,6 @@ def _remap_unlabeled_to_final_category(
}

def register_field(self, adata: AnnData) -> dict:
if self.is_default:
self._setup_default_attr(adata)

state_registry = super().register_field(adata)
mapping = state_registry[self.CATEGORICAL_MAPPING_KEY]
return self._remap_unlabeled_to_final_category(adata, mapping)
Expand All @@ -80,8 +78,25 @@ def transfer_field(
state_registry: dict,
adata_target: AnnData,
extend_categories: bool = False,
allow_missing_labels: bool = False,
**kwargs,
) -> dict:
if (
allow_missing_labels
and self.attr_key is not None
and self.attr_key not in adata_target.obs
):
# Fill in original .obs attribute with unlabeled_category values.
warnings.warn(
f"Missing labels key {self.attr_key}. Filling in with unlabeled category {self._unlabeled_category}."
)
_set_data_in_registry(
adata_target,
self._unlabeled_category,
self.attr_name,
self._original_attr_key,
)

transfer_state_registry = super().transfer_field(
state_registry, adata_target, extend_categories=extend_categories, **kwargs
)
Expand Down
11 changes: 10 additions & 1 deletion scvi/model/_scanvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,16 @@ def predict(
for _, tensors in enumerate(scdl):
x = tensors[REGISTRY_KEYS.X_KEY]
batch = tensors[REGISTRY_KEYS.BATCH_KEY]
pred = self.module.classify(x, batch)

cont_key = REGISTRY_KEYS.CONT_COVS_KEY
cont_covs = tensors[cont_key] if cont_key in tensors.keys() else None

cat_key = REGISTRY_KEYS.CAT_COVS_KEY
cat_covs = tensors[cat_key] if cat_key in tensors.keys() else None

pred = self.module.classify(
x, batch_index=batch, cat_covs=cat_covs, cont_covs=cont_covs
)
if not soft:
pred = pred.argmax(dim=1)
y_pred.append(pred.detach().cpu())
Expand Down
8 changes: 8 additions & 0 deletions scvi/model/base/_archesmixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from anndata import AnnData
from scipy.sparse import csr_matrix

from scvi import REGISTRY_KEYS
from scvi.data import _constants
from scvi.data._compat import manager_from_setup_dict
from scvi.data._constants import _MODEL_NAME_KEY, _SETUP_ARGS_KEY
Expand Down Expand Up @@ -97,6 +98,7 @@ def load_query_data(
adata,
scvi_setup_dict,
extend_categories=True,
allow_missing_labels=True,
unlabeled_category=unlabeled_category,
)
)
Expand All @@ -120,12 +122,18 @@ def load_query_data(
adata,
source_registry=registry,
extend_categories=True,
allow_missing_labels=True,
**registry[_SETUP_ARGS_KEY],
)

model = _initialize_model(cls, adata, attr_dict)
adata_manager = model.get_anndata_manager(adata, required=True)

if REGISTRY_KEYS.CAT_COVS_KEY in adata_manager.data_registry:
raise NotImplementedError(
"scArches currently does not support models with extra categorical covariates."
)

version_split = adata_manager.registry[_constants._SCVI_VERSION_KEY].split(".")
if int(version_split[1]) < 8 and int(version_split[0]) == 0:
warnings.warn(
Expand Down
28 changes: 25 additions & 3 deletions scvi/module/_scanvae.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,10 +185,20 @@ def __init__(
)

@auto_move_data
def classify(self, x, batch_index=None):
def classify(self, x, batch_index=None, cont_covs=None, cat_covs=None):
if self.log_variational:
x = torch.log(1 + x)
qz_m, _, z = self.z_encoder(x, batch_index)

if cont_covs is not None and self.encode_covariates:
encoder_input = torch.cat((x, cont_covs), dim=-1)
else:
encoder_input = x
if cat_covs is not None and self.encode_covariates:
categorical_input = torch.split(cat_covs, 1, dim=1)
else:
categorical_input = tuple()

qz_m, _, z = self.z_encoder(encoder_input, batch_index, *categorical_input)
# We classify using the inferred mean parameter of z_1 in the latent space
z = qz_m
if self.use_labels_groups:
Expand All @@ -210,8 +220,20 @@ def classification_loss(self, labelled_dataset):
x = labelled_dataset[REGISTRY_KEYS.X_KEY]
y = labelled_dataset[REGISTRY_KEYS.LABELS_KEY]
batch_idx = labelled_dataset[REGISTRY_KEYS.BATCH_KEY]
cont_key = REGISTRY_KEYS.CONT_COVS_KEY
cont_covs = (
labelled_dataset[cont_key] if cont_key in labelled_dataset.keys() else None
)

cat_key = REGISTRY_KEYS.CAT_COVS_KEY
cat_covs = (
labelled_dataset[cat_key] if cat_key in labelled_dataset.keys() else None
)
classification_loss = F.cross_entropy(
self.classify(x, batch_idx), y.view(-1).long()
self.classify(
x, batch_index=batch_idx, cat_covs=cat_covs, cont_covs=cont_covs
),
y.view(-1).long(),
)
return classification_loss

Expand Down
4 changes: 2 additions & 2 deletions scvi/module/_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,11 +280,11 @@ def inference(self, x, batch_index, cont_covs=None, cat_covs=None, n_samples=1):
if self.log_variational:
x_ = torch.log(1 + x_)

if cont_covs is not None and self.encode_covariates is True:
if cont_covs is not None and self.encode_covariates:
encoder_input = torch.cat((x_, cont_covs), dim=-1)
else:
encoder_input = x_
if cat_covs is not None and self.encode_covariates is True:
if cat_covs is not None and self.encode_covariates:
categorical_input = torch.split(cat_covs, 1, dim=1)
else:
categorical_input = tuple()
Expand Down
40 changes: 40 additions & 0 deletions tests/models/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1178,6 +1178,46 @@ def test_multiple_covariates_scvi(save_path):
m.train(1)


def test_multiple_encoded_covariates_scvi(save_path):
adata = synthetic_iid()
adata.obs["cont1"] = np.random.normal(size=(adata.shape[0],))
adata.obs["cont2"] = np.random.normal(size=(adata.shape[0],))
adata.obs["cat1"] = np.random.randint(0, 5, size=(adata.shape[0],))
adata.obs["cat2"] = np.random.randint(0, 5, size=(adata.shape[0],))

SCVI.setup_anndata(
adata,
batch_key="batch",
labels_key="labels",
continuous_covariate_keys=["cont1", "cont2"],
categorical_covariate_keys=["cat1", "cat2"],
)
m = SCVI(adata, encode_covariates=True)
m.train(1)

SCANVI.setup_anndata(
adata,
"labels",
"Unknown",
batch_key="batch",
continuous_covariate_keys=["cont1", "cont2"],
categorical_covariate_keys=["cat1", "cat2"],
)
m = SCANVI(adata, encode_covariates=True)
m.train(1)

TOTALVI.setup_anndata(
adata,
batch_key="batch",
protein_expression_obsm_key="protein_expression",
protein_names_uns_key="protein_names",
continuous_covariate_keys=["cont1", "cont2"],
categorical_covariate_keys=["cat1", "cat2"],
)
m = TOTALVI(adata, encode_covariates=True)
m.train(1)


def test_peakvi():
data = synthetic_iid()
PEAKVI.setup_anndata(
Expand Down
63 changes: 62 additions & 1 deletion tests/models/test_scarches.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,16 @@ def test_scanvi_online_update(save_path):
new_labels = adata1.obs.labels.to_numpy()
new_labels[0] = "Unknown"
adata1.obs["labels"] = pd.Categorical(new_labels)
SCANVI.setup_anndata(adata1, "labels", "Unknown", batch_key="batch")
adata1.obs["cont1"] = np.random.normal(size=(adata1.shape[0],))
adata1.obs["cont2"] = np.random.normal(size=(adata1.shape[0],))

SCANVI.setup_anndata(
adata1,
"labels",
"Unknown",
batch_key="batch",
continuous_covariate_keys=["cont1", "cont2"],
)
model = SCANVI(
adata1,
n_latent=n_latent,
Expand All @@ -216,6 +225,20 @@ def test_scanvi_online_update(save_path):
adata2 = synthetic_iid()
adata2.obs["batch"] = adata2.obs.batch.cat.rename_categories(["batch_2", "batch_3"])
adata2.obs["labels"] = "Unknown"
adata2.obs["cont1"] = np.random.normal(size=(adata2.shape[0],))
adata2.obs["cont2"] = np.random.normal(size=(adata2.shape[0],))

model = SCANVI.load_query_data(adata2, dir_path, freeze_batchnorm_encoder=True)
model.train(max_epochs=1)
model.get_latent_representation()
model.predict()

# query has all missing labels and no labels key
adata2 = synthetic_iid()
adata2.obs["batch"] = adata2.obs.batch.cat.rename_categories(["batch_2", "batch_3"])
del adata2.obs["labels"]
adata2.obs["cont1"] = np.random.normal(size=(adata2.shape[0],))
adata2.obs["cont2"] = np.random.normal(size=(adata2.shape[0],))

model = SCANVI.load_query_data(adata2, dir_path, freeze_batchnorm_encoder=True)
model.train(max_epochs=1)
Expand All @@ -225,12 +248,50 @@ def test_scanvi_online_update(save_path):
# query has no missing labels
adata2 = synthetic_iid()
adata2.obs["batch"] = adata2.obs.batch.cat.rename_categories(["batch_2", "batch_3"])
adata2.obs["cont1"] = np.random.normal(size=(adata2.shape[0],))
adata2.obs["cont2"] = np.random.normal(size=(adata2.shape[0],))

model = SCANVI.load_query_data(adata2, dir_path, freeze_batchnorm_encoder=True)
model.train(max_epochs=1)
model.get_latent_representation()
model.predict()

# Test error on extra categoricals
adata1 = synthetic_iid()
new_labels = adata1.obs.labels.to_numpy()
new_labels[0] = "Unknown"
adata1.obs["labels"] = pd.Categorical(new_labels)
adata1.obs["cont1"] = np.random.normal(size=(adata1.shape[0],))
adata1.obs["cont2"] = np.random.normal(size=(adata1.shape[0],))
adata1.obs["cat1"] = np.random.randint(0, 5, size=(adata1.shape[0],))
adata1.obs["cat2"] = np.random.randint(0, 5, size=(adata1.shape[0],))
SCANVI.setup_anndata(
adata1,
"labels",
"Unknown",
batch_key="batch",
continuous_covariate_keys=["cont1", "cont2"],
categorical_covariate_keys=["cat1", "cat2"],
)
model = SCANVI(
adata1,
n_latent=n_latent,
encode_covariates=True,
)
model.train(max_epochs=1, check_val_every_n_epoch=1)
dir_path = os.path.join(save_path, "saved_model/")
model.save(dir_path, overwrite=True)

adata2 = synthetic_iid()
adata2.obs["batch"] = adata2.obs.batch.cat.rename_categories(["batch_2", "batch_3"])
adata2.obs["labels"] = "Unknown"
adata2.obs["cont1"] = np.random.normal(size=(adata2.shape[0],))
adata2.obs["cont2"] = np.random.normal(size=(adata2.shape[0],))
adata2.obs["cat1"] = np.random.randint(0, 5, size=(adata2.shape[0],))
adata2.obs["cat2"] = np.random.randint(0, 5, size=(adata2.shape[0],))
with pytest.raises(NotImplementedError):
SCANVI.load_query_data(adata2, dir_path, freeze_batchnorm_encoder=True)

# ref has fully-observed labels
n_latent = 5
adata1 = synthetic_iid()
Expand Down

0 comments on commit f63bdb0

Please sign in to comment.