-
Notifications
You must be signed in to change notification settings - Fork 3
/
datasets.py
130 lines (109 loc) · 4.34 KB
/
datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.utils.data as data
import os
from PIL import Image
# Number of target classes for each supported dataset key.
# Keys match the loader-builder function names below (imagenet, indoor).
_NUM_CLASSES = {
    'imagenet': 1000,
    'indoor': 67,
}
def default_loader(path):
    """Load the image at *path* with PIL and return it converted to RGB."""
    img = Image.open(path)
    return img.convert('RGB')
def default_flist_reader(flist):
    """Parse an image file-list text file.

    Each non-empty line holds three whitespace-separated fields:
    ``impath label index`` (similar to Caffe's file-list format, with an
    extra per-sample index column — the original docstring claimed only
    two fields, but the code has always unpacked three).

    Args:
        flist: path to the list file.

    Returns:
        List of ``(impath, label, index)`` tuples with ``label`` and
        ``index`` converted to ``int``.

    Raises:
        ValueError: if a non-empty line does not have exactly three fields.
    """
    imlist = []
    with open(flist, 'r') as rf:
        # Iterate the file lazily instead of slurping it with readlines().
        for line in rf:
            fields = line.split()
            if not fields:
                # Tolerate blank/trailing lines instead of crashing on the
                # three-way unpack below.
                continue
            impath, imlabel, imindex = fields
            imlist.append((impath, int(imlabel), int(imindex)))
    return imlist
class ImageFilelist(data.Dataset):
    """Dataset backed by a text file listing (image path, label, index) rows.

    The file list is parsed once at construction time by ``flist_reader``;
    images are loaded lazily per item by ``loader``.
    """

    def __init__(self, flist, transform=None, target_transform=None,
                 flist_reader=default_flist_reader, loader=default_loader):
        self.imlist = flist_reader(flist)
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, idx):
        """Return ``(image, label, stored_index)`` for row *idx* of the list.

        Note that the returned index is the one stored in the file list,
        not necessarily the positional *idx* used for lookup.
        """
        path, label, stored_idx = self.imlist[idx]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return sample, label, stored_idx

    def __len__(self):
        """Number of rows in the file list."""
        return len(self.imlist)
def imagenet(batch_size, train=True, val=True, **kwargs):
    """Build DataLoader(s) over the ImageNet file lists.

    Reads ``imagenet_gt_tr.txt`` / ``imagenet_gt_val.txt`` from the current
    directory via :class:`ImageFilelist`.

    Args:
        batch_size: batch size for both loaders.
        train: build the (shuffled, augmented) training loader.
        val: build the (deterministic, center-crop) validation loader.
        **kwargs: forwarded to ``DataLoader``; ``num_workers`` defaults to 1
            and any ``input_size`` entry is discarded.

    Returns:
        A single DataLoader if exactly one of *train*/*val* is set,
        otherwise a list ``[train_loader, val_loader]`` (empty list if
        neither flag is set).
    """
    num_workers = kwargs.setdefault('num_workers', 1)
    kwargs.pop('input_size', None)  # not applicable; crop size is fixed at 224
    print("Building data loader with {} workers".format(num_workers))

    # Standard ImageNet channel statistics; stateless, so safe to share.
    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    loaders = []
    if train:
        train_pipeline = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        train_loader = DataLoader(
            ImageFilelist(flist='imagenet_gt_tr.txt', transform=train_pipeline),
            batch_size=batch_size, shuffle=True, **kwargs)
        print("Training data size: {}".format(len(train_loader.dataset)))
        loaders.append(train_loader)
    if val:
        eval_pipeline = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])
        test_loader = DataLoader(
            ImageFilelist(flist='imagenet_gt_val.txt', transform=eval_pipeline),
            batch_size=batch_size, shuffle=False, **kwargs)
        print("Testing data size: {}".format(len(test_loader.dataset)))
        loaders.append(test_loader)
    return loaders[0] if len(loaders) == 1 else loaders
def indoor(batch_size, train=True, val=True, **kwargs):
    """Build DataLoader(s) over the MIT Indoor-67 file lists.

    Reads ``indoor_gt_tr.txt`` / ``indoor_gt_val.txt`` from the current
    directory via :class:`ImageFilelist`.

    Args:
        batch_size: batch size for both loaders.
        train: build the (shuffled, augmented) training loader.
        val: build the (deterministic, center-crop) validation loader.
        **kwargs: forwarded to ``DataLoader``; ``num_workers`` defaults to 1
            and any ``input_size`` entry is discarded.

    Returns:
        A single DataLoader if exactly one of *train*/*val* is set,
        otherwise a list ``[train_loader, val_loader]`` (empty list if
        neither flag is set).
    """
    num_workers = kwargs.setdefault('num_workers', 1)
    kwargs.pop('input_size', None)  # not applicable; crop size is fixed at 224
    print("Building data loader with {} workers".format(num_workers))

    # ImageNet channel statistics reused for Indoor-67; stateless transform.
    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    loaders = []
    if train:
        augmented = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        train_loader = DataLoader(
            ImageFilelist(flist='indoor_gt_tr.txt', transform=augmented),
            batch_size=batch_size, shuffle=True, **kwargs)
        print("Training data size: {}".format(len(train_loader.dataset)))
        loaders.append(train_loader)
    if val:
        deterministic = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])
        test_loader = DataLoader(
            ImageFilelist(flist='indoor_gt_val.txt', transform=deterministic),
            batch_size=batch_size, shuffle=False, **kwargs)
        print("Testing data size: {}".format(len(test_loader.dataset)))
        loaders.append(test_loader)
    return loaders[0] if len(loaders) == 1 else loaders