diff --git a/scvi/external/solo/_model.py b/scvi/external/solo/_model.py index 779eecdd9b..0017a6fa18 100644 --- a/scvi/external/solo/_model.py +++ b/scvi/external/solo/_model.py @@ -235,6 +235,15 @@ def create_doublets( "sim_doublet_{}".format(i) for i in range(num_doublets) ] + # if adata setup with a layer, need to add layer to doublets adata + data_registry = adata.uns["_scvi"]["data_registry"] + x_loc = data_registry[_CONSTANTS.X_KEY]["attr_name"] + layer = ( + data_registry[_CONSTANTS.X_KEY]["attr_key"] if x_loc == "layers" else None + ) + if layer is not None: + doublets_ad.layers[layer] = doublets + return doublets_ad def train( diff --git a/tests/external/test_solo.py b/tests/external/test_solo.py index cf1955e7da..121b005b04 100644 --- a/tests/external/test_solo.py +++ b/tests/external/test_solo.py @@ -25,7 +25,8 @@ def test_solo(save_path): def test_solo_multiple_batch(save_path): n_latent = 5 adata = synthetic_iid() - setup_anndata(adata, batch_key="batch") + adata.layers["my_layer"] = adata.X.copy() + setup_anndata(adata, layer="my_layer", batch_key="batch") model = SCVI(adata, n_latent=n_latent) model.train(1, check_val_every_n_epoch=1, train_size=0.5)