Skip to content

Commit 59a9712

Browse files
committed
Added Markov process script.
1 parent c3cb0ed commit 59a9712

File tree

1 file changed

+143
-0
lines changed

1 file changed

+143
-0
lines changed

neat/models/markov.py

+143
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
import pysam
2+
import numpy as np
3+
import pandas as pd
4+
import pathlib
5+
import matplotlib.pyplot as plt
6+
import seaborn as sns
7+
from scipy.stats import norm
8+
9+
def make_qual_score_list(bam_file):
10+
'''Takes an input BAM file and creates lists of quality scores. This becomes a data frame, which will be
11+
pre-processed for Markov chain analysis.'''
12+
13+
index = f'{bam_file}.bai'
14+
15+
if not pathlib.Path(index).exists():
16+
print('No index found, creating one.')
17+
18+
pysam.index(bam_file)
19+
20+
file_to_parse = pysam.AlignmentFile(bam_file, 'rb', check_sq=False)
21+
num_recs = file_to_parse.count()
22+
print(f'{num_recs} records to parse')
23+
24+
modulo = round(num_recs / 9)
25+
26+
qual_list = []
27+
i = 0
28+
j = 0
29+
30+
def print_update(number, factor, percent):
31+
if number % factor == 0:
32+
percent += 10
33+
print(f'{percent}% complete', end='\r')
34+
return percent
35+
36+
print('Parsing file')
37+
38+
for item in file_to_parse.fetch():
39+
if item.is_unmapped or len(item.seq) != 249 or 'S' in item.cigarstring:
40+
i += 1
41+
j = print_update(i, modulo, j)
42+
continue
43+
44+
# Mapping quality scores
45+
46+
align_qual = item.query_alignment_qualities
47+
48+
# Append to master lists
49+
50+
qual_list.append(align_qual)
51+
i += 1
52+
j = print_update(i, modulo, j)
53+
54+
print(f'100% complete')
55+
file_to_parse.close()
56+
57+
# Turn list of lists into a dataframe
58+
59+
quality_df = pd.DataFrame(qual_list)
60+
61+
# Pre-processing - fill in missing data and sample 1% of reads
62+
63+
quality_df = quality_df.fillna(0)
64+
quality_df = quality_df.sample(frac=0.01, axis=0, random_state=42)
65+
66+
return quality_df
67+
68+
def estimate_transition_probabilities():
69+
70+
# Define the probabilities for transition states based on a normal distribution
71+
72+
std_dev = 1
73+
transition_probs = {
74+
-3: norm.pdf(-3, 0, std_dev),
75+
-2: norm.pdf(-2, 0, std_dev),
76+
-1: norm.pdf(-1, 0, std_dev),
77+
0: norm.pdf(0, 0, std_dev),
78+
1: norm.pdf(1, 0, std_dev),
79+
2: norm.pdf(2, 0, std_dev),
80+
3: norm.pdf(3, 0, std_dev)
81+
}
82+
83+
# Normalize the probabilities to sum to 1
84+
85+
total_prob = sum(transition_probs.values())
86+
for k in transition_probs:
87+
transition_probs[k] /= total_prob
88+
89+
return transition_probs
90+
91+
def apply_markov_chain(quality_df, L=249):
92+
93+
transition_probs = estimate_transition_probabilities()
94+
num_rows, num_cols = quality_df.shape
95+
96+
markov_preds = []
97+
98+
for row in quality_df.iterrows():
99+
qualities = row[1].values
100+
pred_qualities = np.zeros_like(qualities)
101+
pred_qualities[0] = qualities[0] # initial state
102+
103+
for i in range(1, L):
104+
prev_quality = pred_qualities[i - 1]
105+
transitions = list(transition_probs.keys())
106+
probabilities = list(transition_probs.values())
107+
next_quality = np.random.choice(transitions, p=probabilities)
108+
pred_qualities[i] = max(0, prev_quality + next_quality) # ensuring no negative qualities
109+
110+
markov_preds.append(pred_qualities)
111+
112+
markov_preds_df = pd.DataFrame(markov_preds)
113+
114+
# Apply final transformations
115+
116+
edge_len = int(L * 0.05)
117+
mid_start = int(L * 0.40)
118+
mid_end = int(L * 0.60)
119+
120+
markov_preds_df.iloc[:, :edge_len] -= 5
121+
markov_preds_df.iloc[:, :edge_len] = markov_preds_df.iloc[:, :edge_len].clip(lower=0)
122+
123+
markov_preds_df.iloc[:, -edge_len:] -= 5
124+
markov_preds_df.iloc[:, -edge_len:] = markov_preds_df.iloc[:, -edge_len:].clip(lower=0)
125+
126+
markov_preds_df.iloc[:, mid_start:mid_end] += 1
127+
128+
return markov_preds_df
129+
130+
def plot_heatmap(y_preds_df, file_path):
131+
'''Takes a dataframe of predicted quality scores and plots a seaborn heatmap to visualize them.'''
132+
133+
sns.heatmap(y_preds_df, vmin=0, vmax=40, cmap='viridis')
134+
plt.savefig(file_path)
135+
print('Heatmap plotted')
136+
137+
# Example usage
138+
139+
bam_file = '/projects/bclt/neat_data/H1N1_new.bam'
140+
test_df = make_qual_score_list(bam_file)
141+
markov_preds_df = apply_markov_chain(test_df)
142+
plot_heatmap(markov_preds_df, 'markov_chain_heatmap.svg')
143+

0 commit comments

Comments
 (0)