Skip to content

Commit

Permalink
implement unpaired alignment && fix colabfold_search error
Browse files Browse the repository at this point in the history
  • Loading branch information
dohyun-s committed Feb 26, 2023
1 parent 7ad01b8 commit 548fdcc
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 53 deletions.
33 changes: 29 additions & 4 deletions colabfold/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,12 @@

from colabfold.inputs import (
get_queries_pairwise, unpack_a3ms,
parse_fasta, get_queries,
parse_fasta, get_queries, msa_to_str
)
from colabfold.run_alphafold import set_model_type

from colabfold.download import default_data_dir, download_alphafold_params

import sys
import logging
logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -304,17 +305,31 @@ def main():
for x in batch:
if x not in query_seqs_unique:
query_seqs_unique.append(x)
query_seqs_cardinality = [0] * len(query_seqs_unique)
for seq in batch:
seq_idx = query_seqs_unique.index(seq)
query_seqs_cardinality[seq_idx] += 1
use_env = "env" in args.msa_mode or "Environmental" in args.msa_mode
paired_a3m_lines = run_mmseqs2(
query_seqs_unique,
str(Path(args.results).joinpath(str(jobname))),
str(Path(args.results).joinpath(str(jobname)+"_paired")),
use_env=use_env,
use_pairwise=True,
use_pairing=True,
host_url=args.host_url,
)

path_o = Path(args.results).joinpath(f"{jobname}_pairwise")
if args.pair_mode == "none" or args.pair_mode == "unpaired" or args.pair_mode == "unpaired_paired":
unpaired_path = Path(args.results).joinpath(str(jobname)+"_unpaired_env")
unpaired_a3m_lines = run_mmseqs2(
query_seqs_unique,
str(Path(args.results).joinpath(str(jobname)+"_unpaired")),
use_env=use_env,
use_pairwise=False,
use_pairing=False,
host_url=args.host_url,
)
path_o = Path(args.results).joinpath(f"{jobname}_paired_pairwise")
for filenum in path_o.iterdir():
queries_new = []
if Path(filenum).suffix.lower() == ".a3m":
Expand All @@ -326,6 +341,16 @@ def main():
query_sequence = seqs[0]
a3m_lines = [Path(file).read_text()]
val = int(header[0].split('\t')[1][1:]) - 102
# match paired seq id and unpaired seq id
if args.pair_mode == "none" or "unpaired" or "unpaired_paired":
tmp = '>101\n' + paired_a3m_lines[0].split('>101\n')[val+1]
a3m_lines = [msa_to_str(
[unpaired_a3m_lines[0], unpaired_a3m_lines[val+1]], [tmp, paired_a3m_lines[val+1]], [batch[0], batch[val+1]], [1, 1]
)]
## Another way: do not use msa_to_str and unserialize function rather
## send unpaired_msa, paired_msa, query_seqs_unique, query_seqs_cardinality as arguments..
##
# a3m_lines = [[unpaired_a3m_lines[0], unpaired_a3m_lines[val+1]], [tmp, paired_a3m_lines[val+1]], [batch[0], batch[val+1]], [1, 1]]
queries_new.append((header_first + '_' + headers_list[jobname][val], query_sequence, a3m_lines))

