From 5e18b3ab5cf9b6e20ec3bd8f0f2a0c02ba9db837 Mon Sep 17 00:00:00 2001 From: fuji Date: Tue, 22 Feb 2022 10:16:10 +0900 Subject: [PATCH 1/4] Add parallel runs of MSA tools --- alphafold/data/pipeline.py | 110 ++++++++++++++++++++++--------------- run_alphafold.py | 5 +- 2 files changed, 70 insertions(+), 45 deletions(-) diff --git a/alphafold/data/pipeline.py b/alphafold/data/pipeline.py index c92944fe3..3a3a18d4a 100644 --- a/alphafold/data/pipeline.py +++ b/alphafold/data/pipeline.py @@ -27,6 +27,8 @@ from alphafold.data.tools import jackhmmer import numpy as np +import concurrent.futures + # Internal import (7716). FeatureDict = MutableMapping[str, np.ndarray] @@ -124,7 +126,8 @@ def __init__(self, use_small_bfd: bool, mgnify_max_hits: int = 501, uniref_max_hits: int = 10000, - use_precomputed_msas: bool = False): + use_precomputed_msas: bool = False, + n_parallel_msa: int = 1): """Initializes the data pipeline.""" self._use_small_bfd = use_small_bfd self.jackhmmer_uniref90_runner = jackhmmer.Jackhmmer( @@ -146,35 +149,13 @@ def __init__(self, self.mgnify_max_hits = mgnify_max_hits self.uniref_max_hits = uniref_max_hits self.use_precomputed_msas = use_precomputed_msas + self.n_parallel_msa = n_parallel_msa - def process(self, input_fasta_path: str, msa_output_dir: str) -> FeatureDict: - """Runs alignment tools on the input sequence and creates features.""" - with open(input_fasta_path) as f: - input_fasta_str = f.read() - input_seqs, input_descs = parsers.parse_fasta(input_fasta_str) - if len(input_seqs) != 1: - raise ValueError( - f'More than one input sequence found in {input_fasta_path}.') - input_sequence = input_seqs[0] - input_description = input_descs[0] - num_res = len(input_sequence) - + def jackhmmer_uniref90_and_pdb_templates_caller(self, input_fasta_path, msa_output_dir, input_sequence): uniref90_out_path = os.path.join(msa_output_dir, 'uniref90_hits.sto') jackhmmer_uniref90_result = run_msa_tool( - msa_runner=self.jackhmmer_uniref90_runner, - input_fasta_path=input_fasta_path, - msa_out_path=uniref90_out_path, - msa_format='sto', - use_precomputed_msas=self.use_precomputed_msas, - max_sto_sequences=self.uniref_max_hits) - mgnify_out_path = os.path.join(msa_output_dir, 'mgnify_hits.sto') - jackhmmer_mgnify_result = run_msa_tool( - msa_runner=self.jackhmmer_mgnify_runner, - input_fasta_path=input_fasta_path, - msa_out_path=mgnify_out_path, - msa_format='sto', - use_precomputed_msas=self.use_precomputed_msas, - max_sto_sequences=self.mgnify_max_hits) + self.jackhmmer_uniref90_runner, input_fasta_path, uniref90_out_path, + 'sto', self.use_precomputed_msas) msa_for_templates = jackhmmer_uniref90_result['sto'] msa_for_templates = parsers.deduplicate_stockholm_msa(msa_for_templates) @@ -196,29 +177,70 @@ def process(self, input_fasta_path: str, msa_output_dir: str) -> FeatureDict: f.write(pdb_templates_result) uniref90_msa = parsers.parse_stockholm(jackhmmer_uniref90_result['sto']) - mgnify_msa = parsers.parse_stockholm(jackhmmer_mgnify_result['sto']) + uniref90_msa = uniref90_msa.truncate(max_seqs=self.uniref_max_hits) pdb_template_hits = self.template_searcher.get_template_hits( output_string=pdb_templates_result, input_sequence=input_sequence) + return uniref90_msa, pdb_template_hits + + def jackhmmer_mgnify_caller(self, input_fasta_path, msa_output_dir): + mgnify_out_path = os.path.join(msa_output_dir, 'mgnify_hits.sto') + jackhmmer_mgnify_result = run_msa_tool( + self.jackhmmer_mgnify_runner, input_fasta_path, mgnify_out_path, 'sto', + self.use_precomputed_msas) + + mgnify_msa = parsers.parse_stockholm(jackhmmer_mgnify_result['sto']) + mgnify_msa = mgnify_msa.truncate(max_seqs=self.mgnify_max_hits) + return mgnify_msa + + def hhblits_bfd_uniclust_caller(self, input_fasta_path, msa_output_dir): + bfd_out_path = os.path.join(msa_output_dir, 'bfd_uniclust_hits.a3m') + hhblits_bfd_uniclust_result = run_msa_tool( + self.hhblits_bfd_uniclust_runner, input_fasta_path, bfd_out_path, + 'a3m', self.use_precomputed_msas) + bfd_msa = parsers.parse_a3m(hhblits_bfd_uniclust_result['a3m']) + return bfd_msa + + def jackhmmer_small_bfd_caller(self, input_fasta_path, msa_output_dir): + bfd_out_path = os.path.join(msa_output_dir, 'small_bfd_hits.sto') + jackhmmer_small_bfd_result = run_msa_tool( + msa_runner=self.jackhmmer_small_bfd_runner, + input_fasta_path=input_fasta_path, + msa_out_path=bfd_out_path, + msa_format='sto', + use_precomputed_msas=self.use_precomputed_msas) + bfd_msa = parsers.parse_stockholm(jackhmmer_small_bfd_result['sto']) + return bfd_msa + + + def process(self, input_fasta_path: str, msa_output_dir: str) -> FeatureDict: + """Runs alignment tools on the input sequence and creates features.""" + with open(input_fasta_path) as f: + input_fasta_str = f.read() + input_seqs, input_descs = parsers.parse_fasta(input_fasta_str) + if len(input_seqs) != 1: + raise ValueError( + f'More than one input sequence found in {input_fasta_path}.') + input_sequence = input_seqs[0] + input_description = input_descs[0] + num_res = len(input_sequence) + + futures = [] if self._use_small_bfd: - bfd_out_path = os.path.join(msa_output_dir, 'small_bfd_hits.sto') - jackhmmer_small_bfd_result = run_msa_tool( - msa_runner=self.jackhmmer_small_bfd_runner, - input_fasta_path=input_fasta_path, - msa_out_path=bfd_out_path, - msa_format='sto', - use_precomputed_msas=self.use_precomputed_msas) - bfd_msa = parsers.parse_stockholm(jackhmmer_small_bfd_result['sto']) + with concurrent.futures.ThreadPoolExecutor(max_workers=self.n_parallel_msa) as executor: + futures.append(executor.submit(self.jackhmmer_uniref90_and_pdb_templates_caller, input_fasta_path, msa_output_dir, input_sequence)) + futures.append(executor.submit(self.jackhmmer_mgnify_caller, input_fasta_path, msa_output_dir)) + futures.append(executor.submit(self.jackhmmer_small_bfd_caller, input_fasta_path, msa_output_dir)) else: - bfd_out_path = os.path.join(msa_output_dir, 'bfd_uniclust_hits.a3m') - hhblits_bfd_uniclust_result = run_msa_tool( - msa_runner=self.hhblits_bfd_uniclust_runner, - input_fasta_path=input_fasta_path, - msa_out_path=bfd_out_path, - msa_format='a3m', - use_precomputed_msas=self.use_precomputed_msas) - bfd_msa = parsers.parse_a3m(hhblits_bfd_uniclust_result['a3m']) + with concurrent.futures.ThreadPoolExecutor(max_workers=self.n_parallel_msa) as executor: + futures.append(executor.submit(self.jackhmmer_uniref90_and_pdb_templates_caller, input_fasta_path, msa_output_dir, input_sequence)) + futures.append(executor.submit(self.jackhmmer_mgnify_caller, input_fasta_path, msa_output_dir)) + futures.append(executor.submit(self.hhblits_bfd_uniclust_caller, input_fasta_path, msa_output_dir)) + + uniref90_msa, pdb_template_hits = futures[0].result() + mgnify_msa = futures[1].result() + bfd_msa = futures[2].result() templates_result = self.template_featurizer.get_templates( query_sequence=input_sequence, diff --git a/run_alphafold.py b/run_alphafold.py index d87e00e4b..6967d0c9c 100644 --- a/run_alphafold.py +++ b/run_alphafold.py @@ -129,6 +129,8 @@ 'recommended to enable if possible. GPUs must be available' ' if this setting is enabled.') +flags.DEFINE_integer('n_parallel_msa', 1, 'Number of parallel runs of MSA tools.') + FLAGS = flags.FLAGS MAX_TEMPLATE_HITS = 20 @@ -346,7 +348,8 @@ def main(argv): template_searcher=template_searcher, template_featurizer=template_featurizer, use_small_bfd=use_small_bfd, - use_precomputed_msas=FLAGS.use_precomputed_msas) + use_precomputed_msas=FLAGS.use_precomputed_msas, + n_parallel_msa=FLAGS.n_parallel_msa) if run_multimer_system: num_predictions_per_model = FLAGS.num_multimer_predictions_per_model From 86c77759e10a1e1716be9af90627d1027b576502 Mon Sep 17 00:00:00 2001 From: fuji8 Date: Sun, 12 Feb 2023 03:19:55 +0900 Subject: [PATCH 2/4] fix for v2.3.1 --- alphafold/data/pipeline.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/alphafold/data/pipeline.py b/alphafold/data/pipeline.py index eeb37216a..eaea5fed2 100644 --- a/alphafold/data/pipeline.py +++ b/alphafold/data/pipeline.py @@ -154,8 +154,12 @@ def __init__(self, def jackhmmer_uniref90_and_pdb_templates_caller(self, input_fasta_path, msa_output_dir, input_sequence): uniref90_out_path = os.path.join(msa_output_dir, 'uniref90_hits.sto') jackhmmer_uniref90_result = run_msa_tool( - self.jackhmmer_uniref90_runner, input_fasta_path, uniref90_out_path, - 'sto', self.use_precomputed_msas) + msa_runner=self.jackhmmer_uniref90_runner, + input_fasta_path=input_fasta_path, + msa_out_path=uniref90_out_path, + msa_format='sto', + use_precomputed_msas=self.use_precomputed_msas, + max_sto_sequences=self.uniref_max_hits) msa_for_templates = jackhmmer_uniref90_result['sto'] msa_for_templates = parsers.deduplicate_stockholm_msa(msa_for_templates) @@ -186,11 +190,14 @@ def jackhmmer_uniref90_and_pdb_templates_caller(self, input_fasta_path, msa_outp def jackhmmer_mgnify_caller(self, input_fasta_path, msa_output_dir): mgnify_out_path = os.path.join(msa_output_dir, 'mgnify_hits.sto') jackhmmer_mgnify_result = run_msa_tool( - self.jackhmmer_mgnify_runner, input_fasta_path, mgnify_out_path, 'sto', - self.use_precomputed_msas) + msa_runner=self.jackhmmer_mgnify_runner, + input_fasta_path=input_fasta_path, + msa_out_path=mgnify_out_path, + msa_format='sto', + use_precomputed_msas=self.use_precomputed_msas, + max_sto_sequences=self.mgnify_max_hits) mgnify_msa = parsers.parse_stockholm(jackhmmer_mgnify_result['sto']) - mgnify_msa = mgnify_msa.truncate(max_seqs=self.mgnify_max_hits) return mgnify_msa def hhblits_bfd_uniclust_caller(self, input_fasta_path, msa_output_dir): @@ -236,7 +243,7 @@ def process(self, input_fasta_path: str, msa_output_dir: str) -> FeatureDict: futures.append(executor.submit(self.jackhmmer_mgnify_caller, input_fasta_path, msa_output_dir)) futures.append(executor.submit(self.jackhmmer_small_bfd_caller, input_fasta_path, msa_output_dir)) else: - with concurrent.futures.ThreadPoolExecutor(max_workers=self.n_parallel_msa) as executor: + with concurrent.futures.ThreadPoolExecutor(max_workers=self.n_parallel_msa) as executor: futures.append(executor.submit(self.jackhmmer_uniref90_and_pdb_templates_caller, input_fasta_path, msa_output_dir, input_sequence)) futures.append(executor.submit(self.jackhmmer_mgnify_caller, input_fasta_path, msa_output_dir)) futures.append(executor.submit(self.hhblits_bfd_uniclust_caller, input_fasta_path, msa_output_dir)) From 6acf53caa74ca1b4900c2ccbc36ea33909965f38 Mon Sep 17 00:00:00 2001 From: fuji8 Date: Sun, 12 Feb 2023 03:28:08 +0900 Subject: [PATCH 3/4] rename func name --- alphafold/data/pipeline.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/alphafold/data/pipeline.py b/alphafold/data/pipeline.py index eaea5fed2..ca8292a53 100644 --- a/alphafold/data/pipeline.py +++ b/alphafold/data/pipeline.py @@ -200,7 +200,7 @@ def jackhmmer_mgnify_caller(self, input_fasta_path, msa_output_dir): mgnify_msa = parsers.parse_stockholm(jackhmmer_mgnify_result['sto']) return mgnify_msa - def hhblits_bfd_uniclust_caller(self, input_fasta_path, msa_output_dir): + def hhblits_bfd_uniref_caller(self, input_fasta_path, msa_output_dir): bfd_out_path = os.path.join(msa_output_dir, 'bfd_uniref_hits.a3m') hhblits_bfd_uniref_result = run_msa_tool( msa_runner=self.hhblits_bfd_uniref_runner, @@ -246,7 +246,7 @@ def process(self, input_fasta_path: str, msa_output_dir: str) -> FeatureDict: with concurrent.futures.ThreadPoolExecutor(max_workers=self.n_parallel_msa) as executor: futures.append(executor.submit(self.jackhmmer_uniref90_and_pdb_templates_caller, input_fasta_path, msa_output_dir, input_sequence)) futures.append(executor.submit(self.jackhmmer_mgnify_caller, input_fasta_path, msa_output_dir)) - futures.append(executor.submit(self.hhblits_bfd_uniclust_caller, input_fasta_path, msa_output_dir)) + futures.append(executor.submit(self.hhblits_bfd_uniref_caller, input_fasta_path, msa_output_dir)) uniref90_msa, pdb_template_hits = futures[0].result() mgnify_msa = futures[1].result() From 36ea35491b40b1104e043dcba7eca7463566b69a Mon Sep 17 00:00:00 2001 From: fuji8 Date: Sun, 12 Feb 2023 18:28:28 +0000 Subject: [PATCH 4/4] fix for docker --- docker/run_docker.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docker/run_docker.py b/docker/run_docker.py index 155d8fe2c..e7e2024ef 100644 --- a/docker/run_docker.py +++ b/docker/run_docker.py @@ -93,6 +93,7 @@ 'will be owned by this user:group. By default, this is the current user. ' 'Valid options are: uid or uid:gid, non-numeric values are not recognised ' 'by Docker unless that user has been created within the container.') +flags.DEFINE_integer('n_parallel_msa', 1, 'Number of parallel runs of MSA tools.') FLAGS = flags.FLAGS @@ -227,6 +228,7 @@ def main(argv): f'--models_to_relax={FLAGS.models_to_relax}', f'--use_gpu_relax={use_gpu_relax}', '--logtostderr', + f'--n_parallel_msa={FLAGS.n_parallel_msa}' ]) client = docker.from_env()