Skip to content

Commit

Permalink
for release
Browse files Browse the repository at this point in the history
  • Loading branch information
hypnopump committed Jun 12, 2021
1 parent b471624 commit cb4315e
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 12 deletions.
6 changes: 3 additions & 3 deletions mp_nerf/kb_proteins.py
Original file line number Diff line number Diff line change
Expand Up @@ -868,16 +868,16 @@ def make_idx_mask(aa):
###################
##### GETTERS #####
###################
INDEX2AAS = "ACDEFGHIKLMNPQRSTVWY_"
AAS2INDEX = {aa:i for i,aa in enumerate(INDEX2AAS)}
SUPREME_INFO = {k: {"cloud_mask": make_cloud_mask(k),
"bond_mask": make_bond_mask(k),
"theta_mask": make_theta_mask(k),
"torsion_mask": make_torsion_mask(k),
"torsion_mask_filled": make_torsion_mask(k, fill=True),
"idx_mask": make_idx_mask(k),
}
for k in "ARNDCQEGHILKMFPSTWYV_"}


for k in INDEX2AAS}



Expand Down
15 changes: 8 additions & 7 deletions mp_nerf/proteins.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,24 @@
from mp_nerf.kb_proteins import *


def scn_cloud_mask(seq, coords=None):
def scn_cloud_mask(seq, coords=None, strict=False):
""" Gets the boolean mask atom positions (not all aas have same atoms).
Inputs:
* seqs: (length) iterable of 1-letter aa codes of a protein
* coords: optional .(batch, lc, 3). sidechainnet coords.
returns the true mask (solves potential atoms that might not be provided)
* strict: bool. whther to discard the next points after a missing one
Outputs: (length, 14) boolean mask
"""
if coords is not None:
start = (( rearrange(coords, 'b (l c) d -> b l c d', c=14) != 0 ).sum(dim=-1) != 0).float()
# if a point is 0, the following are 0s as well
for b in range(start.shape[0]):
for pos in range(start.shape[1]):
for chain in range(start.shape[2]):
if start[b, pos, chain].item() == 0.:
start[b, pos, chain:] *= 0.
if strict:
for b in range(start.shape[0]):
for pos in range(start.shape[1]):
for chain in range(start.shape[2]):
if start[b, pos, chain].item() == 0:
start[b, pos, chain:] *= 0
return start
return torch.tensor([SUPREME_INFO[aa]['cloud_mask'] for aa in seq])

Expand Down Expand Up @@ -311,7 +313,6 @@ def sidechain_fold(wrapper, cloud_mask, point_ref_mask, angles_mask, bond_mask,
Inputs:
* wrapper: (L, 14, 3). coords container with backbone ([:, :3]) and optionally
c_beta ([:, 4])
* seqs: iterable (string, list...) of aas (1 letter corde)
* cloud_mask: (L, 14) mask of points that should be converted to coords
* point_ref_mask: (3, L, 11) maps point (except n-ca-c) to idxs of
previous 3 points in the coords array
Expand Down
2 changes: 0 additions & 2 deletions mp_nerf/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,9 @@ def get_prot(dataloader_=None, vocab_=None, min_len=80, max_len=150, verbose=Tru
if verbose:
print("paddings not matching", padding_seq, padding_angles)
pass

return None



######################
## structural utils ##
######################
Expand Down

0 comments on commit cb4315e

Please sign in to comment.