Skip to content

Commit

Permalink
Merge branch 'master' into jhong/1163
Browse files Browse the repository at this point in the history
  • Loading branch information
adamgayoso authored Feb 15, 2022
2 parents 4b14dae + aab6094 commit 0f7456d
Show file tree
Hide file tree
Showing 7 changed files with 97 additions and 47 deletions.
13 changes: 1 addition & 12 deletions scvi/data/anndata/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,6 @@ def _make_obs_column_categorical(
column_key,
alternate_column_key,
categorical_dtype=None,
return_mapping=False,
):
"""
Makes the data in column_key in obs all categorical.
Expand All @@ -146,16 +145,6 @@ def _make_obs_column_categorical(
)
adata.obs[alternate_column_key] = codes

if not return_mapping:
# store categorical mappings
store_dict = {
alternate_column_key: {"original_key": column_key, "mapping": mapping}
}

if "categorical_mappings" not in adata.uns["_scvi"].keys():
adata.uns["_scvi"]["categorical_mappings"] = dict()
adata.uns["_scvi"]["categorical_mappings"].update(store_dict)

# make sure each category contains enough cells
unique, counts = np.unique(adata.obs[alternate_column_key], return_counts=True)
if np.min(counts) < 3:
Expand All @@ -171,7 +160,7 @@ def _make_obs_column_categorical(
"Is adata.obs['{}'] continuous? SCVI doesn't support continuous obs yet."
)

return mapping if return_mapping else alternate_column_key
return mapping


def _assign_adata_uuid(adata: anndata.AnnData, overwrite: bool = False) -> None:
Expand Down
5 changes: 3 additions & 2 deletions scvi/data/anndata/fields/_obs_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,9 @@ def register_field(self, adata: AnnData) -> dict:

super().register_field(adata)
categorical_mapping = _make_obs_column_categorical(
adata, self._original_attr_key, self.attr_key, return_mapping=True
adata,
self._original_attr_key,
self.attr_key,
)
return {
self.CATEGORICAL_MAPPING_KEY: categorical_mapping,
Expand Down Expand Up @@ -174,7 +176,6 @@ def transfer_field(
self._original_attr_key,
self.attr_key,
categorical_dtype=cat_dtype,
return_mapping=True,
)
return {
self.CATEGORICAL_MAPPING_KEY: new_mapping,
Expand Down
26 changes: 13 additions & 13 deletions scvi/data/anndata/fields/_scanvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,24 +48,24 @@ def _remap_unlabeled_to_final_category(
unlabeled_idx = unlabeled_idx[0][0]
# move unlabeled category to be the last position
mapping[unlabeled_idx], mapping[-1] = mapping[-1], mapping[unlabeled_idx]
cat_dtype = CategoricalDtype(categories=mapping, ordered=True)
# rerun setup for the batch column
mapping = _make_obs_column_categorical(
adata,
self._original_attr_key,
self.attr_key,
categorical_dtype=cat_dtype,
return_mapping=True,
)
remapped = True
else:
remapped = False
# could be in mapping in transfer case
elif self._unlabeled_category not in mapping:
# just put as last category
mapping = np.asarray(list(mapping) + [self._unlabeled_category])

cat_dtype = CategoricalDtype(categories=mapping, ordered=True)
# rerun setup for the batch column
mapping = _make_obs_column_categorical(
adata,
self._original_attr_key,
self.attr_key,
categorical_dtype=cat_dtype,
)

return {
self.CATEGORICAL_MAPPING_KEY: mapping,
self.ORIGINAL_ATTR_KEY: self._original_attr_key,
self.UNLABELED_CATEGORY: self._unlabeled_category,
self.WAS_REMAPPED: remapped,
}

def register_field(self, adata: AnnData) -> dict:
Expand Down
13 changes: 5 additions & 8 deletions scvi/model/_scanvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,11 +110,7 @@ def __init__(
self._dl_cls = AnnDataLoader

# ignores unlabeled category
n_labels = (
self.summary_stats.n_labels - 1
if self.has_unlabeled
else self.summary_stats.n_labels
)
n_labels = self.summary_stats.n_labels - 1
n_cats_per_cov = (
self.adata_manager.get_state_registry(
REGISTRY_KEYS.CAT_COVS_KEY
Expand Down Expand Up @@ -212,7 +208,9 @@ def from_scvi_model(
adata = scvi_model.adata

scvi_setup_kwargs = scvi_model.adata_manager.registry[_SETUP_KWARGS_KEY]
cls.setup_anndata(adata, unlabeled_category, **scvi_setup_kwargs)
cls.setup_anndata(
adata, unlabeled_category=unlabeled_category, **scvi_setup_kwargs
)
scanvi_model = cls(adata, **non_kwargs, **kwargs, **scanvi_kwargs)
scvi_state_dict = scvi_model.module.state_dict()
scanvi_model.module.load_state_dict(scvi_state_dict, strict=False)
Expand All @@ -228,7 +226,6 @@ def _set_indices_and_labels(self):
REGISTRY_KEYS.LABELS_KEY
)
self.unlabeled_category_ = labels_state_registry.unlabeled_category
self.has_unlabeled = labels_state_registry.was_remapped

labels = self.get_from_registry(self.adata, REGISTRY_KEYS.LABELS_KEY)
self._label_mapping = labels_state_registry.categorical_mapping
Expand Down Expand Up @@ -387,10 +384,10 @@ def train(
def setup_anndata(
cls,
adata: AnnData,
labels_key: str,
unlabeled_category: Union[str, int, float],
layer: Optional[str] = None,
batch_key: Optional[str] = None,
labels_key: Optional[str] = None,
size_factor_key: Optional[str] = None,
categorical_covariate_keys: Optional[List[str]] = None,
continuous_covariate_keys: Optional[List[str]] = None,
Expand Down
28 changes: 24 additions & 4 deletions scvi/model/base/_base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,14 @@ class BaseModelClass(metaclass=BaseModelMetaClass):
def __init__(self, adata: Optional[AnnData] = None):
self.id = str(uuid4()) # Used for cls._manager_store keys.
if adata is not None:
self.adata = adata
self.adata_manager = self._get_most_recent_anndata_manager(
self._adata = adata
self._adata_manager = self._get_most_recent_anndata_manager(
adata, required=True
)
self._register_manager_for_instance(self.adata_manager)
# Suffix registry instance variable with _ to include it when saving the model.
self.registry_ = self.adata_manager.registry
self.summary_stats = self.adata_manager.summary_stats
self.registry_ = self._adata_manager.registry
self.summary_stats = self._adata_manager.summary_stats

self.is_trained_ = False
self._model_summary_string = ""
Expand All @@ -81,6 +81,26 @@ def __init__(self, adata: Optional[AnnData] = None):
self.history_ = None
self._data_loader_cls = AnnDataLoader

@property
def adata(self) -> AnnData:
    """AnnData object currently attached to this model instance."""
    return self._adata

@adata.setter
def adata(self, adata: AnnData):
    """
    Attach a different AnnData object to this model instance.

    The new object is passed through ``_validate_anndata`` before it is stored. The
    cached manager, ``registry_`` and ``summary_stats`` are then refreshed from the
    manager returned by ``get_anndata_manager`` for that object.

    Parameters
    ----------
    adata
        AnnData object to attach. Must not be None.

    Raises
    ------
    ValueError
        If ``adata`` is None.
    """
    if adata is None:
        raise ValueError("adata cannot be None.")

    self._validate_anndata(adata)
    self._adata = adata

    manager = self.get_anndata_manager(adata, required=True)
    self._adata_manager = manager
    # Keep the attributes derived from the manager in sync with the new data.
    self.registry_ = manager.registry
    self.summary_stats = manager.summary_stats

@property
def adata_manager(self) -> AnnDataManager:
"""
Manager instance associated with self.adata.

