Skip to content

Commit

Permalink
initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
annashch-insitro committed Jan 23, 2025
0 parents commit 03e2066
Show file tree
Hide file tree
Showing 51 changed files with 18,429 additions and 0 deletions.
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Byte-compiled / optimized / DLL files
__pycache__/

# Jupyter Notebook
.ipynb_checkpoints
43 changes: 43 additions & 0 deletions GSFA/GSFA_time_complexity.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
library(data.table)
library(tidyverse)
library(Matrix)
library(GSFA)
library(ggplot2)

# Timing benchmark: fit GSFA on a fixed-size subset of cells and report the
# wall-clock time of the fit.

# Install reticulate only when it is missing. An unconditional
# install.packages() re-installs on every run and errors under a
# non-interactive Rscript session that has no CRAN mirror configured.
if (!requireNamespace("reticulate", quietly = TRUE)) {
  install.packages("reticulate")
}
library(reticulate)
use_python("/usr/bin/python3")
py_discover_config()
np <- import("numpy")

# Load the expression matrix (array1) and perturbation matrix (array2).
# NOTE(review): allow_pickle = TRUE lets numpy unpickle arbitrary objects;
# only load archives from a trusted source, and drop the flag if the arrays
# are plain numeric.
npz <- np$load("inhouse_GSFA_inputs.npz", allow_pickle = TRUE)
Y <- npz$get("array1")
G <- npz$get("array2")

print(dim(Y))
print(dim(G))
print("loaded data")

# Deviance-residual transform followed by selection of the top 6000 genes.
dev_res <- deviance_residual_transform(Y)
top_gene_index <- select_top_devres_genes(dev_res, num_top_genes = 6000)
dev_res_filtered <- dev_res[, top_gene_index]

write.csv(top_gene_index, "inhouse_top_genes.csv")
# Release the raw inputs before the memory-heavy fit.
rm(npz)
rm(Y)
print(dim(dev_res_filtered))
print("processed data")

set.seed(14314)
num_cells <- 5000
# Fail early with a clear message instead of a "subscript out of bounds"
# error from the subsetting below.
stopifnot(
  num_cells <= nrow(dev_res_filtered),
  num_cells <= nrow(G)
)
cell_idx <- seq_len(num_cells)

time_start <- Sys.time()
fit <- fit_gsfa_multivar(
  # drop = FALSE keeps both inputs as matrices regardless of subset size.
  Y = dev_res_filtered[cell_idx, , drop = FALSE],
  G = G[cell_idx, , drop = FALSE],
  K = 20,
  prior_type = "mixture_normal",
  init.method = "svd",
  niter = 3000, used_niter = 1000,
  verbose = TRUE, return_samples = TRUE
)
print(Sys.time() - time_start)
rm(G)
rm(dev_res_filtered)
# NOTE(review): this is the same output path used by inhouse_GSFA.R; running
# both scripts in one directory overwrites the earlier fit -- confirm intended.
saveRDS(fit, file = "fitted_inhouse.rds")
8 changes: 8 additions & 0 deletions GSFA/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# GSFA (Guided Sparse Factor Analysis)

```
Yifan Zhou, Kaixuan Luo, Lifan Liang, Mengjie Chen and Xin He. A new Bayesian factor analysis method improves detection of genes and biological processes affected by perturbations in single-cell CRISPR screening. Nature Methods. (2023). doi: 10.1038/s41592-023-02017-4. PMID: 37770710
```

