Skip to content

Commit

Permalink
Merge pull request #8 from aljpetri/BubblePopImprovements
Browse files Browse the repository at this point in the history
Merged the changes introduced via BubblePopImprovements into Master
  • Loading branch information
aljpetri authored Sep 5, 2023
2 parents d3bf64c + 9dab035 commit 32f71bd
Show file tree
Hide file tree
Showing 5 changed files with 243 additions and 286 deletions.
63 changes: 42 additions & 21 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
import tempfile
import pickle
from collections import defaultdict, deque


from pyinstrument import Profiler
from modules import help_functions, GraphGeneration, batch_merging_parallel, IsoformGeneration, SimplifyGraph

D = {chr(i) : min( 10**( - (ord(chr(i)) - 33)/10.0 ), 0.79433) for i in range(128)}
Expand Down Expand Up @@ -152,13 +151,13 @@ def get_minimizers_and_positions(reads, w, k, hash_fcn):
return M


def get_minimizer_combinations_database( M, k, x_low, x_high):
M2 = defaultdict(lambda: defaultdict(lambda :array("I")))
def get_minimizer_combinations_database(M, k, x_low, x_high):
M2 = defaultdict(lambda: defaultdict(lambda: array("I")))
tmp_cnt = 0
forbidden = 'A'*k
for r_id in M:
minimizers = M[r_id]
for (m1,p1), m1_curr_spans in minimizers_comb_iterator(minimizers, k, x_low, x_high):
for (m1,p1), m1_curr_spans in minimizers_comb_iterator(minimizers, k, x_low, x_high):
for (m2, p2) in m1_curr_spans:
if m2 == m1 == forbidden:
continue
Expand All @@ -177,7 +176,7 @@ def get_minimizer_combinations_database( M, k, x_low, x_high):
for m2 in list(M2[m1].keys()):
if len(M2[m1][m2]) > 3:
avg_bundance += len(M2[m1][m2])//3
cnt +=1
cnt += 1
else:
del M2[m1][m2]
singleton_minimzer += 1
Expand All @@ -193,7 +192,7 @@ def minimizers_comb_iterator(minimizers, k, x_low, x_high):
m1_curr_spans = []
for j, (m2, p2) in enumerate(minimizers[i+1:]):
if x_low < p2 - p1 and p2 - p1 <= x_high:
m1_curr_spans.append( (m2, p2) )
m1_curr_spans.append((m2, p2))
# yield (m1,p1), (m2, p2)
elif p2 - p1 > x_high:
break
Expand Down Expand Up @@ -304,13 +303,13 @@ def find_most_supported_span(r_id, m1, p1, m1_curr_spans, minimizer_combinations
seqs.append(p1)
seqs.append(p2)
for relevant_read_id, pos1, pos2 in grouper(relevant_reads, 3): #relevant_reads:
if r_id == relevant_read_id:
if r_id == relevant_read_id:
continue
elif abs((p2-p1)-(pos2-pos1)) < delta_len:
seqs.append(relevant_read_id)
seqs.append(pos1)
seqs.append(pos2)
all_intervals.append( (p1 + k_size, p2, len(seqs)//3, seqs) )
all_intervals.append((p1 + k_size, p2, len(seqs)//3, seqs))


def main(args):
Expand All @@ -328,7 +327,7 @@ def main(args):
args.exact = True
if args.set_w_dynamically:
args.w = args.k + min(7, int(len(all_reads) / 500))
delta_iso_len_3=args.delta_iso_len_3
delta_iso_len_3 = args.delta_iso_len_3
delta_iso_len_5 = args.delta_iso_len_5
work_dir = tempfile.mkdtemp()
print("Temporary workdirektory:", work_dir)
Expand All @@ -337,13 +336,15 @@ def main(args):
x_high = args.xmax
x_low = args.xmin
if args.parallel:
filename=args.fastq.split("/")[-1]
tmp_filename=filename.split("_")
tmp_lastpart=tmp_filename[-1].split(".")
p_batch_id=tmp_lastpart[0]
skipfilename="skip"+p_batch_id+".fa"
filename = args.fastq.split("/")[-1]
tmp_filename = filename.split("_")
tmp_lastpart = tmp_filename[-1].split(".")
p_batch_id = tmp_lastpart[0]
skipfilename = "skip"+p_batch_id+".fa"

for batch_id, reads in enumerate(batch(all_reads, args.max_seqs)):


new_all_reads = {}
if args.parallel:
batch_pickle = str(p_batch_id) + "_batch"
Expand Down Expand Up @@ -492,41 +493,61 @@ def main(args):
print("Generating the graph")
all_batch_reads_dict[batch_id] = new_all_reads
read_len_dict = get_read_lengths(all_reads)
#for key,value in all_intervals_for_graph.items():
# print(key,len(value))
#print(all_intervals_for_graph)

#profiler = Profiler()
#profiler.start()
#generate the graph from the intervals

DG, reads_for_isoforms = GraphGeneration.generateGraphfromIntervals(
all_intervals_for_graph, k_size, delta_len, read_len_dict,new_all_reads)
#profiler.stop()

#profiler.print()
#test for cyclicity of the graph - a status we cannot continue working on -> if cyclic we get an error
#is_cyclic=GraphGeneration.isCyclic(DG)
#is_cyclic = SimplifyGraph.isCyclic(DG)
#if is_cyclic:
# k_size+=1
# w+=1
# eprint("The graph has a cycle - critical error")
# return -1

# print("The graph has a cycle - critical error")
#return -1
#else:
# print("No cycle in graph")
if DEBUG==True:
print("BATCHID",batch_id)
for id, value in all_batch_reads_dict.items():
for other_id,other_val in value.items():
print(id,": ",other_id,":",other_val[0],"::",other_val[1])

mode = args.slow
#profiler = Profiler()
#profiler.start()
#the bubble popping step: We simplify the graph by linearizing all poppable bubbles
SimplifyGraph.simplifyGraph(DG, new_all_reads, work_dir, k_size, delta_len, mode)
#profiler.stop()

#profiler.print()
#TODO: add delta as user parameter possibly?
delta = 0.15
print("Starting to generate Isoforms")

if args.parallel:
batch_id = p_batch_id
#profiler = Profiler()
#profiler.start()
#generation of isoforms from the graph structure
IsoformGeneration.generate_isoforms(DG, new_all_reads, reads_for_isoforms, work_dir, outfolder, batch_id, delta, delta_len, delta_iso_len_3, delta_iso_len_5, max_seqs_to_spoa)
#profiler.stop()

print("Isoforms generated-Starting batch merging ")
if not args.parallel:
print("Merging the batches with linear strategy")
#merges the predictions from different batches
batch_merging_parallel.join_back_via_batch_merging(args.outfolder, delta, args.delta_len, args.delta_iso_len_3, args.delta_iso_len_5,
args.max_seqs_to_spoa, args.iso_abundance)
#batch_merging_parallel.join_back_via_batch_merging(args.outfolder, delta, args.delta_len, args.delta_iso_len_3, args.delta_iso_len_5,
# args.max_seqs_to_spoa, args.iso_abundance)

print("removing temporary workdir")
sys.stdout.close()
shutil.rmtree(work_dir)
Expand Down
Loading

0 comments on commit 32f71bd

Please sign in to comment.