Returns the cached ``_adata_manager`` attribute. That attribute is reassigned by the
``adata`` setter (via ``get_anndata_manager(adata, required=True)``), so the two stay
paired when new data is attached through the property.
"""
return self._adata_manager

def to_device(self, device: Union[str, int]):
"""
Move model to device.
Expand Down
43 changes: 38 additions & 5 deletions tests/models/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,33 @@ def test_scvi_sparse(save_path):
model.differential_expression(groupby="labels", group1="label_1")


def test_setting_adata_attr():
"""Check that assigning to ``model.adata`` validates the new data and refreshes registry state."""
n_latent = 5
adata = synthetic_iid()
SCVI.setup_anndata(adata, batch_key="batch")
model = SCVI(adata, n_latent=n_latent)
model.train(1, train_size=0.5)

# Attach a second, separately generated dataset through the property setter.
adata2 = synthetic_iid()
model.adata = adata2

# The equality check is expected to fail: `rep` is computed on the original adata,
# while `rep2` is computed with no argument (presumably on the attached adata2 -- confirm).
with pytest.raises(AssertionError):
rep = model.get_latent_representation(adata)
rep2 = model.get_latent_representation()
np.testing.assert_array_equal(rep, rep2)

# After the swap, the model's registry and summary stats must no longer be the
# objects held by the manager of the original adata.
orig_manager = model.get_anndata_manager(adata)
assert model.registry_ is not orig_manager.registry
assert model.summary_stats is not orig_manager.summary_stats

adata3 = synthetic_iid()
del adata3.obs["batch"]
# validation catches no batch: with the "batch" obs column removed, a KeyError is expected
with pytest.raises(KeyError):
model.adata = adata3
model.get_latent_representation()


