Skip to content

Commit

Permalink
make sure tests pass
Browse files Browse the repository at this point in the history
  • Loading branch information
hypnopump committed Jul 22, 2021
1 parent 2ce9300 commit 58c9fc6
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
8 changes: 4 additions & 4 deletions mp_nerf/ml_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def rename_symmetric_atoms(pred_coors, true_coors, seq_list, cloud_mask, pred_fe
amb_idxs = np.array(pairs["indexs"]).flatten().tolist()
idxs = torch.tensor([
k for k,s in enumerate(seq) if s==aa and \
idx in set( torch.nonzero(aux_cloud_mask[i, :, amb_idxs].sum(dim=-1)).tolist() )
k in set( torch.nonzero(aux_cloud_mask[i, :, amb_idxs].sum(dim=-1)).tolist()[0] )
]).long()
# check if any AAs matching
if idxs.shape[0] == 0:
Expand Down Expand Up @@ -115,7 +115,7 @@ def fape_torch(pred_coords, true_coords, max_val=10., l_func=None,
Outputs: (B, N_atoms)
"""
fape_store = []
if l_func is not None:
if l_func is None:
l_func = lambda x,y,eps=1e-7,sup=max_val: (((x-y)**2).sum(dim=-1) + eps).sqrt()
# for chain
for s in range(pred_coords.shape[0]):
Expand Down Expand Up @@ -144,8 +144,8 @@ def fape_torch(pred_coords, true_coords, max_val=10., l_func=None,

# measure errors - for residue
for i,rot_mat in enumerate(rot_mats):
fape_store[s] += l1( pred_center[s][mask_center[s]] @ rot_mat,
true_center[s][mask_center[s]]
fape_store[s] += l_func( pred_center[s][mask_center[s]] @ rot_mat,
true_center[s][mask_center[s]]
).clamp(0, max_val)
fape_store[s] /= rot_mats.shape[0]

Expand Down
4 changes: 3 additions & 1 deletion tests/test_ml_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,11 @@ def test_rename_symmetric_atoms():
seq_list = ["AGHHKLHRTVNMSTIL"]
pred_coors = torch.randn(1, 16, 14, 3)
true_coors = torch.randn(1, 16, 14, 3)
cloud_mask = scn_cloud_mask(seq_list[0]).unsqueeze(-1)
cloud_mask = scn_cloud_mask(seq_list[0]).unsqueeze(0)
pred_feats = torch.randn(1, 16, 14, 16)

print(cloud_mask.shape)

renamed = rename_symmetric_atoms(pred_coors, true_coors, seq_list, cloud_mask, pred_feats=pred_feats)
assert renamed[0].shape == pred_coors.shape and renamed[1].shape == pred_feats.shape, "Shapes don't match"

Expand Down

0 comments on commit 58c9fc6

Please sign in to comment.