Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementing AlphaFold-based protein refinement #27

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 125 additions & 60 deletions scripts/rgn2_predict_fold.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,29 @@
# Author: Eric Alcaide (@hypnopump)
import os
import re
from pathlib import Path
import json
import numpy as np
import torch
# process
import argparse
import joblib
from tqdm import tqdm
# custom
import esm
import sidechainnet
import mp_nerf
from rgn2_replica import *
from rgn2_replica.embedders import *
from rgn2_replica.rgn2_refine import *
from rgn2_replica.embedders import get_embedder
from rgn2_replica.rgn2_utils import seqs_from_fasta
from rgn2_replica.rgn2_trainers import infer_from_seqs


if __name__ == "__main__":
# !python redictor.py --input proteins.fasta --model ../rgn2_models/baseline_run@_125K.pt --device 2
def parse_arguments():
    # !python predictor.py --input proteins.fasta --model ../rgn2_models/baseline_run@_125K.pt --device 2
parser = argparse.ArgumentParser('Predict with RGN2 model')
input_group = parser.add_mutually_exclusive_group(required=True)

# inputs
parser.add_argument("--input", help="FASTA or MultiFASTA with protein sequences to predict")
parser.add_argument("--batch_size", type=int, default=1, help="batch size for prediction")
# model
parser.add_argument("--embedder_model", type=str, help="esm1b")
parser.add_argument("--model", type=str, help="Model file for prediction")
input_group.add_argument("--pdb_input", type=str, help="PDB file for refinement")

# model params - same as training
input_group.add_argument("--model", type=str, help="Model file for prediction")
parser.add_argument("--embedder_model", help="Embedding model to use", default='esm1b')
parser.add_argument("--num_layers", help="num rnn layers", type=int, default=2)
parser.add_argument("--emb_dim", help="embedding dimension", type=int, default=1280)
Expand All @@ -47,10 +41,12 @@
parser.add_argument("--rosetta_relax", type=int, default=0, help="relax output with Rosetta. 0 for no relax.")
parser.add_argument("--coord_constraint", type=float, default=1.0, help="constraint for Rosetta relax. higher=stricter.")
parser.add_argument("--device", help="Device ('cpu', cuda:0', ...)", type=str, required=True)

# outputs
parser.add_argument("--output_path", type=str, default=None, # prot_id.fasta -> prot_id_0.fasta,
parser.add_argument("--output_path", type=str, default=None, #  prot_id.fasta -> prot_id_0.fasta,
help="path for output .pdb files. Defaults to input name + seq num")
# refiner params
parser.add_argument("--af2_refine", type=bool, default=False, help="refine output with AlphaFold2")
parser.add_argument("--refiner_args", help="args for refiner module", type=json.loads, default={})
parser.add_argument("--seed", help="Random seed", default=101)

Expand All @@ -60,14 +56,15 @@
args.refiner_args = dict(args.refiner_args)

# mod parsed args
if args.output_path is None:
if args.output_path is None:
args.output_path = args.input.replace(".fasta", "_")

# get sequences
seq_list, seq_names = seqs_from_fasta(args.input, names=True)
return args


# predict structures
def load_model(args):
config = args
config.mlp_hidden = [128, 4 if args.angularize == 0 else args.angularize] # 4 # 64
set_seed(config.seed)
model = RGN2_Naive(layers=config.num_layers,
emb_dim=config.emb_dim+4,
Expand All @@ -79,66 +76,134 @@
input_dropout=config.input_dropout,
angularize=config.angularize,
refiner_args=config.refiner_args,
).to(device)
).to(args.device)

model.load_my_state_dict(torch.load(args.model, map_location=args.device))
model = model.eval()

# # Load ESM-1b model
embedder = get_embedder(args, args.device)

return model.eval(), embedder


def predict(model, embedder, seq_list, seq_names, args):
    """Predict 3D structures for a list of sequences and write them as PDBs.

    Args:
        model: trained RGN2 model, already in eval mode and on `args.device`.
        embedder: sequence embedding model (e.g. ESM-1b wrapper).
        seq_list: list of amino-acid sequence strings.
        seq_names: list of names parallel to `seq_list`, used in output filenames.
        args: parsed CLI namespace; reads `batch_size`, `num_recycles_pred`,
            `device` and `output_path`.

    Returns:
        tuple ``(pred_dict, out_files)``: the aggregated inference outputs
        (lists keyed by output name, e.g. "int_seq", "coords") and the list
        of written ``.pdb`` file paths.
    """
    pred_dict = {}
    # ceil(len(seq_list) / batch_size) without importing math
    num_batches = len(seq_list) // args.batch_size + \
        int(bool(len(seq_list) % args.batch_size))

    for i in range(num_batches):
        aux = infer_from_seqs(
            seq_list[args.batch_size*i: args.batch_size*(i+1)],
            model=model,
            embedder=embedder,
            recycle_func=lambda x: int(args.num_recycles_pred),
            device=args.device
        )
        # merge this batch's outputs into the aggregate dict of lists
        for k, v in aux.items():
            try:
                pred_dict[k] += v
            except KeyError:
                pred_dict[k] = v

    # save structures: one PDB per input sequence
    out_files = []
    for i, seq in enumerate(seq_list):
        struct_pred = sidechainnet.StructureBuilder(
            pred_dict["int_seq"][i].cpu(),
            crd=pred_dict["coords"][i].reshape(-1, 3).cpu()
        )
        out_files.append(args.output_path + str(i) + "_" + seq_names[i] + ".pdb")
        struct_pred.to_pdb(out_files[-1])

        print("Saved", out_files[-1])

    return pred_dict, out_files


