From 8b9f615a5ca1c6824aeebb8cee4ecf96efcb44b1 Mon Sep 17 00:00:00 2001 From: Jonathan King Date: Fri, 2 Sep 2022 16:43:16 -0400 Subject: [PATCH] Fix for #48. There were two issues. 1) The uploaded debug dataset included nans (this is for a future release, should contain 0 padding). 2) The structure builder creates tensor coordinates which must be transformed to numpy for viewing. --- sidechainnet/dataloaders/collate.py | 35 +++++++++++++--------- sidechainnet/structure/StructureBuilder.py | 8 +++-- 2 files changed, 27 insertions(+), 16 deletions(-) diff --git a/sidechainnet/dataloaders/collate.py b/sidechainnet/dataloaders/collate.py index b3b9980..85fe21c 100644 --- a/sidechainnet/dataloaders/collate.py +++ b/sidechainnet/dataloaders/collate.py @@ -11,11 +11,10 @@ from sidechainnet.structure.build_info import NUM_COORDS_PER_RES from sidechainnet.utils.download import MAX_SEQ_LEN - -Batch = collections.namedtuple("Batch", - "pids seqs msks evos secs angs " - "crds int_seqs seq_evo_sec resolutions is_modified " - "lengths str_seqs") +Batch = collections.namedtuple( + "Batch", "pids seqs msks evos secs angs " + "crds int_seqs seq_evo_sec resolutions is_modified " + "lengths str_seqs") def get_collate_fn(aggregate_input, seqs_as_onehot=None): @@ -60,7 +59,8 @@ def collate_fn(insts): """ # Instead of working with a list of tuples, we extract out each category of info # so it can be padded and re-provided to the user. - pnids, sequences, masks, pssms, secs, angles, coords, resolutions, mods, str_seqs = list(zip(*insts)) + pnids, sequences, masks, pssms, secs, angles, coords, resolutions, mods, str_seqs = list( + zip(*insts)) lengths = tuple(len(s) for s in sequences) max_batch_len = max(lengths) @@ -248,13 +248,20 @@ def prepare_dataloaders(data, downsample=train_eval_downsample)) valid_loaders = {} - for vsplit in VALID_SPLITS: - valid_loader = torch.utils.data.DataLoader(ProteinDataset( - data[vsplit], vsplit, data['settings'], data['date']), - num_workers=num_workers, - batch_size=batch_size, - collate_fn=collate_fn) - valid_loaders[vsplit] = valid_loader + valid_splits = [splitname for splitname in data.keys() if "valid" in splitname] + for vsplit in valid_splits: + try: + valid_loader = torch.utils.data.DataLoader(ProteinDataset( + data[vsplit], + vsplit, + data['settings'], + data['date']), + num_workers=1, + batch_size=batch_size, + collate_fn=collate_fn) + valid_loaders[vsplit] = valid_loader + except KeyError: + pass test_loader = torch.utils.data.DataLoader(ProteinDataset(data['test'], 'test', data['settings'], @@ -268,7 +275,7 @@ def prepare_dataloaders(data, 'train-eval': train_eval_loader, 'test': test_loader } - dataloaders.update({vsplit: valid_loaders[vsplit] for vsplit in VALID_SPLITS}) + dataloaders.update(valid_loaders) return dataloaders diff --git a/sidechainnet/structure/StructureBuilder.py b/sidechainnet/structure/StructureBuilder.py index 0ecc81c..b6c0998 100644 --- a/sidechainnet/structure/StructureBuilder.py +++ b/sidechainnet/structure/StructureBuilder.py @@ -167,10 +167,14 @@ def build(self): self.coords += res.build() prev_res = res - if self.data_type == 'torch': + if self.data_type == 'numpy' and torch.is_tensor(self.coords[0]): self.coords = torch.stack(self.coords) - else: + self.coords = self.coords.detach().numpy() + elif self.data_type == 'numpy' and isinstance(self.coords[0], np.ndarray): self.coords = np.stack(self.coords) + else: + self.coords = torch.stack(self.coords) + self.data_type = 'torch' return self.coords