From bdb646cc08cfed337235a75d37a9fd0d023a9d90 Mon Sep 17 00:00:00 2001 From: Karin Hrovatin <47607471+Hrovatin@users.noreply.github.com> Date: Sun, 1 Oct 2023 13:11:29 +0200 Subject: [PATCH] fix cell label assignment issue in batch removal https://github.com/theislab/scgen/issues/86 --- scgen/_scgen.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/scgen/_scgen.py b/scgen/_scgen.py index c86274d..428f85e 100644 --- a/scgen/_scgen.py +++ b/scgen/_scgen.py @@ -274,6 +274,7 @@ def batch_removal(self, adata: Optional[AnnData] = None) -> AnnData: corrected = AnnData( self.module.generative(torch.Tensor(all_shared_ann.X))["px"] .cpu() + .detach() .numpy(), obs=all_shared_ann.obs, ) @@ -283,7 +284,7 @@ def batch_removal(self, adata: Optional[AnnData] = None) -> AnnData: adata_raw = AnnData(X=adata.raw.X, var=adata.raw.var) adata_raw.obs_names = adata.obs_names corrected.raw = adata_raw - corrected.obsm["latent"] = all_shared_ann.X + corrected.obsm["latent"] = all_shared_ann[corrected.obs_names,:].X corrected.obsm["corrected_latent"] = self.get_latent_representation( corrected ) @@ -303,6 +304,7 @@ def batch_removal(self, adata: Optional[AnnData] = None) -> AnnData: corrected = AnnData( self.module.generative(torch.Tensor(all_corrected_data.X))["px"] .cpu() + .detach() .numpy(), obs=all_corrected_data.obs, ) @@ -312,7 +314,7 @@ def batch_removal(self, adata: Optional[AnnData] = None) -> AnnData: adata_raw = AnnData(X=adata.raw.X, var=adata.raw.var) adata_raw.obs_names = adata.obs_names corrected.raw = adata_raw - corrected.obsm["latent"] = all_corrected_data.X + corrected.obsm["latent"] = all_corrected_data[corrected.obs_names,:].X corrected.obsm["corrected_latent"] = self.get_latent_representation( corrected )