-
Notifications
You must be signed in to change notification settings - Fork 0
/
peakVI_process.py
66 lines (50 loc) · 1.78 KB
/
peakVI_process.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
"""Script for running PeakVI on a 10x formatted dataset."""
import os
from typing import Sequence
from absl import app
from absl import flags
import anndata
import pandas as pd
import scipy.io
import scipy.sparse
import scvi
import time
import tensorflow as tf
FLAGS = flags.FLAGS

flags.DEFINE_string('input_path', None, 'Path to the 10x formatted folder.')
flags.DEFINE_string('output_path', None, 'Path to the output directory.')

# Both paths default to None and the script cannot do anything useful without
# them. Marking them required makes absl reject a missing flag at startup
# instead of the run failing later with a TypeError from os.path.join / gfile.
flags.mark_flag_as_required('input_path')
flags.mark_flag_as_required('output_path')
def _read_names(file_path: str) -> pd.Series:
    """Reads the first tab-separated column of `file_path` as verbatim strings."""
    with tf.io.gfile.GFile(file_path, mode='r') as f:
        # dtype=str keeps numeric-looking names (e.g. '001') from being parsed
        # into ints/floats, and na_filter=False keeps names such as 'NA' from
        # turning into NaN; AnnData expects string obs/var names.
        return pd.read_csv(
            f, sep='\t', header=None, dtype=str, na_filter=False)[0]


def create_anndata(path: os.PathLike) -> anndata.AnnData:
    """Creates an AnnData object from a 10x formatted folder.

    The folder must contain:
      * ``matrix.mtx``: Matrix Market matrix with one row per bin and one
        column per cell barcode.
      * ``barcodes.tsv``: one barcode per line (first column is used).
      * ``bins.tsv``: one bin name per line (first column is used).

    Args:
      path: Path to the 10x formatted input folder. A plain ``str`` path (as
        supplied via ``--input_path``) works as well.

    Returns:
      AnnData object with cells (barcodes) as observations and bins as
      variables, backed by a CSR sparse matrix.
    """
    with tf.io.gfile.GFile(os.path.join(path, 'matrix.mtx'), mode='rb') as f:
        matrix = scipy.io.mmread(f)
    # Transpose before building the AnnData so cells are rows from the start,
    # instead of building a bins x cells object and transposing it afterwards.
    adata = anndata.AnnData(scipy.sparse.csr_matrix(matrix.transpose()))
    adata.obs_names = _read_names(os.path.join(path, 'barcodes.tsv'))
    adata.var_names = _read_names(os.path.join(path, 'bins.tsv'))
    return adata
def main(argv: Sequence[str]) -> None:
    """Runs PeakVI on --input_path and writes the latent space to --output_path.

    The latent representation is written to ``<output_path>/peakVI.csv`` with
    one row per cell barcode.

    Args:
      argv: Command-line arguments remaining after flag parsing; anything
        beyond the program name is rejected.

    Raises:
      app.UsageError: If extra positional arguments are given.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # Create the output directory up front so a bad --output_path fails
    # before the (potentially long) training run rather than after it.
    tf.io.gfile.makedirs(FLAGS.output_path)

    adata = create_anndata(FLAGS.input_path)
    scvi.model.PEAKVI.setup_anndata(adata)

    # perf_counter is monotonic, so the reported duration is not skewed by
    # wall-clock adjustments during training (time.time() can jump).
    start = time.perf_counter()
    vae = scvi.model.PEAKVI(adata)
    vae.train()
    latent = pd.DataFrame(
        vae.get_latent_representation(), index=adata.obs_names)
    delta = time.perf_counter() - start
    print(f'Time to run PeakVI {FLAGS.input_path} {delta}')

    output_file = os.path.join(FLAGS.output_path, 'peakVI.csv')
    with tf.io.gfile.GFile(output_file, 'w') as f:
        latent.to_csv(f)
if __name__ == '__main__':
    # Hand control to absl: it parses the flags defined above and then calls
    # main() with the remaining command-line arguments.
    app.run(main)