diff --git a/docs/conf.py b/docs/conf.py index fee2c08..21280bd 100755 --- a/docs/conf.py +++ b/docs/conf.py @@ -46,7 +46,6 @@ "sphinx.ext.autosummary", "scanpydoc.elegant_typehints", "scanpydoc.definition_list_typed_field", - "scanpydoc.autosummary_generate_imported", *[p.stem for p in (HERE / "extensions").glob("*.py")], ] @@ -75,6 +74,7 @@ todo_include_todos = False numpydoc_show_class_members = False annotate_defaults = True +autosummary_generate_imported = True # The master toctree document. master_doc = "index" diff --git a/docs/tutorials/scgen_perturbation_prediction.ipynb b/docs/tutorials/scgen_perturbation_prediction.ipynb index dd39acb..bc456fd 100644 --- a/docs/tutorials/scgen_perturbation_prediction.ipynb +++ b/docs/tutorials/scgen_perturbation_prediction.ipynb @@ -14,7 +14,8 @@ "outputs": [], "source": [ "import sys\n", - "#if branch is stable, will install via pypi, else will install from source\n", + "\n", + "# if branch is stable, will install via pypi, else will install from source\n", "branch = \"stable\"\n", "IN_COLAB = \"google.colab\" in sys.modules\n", "\n", @@ -78,8 +79,10 @@ } ], "source": [ - "train = sc.read(\"./tests/data/train_kang.h5ad\",\n", - " backup_url='https://drive.google.com/uc?id=1r87vhoLLq6PXAYdmyyd89zG90eJOFYLk')" + "train = sc.read(\n", + " \"./tests/data/train_kang.h5ad\",\n", + " backup_url=\"https://drive.google.com/uc?id=1r87vhoLLq6PXAYdmyyd89zG90eJOFYLk\",\n", + ")" ] }, { @@ -95,8 +98,9 @@ "metadata": {}, "outputs": [], "source": [ - "train_new = train[~((train.obs[\"cell_type\"] == \"CD4T\") &\n", - " (train.obs[\"condition\"] == \"stimulated\"))]" + "train_new = train[\n", + " ~((train.obs[\"cell_type\"] == \"CD4T\") & (train.obs[\"condition\"] == \"stimulated\"))\n", + "].copy()" ] }, { @@ -198,10 +202,7 @@ ], "source": [ "model.train(\n", - " max_epochs=100,\n", - " batch_size=32,\n", - " early_stopping=True,\n", - " early_stopping_patience=25\n", + " max_epochs=100, batch_size=32, early_stopping=True, early_stopping_patience=25\n", ")" ] }, @@ -258,8 +259,13 @@ "source": [ "sc.pp.neighbors(latent_adata)\n", "sc.tl.umap(latent_adata)\n", - "sc.pl.umap(latent_adata, color=['condition', 'cell_type'], wspace=0.4, frameon=False,\n", - " save='latentspace_batch32_klw000005_z100__100e.pdf')" + "sc.pl.umap(\n", + " latent_adata,\n", + " color=[\"condition\", \"cell_type\"],\n", + " wspace=0.4,\n", + " frameon=False,\n", + " save=\"latentspace_batch32_klw000005_z100__100e.pdf\",\n", + ")" ] }, { @@ -339,11 +345,9 @@ ], "source": [ "pred, delta = model.predict(\n", - " ctrl_key='control',\n", - " stim_key='stimulated',\n", - " celltype_to_predict='CD4T'\n", + " ctrl_key=\"control\", stim_key=\"stimulated\", celltype_to_predict=\"CD4T\"\n", ")\n", - "pred.obs['condition'] = 'pred'" + "pred.obs[\"condition\"] = \"pred\"" ] }, { @@ -380,8 +384,12 @@ "metadata": {}, "outputs": [], "source": [ - "ctrl_adata = train[((train.obs['cell_type'] == 'CD4T') & (train.obs['condition'] == 'control'))]\n", - "stim_adata = train[((train.obs['cell_type'] == 'CD4T') & (train.obs['condition'] == 'stimulated'))]" + "ctrl_adata = train[\n", + " ((train.obs[\"cell_type\"] == \"CD4T\") & (train.obs[\"condition\"] == \"control\"))\n", + "]\n", + "stim_adata = train[\n", + " ((train.obs[\"cell_type\"] == \"CD4T\") & (train.obs[\"condition\"] == \"stimulated\"))\n", + "]" ] }, { @@ -433,8 +441,12 @@ ], "source": [ "sc.tl.pca(eval_adata)\n", - "sc.pl.pca(eval_adata, color=\"condition\", frameon=False,\n", - " save='pred_stim_b32_klw000005_z100__100e.pdf')" + "sc.pl.pca(\n", + " eval_adata,\n", + " color=\"condition\",\n", + " frameon=False,\n", + " save=\"pred_stim_b32_klw000005_z100__100e.pdf\",\n", + ")" ] }, { @@ -466,7 +478,7 @@ } ], "source": [ - "CD4T = train[train.obs[\"cell_type\"] ==\"CD4T\"]" + "CD4T = train[train.obs[\"cell_type\"] == \"CD4T\"]" ] }, { @@ -527,7 +539,7 @@ " labels={\"x\": \"predicted\", \"y\": \"ground truth\"},\n", " path_to_save=\"./reg_mean1.pdf\",\n", " show=True,\n", - " legend=False\n", + " legend=False,\n", ")" ] }, @@ -567,11 +579,11 @@ " eval_adata,\n", " axis_keys={\"x\": \"pred\", \"y\": \"stimulated\"},\n", " gene_list=diff_genes[:10],\n", - " top_100_genes= diff_genes,\n", - " labels={\"x\": \"predicted\",\"y\": \"ground truth\"},\n", + " top_100_genes=diff_genes,\n", + " labels={\"x\": \"predicted\", \"y\": \"ground truth\"},\n", " path_to_save=\"./reg_mean1.pdf\",\n", " show=True,\n", - " legend=False\n", + " legend=False,\n", ")" ] }, diff --git a/pyproject.toml b/pyproject.toml index 8c918d8..0aa1da8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,7 +55,8 @@ scanpydoc = {version = ">=0.5", optional = true} scikit-misc = {version = ">=0.1.3", optional = true} scvi-tools = ">=0.20.0" seaborn = ">=0.11" -sphinx = {version = ">=4.1,<4.4", optional = true} +numpy = "<2.0.0" +sphinx = {version = ">=5.0", optional = true} sphinx-autodoc-typehints = {version = "*", optional = true} sphinx-material = {version = "*", optional = true} typing_extensions = {version = "*", python = "<3.8"} diff --git a/scgen/_scgen.py b/scgen/_scgen.py index 428f85e..bdf470a 100644 --- a/scgen/_scgen.py +++ b/scgen/_scgen.py @@ -284,7 +284,7 @@ def batch_removal(self, adata: Optional[AnnData] = None) -> AnnData: adata_raw = AnnData(X=adata.raw.X, var=adata.raw.var) adata_raw.obs_names = adata.obs_names corrected.raw = adata_raw - corrected.obsm["latent"] = all_shared_ann[corrected.obs_names,:].X + corrected.obsm["latent"] = all_shared_ann[corrected.obs_names, :].X corrected.obsm["corrected_latent"] = self.get_latent_representation( corrected ) @@ -314,7 +314,7 @@ def batch_removal(self, adata: Optional[AnnData] = None) -> AnnData: adata_raw = AnnData(X=adata.raw.X, var=adata.raw.var) adata_raw.obs_names = adata.obs_names corrected.raw = adata_raw - corrected.obsm["latent"] = all_corrected_data[corrected.obs_names,:].X + corrected.obsm["latent"] = all_corrected_data[corrected.obs_names, :].X corrected.obsm["corrected_latent"] = self.get_latent_representation( corrected ) @@ -447,7 +447,7 @@ def reg_mean_plot( x=x, y=y, arrowprops=dict(arrowstyle="->", color="grey", lw=0.5), - force_points=(0.0, 0.0), + force_static=(0.0, 0.0), ) if legend: pyplot.legend(loc="center left", bbox_to_anchor=(1, 0.5))