Skip to content

Commit

Permalink
last fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
hypnopump committed May 31, 2021
1 parent f05595a commit 221a21a
Show file tree
Hide file tree
Showing 4 changed files with 29,704 additions and 1,756 deletions.
24 changes: 12 additions & 12 deletions mp_nerf/proteins.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def scn_angle_mask(seq, angles=None, device=None):
# get masks
theta_mask = torch.tensor([SUPREME_INFO[aa]['theta_mask'] for aa in seq], dtype=precise).to(device)
torsion_mask = torch.tensor([SUPREME_INFO[aa][torsion_mask_use] for aa in seq], dtype=precise).to(device)
# O placement - same as in sidechainnet
# =O placement - same as in sidechainnet
theta_mask[:, 3] = BB_BUILD_INFO["BONDANGS"]["ca-c-o"]
# https://github.com/jonathanking/sidechainnet/blob/master/sidechainnet/structure/StructureBuilder.py#L313
torsion_mask[:, 3] = angles[:, 1] - np.pi if angles is not None else -2.406 # from the xtension
Expand All @@ -66,7 +66,7 @@ def scn_angle_mask(seq, angles=None, device=None):
torsion_mask[:, 0] = angles[:, 1] # n determined by psi of previous
torsion_mask[1:, 1] = angles[:-1, 2] # ca determined by omega of previous
torsion_mask[:, 2] = angles[:, 0] # c determined by phi


# add torsions to sidechains
to_fill = torsion_mask != torsion_mask # "p" fill with passed values
Expand Down Expand Up @@ -286,13 +286,13 @@ def protein_fold(cloud_mask, point_ref_mask, angles_mask, bond_mask,

# to place C-beta, we need the carbons from prev res - not available for the 1st res
if i == 4:
# for 1st residue, use position of the second residue's N
first_next_n = coords[1, :1] # 1, 3
# the c requested is from the previous residue - offset boolean mask by one
# can't be done with slicing bc glycines are inside chain (dont have cb)
main_c_prev_idxs = coords[(torch.nonzero(level_mask).view(-1) - 1), idx_a][1:] # (L-1), 3
# concat coords
coords_a = torch.cat([first_next_n, main_c_prev_idxs])
coords_a = coords[(level_mask.nonzero().view(-1) - 1), idx_a] # (L-1), 3
# if first residue is not glycine,
# for 1st residue, use position of the second residue's N (1,3)
if level_mask[0].item():
coords_a[0] = coords[1, 1]
else:
coords_a = coords[level_mask, idx_a]

Expand Down Expand Up @@ -335,13 +335,13 @@ def sidechain_fold(wrapper, cloud_mask, point_ref_mask, angles_mask, bond_mask,

# to place C-beta, we need the carbons from prev res - not available for the 1st res
if i == 4:
# for 1st residue, use position of the second residue's N
first_next_n = wrapper[1, :1] # 1, 3
# the c requested is from the previous residue - offset boolean mask by one
# can't be done with slicing bc glycines are inside chain (dont have cb)
main_c_prev_idxs = wrapper[(level_mask.nonzero().view(-1) - 1), idx_a][1:] # (L-1), 3
# concat coords
coords_a = torch.cat([first_next_n, main_c_prev_idxs])
coords_a = wrapper[(level_mask.nonzero().view(-1) - 1), idx_a] # (L-1), 3
# if first residue is not glycine,
# for 1st residue, use position of the second residue's N (1,3)
if level_mask[0].item():
coords_a[0] = wrapper[1, 1]
else:
coords_a = wrapper[level_mask, idx_a]

Expand Down
27 changes: 15 additions & 12 deletions mp_nerf/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@ def get_prot(dataloader_=None, vocab_=None, min_len=80, max_len=150, verbose=Tru
Outputs: (cleaned, without padding)
(seq_str, int_seq, coords, angles, padding_seq, mask, pid)
"""
for b,batch in enumerate(dataloader_['train']):
# try for breaking from 2 loops at once
try:
while True:
for b,batch in enumerate(dataloader_['train']):
for i in range(batch.int_seqs.shape[0]):
# strip padding
# strip padding - matching angles to string means
# only accepting prots with no missing residues (angles would be 0)
padding_seq = (batch.int_seqs[i] == 20).sum().item()
padding_angles = (torch.abs(batch.angs[i]).sum(dim=-1) == 0).long().sum().item()

Expand All @@ -37,17 +37,20 @@ def get_prot(dataloader_=None, vocab_=None, min_len=80, max_len=150, verbose=Tru
mask = batch.msks[i][:-padding_seq or None]
coords = batch.crds[i][:-padding_seq*14 or None]

print("stopping at sequence of length", real_len)
raise StopIteration
if verbose:
print("stopping at sequence of length", real_len)
return seq, int_seq, coords, angles, padding_seq, mask, batch.pids[i]
else:
if verbose:
print("found a seq of length:", batch.int_seqs[i].shape,
"but oustide the threshold:", min_len, max_len)
else:
# print("found a seq of length:", len(seq),
# "but oustide the threshold:", min_len, max_len)
if verbose:
print("paddings not matching", padding_seq, padding_angles)
pass

except StopIteration:
break

return seq, int_seq, coords, angles, padding_seq, mask, batch.pids[i]
return None



######################
Expand Down
Loading

0 comments on commit 221a21a

Please sign in to comment.