Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ components = [
'AlphaFold-2.1.0_fix-scp-path.patch',
'AlphaFold-2.0.1_setup_rm_tfcpu.patch',
'AlphaFold-2.3.1_use_openmm_7.7.0.patch',
'AlphaFold-2.3.1_add-run_features_only-option.patch',
'AlphaFold-2.3.1_parallel-execution-of-MSA-tools.patch',
],
'checksums': [
'1161b2609fa896b16399b900ec2b813e5a0b363fe4e2b26bd826953ba234736a', # v2.3.1.tar.gz
Expand All @@ -72,7 +74,10 @@ components = [
'1a2e4e843bd9a4d15ee39e6c37cc63ba281311cc7a0a5610f0e43b52ef93faac', # AlphaFold-2.0.1_setup_rm_tfcpu.patch
# AlphaFold-2.3.1_use_openmm_7.7.0.patch
'd800bb085deac7edbe2d72916c1194043964aaf51b88a3b5a5016ab68a1090ec',

# AlphaFold-2.3.1_add-run_features_only-option.patch
'9221a277f3c966d50f3c07b27eec3b0912dbdd953ed4eb9f2d0e6ffa8d50cfd5',
# AlphaFold-2.3.1_parallel-execution-of-MSA-tools.patch
'079840026a734efd67f790c7ae0f95cf3958ae7f35429671445c060668080f79',
],
'start_dir': 'alphafold-%(version)s',
'use_pip': True,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
From d2daf93b77048acb98cd470e4f6670660b394c8d Mon Sep 17 00:00:00 2001
From: Viktor Rehnberg <viktor.rehnberg@gmail.com>
Date: Wed, 15 May 2024 13:09:14 +0000
Subject: [PATCH 1/1] Run features only option

Adds a flag `--run_feature_only` that stops after feature generation, and logic so that if
`features.pkl` already exists, the MSA and template search step is skipped.

To a large degree taken from https://github.com/Zuricho/ParallelFold/

---
run_alphafold.py | 27 +++++++++++++++++++++------
run_alphafold_test.py | 3 ++-
2 files changed, 23 insertions(+), 7 deletions(-)

diff --git a/run_alphafold.py b/run_alphafold.py
index 72416e0..7ca81d0 100644
--- a/run_alphafold.py
+++ b/run_alphafold.py
@@ -168,6 +168,8 @@ flags.DEFINE_boolean('use_gpu_relax', use_gpu_relax, 'Whether to relax on GPU. '
'Relax on GPU can be much faster than CPU, so it is '
'recommended to enable if possible. GPUs must be available'
' if this setting is enabled.')
+flags.DEFINE_boolean('run_feature_only', False, 'Calculate MSA and template to generate '
+ 'feature and then stop.')

FLAGS = flags.FLAGS

@@ -196,7 +198,8 @@ def predict_structure(
model_runners: Dict[str, model.RunModel],
amber_relaxer: relax.AmberRelaxation,
benchmark: bool,
- random_seed: int):
+ random_seed: int,
+ run_feature_only: bool):
"""Predicts structure using AlphaFold for the given sequence."""
logging.info('Predicting %s', fasta_name)
timings = {}
@@ -209,16 +212,27 @@ def predict_structure(

# Get features.
t_0 = time.time()
- feature_dict = data_pipeline.process(
- input_fasta_path=fasta_path,
- msa_output_dir=msa_output_dir)
- timings['features'] = time.time() - t_0
+ features_output_path = os.path.join(output_dir, 'features.pkl')
+
+ # If we already have a features.pkl file, skip the MSA and template search step
+ if os.path.exists(features_output_path):
+ feature_dict = pickle.load(open(features_output_path, 'rb'))
+
+ else:
+ feature_dict = data_pipeline.process(
+ input_fasta_path=fasta_path,
+ msa_output_dir=msa_output_dir)

# Write out features as a pickled dictionary.
features_output_path = os.path.join(output_dir, 'features.pkl')
with open(features_output_path, 'wb') as f:
pickle.dump(feature_dict, f, protocol=4)

+ timings['features'] = time.time() - t_0
+
+ if run_feature_only: # features-only run: skip the rest of the function
+ return 0
+
unrelaxed_pdbs = {}
relaxed_pdbs = {}
relax_metrics = {}
@@ -457,7 +471,8 @@ def main(argv):
model_runners=model_runners,
amber_relaxer=amber_relaxer,
benchmark=FLAGS.benchmark,
- random_seed=random_seed)
+ random_seed=random_seed,
+ run_feature_only=FLAGS.run_feature_only)


if __name__ == '__main__':
diff --git a/run_alphafold_test.py b/run_alphafold_test.py
index b91189c..efe4dd6 100644
--- a/run_alphafold_test.py
+++ b/run_alphafold_test.py
@@ -74,7 +74,8 @@ class RunAlphafoldTest(parameterized.TestCase):
model_runners={'model1': model_runner_mock},
amber_relaxer=amber_relaxer_mock if do_relax else None,
benchmark=False,
- random_seed=0)
+ random_seed=0,
+ run_feature_only=False)

