Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backport PR #1458 on branch 0.15.x (SCANVI bug fixes) #1459

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions docs/release_notes/v0.15.3.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# New in 0.15.3 (2022-MM-DD)

## Changes

## Bug fixes

- Raise `NotImplementedError` when `categorical_covariate_keys` are used with {meth}`scvi.model.SCANVI.load_query_data` ([#1458]).
- Fix behavior when `continuous_covariate_keys` are used with {meth}`scvi.model.SCANVI.classify` ([#1458]).
- Unlabeled category values are automatically populated when {meth}`scvi.model.SCANVI.load_query_data` is run on an `adata_target` that is missing the labels column ([#1458]).

## Contributors

- [@jjhong922]
- [@adamgayoso]

[#1458]: https://github.com/YosefLab/scvi-tools/pull/1458
[@adamgayoso]: https://github.com/adamgayoso
[@jjhong922]: https://github.com/jjhong922
23 changes: 19 additions & 4 deletions scvi/data/fields/_scanvi.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import warnings
from typing import Optional, Union

import numpy as np
from anndata import AnnData
from pandas.api.types import CategoricalDtype

from scvi.data._utils import _make_column_categorical
from scvi.data._utils import _make_column_categorical, _set_data_in_registry

from ._obs_field import CategoricalObsField

Expand Down Expand Up @@ -68,9 +69,6 @@ def _remap_unlabeled_to_final_category(
}

def register_field(self, adata: AnnData) -> dict:
if self.is_default:
self._setup_default_attr(adata)

state_registry = super().register_field(adata)
mapping = state_registry[self.CATEGORICAL_MAPPING_KEY]
return self._remap_unlabeled_to_final_category(adata, mapping)
Expand All @@ -80,8 +78,25 @@ def transfer_field(
state_registry: dict,
adata_target: AnnData,
extend_categories: bool = False,
allow_missing_labels: bool = False,
**kwargs,
) -> dict:
if (
allow_missing_labels
and self.attr_key is not None
and self.attr_key not in adata_target.obs
):
# Fill in original .obs attribute with unlabeled_category values.
warnings.warn(
f"Missing labels key {self.attr_key}. Filling in with unlabeled category {self._unlabeled_category}."
)
_set_data_in_registry(
adata_target,
self._unlabeled_category,
self.attr_name,
self._original_attr_key,
)

transfer_state_registry = super().transfer_field(
state_registry, adata_target, extend_categories=extend_categories, **kwargs
)
Expand Down
11 changes: 10 additions & 1 deletion scvi/model/_scanvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,16 @@ def predict(
for _, tensors in enumerate(scdl):
x = tensors[REGISTRY_KEYS.X_KEY]
batch = tensors[REGISTRY_KEYS.BATCH_KEY]
pred = self.module.classify(x, batch)

cont_key = REGISTRY_KEYS.CONT_COVS_KEY
cont_covs = tensors[cont_key] if cont_key in tensors.keys() else None

cat_key = REGISTRY_KEYS.CAT_COVS_KEY
cat_covs = tensors[cat_key] if cat_key in tensors.keys() else None

pred = self.module.classify(
x, batch_index=batch, cat_covs=cat_covs, cont_covs=cont_covs
)
if not soft:
pred = pred.argmax(dim=1)
y_pred.append(pred.detach().cpu())
Expand Down
8 changes: 8 additions & 0 deletions scvi/model/base/_archesmixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from anndata import AnnData
from scipy.sparse import csr_matrix

from scvi import REGISTRY_KEYS
from scvi.data import _constants
from scvi.data._compat import manager_from_setup_dict
from scvi.data._constants import _MODEL_NAME_KEY, _SETUP_ARGS_KEY
Expand Down Expand Up @@ -97,6 +98,7 @@ def load_query_data(
adata,
scvi_setup_dict,
extend_categories=True,
allow_missing_labels=True,
unlabeled_category=unlabeled_category,
)
)
Expand All @@ -120,12 +122,18 @@ def load_query_data(
adata,
source_registry=registry,
extend_categories=True,
allow_missing_labels=True,
**registry[_SETUP_ARGS_KEY],
)

model = _initialize_model(cls, adata, attr_dict)
adata_manager = model.get_anndata_manager(adata, required=True)

if REGISTRY_KEYS.CAT_COVS_KEY in adata_manager.data_registry:
raise NotImplementedError(
"scArches currently does not support models with extra categorical covariates."
)

version_split = adata_manager.registry[_constants._SCVI_VERSION_KEY].split(".")
if int(version_split[1]) < 8 and int(version_split[0]) == 0:
warnings.warn(
Expand Down
28 changes: 25 additions & 3 deletions scvi/module/_scanvae.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,10 +185,20 @@ def __init__(
)

@auto_move_data
def classify(self, x, batch_index=None):
def classify(self, x, batch_index=None, cont_covs=None, cat_covs=None):
if self.log_variational:
x = torch.log(1 + x)
qz_m, _, z = self.z_encoder(x, batch_index)

if cont_covs is not None and self.encode_covariates:
encoder_input = torch.cat((x, cont_covs), dim=-1)
else:
encoder_input = x
if cat_covs is not None and self.encode_covariates:
categorical_input = torch.split(cat_covs, 1, dim=1)
else:
categorical_input = tuple()

qz_m, _, z = self.z_encoder(encoder_input, batch_index, *categorical_input)
# We classify using the inferred mean parameter of z_1 in the latent space
z = qz_m
if self.use_labels_groups:
Expand All @@ -210,8 +220,20 @@ def classification_loss(self, labelled_dataset):
x = labelled_dataset[REGISTRY_KEYS.X_KEY]
y = labelled_dataset[REGISTRY_KEYS.LABELS_KEY]
batch_idx = labelled_dataset[REGISTRY_KEYS.BATCH_KEY]
cont_key = REGISTRY_KEYS.CONT_COVS_KEY
cont_covs = (
labelled_dataset[cont_key] if cont_key in labelled_dataset.keys() else None
)

cat_key = REGISTRY_KEYS.CAT_COVS_KEY
cat_covs = (
labelled_dataset[cat_key] if cat_key in labelled_dataset.keys() else None
)
classification_loss = F.cross_entropy(
self.classify(x, batch_idx), y.view(-1).long()
self.classify(
x, batch_index=batch_idx, cat_covs=cat_covs, cont_covs=cont_covs
),
y.view(-1).long(),
)
return classification_loss

Expand Down
4 changes: 2 additions & 2 deletions scvi/module/_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,11 +280,11 @@ def inference(self, x, batch_index, cont_covs=None, cat_covs=None, n_samples=1):
if self.log_variational:
x_ = torch.log(1 + x_)

if cont_covs is not None and self.encode_covariates is True:
if cont_covs is not None and self.encode_covariates:
encoder_input = torch.cat((x_, cont_covs), dim=-1)
else:
encoder_input = x_
if cat_covs is not None and self.encode_covariates is True:
if cat_covs is not None and self.encode_covariates:
categorical_input = torch.split(cat_covs, 1, dim=1)
else:
categorical_input = tuple()
Expand Down
40 changes: 40 additions & 0 deletions tests/models/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1178,6 +1178,46 @@ def test_multiple_covariates_scvi(save_path):
m.train(1)


def test_multiple_encoded_covariates_scvi(save_path):
    """Smoke-test SCVI, SCANVI and TOTALVI with ``encode_covariates=True``.

    Two continuous and two categorical covariate columns are added to the
    AnnData returned by ``synthetic_iid()``. Each model is then set up with
    those covariates, constructed with ``encode_covariates=True`` and run
    through ``train(1)``. The test passes as long as nothing raises.
    """
    adata = synthetic_iid()
    n_obs = adata.shape[0]
    cont_keys = ("cont1", "cont2")
    cat_keys = ("cat1", "cat2")

    # Continuous columns are drawn before categorical ones so the sequence of
    # random draws matches a straight-line cont1, cont2, cat1, cat2 ordering.
    for key in cont_keys:
        adata.obs[key] = np.random.normal(size=(n_obs,))
    for key in cat_keys:
        adata.obs[key] = np.random.randint(0, 5, size=(n_obs,))

    # SCVI: batch + labels + extra covariates.
    SCVI.setup_anndata(
        adata,
        batch_key="batch",
        labels_key="labels",
        continuous_covariate_keys=list(cont_keys),
        categorical_covariate_keys=list(cat_keys),
    )
    model = SCVI(adata, encode_covariates=True)
    model.train(1)

    # SCANVI: labels key and unlabeled category are passed positionally.
    SCANVI.setup_anndata(
        adata,
        "labels",
        "Unknown",
        batch_key="batch",
        continuous_covariate_keys=list(cont_keys),
        categorical_covariate_keys=list(cat_keys),
    )
    model = SCANVI(adata, encode_covariates=True)
    model.train(1)

    # TOTALVI: additionally needs the protein expression fields.
    TOTALVI.setup_anndata(
        adata,
        batch_key="batch",
        protein_expression_obsm_key="protein_expression",
        protein_names_uns_key="protein_names",
        continuous_covariate_keys=list(cont_keys),
        categorical_covariate_keys=list(cat_keys),
    )
    model = TOTALVI(adata, encode_covariates=True)
    model.train(1)


def test_peakvi():
data = synthetic_iid()
PEAKVI.setup_anndata(
Expand Down
63 changes: 62 additions & 1 deletion tests/models/test_scarches.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,16 @@ def test_scanvi_online_update(save_path):
new_labels = adata1.obs.labels.to_numpy()
new_labels[0] = "Unknown"
adata1.obs["labels"] = pd.Categorical(new_labels)
SCANVI.setup_anndata(adata1, "labels", "Unknown", batch_key="batch")
adata1.obs["cont1"] = np.random.normal(size=(adata1.shape[0],))
adata1.obs["cont2"] = np.random.normal(size=(adata1.shape[0],))

SCANVI.setup_anndata(
adata1,
"labels",
"Unknown",
batch_key="batch",
continuous_covariate_keys=["cont1", "cont2"],
)
model = SCANVI(
adata1,
n_latent=n_latent,
Expand All @@ -216,6 +225,20 @@ def test_scanvi_online_update(save_path):
adata2 = synthetic_iid()
adata2.obs["batch"] = adata2.obs.batch.cat.rename_categories(["batch_2", "batch_3"])
adata2.obs["labels"] = "Unknown"
adata2.obs["cont1"] = np.random.normal(size=(adata2.shape[0],))
adata2.obs["cont2"] = np.random.normal(size=(adata2.shape[0],))

model = SCANVI.load_query_data(adata2, dir_path, freeze_batchnorm_encoder=True)
model.train(max_epochs=1)
model.get_latent_representation()
model.predict()

# query has all missing labels and no labels key
adata2 = synthetic_iid()
adata2.obs["batch"] = adata2.obs.batch.cat.rename_categories(["batch_2", "batch_3"])
del adata2.obs["labels"]
adata2.obs["cont1"] = np.random.normal(size=(adata2.shape[0],))
adata2.obs["cont2"] = np.random.normal(size=(adata2.shape[0],))

model = SCANVI.load_query_data(adata2, dir_path, freeze_batchnorm_encoder=True)
model.train(max_epochs=1)
Expand All @@ -225,12 +248,50 @@ def test_scanvi_online_update(save_path):
# query has no missing labels
adata2 = synthetic_iid()
adata2.obs["batch"] = adata2.obs.batch.cat.rename_categories(["batch_2", "batch_3"])
adata2.obs["cont1"] = np.random.normal(size=(adata2.shape[0],))
adata2.obs["cont2"] = np.random.normal(size=(adata2.shape[0],))

model = SCANVI.load_query_data(adata2, dir_path, freeze_batchnorm_encoder=True)
model.train(max_epochs=1)
model.get_latent_representation()
model.predict()

# Test error on extra categoricals
adata1 = synthetic_iid()
new_labels = adata1.obs.labels.to_numpy()
new_labels[0] = "Unknown"
adata1.obs["labels"] = pd.Categorical(new_labels)
adata1.obs["cont1"] = np.random.normal(size=(adata1.shape[0],))
adata1.obs["cont2"] = np.random.normal(size=(adata1.shape[0],))
adata1.obs["cat1"] = np.random.randint(0, 5, size=(adata1.shape[0],))
adata1.obs["cat2"] = np.random.randint(0, 5, size=(adata1.shape[0],))
SCANVI.setup_anndata(
adata1,
"labels",
"Unknown",
batch_key="batch",
continuous_covariate_keys=["cont1", "cont2"],
categorical_covariate_keys=["cat1", "cat2"],
)
model = SCANVI(
adata1,
n_latent=n_latent,
encode_covariates=True,
)
model.train(max_epochs=1, check_val_every_n_epoch=1)
dir_path = os.path.join(save_path, "saved_model/")
model.save(dir_path, overwrite=True)

adata2 = synthetic_iid()
adata2.obs["batch"] = adata2.obs.batch.cat.rename_categories(["batch_2", "batch_3"])
adata2.obs["labels"] = "Unknown"
adata2.obs["cont1"] = np.random.normal(size=(adata2.shape[0],))
adata2.obs["cont2"] = np.random.normal(size=(adata2.shape[0],))
adata2.obs["cat1"] = np.random.randint(0, 5, size=(adata2.shape[0],))
adata2.obs["cat2"] = np.random.randint(0, 5, size=(adata2.shape[0],))
with pytest.raises(NotImplementedError):
SCANVI.load_query_data(adata2, dir_path, freeze_batchnorm_encoder=True)

# ref has fully-observed labels
n_latent = 5
adata1 = synthetic_iid()
Expand Down