9
9
10
10
11
11
class BasicDataset (Dataset ):
12
- def __init__ (self , imgs_dir , masks_dir , scale = 1 ):
12
+ def __init__ (self , imgs_dir , masks_dir , scale = 1 , mask_suffix = '' ):
13
13
self .imgs_dir = imgs_dir
14
14
self .masks_dir = masks_dir
15
15
self .scale = scale
16
+ self .mask_suffix = mask_suffix
16
17
assert 0 < scale <= 1 , 'Scale must be between 0 and 1'
17
18
18
19
self .ids = [splitext (file )[0 ] for file in listdir (imgs_dir )
@@ -43,7 +44,7 @@ def preprocess(cls, pil_img, scale):
43
44
44
45
def __getitem__ (self , i ):
45
46
idx = self .ids [i ]
46
- mask_file = glob (self .masks_dir + idx + '.*' )
47
+ mask_file = glob (self .masks_dir + idx + self . mask_suffix + '.*' )
47
48
img_file = glob (self .imgs_dir + idx + '.*' )
48
49
49
50
assert len (mask_file ) == 1 , \
@@ -63,3 +64,8 @@ def __getitem__(self, i):
63
64
'image' : torch .from_numpy (img ).type (torch .FloatTensor ),
64
65
'mask' : torch .from_numpy (mask ).type (torch .FloatTensor )
65
66
}
67
+
68
+
69
+ class CarvanaDataset (BasicDataset ):
70
+ def __init__ (self , imgs_dir , masks_dir , scale = 1 ):
71
+ super ().__init__ (imgs_dir , masks_dir , scale , mask_suffix = '_mask' )
0 commit comments