-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 03e2066
Showing
51 changed files
with
18,429 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Byte-compiled / optimized / DLL files | ||
__pycache__/ | ||
|
||
# Jupyter Notebook | ||
.ipynb_checkpoints |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
library(data.table) | ||
library(tidyverse) | ||
library(Matrix) | ||
library(GSFA) | ||
library(ggplot2) | ||
|
||
# Train GSFA on a subset of the in-house dataset.
#
# Reads `inhouse_GSFA_inputs.npz` (array1 -> Y, array2 -> G), applies the GSFA
# deviance-residual transform, keeps the top 6,000 genes, fits the model on at
# most `num_cells` rows and saves the fit to `fitted_inhouse.rds`.

# Install reticulate only when it is missing: an unconditional
# install.packages() re-downloads the package on every run and errors in
# non-interactive sessions that have no CRAN mirror configured.
if (!requireNamespace("reticulate", quietly = TRUE)) {
  install.packages("reticulate")
}
library(reticulate)
use_python("/usr/bin/python3")
py_discover_config()
np <- import("numpy")

# NOTE(review): allow_pickle = TRUE lets numpy unpickle arbitrary objects;
# only load archives that come from a trusted source.
npz <- np$load("inhouse_GSFA_inputs.npz", allow_pickle = TRUE)
Y <- npz$get("array1")
G <- npz$get("array2")

print(dim(Y))
print(dim(G))
print("loaded data")

# GSFA-specific preprocessing of the expression matrix.
dev_res <- deviance_residual_transform(Y)
top_gene_index <- select_top_devres_genes(dev_res, num_top_genes = 6000)
dev_res_filtered <- dev_res[, top_gene_index]

# Keep the selected column indices so downstream analysis can map back to genes.
write.csv(top_gene_index, "inhouse_top_genes.csv")
rm(npz)
rm(Y)
print(dim(dev_res_filtered))
print("processed data")

