Skip to content

Commit

Permalink
new utils for scn
Browse files Browse the repository at this point in the history
  • Loading branch information
hypnopump committed Jul 22, 2021
1 parent 58c9fc6 commit 1d2bda0
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 1 deletion.
20 changes: 20 additions & 0 deletions mp_nerf/proteins.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,26 @@ def build_scaffolds_from_scn_angles(seq, angles=None, coords=None, device="auto"
#############################


def modify_angles_mask_with_torsions(seq, angles_mask, torsions):
""" Modifies a torsion mask to include variable torsions.
Inputs:
* seq: (L,) str. FASTA sequence
* angles_mask: (2, L, 14) float tensor of (angles, torsions)
* torsions: (L, 4) float tensor (or (L, 5) if it includes torsion for cb)
Outputs: (2, L, 14) a new angles mask
"""
c_beta = torsions.shape[-1] == 5 # whether c_beta torsion is passed as well
start = 4 if c_beta else 5
# get mask of to-fill values
torsion_mask = torch.tensor([SUPREME_INFO[aa]["torsion_mask"] for aa in seq]).to(torsions.device) # (L, 14)
torsion_mask = torsion_mask != torsion_mask # values that are nan need replace
# undesired outside of margins
torsion_mask[:, :start] = torsion_mask[:, start+torsions.shape[-1]:] = False

angles_mask[1, torsion_mask] = torsions[ torsion_mask[:, start:start+torsions.shape[-1]] ]
return angles_mask


def modify_scaffolds_with_coords(scaffolds, coords):
""" Gets scaffolds and fills in the right data.
Inputs:
Expand Down
12 changes: 11 additions & 1 deletion tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,14 @@ def test_nerf_and_dihedral():
# doesnt work because the scn angle was not measured correctly
# so the method corrects that incorrection
assert (mp_nerf_torch(a, b, c, l, theta, chi - np.pi) - torch.tensor([1,0,6])).sum().abs() < 0.1
assert get_dihedral(a, b, c, d).item() == chi
assert get_dihedral(a, b, c, d).item() == chi


def test_modify_angles_mask_with_torsions():
# create inputs
seq = "AGHHKLHRTVNMSTIL"
angles_mask = torch.randn(2, 16, 14)
torsions = torch.ones(16, 4)
# ensure shape
assert modify_angles_mask_with_torsions(seq, angles_mask, torsions).shape == angles_mask.shape, \
"Shapes don't match"

0 comments on commit 1d2bda0

Please sign in to comment.