def refine(seq_list, pdb_files, args):
    """Dispatch optional structure refinement on predicted PDB files.

    Exactly one backend runs, selected from the CLI flags: Rosetta
    refinement (``args.rosetta_refine`` truthy) takes precedence over
    AlphaFold2 AMBER relaxation (``args.af2_refine``). With neither flag
    set, no refinement is performed.

    Args:
        seq_list: sequences matching `pdb_files` (used by the Rosetta path).
        pdb_files: list of paths to predicted .pdb structures.
        args: parsed CLI namespace with the refinement flags.

    Returns:
        The backend's result (list of refined file paths for af2_refine),
        or ``None`` when no refinement was requested.
    """
    if args.rosetta_refine:
        result = rosetta_refine(seq_list, pdb_files, args)
    elif args.af2_refine:
        result = af2_refine(pdb_files, args)
    else:
        result = None

    return result


def rosetta_refine(seq_list, pdb_files, args):
    """Refine (and optionally relax) predicted structures with PyRosetta.

    Args:
        seq_list: sequences matching `pdb_files` (iterated in lockstep).
        pdb_files: paths to the predicted .pdb files to refine.
        args: parsed CLI namespace; reads `rosetta_refine` (min iterations),
            `rosetta_relax` (0 = refine only) and `coord_constraint`.

    Returns:
        List of paths to the refined .pdb files, parallel to `pdb_files`.
        (Previously returned None, which made the caller print ``None`` as
        the result files on this path.)
    """
    # imported lazily: PyRosetta is heavy and only needed on this path
    from rgn2_replica.rgn2_refine import quick_refine, relax_refine

    out_files = []
    for i, seq in enumerate(seq_list):
        in_pdb = pdb_files[i]
        if args.rosetta_relax == 0:
            # only refine
            out_pdb = in_pdb[:-4] + "_refined.pdb"
            quick_refine(
                in_pdb=in_pdb,
                out_pdb=out_pdb,
                min_iter=args.rosetta_refine
            )
        else:
            # refine and relax
            out_pdb = in_pdb[:-4] + "_refined_relaxed.pdb"
            relax_refine(
                in_pdb,
                out_pdb=out_pdb,
                min_iter=args.rosetta_refine,
                relax_iter=args.rosetta_relax,
                coord_constraint=args.coord_constraint,
            )
        out_files.append(out_pdb)
        print(in_pdb, "was refined successfully")

    return out_files


def af2_refine(pdb_files, args):
    """Relax predicted structures with AlphaFold2's AMBER relaxation.

    Args:
        pdb_files: paths to predicted .pdb files to relax.
        args: parsed CLI namespace; `args.output_path` is used as the
            destination directory for the relaxed structures.

    Returns:
        List of `Path` objects for the relaxed .pdb files written.
    """
    # imported lazily: alphafold is heavy and only needed on this path
    from alphafold.relax import relax
    from alphafold.common import protein

    # Relaxation parameters as configured in this project
    # (max_iterations=0 lets the minimiser run to the given tolerance).
    amber_relaxer = relax.AmberRelaxation(
        max_iterations=0,
        tolerance=2.39,
        stiffness=10.0,
        exclude_residues=[],
        max_outer_iterations=20)

    out_dir = Path(args.output_path)
    # make sure the destination exists before writing relaxed files
    out_dir.mkdir(parents=True, exist_ok=True)

    relaxed_pdbs_files = []
    for pdb_file in pdb_files:
        pdb_file = Path(pdb_file)
        prot = protein.from_pdb_string(pdb_file.read_text())
        min_pdb, debug_data, violations = amber_relaxer.process(prot=prot)

        # NOTE(review): keeps the source filename — if out_dir equals the
        # input file's directory this overwrites the unrelaxed structure;
        # confirm that is intended.
        dest_file = out_dir / pdb_file.name
        dest_file.write_text(min_pdb)

        relaxed_pdbs_files.append(dest_file)

    return relaxed_pdbs_files


def predict_refine():
    """CLI entry point: predict structures from sequences and/or refine PDBs.

    Flow: parse args -> load sequences (if a FASTA was given) -> predict
    with the RGN2 model (if a model was given) or take the provided PDB ->
    optionally refine -> print the resulting file paths.
    """
    args = parse_arguments()

    # get sequences (absent when refining an existing PDB via --pdb_input)
    if args.input is not None:
        seq_list, seq_names = seqs_from_fasta(args.input, names=True)
    else:
        seq_list, seq_names = None, None

    if args.model is not None:
        model, embedder = load_model(args)
        _, pdb_files = predict(model, embedder, seq_list, seq_names, args)
    else:
        pdb_files = [args.pdb_input]

    relaxed_pdbs = refine(seq_list, pdb_files, args)
    if relaxed_pdbs is None:
        # no refinement requested (or the backend returned nothing):
        # report the unrefined predictions instead of printing None
        relaxed_pdbs = pdb_files

    print('result files:')
    print(relaxed_pdbs)


# Script entry point when executed directly (indentation restored).
if __name__ == "__main__":
    predict_refine()