Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 70 additions & 38 deletions alphafold/data/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
from alphafold.data.tools import jackhmmer
import numpy as np

import concurrent.futures

# Internal import (7716).

FeatureDict = MutableMapping[str, np.ndarray]
Expand Down Expand Up @@ -124,7 +126,8 @@ def __init__(self,
use_small_bfd: bool,
mgnify_max_hits: int = 501,
uniref_max_hits: int = 10000,
use_precomputed_msas: bool = False):
use_precomputed_msas: bool = False,
n_parallel_msa: int = 1):
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be helpful to add a comment here documenting how many logical threads a non-default value of n_parallel_msa will use. Alternatively, if the number of available cores could be supplied, a helper function could adjust both this value and the n_cpu settings in the tools folder to optimize the run.

"""Initializes the data pipeline."""
self._use_small_bfd = use_small_bfd
self.jackhmmer_uniref90_runner = jackhmmer.Jackhmmer(
Expand All @@ -146,19 +149,9 @@ def __init__(self,
self.mgnify_max_hits = mgnify_max_hits
self.uniref_max_hits = uniref_max_hits
self.use_precomputed_msas = use_precomputed_msas
self.n_parallel_msa = n_parallel_msa

def process(self, input_fasta_path: str, msa_output_dir: str) -> FeatureDict:
"""Runs alignment tools on the input sequence and creates features."""
with open(input_fasta_path) as f:
input_fasta_str = f.read()
input_seqs, input_descs = parsers.parse_fasta(input_fasta_str)
if len(input_seqs) != 1:
raise ValueError(
f'More than one input sequence found in {input_fasta_path}.')
input_sequence = input_seqs[0]
input_description = input_descs[0]
num_res = len(input_sequence)

def jackhmmer_uniref90_and_pdb_templates_caller(self, input_fasta_path, msa_output_dir, input_sequence):
uniref90_out_path = os.path.join(msa_output_dir, 'uniref90_hits.sto')
jackhmmer_uniref90_result = run_msa_tool(
msa_runner=self.jackhmmer_uniref90_runner,
Expand All @@ -167,14 +160,6 @@ def process(self, input_fasta_path: str, msa_output_dir: str) -> FeatureDict:
msa_format='sto',
use_precomputed_msas=self.use_precomputed_msas,
max_sto_sequences=self.uniref_max_hits)
mgnify_out_path = os.path.join(msa_output_dir, 'mgnify_hits.sto')
jackhmmer_mgnify_result = run_msa_tool(
msa_runner=self.jackhmmer_mgnify_runner,
input_fasta_path=input_fasta_path,
msa_out_path=mgnify_out_path,
msa_format='sto',
use_precomputed_msas=self.use_precomputed_msas,
max_sto_sequences=self.mgnify_max_hits)

msa_for_templates = jackhmmer_uniref90_result['sto']
msa_for_templates = parsers.deduplicate_stockholm_msa(msa_for_templates)
Expand All @@ -196,29 +181,76 @@ def process(self, input_fasta_path: str, msa_output_dir: str) -> FeatureDict:
f.write(pdb_templates_result)

uniref90_msa = parsers.parse_stockholm(jackhmmer_uniref90_result['sto'])
mgnify_msa = parsers.parse_stockholm(jackhmmer_mgnify_result['sto'])
uniref90_msa = uniref90_msa.truncate(max_seqs=self.uniref_max_hits)

pdb_template_hits = self.template_searcher.get_template_hits(
output_string=pdb_templates_result, input_sequence=input_sequence)
return uniref90_msa, pdb_template_hits

def jackhmmer_mgnify_caller(self, input_fasta_path, msa_output_dir):
    """Searches MGnify with Jackhmmer and returns the parsed MSA.

    Args:
      input_fasta_path: Path to the single-sequence query FASTA file.
      msa_output_dir: Directory where the raw Stockholm hits are written.

    Returns:
      The MGnify MSA parsed from the Stockholm output.
    """
    out_path = os.path.join(msa_output_dir, 'mgnify_hits.sto')
    result = run_msa_tool(
        msa_runner=self.jackhmmer_mgnify_runner,
        input_fasta_path=input_fasta_path,
        msa_out_path=out_path,
        msa_format='sto',
        use_precomputed_msas=self.use_precomputed_msas,
        max_sto_sequences=self.mgnify_max_hits)
    # Parse the raw Stockholm text into an Msa object for featurization.
    return parsers.parse_stockholm(result['sto'])

def hhblits_bfd_uniref_caller(self, input_fasta_path, msa_output_dir):
    """Searches BFD + UniRef with HHBlits and returns the parsed MSA.

    Args:
      input_fasta_path: Path to the single-sequence query FASTA file.
      msa_output_dir: Directory where the raw A3M hits are written.

    Returns:
      The BFD/UniRef MSA parsed from the A3M output.
    """
    out_path = os.path.join(msa_output_dir, 'bfd_uniref_hits.a3m')
    result = run_msa_tool(
        msa_runner=self.hhblits_bfd_uniref_runner,
        input_fasta_path=input_fasta_path,
        msa_out_path=out_path,
        msa_format='a3m',
        use_precomputed_msas=self.use_precomputed_msas)
    # HHBlits emits A3M rather than Stockholm, hence the different parser.
    return parsers.parse_a3m(result['a3m'])

def jackhmmer_small_bfd_caller(self, input_fasta_path, msa_output_dir):
    """Searches the reduced (small) BFD with Jackhmmer and returns the parsed MSA.

    Args:
      input_fasta_path: Path to the single-sequence query FASTA file.
      msa_output_dir: Directory where the raw Stockholm hits are written.

    Returns:
      The small-BFD MSA parsed from the Stockholm output.
    """
    out_path = os.path.join(msa_output_dir, 'small_bfd_hits.sto')
    result = run_msa_tool(
        msa_runner=self.jackhmmer_small_bfd_runner,
        input_fasta_path=input_fasta_path,
        msa_out_path=out_path,
        msa_format='sto',
        use_precomputed_msas=self.use_precomputed_msas)
    # Parse the raw Stockholm text into an Msa object for featurization.
    return parsers.parse_stockholm(result['sto'])


def process(self, input_fasta_path: str, msa_output_dir: str) -> FeatureDict:
"""Runs alignment tools on the input sequence and creates features."""
with open(input_fasta_path) as f:
input_fasta_str = f.read()
input_seqs, input_descs = parsers.parse_fasta(input_fasta_str)
if len(input_seqs) != 1:
raise ValueError(
f'More than one input sequence found in {input_fasta_path}.')
input_sequence = input_seqs[0]
input_description = input_descs[0]
num_res = len(input_sequence)


futures = []
if self._use_small_bfd:
bfd_out_path = os.path.join(msa_output_dir, 'small_bfd_hits.sto')
jackhmmer_small_bfd_result = run_msa_tool(
msa_runner=self.jackhmmer_small_bfd_runner,
input_fasta_path=input_fasta_path,
msa_out_path=bfd_out_path,
msa_format='sto',
use_precomputed_msas=self.use_precomputed_msas)
bfd_msa = parsers.parse_stockholm(jackhmmer_small_bfd_result['sto'])
with concurrent.futures.ThreadPoolExecutor(max_workers=self.n_parallel_msa) as executor:
futures.append(executor.submit(self.jackhmmer_uniref90_and_pdb_templates_caller, input_fasta_path, msa_output_dir, input_sequence))
futures.append(executor.submit(self.jackhmmer_mgnify_caller, input_fasta_path, msa_output_dir))
futures.append(executor.submit(self.jackhmmer_small_bfd_caller, input_fasta_path, msa_output_dir))
else:
bfd_out_path = os.path.join(msa_output_dir, 'bfd_uniref_hits.a3m')
hhblits_bfd_uniref_result = run_msa_tool(
msa_runner=self.hhblits_bfd_uniref_runner,
input_fasta_path=input_fasta_path,
msa_out_path=bfd_out_path,
msa_format='a3m',
use_precomputed_msas=self.use_precomputed_msas)
bfd_msa = parsers.parse_a3m(hhblits_bfd_uniref_result['a3m'])
with concurrent.futures.ThreadPoolExecutor(max_workers=self.n_parallel_msa) as executor:
futures.append(executor.submit(self.jackhmmer_uniref90_and_pdb_templates_caller, input_fasta_path, msa_output_dir, input_sequence))
futures.append(executor.submit(self.jackhmmer_mgnify_caller, input_fasta_path, msa_output_dir))
futures.append(executor.submit(self.hhblits_bfd_uniref_caller, input_fasta_path, msa_output_dir))

uniref90_msa, pdb_template_hits = futures[0].result()
mgnify_msa = futures[1].result()
bfd_msa = futures[2].result()

templates_result = self.template_featurizer.get_templates(
query_sequence=input_sequence,
Expand Down
2 changes: 2 additions & 0 deletions docker/run_docker.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@
'will be owned by this user:group. By default, this is the current user. '
'Valid options are: uid or uid:gid, non-numeric values are not recognised '
'by Docker unless that user has been created within the container.')
flags.DEFINE_integer('n_parallel_msa', 1, 'Number of parallel runs of MSA tools.')

FLAGS = flags.FLAGS

Expand Down Expand Up @@ -227,6 +228,7 @@ def main(argv):
f'--models_to_relax={FLAGS.models_to_relax}',
f'--use_gpu_relax={use_gpu_relax}',
'--logtostderr',
f'--n_parallel_msa={FLAGS.n_parallel_msa}'
])

client = docker.from_env()
Expand Down
5 changes: 4 additions & 1 deletion run_alphafold.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,8 @@ class ModelsToRelax(enum.Enum):
'recommended to enable if possible. GPUs must be available'
' if this setting is enabled.')

flags.DEFINE_integer('n_parallel_msa', 1, 'Number of parallel runs of MSA tools.')

FLAGS = flags.FLAGS

MAX_TEMPLATE_HITS = 20
Expand Down Expand Up @@ -394,7 +396,8 @@ def main(argv):
template_searcher=template_searcher,
template_featurizer=template_featurizer,
use_small_bfd=use_small_bfd,
use_precomputed_msas=FLAGS.use_precomputed_msas)
use_precomputed_msas=FLAGS.use_precomputed_msas,
n_parallel_msa=FLAGS.n_parallel_msa)

if run_multimer_system:
num_predictions_per_model = FLAGS.num_multimer_predictions_per_model
Expand Down