diff --git a/scripts/rgn2_predict_fold.py b/scripts/rgn2_predict_fold.py
index 9d5567f..90d9713 100644
--- a/scripts/rgn2_predict_fold.py
+++ b/scripts/rgn2_predict_fold.py
@@ -1,35 +1,29 @@
 # Author: Eirc Alcaide (@hypnopump)
-import os
-import re
+from pathlib import Path
 import json
-import numpy as np
 import torch
 # process
 import argparse
-import joblib
-from tqdm import tqdm
 # custom
-import esm
 import sidechainnet
-import mp_nerf
 from rgn2_replica import *
-from rgn2_replica.embedders import *
-from rgn2_replica.rgn2_refine import *
+from rgn2_replica.embedders import get_embedder
 from rgn2_replica.rgn2_utils import seqs_from_fasta
 from rgn2_replica.rgn2_trainers import infer_from_seqs
 
 
-if __name__ == "__main__":
-    # !python redictor.py --input proteins.fasta --model ../rgn2_models/baseline_run@_125K.pt --device 2
+def parse_arguments():
+    # !python redictor.py --input proteins.fasta --model ../rgn2_models/baseline_run@_125K.pt --device 2
     parser = argparse.ArgumentParser('Predict with RGN2 model')
+    input_group = parser.add_mutually_exclusive_group(required=True)
+
     # inputs
     parser.add_argument("--input", help="FASTA or MultiFASTA with protein sequences to predict")
     parser.add_argument("--batch_size", type=int, default=1, help="batch size for prediction")
-    # model
-    parser.add_argument("--embedder_model", type=str, help="esm1b")
-    parser.add_argument("--model", type=str, help="Model file for prediction")
+    input_group.add_argument("--pdb_input", type=str, help="PDB file for refinement")
     # model params - same as training
+    input_group.add_argument("--model", type=str, help="Model file for prediction")
     parser.add_argument("--embedder_model", help="Embedding model to use", default='esm1b')
     parser.add_argument("--num_layers", help="num rnn layers", type=int, default=2)
     parser.add_argument("--emb_dim", help="embedding dimension", type=int, default=1280)
@@ -47,10 +41,12 @@
     parser.add_argument("--rosetta_relax", type=int, default=0, help="relax output with Rosetta. 0 for no relax.")
     parser.add_argument("--coord_constraint", type=float, default=1.0, help="constraint for Rosetta relax. higher=stricter.")
     parser.add_argument("--device", help="Device ('cpu', cuda:0', ...)", type=str, required=True)
+
     # outputs
-    parser.add_argument("--output_path", type=str, default=None, # prot_id.fasta -> prot_id_0.fasta,
+    parser.add_argument("--output_path", type=str, default=None,  # prot_id.fasta -> prot_id_0.fasta,
                         help="path for output .pdb files. Defaults to input name + seq num")
     # refiner params
+    parser.add_argument("--af2_refine", type=bool, default=False, help="refine output with AlphaFold2")
     parser.add_argument("--refiner_args", help="args for refiner module", type=json.loads, default={})
     parser.add_argument("--seed", help="Random seed", default=101)
@@ -60,14 +56,15 @@
     args.refiner_args = dict(args.refiner_args)
 
     # mod parsed args
-    if args.output_path is None: 
+    if args.output_path is None:
         args.output_path = args.input.replace(".fasta", "_")
 
-    # get sequences
-    seq_list, seq_names = seqs_from_fasta(args.input, names=True)
+    return args
+
 
-    # predict structures
+def load_model(args):
     config = args
+    config.mlp_hidden = [128, 4 if args.angularize == 0 else args.angularize]  # 4 # 64
     set_seed(config.seed)
     model = RGN2_Naive(layers=config.num_layers,
                        emb_dim=config.emb_dim+4,
@@ -79,66 +76,134 @@
                        input_dropout=config.input_dropout,
                        angularize=config.angularize,
                        refiner_args=config.refiner_args,
-                       ).to(device)
+                       ).to(args.device)
     model.load_my_state_dict(torch.load(args.model, map_location=args.device))
     model = model.eval()
+
     # # Load ESM-1b model
     embedder = get_embedder(args, args.device)
+    return model.eval(), embedder
+
+
+def predict(model, embedder, seq_list, seq_names, args):
     # batch wrapper
     pred_dict = {}
     num_batches = len(seq_list) // args.batch_size + \
-                  int(bool(len(seq_list) % args.batch_size))
+        int(bool(len(seq_list) % args.batch_size))
 
-    for i in range( num_batches ):
+    for i in range(num_batches):
         aux = infer_from_seqs(
-            seq_list[args.batch_size*i : args.batch_size*(i+1)],
-            model = model,
-            embedder = embedder,
-            recycle_func=lambda x: int(args.recycle),
+            seq_list[args.batch_size*i: args.batch_size*(i+1)],
+            model=model,
+            embedder=embedder,
+            recycle_func=lambda x: int(args.num_recycles_pred),
             device=args.device
         )
