diff --git a/scgen/_scgen.py b/scgen/_scgen.py index c86274d..428f85e 100644 --- a/scgen/_scgen.py +++ b/scgen/_scgen.py @@ -274,6 +274,7 @@ def batch_removal(self, adata: Optional[AnnData] = None) -> AnnData: corrected = AnnData( self.module.generative(torch.Tensor(all_shared_ann.X))["px"] .cpu() + .detach() .numpy(), obs=all_shared_ann.obs, ) @@ -283,7 +284,7 @@ def batch_removal(self, adata: Optional[AnnData] = None) -> AnnData: adata_raw = AnnData(X=adata.raw.X, var=adata.raw.var) adata_raw.obs_names = adata.obs_names corrected.raw = adata_raw - corrected.obsm["latent"] = all_shared_ann.X + corrected.obsm["latent"] = all_shared_ann[corrected.obs_names,:].X corrected.obsm["corrected_latent"] = self.get_latent_representation( corrected ) @@ -303,6 +304,7 @@ def batch_removal(self, adata: Optional[AnnData] = None) -> AnnData: corrected = AnnData( self.module.generative(torch.Tensor(all_corrected_data.X))["px"] .cpu() + .detach() .numpy(), obs=all_corrected_data.obs, ) @@ -312,7 +314,7 @@ def batch_removal(self, adata: Optional[AnnData] = None) -> AnnData: adata_raw = AnnData(X=adata.raw.X, var=adata.raw.var) adata_raw.obs_names = adata.obs_names corrected.raw = adata_raw - corrected.obsm["latent"] = all_corrected_data.X + corrected.obsm["latent"] = all_corrected_data[corrected.obs_names,:].X corrected.obsm["corrected_latent"] = self.get_latent_representation( corrected )