Skip to content

Commit 4ad8323

Browse files
committed
Carvana dataset loader
1 parent 84f8392 commit 4ad8323

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

utils/dataset.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,11 @@
99

1010

1111
class BasicDataset(Dataset):
12-
def __init__(self, imgs_dir, masks_dir, scale=1):
12+
def __init__(self, imgs_dir, masks_dir, scale=1, mask_suffix=''):
1313
self.imgs_dir = imgs_dir
1414
self.masks_dir = masks_dir
1515
self.scale = scale
16+
self.mask_suffix = mask_suffix
1617
assert 0 < scale <= 1, 'Scale must be between 0 and 1'
1718

1819
self.ids = [splitext(file)[0] for file in listdir(imgs_dir)
@@ -43,7 +44,7 @@ def preprocess(cls, pil_img, scale):
4344

4445
def __getitem__(self, i):
4546
idx = self.ids[i]
46-
mask_file = glob(self.masks_dir + idx + '.*')
47+
mask_file = glob(self.masks_dir + idx + self.mask_suffix + '.*')
4748
img_file = glob(self.imgs_dir + idx + '.*')
4849

4950
assert len(mask_file) == 1, \
@@ -63,3 +64,8 @@ def __getitem__(self, i):
6364
'image': torch.from_numpy(img).type(torch.FloatTensor),
6465
'mask': torch.from_numpy(mask).type(torch.FloatTensor)
6566
}
67+
68+
69+
class CarvanaDataset(BasicDataset):
70+
def __init__(self, imgs_dir, masks_dir, scale=1):
71+
super().__init__(imgs_dir, masks_dir, scale, mask_suffix='_mask')

0 commit comments

Comments
 (0)