Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 66 additions & 44 deletions alphafold/data/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
from alphafold.data.tools import jackhmmer
import numpy as np

import concurrent.futures

# Internal import (7716).

FeatureDict = MutableMapping[str, np.ndarray]
Expand Down Expand Up @@ -124,7 +126,8 @@ def __init__(self,
use_small_bfd: bool,
mgnify_max_hits: int = 501,
uniref_max_hits: int = 10000,
use_precomputed_msas: bool = False):
use_precomputed_msas: bool = False,
n_parallel_msa: int = 1):

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be helpful if you were to add a comment here about the number of logical threads involved with a non-default value of n_parallel_msa. Alternatively, if the number of available cores could be supplied, a function could be devised to adjust both this and the n_cpu variables in the tools folder to optimize the run

"""Initializes the data pipeline."""
self._use_small_bfd = use_small_bfd
self.jackhmmer_uniref90_runner = jackhmmer.Jackhmmer(
Expand All @@ -146,35 +149,13 @@ def __init__(self,
self.mgnify_max_hits = mgnify_max_hits
self.uniref_max_hits = uniref_max_hits
self.use_precomputed_msas = use_precomputed_msas
self.n_parallel_msa = n_parallel_msa

def process(self, input_fasta_path: str, msa_output_dir: str) -> FeatureDict:
"""Runs alignment tools on the input sequence and creates features."""
with open(input_fasta_path) as f:
input_fasta_str = f.read()
input_seqs, input_descs = parsers.parse_fasta(input_fasta_str)
if len(input_seqs) != 1:
raise ValueError(
f'More than one input sequence found in {input_fasta_path}.')
input_sequence = input_seqs[0]
input_description = input_descs[0]
num_res = len(input_sequence)

def jackhmmer_uniref90_and_pdb_templates_caller(self, input_fasta_path, msa_output_dir, input_sequence):
uniref90_out_path = os.path.join(msa_output_dir, 'uniref90_hits.sto')
jackhmmer_uniref90_result = run_msa_tool(
msa_runner=self.jackhmmer_uniref90_runner,
input_fasta_path=input_fasta_path,
msa_out_path=uniref90_out_path,
msa_format='sto',
use_precomputed_msas=self.use_precomputed_msas,
max_sto_sequences=self.uniref_max_hits)
mgnify_out_path = os.path.join(msa_output_dir, 'mgnify_hits.sto')
jackhmmer_mgnify_result = run_msa_tool(
msa_runner=self.jackhmmer_mgnify_runner,
input_fasta_path=input_fasta_path,
msa_out_path=mgnify_out_path,
msa_format='sto',
use_precomputed_msas=self.use_precomputed_msas,
max_sto_sequences=self.mgnify_max_hits)
self.jackhmmer_uniref90_runner, input_fasta_path, uniref90_out_path,
'sto', self.use_precomputed_msas)

msa_for_templates = jackhmmer_uniref90_result['sto']
msa_for_templates = parsers.deduplicate_stockholm_msa(msa_for_templates)
Expand All @@ -196,29 +177,70 @@ def process(self, input_fasta_path: str, msa_output_dir: str) -> FeatureDict:
f.write(pdb_templates_result)

uniref90_msa = parsers.parse_stockholm(jackhmmer_uniref90_result['sto'])
mgnify_msa = parsers.parse_stockholm(jackhmmer_mgnify_result['sto'])
uniref90_msa = uniref90_msa.truncate(max_seqs=self.uniref_max_hits)

pdb_template_hits = self.template_searcher.get_template_hits(
output_string=pdb_templates_result, input_sequence=input_sequence)
return uniref90_msa, pdb_template_hits

def jackhmmer_mgnify_caller(self, input_fasta_path, msa_output_dir):
mgnify_out_path = os.path.join(msa_output_dir, 'mgnify_hits.sto')
jackhmmer_mgnify_result = run_msa_tool(
self.jackhmmer_mgnify_runner, input_fasta_path, mgnify_out_path, 'sto',
self.use_precomputed_msas)

mgnify_msa = parsers.parse_stockholm(jackhmmer_mgnify_result['sto'])
mgnify_msa = mgnify_msa.truncate(max_seqs=self.mgnify_max_hits)
return mgnify_msa

def hhblits_bfd_uniclust_caller(self, input_fasta_path, msa_output_dir):
bfd_out_path = os.path.join(msa_output_dir, 'bfd_uniclust_hits.a3m')
hhblits_bfd_uniclust_result = run_msa_tool(
self.hhblits_bfd_uniclust_runner, input_fasta_path, bfd_out_path,
'a3m', self.use_precomputed_msas)
bfd_msa = parsers.parse_a3m(hhblits_bfd_uniclust_result['a3m'])
return bfd_msa

def jackhmmer_small_bfd_caller(self, input_fasta_path, msa_output_dir):
bfd_out_path = os.path.join(msa_output_dir, 'small_bfd_hits.sto')
jackhmmer_small_bfd_result = run_msa_tool(
msa_runner=self.jackhmmer_small_bfd_runner,
input_fasta_path=input_fasta_path,
msa_out_path=bfd_out_path,
msa_format='sto',
use_precomputed_msas=self.use_precomputed_msas)
bfd_msa = parsers.parse_stockholm(jackhmmer_small_bfd_result['sto'])
return bfd_msa


def process(self, input_fasta_path: str, msa_output_dir: str) -> FeatureDict:
"""Runs alignment tools on the input sequence and creates features."""
with open(input_fasta_path) as f:
input_fasta_str = f.read()
input_seqs, input_descs = parsers.parse_fasta(input_fasta_str)
if len(input_seqs) != 1:
raise ValueError(
f'More than one input sequence found in {input_fasta_path}.')
input_sequence = input_seqs[0]
input_description = input_descs[0]
num_res = len(input_sequence)


futures = []
if self._use_small_bfd:
bfd_out_path = os.path.join(msa_output_dir, 'small_bfd_hits.sto')
jackhmmer_small_bfd_result = run_msa_tool(
msa_runner=self.jackhmmer_small_bfd_runner,
input_fasta_path=input_fasta_path,
msa_out_path=bfd_out_path,
msa_format='sto',
use_precomputed_msas=self.use_precomputed_msas)
bfd_msa = parsers.parse_stockholm(jackhmmer_small_bfd_result['sto'])
with concurrent.futures.ThreadPoolExecutor(max_workers=self.n_parallel_msa) as executor:
futures.append(executor.submit(self.jackhmmer_uniref90_and_pdb_templates_caller, input_fasta_path, msa_output_dir, input_sequence))
futures.append(executor.submit(self.jackhmmer_mgnify_caller, input_fasta_path, msa_output_dir))
futures.append(executor.submit(self.jackhmmer_small_bfd_caller, input_fasta_path, msa_output_dir))
else:
bfd_out_path = os.path.join(msa_output_dir, 'bfd_uniclust_hits.a3m')
hhblits_bfd_uniclust_result = run_msa_tool(
msa_runner=self.hhblits_bfd_uniclust_runner,
input_fasta_path=input_fasta_path,
msa_out_path=bfd_out_path,
msa_format='a3m',
use_precomputed_msas=self.use_precomputed_msas)
bfd_msa = parsers.parse_a3m(hhblits_bfd_uniclust_result['a3m'])
with concurrent.futures.ThreadPoolExecutor(max_workers=self.n_parallel_msa) as executor:
futures.append(executor.submit(self.jackhmmer_uniref90_and_pdb_templates_caller, input_fasta_path, msa_output_dir, input_sequence))
futures.append(executor.submit(self.jackhmmer_mgnify_caller, input_fasta_path, msa_output_dir))
futures.append(executor.submit(self.hhblits_bfd_uniclust_caller, input_fasta_path, msa_output_dir))

uniref90_msa, pdb_template_hits = futures[0].result()
mgnify_msa = futures[1].result()
bfd_msa = futures[2].result()

templates_result = self.template_featurizer.get_templates(
query_sequence=input_sequence,
Expand Down
5 changes: 4 additions & 1 deletion run_alphafold.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,8 @@
'recommended to enable if possible. GPUs must be available'
' if this setting is enabled.')

flags.DEFINE_integer('n_parallel_msa', 1, 'Number of parallel runs of MSA tools.')

FLAGS = flags.FLAGS

MAX_TEMPLATE_HITS = 20
Expand Down Expand Up @@ -346,7 +348,8 @@ def main(argv):
template_searcher=template_searcher,
template_featurizer=template_featurizer,
use_small_bfd=use_small_bfd,
use_precomputed_msas=FLAGS.use_precomputed_msas)
use_precomputed_msas=FLAGS.use_precomputed_msas,
n_parallel_msa=FLAGS.n_parallel_msa)

if run_multimer_system:
num_predictions_per_model = FLAGS.num_multimer_predictions_per_model
Expand Down