set.seed(14314)
time_start <- Sys.time()
num_cells <- 5000
# Clamp to the number of available rows: `1:num_cells` raises
# "subscript out of bounds" when the input has fewer than `num_cells` rows.
cell_idx <- seq_len(min(num_cells, nrow(dev_res_filtered)))
fit <- fit_gsfa_multivar(
  Y = dev_res_filtered[cell_idx, , drop = FALSE],
  G = G[cell_idx, , drop = FALSE],
  K = 20,
  prior_type = "mixture_normal",
  init.method = "svd",
  niter = 3000, used_niter = 1000,
  verbose = TRUE, return_samples = TRUE
)
print(Sys.time() - time_start)
rm(G)
rm(dev_res_filtered)
saveRDS(fit, file = "fitted_inhouse.rds")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# GSFA (Guided Sparse Factor Analysis) | ||
|
||
``` | ||
Yifan Zhou, Kaixuan Luo, Lifan Liang, Mengjie Chen and Xin He. A new Bayesian factor analysis method improves detection of genes and biological processes affected by perturbations in single-cell CRISPR screening. Nature Methods. (2023). doi: 10.1038/s41592-023-02017-4. PMID: 37770710 | ||
``` | ||
|
||
Training scripts for GSFA to use as a benchmark. | ||
Link to paper: [paper](https://www.nature.com/articles/s41592-023-02017-4) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
library(data.table) | ||
library(tidyverse) | ||
library(Matrix) | ||
library(GSFA) | ||
library(ggplot2) | ||
|
||
# Train GSFA on the full in-house input.
#
# Reads `inhouse_GSFA_inputs.npz` (array1 -> Y, array2 -> G), applies the GSFA
# deviance-residual transform, keeps the top 6,000 genes, saves the
# preprocessed matrix and gene indices, then fits and saves the model.

# Install reticulate only when it is missing: an unconditional
# install.packages() re-downloads the package on every run and errors in
# non-interactive sessions that have no CRAN mirror configured.
if (!requireNamespace("reticulate", quietly = TRUE)) {
  install.packages("reticulate")
}
library(reticulate)
use_python("/usr/bin/python3")
py_discover_config()
np <- import("numpy")

# read in inputs generated by inhouse_gsfa_preprocessing.ipynb
# NOTE(review): the preprocessing notebook appears to save its inputs as
# `rna565_GSFA_inputs.npz` -- confirm this file name matches what it writes.
# NOTE(review): allow_pickle = TRUE lets numpy unpickle arbitrary objects;
# only load archives that come from a trusted source.
npz <- np$load("inhouse_GSFA_inputs.npz", allow_pickle = TRUE)
Y <- npz$get("array1")
G <- npz$get("array2")

print(dim(Y))
print(dim(G))
print("loaded data")

# GSFA-specific preprocessing of gene expression
dev_res <- deviance_residual_transform(Y)
top_gene_index <- select_top_devres_genes(dev_res, num_top_genes = 6000)
dev_res_filtered <- dev_res[, top_gene_index]
# save for downstream analysis
np$savez("inhouse_GSFA_preprocessed.npz", array1 = dev_res_filtered)
write.csv(top_gene_index, "inhouse_top_genes.csv")
rm(npz)
rm(Y)
print(dim(dev_res_filtered))
print("processed data")

# train and save GSFA model
set.seed(14314)
time_start <- Sys.time()
fit <- fit_gsfa_multivar(
  Y = dev_res_filtered, G = G,
  K = 20,
  prior_type = "mixture_normal",
  init.method = "svd",
  niter = 3000, used_niter = 1000,
  # Spell out TRUE: `T` is an ordinary variable that can be reassigned.
  verbose = TRUE, return_samples = TRUE
)
print(Sys.time() - time_start)
rm(G)
rm(dev_res_filtered)
saveRDS(fit, file = "fitted_inhouse.rds")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,249 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"id": "0c3f7ba9", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"The autoreload extension is already loaded. To reload it, use:\n", | ||
" %reload_ext autoreload\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"%load_ext autoreload\n", | ||
"%autoreload 2" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"id": "dfbdfee5", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import sys\n", | ||
"\n", | ||
"import numpy as np\n", | ||
"import pandas as pd\n", | ||
"\n", | ||
"sys.path.append(\"..\")\n", | ||
"from src.Spectra.Spectra_Pert import vectorize_perts\n", | ||
"from utils import (\n", | ||
" filter_noisy_genes,\n", | ||
" generate_k_fold,\n", | ||
" inhouse_preprocess,\n", | ||
" read_aws_h5ad,\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "426b0898-25f4-492b-b77c-d9f1fc17010f", | ||
"metadata": {}, | ||
"source": [ | ||
"### Get train/test splits consistent with other models" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "1dd10044-66e2-41c3-a724-282c4792f6d8", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# use anndata generate by ..data_processing/inhouse_prior_graph_preprocessing.ipynb\n", | ||
"unfilterd_adata = read_aws_h5ad(\"path to preprocessed h5ad here\")\n", | ||
"adata = filter_noisy_genes(unfilterd_adata)\n", | ||
"adata = inhouse_preprocess(adata)\n", | ||
"adata.layers[\"logcounts\"] = adata.X.copy()\n", | ||
"adata.X = adata.X.todense()\n", | ||
"gene_network = adata.uns[\"sparse_gene_network\"].todense()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"id": "d7e163dc-2a94-41de-b7f5-f717ac47fb4a", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# powered perturbations\n", | ||
"adata.obs[\"condition\"] = adata.obs[\"condition\"].astype(str)\n", | ||
"adata.obs[\"Treatment\"] = adata.obs[\"Treatment\"].astype(str)\n", | ||
"adata.obs[\"pert_treat\"] = adata.obs[\"condition\"] + \"+\" + adata.obs[\"Treatment\"]\n", | ||
"obs_df = pd.DataFrame(adata.obs[\"pert_treat\"])\n", | ||
"category_counts = obs_df[\"pert_treat\"].value_counts()\n", | ||
"filtered_categories = category_counts[category_counts >= 50].index\n", | ||
"adata = adata[adata.obs[\"pert_treat\"].isin(filtered_categories)]" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"id": "2e20577d-23ec-4336-a3f5-8fcae53252a8", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"train_idx, val_idx, test_idx = generate_k_fold(\n", | ||
" adata, adata.X, adata.obs[\"condition\"], fold_idx=0\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "431abcfa-4929-4c0c-8651-c1f9c35f37f2", | ||
"metadata": {}, | ||
"source": [ | ||
"### Process GSFA-specifc input" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "96e401f3", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# use inhouse dataset from s3://pert-spectra\n", | ||
"adata = read_aws_h5ad(\n", | ||
" \"s3://pert-spectra/rnaseq565.filtered.actionet.guide_corrected.h5ad\"\n", | ||
")\n", | ||
"adata = inhouse_preprocess(adata)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 7, | ||
"id": "76e4b5e7", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# filter adata to perturbations with at least 50 samples for each treatment\n", | ||
"adata.obs[\"condition\"] = adata.obs[\"condition\"].astype(str)\n", | ||
"adata.obs[\"Treatment\"] = adata.obs[\"Treatment\"].astype(str)\n", | ||
"adata.obs[\"pert_treat\"] = adata.obs[\"condition\"] + \"+\" + adata.obs[\"Treatment\"]\n", | ||
"obs_df = pd.DataFrame(adata.obs[\"pert_treat\"])\n", | ||
"category_counts = obs_df[\"pert_treat\"].value_counts()\n", | ||
"filtered_categories = category_counts[category_counts >= 50].index\n", | ||
"adata = adata[adata.obs[\"pert_treat\"].isin(filtered_categories)]" | ||
] | ||
}, | ||
{
  "cell_type": "code",
  "execution_count": null,
  "id": "1fdfbf4b",
  "metadata": {},
  "outputs": [],
  "source": [
    "# create binary perturbation matrix\n",
    "D, pert_labels = vectorize_perts(adata, \"condition\", [\"ctrl\", \"nan\"])\n",
    "# position of each perturbation's target in adata.var_names, taken from the\n",
    "# second \"_\"-separated token of the label; -1 when it is not a variable name\n",
    "pert_idx = np.array(\n",
    "    [\n",
    "        adata.var_names.get_loc(i.split(\"_\")[1])\n",
    "        if i.split(\"_\")[1] in adata.var_names\n",
    "        else -1\n",
    "        for i in pert_labels\n",
    "    ]\n",
    ")\n",
    "# add ctrl one-hot-encoding\n",
    "ctrl_vector = np.array([1.0 if i == \"ctrl\" else 0.0 for i in adata.obs[\"condition\"]])\n",
    "D = np.concatenate([D, ctrl_vector.reshape(len(ctrl_vector), 1)], axis=1).astype(\n",
    "    np.float32\n",
    ")\n",
    "# exactly one column (ctrl) was appended to D, so append exactly one\n",
    "# placeholder index to keep pert_idx aligned with D's columns and pert_labels\n",
    "pert_idx = np.append(pert_idx, [-1])\n",
    "pert_labels = pert_labels + [\"ctrl\"]\n",
    "print(D.shape)"
  ]
},
{ | ||
"cell_type": "code", | ||
"execution_count": 9, | ||
"id": "4ac6b398", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# subset to kfold and TNFA+ treatment\n", | ||
"D_train = D[train_idx]\n", | ||
"adata_train = adata[train_idx]\n", | ||
"D_train = D_train[adata_train.obs[\"Treatment\"] == \"TNFA+\"]\n", | ||
"adata_train = adata_train[adata_train.obs[\"Treatment\"] == \"TNFA+\"]" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 13, | ||
"id": "0eb35473", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# subset further for GSFA to run without OOM issues\n", | ||
"from sklearn.model_selection import train_test_split\n", | ||
"\n", | ||
"Y, _, G, _ = train_test_split(\n", | ||
" adata_train.layers[\"counts\"],\n", | ||
" D_train,\n", | ||
" test_size=0.2,\n", | ||
" random_state=42,\n", | ||
" stratify=D_train,\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 14, | ||
"id": "5d496dfd-bbdf-4e90-aba2-a224674d2876", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# save inputs for GSFA\n", | ||
"np.savez(\"rna565_GSFA_inputs.npz\", array1=Y.todense(), array2=G)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 12, | ||
"id": "c8d4ae72-33c8-4fe1-997b-47be95d0084b", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# save additional perturbation labels for downstream analysis\n", | ||
"np.savez(\"rna565_G_labels.npz\", pert_labels)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "3c76eeff-6cbb-4d30-a2ae-a336fd6e1794", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "pertspectra", | ||
"language": "python", | ||
"name": "pertspectra" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.11.0" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# Export the posterior summaries of the fitted rna565 GSFA model to CSV files
# under ~/GSFA/rna565_gsfa_outputs.
fitted_rna565_GSFA <- readRDS("~/GSFA/fitted_rna565_GSFA.Rds")

posterior_means <- fitted_rna565_GSFA$posterior_means
Z <- posterior_means$Z_pm
beta <- posterior_means$beta_pm
W <- posterior_means$W_pm
# Stored as `F_pm` rather than `F` so the built-in alias `F` (FALSE) is not
# masked for any code that runs later in the same session.
F_pm <- posterior_means$F_pm
lfsr <- fitted_rna565_GSFA$lfsr

# write.csv() does not create missing directories; make sure the target exists.
out_dir <- "~/GSFA/rna565_gsfa_outputs"
dir.create(out_dir, recursive = TRUE, showWarnings = FALSE)

write.csv(Z, file.path(out_dir, "Z.csv"))
write.csv(beta, file.path(out_dir, "beta.csv"))
write.csv(W, file.path(out_dir, "W.csv"))
write.csv(F_pm, file.path(out_dir, "F.csv"))
# The file name keeps the existing "lsfr" spelling so that anything already
# reading this path continues to work.
write.csv(lfsr, file.path(out_dir, "lsfr.csv"))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# Export the posterior summaries of the fitted norman GSFA model to CSV files
# under ~/GSFA/norman_gsfa_outputs.
fitted_norman_GSFA <- readRDS("~/GSFA/fitted_norman_GSFA.rds")

posterior_means <- fitted_norman_GSFA$posterior_means
Z <- posterior_means$Z_pm
beta <- posterior_means$beta_pm
W <- posterior_means$W_pm
# Stored as `F_pm` rather than `F` so the built-in alias `F` (FALSE) is not
# masked for any code that runs later in the same session.
F_pm <- posterior_means$F_pm
lfsr <- fitted_norman_GSFA$lfsr

# write.csv() does not create missing directories; make sure the target exists.
out_dir <- "~/GSFA/norman_gsfa_outputs"
dir.create(out_dir, recursive = TRUE, showWarnings = FALSE)

write.csv(Z, file.path(out_dir, "Z.csv"))
write.csv(beta, file.path(out_dir, "beta.csv"))
write.csv(W, file.path(out_dir, "W.csv"))
write.csv(F_pm, file.path(out_dir, "F.csv"))
# The file name keeps the existing "lsfr" spelling so that anything already
# reading this path continues to work.
write.csv(lfsr, file.path(out_dir, "lsfr.csv"))
Oops, something went wrong.