Skip to content

Commit

Permalink
Backport PR scverse#1416: Transfer when copy detected
Browse files Browse the repository at this point in the history
  • Loading branch information
justjhong authored and meeseeksmachine committed Mar 10, 2022
1 parent c2cfdca commit aa17083
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 5 deletions.
10 changes: 6 additions & 4 deletions scvi/model/base/_base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,10 +255,12 @@ def get_anndata_manager(

adata_manager = cls._per_instance_manager_store[self.id][adata_id]
if adata_manager.adata is not adata:
raise ValueError(
"The provided AnnData object does not match the AnnData object "
"previously provided for setup. Did you make a copy?"
logger.info(
"AnnData object appears to be a copy. Attempting to transfer setup."
)
_assign_adata_uuid(adata, overwrite=True)
adata_manager = self.adata_manager.transfer_fields(adata)
self._register_manager_for_instance(adata_manager)

return adata_manager

Expand Down Expand Up @@ -355,7 +357,7 @@ def _validate_anndata(
if adata.is_view:
if copy_if_view:
logger.info("Received view of anndata, making copy.")
adata = adata.copy()
adata._init_as_actual(adata.copy())
# Reassign AnnData UUID to produce a separate AnnDataManager.
_assign_adata_uuid(adata, overwrite=True)
else:
Expand Down
14 changes: 13 additions & 1 deletion tests/models/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def test_scvi(save_path):
model._validate_anndata(adata2)
model.get_elbo(adata2)

# test automatic transfer_anndata_setup + on a view
# test automatic transfer_anndata_setup on a view
adata = synthetic_iid()
SCVI.setup_anndata(
adata,
Expand All @@ -236,6 +236,18 @@ def test_scvi(save_path):
adata2 = synthetic_iid()
model.get_elbo(adata2[:10])

# test automatic transfer_anndata_setup on a copy
adata = synthetic_iid()
SCVI.setup_anndata(
adata,
batch_key="batch",
labels_key="labels",
)
model = SCVI(adata)
adata2 = adata.copy()
model.get_elbo(adata2)
assert adata.uns[_constants._SCVI_UUID_KEY] != adata2.uns[_constants._SCVI_UUID_KEY]

# test mismatched categories raises ValueError
adata2 = synthetic_iid()
adata2.obs.labels.cat.rename_categories(["a", "b", "c"], inplace=True)
Expand Down

0 comments on commit aa17083

Please sign in to comment.