From 58c9fc6c3fb656adf222b10aa980fd906db4ab58 Mon Sep 17 00:00:00 2001 From: hypnopump Date: Thu, 22 Jul 2021 21:20:15 +0200 Subject: [PATCH] make sure tests pass --- mp_nerf/ml_utils.py | 8 ++++---- tests/test_ml_utils.py | 4 +++- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/mp_nerf/ml_utils.py b/mp_nerf/ml_utils.py index bd69041..584f5aa 100644 --- a/mp_nerf/ml_utils.py +++ b/mp_nerf/ml_utils.py @@ -56,7 +56,7 @@ def rename_symmetric_atoms(pred_coors, true_coors, seq_list, cloud_mask, pred_fe amb_idxs = np.array(pairs["indexs"]).flatten().tolist() idxs = torch.tensor([ k for k,s in enumerate(seq) if s==aa and \ - idx in set( torch.nonzero(aux_cloud_mask[i, :, amb_idxs].sum(dim=-1)).tolist() ) + k in set( torch.nonzero(aux_cloud_mask[i, :, amb_idxs].sum(dim=-1)).tolist()[0] ) ]).long() # check if any AAs matching if idxs.shape[0] == 0: @@ -115,7 +115,7 @@ def fape_torch(pred_coords, true_coords, max_val=10., l_func=None, Outputs: (B, N_atoms) """ fape_store = [] - if l_func is not None: + if l_func is None: l_func = lambda x,y,eps=1e-7,sup=max_val: (((x-y)**2).sum(dim=-1) + eps).sqrt() # for chain for s in range(pred_coords.shape[0]): @@ -144,8 +144,8 @@ def fape_torch(pred_coords, true_coords, max_val=10., l_func=None, # measure errors - for residue for i,rot_mat in enumerate(rot_mats): - fape_store[s] += l1( pred_center[s][mask_center[s]] @ rot_mat, - true_center[s][mask_center[s]] + fape_store[s] += l_func( pred_center[s][mask_center[s]] @ rot_mat, + true_center[s][mask_center[s]] ).clamp(0, max_val) fape_store[s] /= rot_mats.shape[0] diff --git a/tests/test_ml_utils.py b/tests/test_ml_utils.py index 99fd3bb..999f4a9 100644 --- a/tests/test_ml_utils.py +++ b/tests/test_ml_utils.py @@ -26,9 +26,11 @@ def test_rename_symmetric_atoms(): seq_list = ["AGHHKLHRTVNMSTIL"] pred_coors = torch.randn(1, 16, 14, 3) true_coors = torch.randn(1, 16, 14, 3) - cloud_mask = scn_cloud_mask(seq_list[0]).unsqueeze(-1) + cloud_mask = scn_cloud_mask(seq_list[0]).unsqueeze(0) pred_feats = torch.randn(1, 16, 14, 16) + print(cloud_mask.shape) + renamed = rename_symmetric_atoms(pred_coors, true_coors, seq_list, cloud_mask, pred_feats=pred_feats) assert renamed[0].shape == pred_coors.shape and renamed[1].shape == pred_feats.shape, "Shapes don't match"