Skip to content

Commit

Permalink
Fix for #48.
Browse files Browse the repository at this point in the history
There were two issues. 1) The uploaded debug dataset included
 NaNs (this data is intended for a future release and should instead contain 0 padding).
 2) The structure builder creates tensor coordinates, which
 must be converted to NumPy arrays for viewing.
  • Loading branch information
jonathanking committed Sep 2, 2022
1 parent ebdafdc commit 8b9f615
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 16 deletions.
35 changes: 21 additions & 14 deletions sidechainnet/dataloaders/collate.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,10 @@
from sidechainnet.structure.build_info import NUM_COORDS_PER_RES
from sidechainnet.utils.download import MAX_SEQ_LEN


Batch = collections.namedtuple("Batch",
"pids seqs msks evos secs angs "
"crds int_seqs seq_evo_sec resolutions is_modified "
"lengths str_seqs")
Batch = collections.namedtuple(
"Batch", "pids seqs msks evos secs angs "
"crds int_seqs seq_evo_sec resolutions is_modified "
"lengths str_seqs")


def get_collate_fn(aggregate_input, seqs_as_onehot=None):
Expand Down Expand Up @@ -60,7 +59,8 @@ def collate_fn(insts):
"""
# Instead of working with a list of tuples, we extract out each category of info
# so it can be padded and re-provided to the user.
pnids, sequences, masks, pssms, secs, angles, coords, resolutions, mods, str_seqs = list(zip(*insts))
pnids, sequences, masks, pssms, secs, angles, coords, resolutions, mods, str_seqs = list(
zip(*insts))
lengths = tuple(len(s) for s in sequences)
max_batch_len = max(lengths)

Expand Down Expand Up @@ -248,13 +248,20 @@ def prepare_dataloaders(data,
downsample=train_eval_downsample))

valid_loaders = {}
for vsplit in VALID_SPLITS:
valid_loader = torch.utils.data.DataLoader(ProteinDataset(
data[vsplit], vsplit, data['settings'], data['date']),
num_workers=num_workers,
batch_size=batch_size,
collate_fn=collate_fn)
valid_loaders[vsplit] = valid_loader
valid_splits = [splitname for splitname in data.keys() if "valid" in splitname]
for vsplit in valid_splits:
try:
valid_loader = torch.utils.data.DataLoader(ProteinDataset(
data[vsplit],
vsplit,
data['settings'],
data['date']),
num_workers=1,
batch_size=batch_size,
collate_fn=collate_fn)
valid_loaders[vsplit] = valid_loader
except KeyError:
pass

test_loader = torch.utils.data.DataLoader(ProteinDataset(data['test'], 'test',
data['settings'],
Expand All @@ -268,7 +275,7 @@ def prepare_dataloaders(data,
'train-eval': train_eval_loader,
'test': test_loader
}
dataloaders.update({vsplit: valid_loaders[vsplit] for vsplit in VALID_SPLITS})
dataloaders.update(valid_loaders)

return dataloaders

Expand Down
8 changes: 6 additions & 2 deletions sidechainnet/structure/StructureBuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,10 +167,14 @@ def build(self):
self.coords += res.build()
prev_res = res

if self.data_type == 'torch':
if self.data_type == 'numpy' and torch.is_tensor(self.coords[0]):
self.coords = torch.stack(self.coords)
else:
self.coords = self.coords.detach().numpy()
elif self.data_type == 'numpy' and isinstance(self.coords[0], np.ndarray):
self.coords = np.stack(self.coords)
else:
self.coords = torch.stack(self.coords)
self.data_type = 'torch'

return self.coords

Expand Down

0 comments on commit 8b9f615

Please sign in to comment.