Training scripts for GSFA to use as a benchmark.
Link to paper: [paper](https://www.nature.com/articles/s41592-023-02017-4)
46 changes: 46 additions & 0 deletions GSFA/inhouse_GSFA.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
library(data.table)
library(tidyverse)
library(Matrix)
library(GSFA)
library(ggplot2)

# Install reticulate only when it is missing. An unconditional
# install.packages() re-installs on every run and errors under a
# non-interactive Rscript session that has no CRAN mirror configured.
if (!requireNamespace("reticulate", quietly = TRUE)) {
  install.packages("reticulate")
}
library(reticulate)
use_python("/usr/bin/python3")
py_discover_config()
np <- import("numpy")

# read in inputs generated by inhouse_gsfa_preprocessing.ipynb
# NOTE(review): allow_pickle = TRUE lets numpy unpickle arbitrary objects;
# only load archives from a trusted source, and drop the flag if the arrays
# are plain numeric.
npz <- np$load("inhouse_GSFA_inputs.npz", allow_pickle = TRUE)
Y <- npz$get("array1")
G <- npz$get("array2")

print(dim(Y))
print(dim(G))
print("loaded data")

# GSFA-specific preprocessing of gene expression:
# deviance-residual transform, then keep the top 6000 genes.
dev_res <- deviance_residual_transform(Y)
top_gene_index <- select_top_devres_genes(dev_res, num_top_genes = 6000)
dev_res_filtered <- dev_res[, top_gene_index]
# save for downstream analysis
np$savez("inhouse_GSFA_preprocessed.npz", array1 = dev_res_filtered)
write.csv(top_gene_index, "inhouse_top_genes.csv")
# Release the raw inputs before the memory-heavy fit.
rm(npz)
rm(Y)
print(dim(dev_res_filtered))
print("processed data")

# train and save GSFA model
set.seed(14314)
time_start <- Sys.time()
fit <- fit_gsfa_multivar(
  Y = dev_res_filtered, G = G,
  K = 20,
  prior_type = "mixture_normal",
  init.method = "svd",
  niter = 3000, used_niter = 1000,
  verbose = TRUE, return_samples = TRUE
)
# Report the wall-clock duration of the fit.
print(Sys.time() - time_start)
rm(G)
rm(dev_res_filtered)
saveRDS(fit, file = "fitted_inhouse.rds")
249 changes: 249 additions & 0 deletions GSFA/inhouse_gsfa_preprocessing.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,249 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 3,
"id": "0c3f7ba9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The autoreload extension is already loaded. To reload it, use:\n",
" %reload_ext autoreload\n"
]
}
],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "dfbdfee5",
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"sys.path.append(\"..\")\n",
"from src.Spectra.Spectra_Pert import vectorize_perts\n",
"from utils import (\n",
" filter_noisy_genes,\n",
" generate_k_fold,\n",
" inhouse_preprocess,\n",
" read_aws_h5ad,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "426b0898-25f4-492b-b77c-d9f1fc17010f",
"metadata": {},
"source": [
"### Get train/test splits consistent with other models"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1dd10044-66e2-41c3-a724-282c4792f6d8",
"metadata": {},
"outputs": [],
"source": [
"# use anndata generated by ..data_processing/inhouse_prior_graph_preprocessing.ipynb\n",
"unfilterd_adata = read_aws_h5ad(\"path to preprocessed h5ad here\")\n",
"adata = filter_noisy_genes(unfilterd_adata)\n",
"adata = inhouse_preprocess(adata)\n",
"adata.layers[\"logcounts\"] = adata.X.copy()\n",
"adata.X = adata.X.todense()\n",
"gene_network = adata.uns[\"sparse_gene_network\"].todense()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "d7e163dc-2a94-41de-b7f5-f717ac47fb4a",
"metadata": {},
"outputs": [],
"source": [
"# powered perturbations\n",
"adata.obs[\"condition\"] = adata.obs[\"condition\"].astype(str)\n",
"adata.obs[\"Treatment\"] = adata.obs[\"Treatment\"].astype(str)\n",
"adata.obs[\"pert_treat\"] = adata.obs[\"condition\"] + \"+\" + adata.obs[\"Treatment\"]\n",
"obs_df = pd.DataFrame(adata.obs[\"pert_treat\"])\n",
"category_counts = obs_df[\"pert_treat\"].value_counts()\n",
"filtered_categories = category_counts[category_counts >= 50].index\n",
"adata = adata[adata.obs[\"pert_treat\"].isin(filtered_categories)]"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "2e20577d-23ec-4336-a3f5-8fcae53252a8",
"metadata": {},
"outputs": [],
"source": [
"train_idx, val_idx, test_idx = generate_k_fold(\n",
" adata, adata.X, adata.obs[\"condition\"], fold_idx=0\n",
")"
]
},
{
"cell_type": "markdown",
"id": "431abcfa-4929-4c0c-8651-c1f9c35f37f2",
"metadata": {},
"source": [
"### Process GSFA-specific input"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "96e401f3",
"metadata": {},
"outputs": [],
"source": [
"# use inhouse dataset from s3://pert-spectra\n",
"adata = read_aws_h5ad(\n",
" \"s3://pert-spectra/rnaseq565.filtered.actionet.guide_corrected.h5ad\"\n",
")\n",
"adata = inhouse_preprocess(adata)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "76e4b5e7",
"metadata": {},
"outputs": [],
"source": [
"# filter adata to perturbations with at least 50 samples for each treatment\n",
"adata.obs[\"condition\"] = adata.obs[\"condition\"].astype(str)\n",
"adata.obs[\"Treatment\"] = adata.obs[\"Treatment\"].astype(str)\n",
"adata.obs[\"pert_treat\"] = adata.obs[\"condition\"] + \"+\" + adata.obs[\"Treatment\"]\n",
"obs_df = pd.DataFrame(adata.obs[\"pert_treat\"])\n",
"category_counts = obs_df[\"pert_treat\"].value_counts()\n",
"filtered_categories = category_counts[category_counts >= 50].index\n",
"adata = adata[adata.obs[\"pert_treat\"].isin(filtered_categories)]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1fdfbf4b",
"metadata": {},
"outputs": [],
"source": [
"# create binary perturbation matrix\n",
"D, pert_labels = vectorize_perts(adata, \"condition\", [\"ctrl\", \"nan\"])\n",
"pert_idx = np.array(\n",
" [\n",
" adata.var_names.get_loc(i.split(\"_\")[1])\n",
" if i.split(\"_\")[1] in adata.var_names\n",
" else -1\n",
" for i in pert_labels\n",
" ]\n",
")\n",
"# add ctrl one-hot-encoding\n",
"ctrl_vector = np.array([1.0 if i == \"ctrl\" else 0.0 for i in adata.obs[\"condition\"]])\n",
"D = np.concatenate([D, ctrl_vector.reshape(len(ctrl_vector), 1)], axis=1).astype(\n",
" np.float32\n",
")\n",
"pert_idx = np.append(pert_idx, [-1, -1])\n",
"pert_labels = pert_labels + [\"ctrl\"]\n",
"print(D.shape)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "4ac6b398",
"metadata": {},
"outputs": [],
"source": [
"# subset to kfold and TNFA+ treatment\n",
"D_train = D[train_idx]\n",
"adata_train = adata[train_idx]\n",
"D_train = D_train[adata_train.obs[\"Treatment\"] == \"TNFA+\"]\n",
"adata_train = adata_train[adata_train.obs[\"Treatment\"] == \"TNFA+\"]"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "0eb35473",
"metadata": {},
"outputs": [],
"source": [
"# subset further for GSFA to run without OOM issues\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"Y, _, G, _ = train_test_split(\n",
" adata_train.layers[\"counts\"],\n",
" D_train,\n",
" test_size=0.2,\n",
" random_state=42,\n",
" stratify=D_train,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "5d496dfd-bbdf-4e90-aba2-a224674d2876",
"metadata": {},
"outputs": [],
"source": [
"# save inputs for GSFA\n",
"np.savez(\"rna565_GSFA_inputs.npz\", array1=Y.todense(), array2=G)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "c8d4ae72-33c8-4fe1-997b-47be95d0084b",
"metadata": {},
"outputs": [],
"source": [
"# save additional perturbation labels for downstream analysis\n",
"np.savez(\"rna565_G_labels.npz\", pert_labels)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3c76eeff-6cbb-4d30-a2ae-a336fd6e1794",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "pertspectra",
"language": "python",
"name": "pertspectra"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
13 changes: 13 additions & 0 deletions GSFA/load_inhouse_GSFA.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Export the posterior summaries of the fitted rna565 GSFA model to CSV files.
fitted_rna565_GSFA <- readRDS("~/GSFA/fitted_rna565_GSFA.Rds")

# Posterior-mean matrices stored under `posterior_means`, plus the `lfsr`
# element of the fit.
# NOTE(review): `F` shadows R's `F` shorthand for FALSE and `beta` shadows
# base::beta(); the names are kept so code that sources this script still
# finds them.
Z <- fitted_rna565_GSFA$posterior_means$Z_pm
beta <- fitted_rna565_GSFA$posterior_means$beta_pm
W <- fitted_rna565_GSFA$posterior_means$W_pm
F <- fitted_rna565_GSFA$posterior_means$F_pm
lsfr <- fitted_rna565_GSFA$lfsr

# write.csv() does not create missing directories, so make sure the output
# folder exists first (no-op when it is already there).
out_dir <- "~/GSFA/rna565_gsfa_outputs"
dir.create(out_dir, recursive = TRUE, showWarnings = FALSE)

write.csv(Z, file.path(out_dir, "Z.csv"))
write.csv(beta, file.path(out_dir, "beta.csv"))
write.csv(W, file.path(out_dir, "W.csv"))
write.csv(F, file.path(out_dir, "F.csv"))
# File name "lsfr.csv" is kept as-is so downstream readers keep working.
write.csv(lsfr, file.path(out_dir, "lsfr.csv"))
13 changes: 13 additions & 0 deletions GSFA/load_norman_GSFA.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Export the posterior summaries of the fitted Norman GSFA model to CSV files.
fitted_norman_GSFA <- readRDS("~/GSFA/fitted_norman_GSFA.rds")

# Posterior-mean matrices stored under `posterior_means`, plus the `lfsr`
# element of the fit.
# NOTE(review): `F` shadows R's `F` shorthand for FALSE and `beta` shadows
# base::beta(); the names are kept so code that sources this script still
# finds them.
Z <- fitted_norman_GSFA$posterior_means$Z_pm
beta <- fitted_norman_GSFA$posterior_means$beta_pm
W <- fitted_norman_GSFA$posterior_means$W_pm
F <- fitted_norman_GSFA$posterior_means$F_pm
lsfr <- fitted_norman_GSFA$lfsr

# write.csv() does not create missing directories, so make sure the output
# folder exists first (no-op when it is already there).
out_dir <- "~/GSFA/norman_gsfa_outputs"
dir.create(out_dir, recursive = TRUE, showWarnings = FALSE)

write.csv(Z, file.path(out_dir, "Z.csv"))
write.csv(beta, file.path(out_dir, "beta.csv"))
write.csv(W, file.path(out_dir, "W.csv"))
write.csv(F, file.path(out_dir, "F.csv"))
# File name "lsfr.csv" is kept as-is so downstream readers keep working.
write.csv(lsfr, file.path(out_dir, "lsfr.csv"))
Loading

0 comments on commit 03e2066

Please sign in to comment.