base_output_files = os.listdir(out_dir)
self.assertIn('target.fasta', base_output_files)
--
2.39.3

Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
From 50086f389a90ee797f336d23564e357bec127bde Mon Sep 17 00:00:00 2001
From: Viktor Rehnberg <viktor.rehnberg@gmail.com>
Date: Wed, 15 May 2024 15:26:32 +0000
Subject: [PATCH 1/1] Parallel execution of MSA tools

MSA search is done by three tools; this adds an option (`n_parallel_msa`) to run them in
parallel. Taken from https://github.com/google-deepmind/alphafold/pull/399

---
alphafold/data/pipeline.py | 108 ++++++++++++++++++++++++-------------
run_alphafold.py | 4 +-
2 files changed, 73 insertions(+), 39 deletions(-)

diff --git a/alphafold/data/pipeline.py b/alphafold/data/pipeline.py
index a90eb57..927aa33 100644
--- a/alphafold/data/pipeline.py
+++ b/alphafold/data/pipeline.py
@@ -27,6 +27,8 @@ from alphafold.data.tools import hmmsearch
from alphafold.data.tools import jackhmmer
import numpy as np

+import concurrent.futures
+
# Internal import (7716).

FeatureDict = MutableMapping[str, np.ndarray]
@@ -124,7 +126,8 @@ class DataPipeline:
use_small_bfd: bool,
mgnify_max_hits: int = 501,
uniref_max_hits: int = 10000,
- use_precomputed_msas: bool = False):
+ use_precomputed_msas: bool = False,
+ n_parallel_msa: int = 1):
"""Initializes the data pipeline."""
self._use_small_bfd = use_small_bfd
self.jackhmmer_uniref90_runner = jackhmmer.Jackhmmer(
@@ -146,19 +149,9 @@ class DataPipeline:
self.mgnify_max_hits = mgnify_max_hits
self.uniref_max_hits = uniref_max_hits
self.use_precomputed_msas = use_precomputed_msas
+ self.n_parallel_msa = n_parallel_msa

- def process(self, input_fasta_path: str, msa_output_dir: str) -> FeatureDict:
- """Runs alignment tools on the input sequence and creates features."""
- with open(input_fasta_path) as f:
- input_fasta_str = f.read()
- input_seqs, input_descs = parsers.parse_fasta(input_fasta_str)
- if len(input_seqs) != 1:
- raise ValueError(
- f'More than one input sequence found in {input_fasta_path}.')
- input_sequence = input_seqs[0]
- input_description = input_descs[0]
- num_res = len(input_sequence)
-
+ def jackhmmer_uniref90_and_pdb_templates_caller(self, input_fasta_path, msa_output_dir, input_sequence):
uniref90_out_path = os.path.join(msa_output_dir, 'uniref90_hits.sto')
jackhmmer_uniref90_result = run_msa_tool(
msa_runner=self.jackhmmer_uniref90_runner,
@@ -167,14 +160,6 @@ class DataPipeline:
msa_format='sto',
use_precomputed_msas=self.use_precomputed_msas,
max_sto_sequences=self.uniref_max_hits)
- mgnify_out_path = os.path.join(msa_output_dir, 'mgnify_hits.sto')
- jackhmmer_mgnify_result = run_msa_tool(
- msa_runner=self.jackhmmer_mgnify_runner,
- input_fasta_path=input_fasta_path,
- msa_out_path=mgnify_out_path,
- msa_format='sto',
- use_precomputed_msas=self.use_precomputed_msas,
- max_sto_sequences=self.mgnify_max_hits)

msa_for_templates = jackhmmer_uniref90_result['sto']
msa_for_templates = parsers.deduplicate_stockholm_msa(msa_for_templates)
@@ -196,29 +181,76 @@ class DataPipeline:
f.write(pdb_templates_result)

uniref90_msa = parsers.parse_stockholm(jackhmmer_uniref90_result['sto'])
- mgnify_msa = parsers.parse_stockholm(jackhmmer_mgnify_result['sto'])
+ uniref90_msa = parsers.truncate(max_seqs=self.uniref_max_hits)

pdb_template_hits = self.template_searcher.get_template_hits(
output_string=pdb_templates_result, input_sequence=input_sequence)
+ return uniref90_msa, pdb_template_hits
+
+ def jackhmmer_mgnify_caller(self, input_fasta_path, msa_output_dir):
+ mgnify_out_path = os.path.join(msa_output_dir, 'mgnify_hits.sto')
+ jackhmmer_mgnify_result = run_msa_tool(
+ msa_runner=self.jackhmmer_mgnify_runner,
+ input_fasta_path=input_fasta_path,
+ msa_out_path=mgnify_out_path,
+ msa_format='sto',
+ use_precomputed_msas=self.use_precomputed_msas,
+ max_sto_sequences=self.mgnify_max_hits)
+
+ mgnify_msa = parsers.parse_stockholm(jackhmmer_mgnify_result['sto'])
+ return mgnify_msa
+
+ def hhblits_bfd_uniref_caller(self, input_fasta_path, msa_output_dir):
+ bfd_out_path = os.path.join(msa_output_dir, 'bfd_uniref_hits.a3m')
+ hhblits_bfd_uniref_result = run_msa_tool(
+ msa_runner=self.hhblits_bfd_uniref_runner,
+ input_fasta_path=input_fasta_path,
+ msa_out_path=bfd_out_path,
+ msa_format='a3m',
+ use_precomputed_msas=self.use_precomputed_msas)
+ bfd_msa = parsers.parse_a3m(hhblits_bfd_uniref_result['a3m'])
+ return bfd_msa
+
+ def jackhmmer_small_bfd_caller(self, input_fasta_path, msa_output_dir):
+ bfd_out_path = os.path.join(msa_output_dir, 'small_bfd_hits.sto')
+ jackhmmer_small_bfd_result = run_msa_tool(
+ msa_runner=self.jackhmmer_small_bfd_runner,
+ input_fasta_path=input_fasta_path,
+ msa_out_path=bfd_out_path,
+ msa_format='sto',
+ use_precomputed_msas=self.use_precomputed_msas)
+ bfd_msa = parsers.parse_stockholm(jackhmmer_small_bfd_result['sto'])
+ return bfd_msa
+
+
+ def process(self, input_fasta_path: str, msa_output_dir: str) -> FeatureDict:
+ """Runs alignment tools on the input sequence and creates features."""
+ with open(input_fasta_path) as f:
+ input_fasta_str = f.read()
+ input_seqs, input_descs = parsers.parse_fasta(input_fasta_str)
+ if len(input_seqs) != 1:
+ raise ValueError(
+ f'More than one input sequence found in {input_fasta_path}.')
+ input_sequence = input_seqs[0]
+ input_description = input_descs[0]
+ num_res = len(input_sequence)
+
+ futures = []

if self._use_small_bfd:
- bfd_out_path = os.path.join(msa_output_dir, 'small_bfd_hits.sto')
- jackhmmer_small_bfd_result = run_msa_tool(
- msa_runner=self.jackhmmer_small_bfd_runner,
- input_fasta_path=input_fasta_path,
- msa_out_path=bfd_out_path,
- msa_format='sto',
- use_precomputed_msas=self.use_precomputed_msas)
- bfd_msa = parsers.parse_stockholm(jackhmmer_small_bfd_result['sto'])
+ with concurrent.futures.ThreadPoolExecutor(max_workers=self.n_parallel_msa) as executor:
+ futures.append(executor.submit(self.jackhmmer_uniref90_and_pdb_templates_caller, input_fasta_path, msa_output_dir, input_sequence))
+ futures.append(executor.submit(self.jackhmmer_mgnify_caller, input_fasta_path, msa_output_dir))
+ futures.append(executor.submit(self.jackhmmer_small_bfd_caller, input_fasta_path, msa_output_dir))
else:
- bfd_out_path = os.path.join(msa_output_dir, 'bfd_uniref_hits.a3m')
- hhblits_bfd_uniref_result = run_msa_tool(
- msa_runner=self.hhblits_bfd_uniref_runner,
- input_fasta_path=input_fasta_path,
- msa_out_path=bfd_out_path,
- msa_format='a3m',
- use_precomputed_msas=self.use_precomputed_msas)
- bfd_msa = parsers.parse_a3m(hhblits_bfd_uniref_result['a3m'])
+ with concurrent.futures.ThreadPoolExecutor(max_workers=self.n_parallel_msa) as executor:
+ futures.append(executor.submit(self.jackhmmer_uniref90_and_pdb_templates_caller, input_fasta_path, msa_output_dir, input_sequence))
+ futures.append(executor.submit(self.jackhmmer_mgnify_caller, input_fasta_path, msa_output_dir))
+ futures.append(executor.submit(self.hhblits_bfd_uniref_caller, input_fasta_path, msa_output_dir))
+
+ uniref90_msa, pdb_template_hits = futures[0].result()
+ mgnify_msa = futures[1].result()
+ bfd_msa = futures[2].result()

templates_result = self.template_featurizer.get_templates(
query_sequence=input_sequence,
diff --git a/run_alphafold.py b/run_alphafold.py
index 7ca81d0..d4a0644 100644
--- a/run_alphafold.py
+++ b/run_alphafold.py
@@ -170,6 +170,7 @@ flags.DEFINE_boolean('use_gpu_relax', use_gpu_relax, 'Whether to relax on GPU. '
' if this setting is enabled.')
flags.DEFINE_boolean('run_feature_only', False, 'Calculate MSA and template to generate '
'feature and then stop.')
+flags.DEFINE_integer('n_parallel_msa', 1, 'Number of parallel runs of MSA tools.')

FLAGS = flags.FLAGS

@@ -414,7 +415,8 @@ def main(argv):
template_searcher=template_searcher,
template_featurizer=template_featurizer,
use_small_bfd=use_small_bfd,
- use_precomputed_msas=FLAGS.use_precomputed_msas)
+ use_precomputed_msas=FLAGS.use_precomputed_msas,
+ n_parallel_msa=FLAGS.n_parallel_msa)

if run_multimer_system:
num_predictions_per_model = FLAGS.num_multimer_predictions_per_model
--
2.39.3