def test_saving_and_loading(save_path):
def legacy_save(
model,
Expand Down Expand Up @@ -451,7 +478,7 @@ def test_save_load_scanvi(legacy=False):
np.testing.assert_array_equal(p1, p2)
assert model.is_trained is True

SCANVI.setup_anndata(adata, "label_0", batch_key="batch", labels_key="labels")
SCANVI.setup_anndata(adata, "labels", "label_0", batch_key="batch")
test_save_load_scanvi(legacy=True)
test_save_load_scanvi()
# Test load prioritizes newer save paradigm and thus mismatches legacy save.
Expand Down Expand Up @@ -721,9 +748,9 @@ def test_scanvi(save_path):
adata = synthetic_iid()
SCANVI.setup_anndata(
adata,
"labels",
"label_0",
batch_key="batch",
labels_key="labels",
)
model = SCANVI(adata, n_latent=10)
model.train(1, train_size=0.5, check_val_every_n_epoch=1)
Expand All @@ -750,7 +777,10 @@ def test_scanvi(save_path):
unknown_label = "asdf"
a = scvi.data.synthetic_iid()
scvi.model.SCANVI.setup_anndata(
a, unknown_label, batch_key="batch", labels_key="labels"
a,
"labels",
unknown_label,
batch_key="batch",
)
m = scvi.model.SCANVI(a)
m.train(1)
Expand All @@ -759,7 +789,10 @@ def test_scanvi(save_path):
unknown_label = "label_0"
a = scvi.data.synthetic_iid()
scvi.model.SCANVI.setup_anndata(
a, unknown_label, batch_key="batch", labels_key="labels"
a,
"labels",
unknown_label,
batch_key="batch",
)
m = scvi.model.SCANVI(a)
m.train(1, train_size=0.9)
Expand Down Expand Up @@ -1069,9 +1102,9 @@ def test_multiple_covariates_scvi(save_path):

SCANVI.setup_anndata(
adata,
"labels",
"Unknown",
batch_key="batch",
labels_key="labels",
continuous_covariate_keys=["cont1", "cont2"],
categorical_covariate_keys=["cat1", "cat2"],
)
Expand Down
16 changes: 13 additions & 3 deletions tests/models/test_scarches.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def test_scanvi_online_update(save_path):
new_labels = adata1.obs.labels.to_numpy()
new_labels[0] = "Unknown"
adata1.obs["labels"] = pd.Categorical(new_labels)
SCANVI.setup_anndata(adata1, "Unknown", batch_key="batch", labels_key="labels")
SCANVI.setup_anndata(adata1, "labels", "Unknown", batch_key="batch")
model = SCANVI(
adata1,
n_latent=n_latent,
Expand All @@ -176,6 +176,7 @@ def test_scanvi_online_update(save_path):
dir_path = os.path.join(save_path, "saved_model/")
model.save(dir_path, overwrite=True)

# query has all missing labels
adata2 = synthetic_iid()
adata2.obs["batch"] = adata2.obs.batch.cat.rename_categories(["batch_2", "batch_3"])
adata2.obs["labels"] = "Unknown"
Expand All @@ -185,12 +186,21 @@ def test_scanvi_online_update(save_path):
model.get_latent_representation()
model.predict()

# query has no missing labels
adata2 = synthetic_iid()
adata2.obs["batch"] = adata2.obs.batch.cat.rename_categories(["batch_2", "batch_3"])

model = SCANVI.load_query_data(adata2, dir_path, freeze_batchnorm_encoder=True)
model.train(max_epochs=1)
model.get_latent_representation()
model.predict()

# ref has fully-observed labels
n_latent = 5
adata1 = synthetic_iid()
new_labels = adata1.obs.labels.to_numpy()
adata1.obs["labels"] = pd.Categorical(new_labels)
SCANVI.setup_anndata(adata1, "Unknown", batch_key="batch", labels_key="labels")
SCANVI.setup_anndata(adata1, "labels", "Unknown", batch_key="batch")
model = SCANVI(adata1, n_latent=n_latent, encode_covariates=True)
model.train(max_epochs=1, check_val_every_n_epoch=1)
dir_path = os.path.join(save_path, "saved_model/")
Expand Down Expand Up @@ -253,7 +263,7 @@ def test_scanvi_online_update(save_path):
# test saving and loading of online scanvi
a = synthetic_iid()
ref = a[a.obs["labels"] != "label_2"].copy() # only has labels 0 and 1
SCANVI.setup_anndata(ref, "label_2", batch_key="batch", labels_key="labels")
SCANVI.setup_anndata(ref, "labels", "label_2", batch_key="batch")
m = SCANVI(ref)
m.train(max_epochs=1)
m.save(save_path, overwrite=True)
Expand Down

0 comments on commit 0f7456d

Please sign in to comment.