-        for k,v in aux.items():
-            try: pred_dict[k] += v
-            except KeyError: pred_dict[k] = v
+        for k, v in aux.items():
+            try:
+                pred_dict[k] += v
+            except KeyError:
+                pred_dict[k] = v
 
     # save structures
     out_files = []
-    for i, seq in enumerate(seq_list): 
+    for i, seq in enumerate(seq_list):
         struct_pred = sidechainnet.StructureBuilder(
-            pred_dict["int_seq"][i].cpu(),
-            crd = pred_dict["coords"][i].reshape(-1, 3).cpu()
-        )
-        out_files.append( args.output_path+str(i)+"_"+seq_names[i]+".pdb" )
-        struct_pred.to_pdb( out_files[-1] )
-
+            pred_dict["int_seq"][i].cpu(),
+            crd=pred_dict["coords"][i].reshape(-1, 3).cpu()
+        )
+        out_files.append(args.output_path+str(i)+"_"+seq_names[i]+".pdb")
+        struct_pred.to_pdb(out_files[-1])
+        print("Saved", out_files[-1])
+    return pred_dict, out_files
+
+
+def refine(seq_list, pdb_files, args):
     # refine structs
-    if args.rosetta_refine:
-        from typing import Optional
-        import pyrosetta
-
-        for i, seq in enumerate(seq_list):
-            # only refine
-            if args.rosetta_relax == 0:
-                quick_refine(
-                    in_pdb = out_files[i],
-                    out_pdb = out_files[i][:-4]+"_refined.pdb",
-                    min_iter = args.rosetta_refine
-                )
-            # refine and relax
-            else:
-                relax_refine(
-                    out_files[i],
-                    out_pdb=out_files[i][:-4]+"_refined_relaxed.pdb",
-                    min_iter = args.rosetta_refine,
-                    relax_iter = args.rosetta_relax,
-                    coord_constraint = args.coord_constraint,
-                )
-            print(out_files[i], "was refined successfully")
-
-    print("All tasks done. Exiting...")
+    if args.rosetta_refine:
+        result = rosetta_refine(seq_list, pdb_files, args)
+    elif args.af2_refine:
+        result = af2_refine(pdb_files, args)
+    else:
+        result = None
+
+    return result
+
+
+def rosetta_refine(seq_list, pdb_files, args):
+    from rgn2_replica.rgn2_refine import quick_refine, relax_refine
+
+    for i, seq in enumerate(seq_list):
+        # only refine
+        if args.rosetta_relax == 0:
+            quick_refine(
+                in_pdb=pdb_files[i],
+                out_pdb=pdb_files[i][:-4]+"_refined.pdb",
+                min_iter=args.rosetta_refine
+            )
+        # refine and relax
+        else:
+            relax_refine(
+                pdb_files[i],
+                out_pdb=pdb_files[i][:-4]+"_refined_relaxed.pdb",
+                min_iter=args.rosetta_refine,
+                relax_iter=args.rosetta_relax,
+                coord_constraint=args.coord_constraint,
+            )
+        print(pdb_files[i], "was refined successfully")
+
+
+def af2_refine(pdb_files, args):
+    from alphafold.relax import relax
+    from alphafold.common import protein
+
+    amber_relaxer = relax.AmberRelaxation(
+        max_iterations=0,
+        tolerance=2.39,
+        stiffness=10.0,
+        exclude_residues=[],
+        max_outer_iterations=20)
+
+    out_dir = Path(args.output_path)
+    relaxed_pdbs_files = []
+    for pdb_file in pdb_files:
+        pdb_file = Path(pdb_file)
+        pdb_str = pdb_file.read_text()
+        prot = protein.from_pdb_string(pdb_str)
+        min_pdb, debug_data, violations = amber_relaxer.process(prot=prot)
+
+        dest_file = out_dir / pdb_file.name
+        dest_file.write_text(min_pdb)
+
+        relaxed_pdbs_files.append(dest_file)
+
+    return relaxed_pdbs_files
+
+
+def predict_refine():
+    args = parse_arguments()
+
+    # get sequences
+    if args.input is not None:
+        seq_list, seq_names = seqs_from_fasta(args.input, names=True)
+    else:
+        seq_list, seq_names = None, None
+
+    if args.model is not None:
+        model, embedder = load_model(args)
+        _, pdb_files = predict(model, embedder, seq_list, seq_names, args)
+    else:
+        pdb_files = [args.pdb_input]
+
+    relaxed_pdbs = refine(seq_list, pdb_files, args)
+
+    print('result files:')
+    print(relaxed_pdbs)
+if __name__ == "__main__":
+    predict_refine()