if args.sort_queries_by == "length":
Expand Down
6 changes: 3 additions & 3 deletions colabfold/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ def pad_input_multimer(
pad_len: int,
use_templates: bool,
) -> model.features.FeatureDict:
model_config = model_runner.config
shape_schema = {
"aatype": ["num residues placeholder"],
"residue_index": ["num residues placeholder"],
Expand Down Expand Up @@ -696,7 +695,7 @@ def unserialize_msa(
)
prev_query_start += query_len
paired_msa = [""] * len(query_seq_len)
unpaired_msa = None
unpaired_msa = [""] * len(query_seq_len)
already_in = dict()
for i in range(1, len(a3m_lines), 2):
header = a3m_lines[i]
Expand Down Expand Up @@ -734,7 +733,6 @@ def unserialize_msa(
paired_msa[j] += ">" + header_no_faster_split[j] + "\n"
paired_msa[j] += seqs_line[j] + "\n"
else:
unpaired_msa = [""] * len(query_seq_len)
for j, seq in enumerate(seqs_line):
if has_amino_acid[j]:
unpaired_msa[j] += header + "\n"
Expand All @@ -752,6 +750,8 @@ def unserialize_msa(
template_feature = mk_mock_template(query_seq)
template_features.append(template_feature)

if unpaired_msa == [""] * len(query_seq_len):
unpaired_msa = None
return (
unpaired_msa,
paired_msa,
Expand Down
108 changes: 63 additions & 45 deletions colabfold/mmseqs/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,9 @@ def get_queries_pairwise(
df = pandas.read_csv(input_path, sep=sep)
assert "id" in df.columns and "sequence" in df.columns
queries = [
(str(df["id"][0])+'&'+str(seq_id), [df["sequence"][0].upper(),sequence.upper()], None)
for i, (seq_id, sequence) in enumerate(df[["id", "sequence"]].itertuples(index=False)) if i!=0
(str(df["id"][0])+'&'+str(seq_id), sequence.upper(), None)
for i, (seq_id, sequence) in enumerate(df[["id", "sequence"]].itertuples(index=False))
]
for i in range(len(queries)):
if len(queries[i][1]) == 1:
queries[i] = (queries[i][0], queries[i][1][0], None)
elif input_path.suffix == ".a3m":
raise NotImplementedError()
elif input_path.suffix in [".fasta", ".faa", ".fa"]:
Expand All @@ -47,9 +44,7 @@ def get_queries_pairwise(
sequence = sequence.upper()
if sequence.count(":") == 0:
# Single sequence
if i==0:
continue
queries.append((headers[0]+'&'+header, [sequences[0],sequence], None))
queries.append((header, sequence, None))
else:
# Complex mode
queries.append((header, sequence.upper().split(":"), None))
Expand Down Expand Up @@ -449,9 +444,9 @@ def main():
args = parser.parse_args()

if args.interaction_scan:
queries, is_complex = get_queries_pairwise(args.query, None)
queries, is_complex = get_queries_pairwise(args.query)
else:
queries, is_complex = get_queries(args.query, None)
queries, is_complex = get_queries(args.query)

queries_unique = []
for job_number, (raw_jobname, query_sequences, a3m_lines) in enumerate(queries):
Expand Down Expand Up @@ -481,10 +476,9 @@ def main():
query_seqs_cardinality,
) in enumerate(queries_unique):
if job_number==0:
f.write(f">{raw_jobname}_0\n{query_sequences[0]}\n")
f.write(f">{raw_jobname}\n{query_sequences[1]}\n")
f.write(f">{raw_jobname}_0\n{query_sequences}\n")
else:
f.write(f">{raw_jobname}\n{query_sequences[1]}\n")
f.write(f">{queries_unique[0][0]+'&'+raw_jobname}\n{query_sequences}\n")
else:
with query_file.open("w") as f:
for job_number, (
Expand All @@ -498,18 +492,6 @@ def main():
args.mmseqs,
["createdb", query_file, args.base.joinpath("qdb"), "--shuffle", "0"],
)
with args.base.joinpath("qdb.lookup").open("w") as f:
id = 0
file_number = 0
for job_number, (
raw_jobname,
query_sequences,
query_seqs_cardinality,
) in enumerate(queries_unique):
for seq in query_sequences:
f.write(f"{id}\t{raw_jobname}\t{file_number}\n")
id += 1
file_number += 1

mmseqs_search_monomer(
mmseqs=args.mmseqs,
Expand Down Expand Up @@ -542,30 +524,66 @@ def main():
interaction_scan=args.interaction_scan,
)

if args.interaction_scan:
if len(queries_unique) > 1:
for i in range(len(queries_unique)-2):
idx = 2 + i*2
## delete duplicated query files 2.paired, 4.paired...
os.remove(args.base.joinpath(f"{idx}.paired.a3m"))
for j in range(len(queries_unique)-2):
# replace targets' right file name
id1 = j*2 + 3
id2 = j + 2
os.replace(args.base.joinpath(f"{id1}.paired.a3m"), args.base.joinpath(f"{id2}.paired.a3m"))

id = 0
for job_number, (
raw_jobname,
query_sequences,
query_seqs_cardinality,
) in enumerate(queries_unique):
unpaired_msa = []
paired_msa = None
if len(query_seqs_cardinality) > 1:
if not args.interaction_scan:
for job_number, (
raw_jobname,
query_sequences,
query_seqs_cardinality,
) in enumerate(queries_unique):
unpaired_msa = []
paired_msa = None
if len(query_seqs_cardinality) > 1:
paired_msa = []
else:
for seq in query_sequences:
with args.base.joinpath(f"{id}.a3m").open("r") as f:
unpaired_msa.append(f.read())
args.base.joinpath(f"{id}.a3m").unlink()
if len(query_seqs_cardinality) > 1:
with args.base.joinpath(f"{id}.paired.a3m").open("r") as f:
paired_msa.append(f.read())
args.base.joinpath(f"{id}.paired.a3m").unlink()
id += 1
msa = msa_to_str(
unpaired_msa, paired_msa, query_sequences, query_seqs_cardinality
)
args.base.joinpath(f"{job_number}.a3m").write_text(msa)
else:
for job_number, _ in enumerate(queries_unique[:-1]):
query_sequences = [queries_unique[0][1], queries_unique[job_number+1][1]]
unpaired_msa = []
paired_msa = []
for seq in query_sequences:
with args.base.joinpath(f"{id}.a3m").open("r") as f:
with args.base.joinpath(f"0.a3m").open("r") as f:
unpaired_msa.append(f.read())
with args.base.joinpath(f"{job_number+1}.a3m").open("r") as f:
unpaired_msa.append(f.read())
args.base.joinpath(f"{id}.a3m").unlink()
if len(query_seqs_cardinality) > 1:
with args.base.joinpath(f"{id}.paired.a3m").open("r") as f:
paired_msa.append(f.read())
args.base.joinpath(f"{id}.paired.a3m").unlink()
id += 1
msa = msa_to_str(
unpaired_msa, paired_msa, query_sequences, query_seqs_cardinality
)
args.base.joinpath(f"{job_number}.a3m").write_text(msa)

with args.base.joinpath(f"0.paired.a3m").open("r") as f:
paired_msa.append(f.read())
with args.base.joinpath(f"{job_number+1}.paired.a3m").open("r") as f:
paired_msa.append(f.read())
msa = msa_to_str(
unpaired_msa, paired_msa, query_sequences, [1,1]
)
args.base.joinpath(f"{job_number}_final.a3m").write_text(msa)
for job_number, _ in enumerate(queries_unique):
args.base.joinpath(f"{job_number}.a3m").unlink()
args.base.joinpath(f"{job_number}.paired.a3m").unlink()
for job_number, _ in enumerate(queries_unique[:-1]):
os.replace(args.base.joinpath(f"{job_number}_final.a3m"), args.base.joinpath(f"{job_number}.a3m"))
query_file.unlink()
run_mmseqs(args.mmseqs, ["rmdb", args.base.joinpath("qdb")])
run_mmseqs(args.mmseqs, ["rmdb", args.base.joinpath("qdb_h")])
Expand Down
10 changes: 9 additions & 1 deletion colabfold/run_alphafold.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,15 @@ def run(
(unpaired_msa, paired_msa, query_seqs_unique, query_seqs_cardinality, template_features_) \
= unserialize_msa(a3m_lines, query_sequence)
if not use_templates: template_features = template_features_

## Another way passing argument
##
# (unpaired_msa, paired_msa, query_seqs_unique, query_seqs_cardinality) = a3m_lines
# template_features_ = []
# from colabfold.inputs import mk_mock_template
# for query_seq in query_seqs_unique:
# template_feature = mk_mock_template(query_seq)
# template_features_.append(template_feature)
# if not use_templates: template_features = template_features_
# save a3m
msa = msa_to_str(unpaired_msa, paired_msa, query_seqs_unique, query_seqs_cardinality)
result_dir.joinpath(f"{jobname}.a3m").write_text(msa)
Expand Down

0 comments on commit 548fdcc

Please sign in to comment.