Skip to content

Commit d2d387c

Browse files
authored
Add files via upload
1 parent 987eede commit d2d387c

File tree

1 file changed

+94
-0
lines changed

1 file changed

+94
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
import argparse
2+
import os
3+
from os import listdir
4+
from os.path import join
5+
6+
from PIL import Image
7+
from torch.utils.data.dataset import Dataset
8+
from torchvision.transforms import Compose, CenterCrop, Scale
9+
from tqdm import tqdm
10+
11+
12+
def is_image_file(filename):
13+
return any(filename.endswith(extension) for extension in ['.png', '.jpg', '.jpeg', '.JPG', '.JPEG', '.PNG'])
14+
15+
16+
def calculate_valid_crop_size(crop_size, upscale_factor):
17+
return crop_size - (crop_size % upscale_factor)
18+
19+
20+
def input_transform(crop_size, upscale_factor):
21+
return Compose([
22+
CenterCrop(crop_size),
23+
Scale(crop_size // upscale_factor, interpolation=Image.BICUBIC)
24+
])
25+
26+
27+
def target_transform(crop_size):
28+
return Compose([
29+
CenterCrop(crop_size)
30+
])
31+
32+
33+
class DatasetFromFolder(Dataset):
34+
def __init__(self, dataset_dir, upscale_factor, input_transform=None, target_transform=None):
35+
super(DatasetFromFolder, self).__init__()
36+
self.image_dir = dataset_dir + '/SRF_' + str(upscale_factor) + '/data'
37+
self.target_dir = dataset_dir + '/SRF_' + str(upscale_factor) + '/target'
38+
self.image_filenames = [join(self.image_dir, x) for x in listdir(self.image_dir) if is_image_file(x)]
39+
self.target_filenames = [join(self.target_dir, x) for x in listdir(self.target_dir) if is_image_file(x)]
40+
self.input_transform = input_transform
41+
self.target_transform = target_transform
42+
43+
def __getitem__(self, index):
44+
image, _, _ = Image.open(self.image_filenames[index]).convert('YCbCr').split()
45+
target, _, _ = Image.open(self.target_filenames[index]).convert('YCbCr').split()
46+
if self.input_transform:
47+
image = self.input_transform(image)
48+
if self.target_transform:
49+
target = self.target_transform(target)
50+
51+
return image, target
52+
53+
def __len__(self):
54+
return len(self.image_filenames)
55+
56+
57+
def generate_dataset(data_type, upscale_factor):
58+
images_name = [x for x in listdir('D:\\ALL_DataSet\\VOCdevkit\\VOC2012\\VOC\\' + data_type) if is_image_file(x)]
59+
crop_size = calculate_valid_crop_size(256, upscale_factor)
60+
lr_transform = input_transform(crop_size, upscale_factor)
61+
hr_transform = target_transform(crop_size)
62+
63+
root = 'data/' + data_type
64+
if not os.path.exists(root):
65+
os.makedirs(root)
66+
path = root + '/SRF_' + str(upscale_factor)
67+
if not os.path.exists(path):
68+
os.makedirs(path)
69+
image_path = path + '/data'
70+
if not os.path.exists(image_path):
71+
os.makedirs(image_path)
72+
target_path = path + '/target'
73+
if not os.path.exists(target_path):
74+
os.makedirs(target_path)
75+
76+
for image_name in tqdm(images_name, desc='generate ' + data_type + ' dataset with upscale factor = '
77+
+ str(upscale_factor) + ' from VOC2012'):
78+
image = Image.open('D:\\ALL_DataSet\\VOCdevkit\\VOC2012\\VOC\\' + data_type + '/' + image_name)
79+
target = image.copy()
80+
image = lr_transform(image)
81+
target = hr_transform(target)
82+
83+
image.save(image_path + '/' + image_name)
84+
target.save(target_path + '/' + image_name)
85+
86+
87+
if __name__ == "__main__":
88+
parser = argparse.ArgumentParser(description='Generate Super Resolution Dataset')
89+
parser.add_argument('--upscale_factor', default=4, type=int, help='super resolution upscale factor')
90+
opt = parser.parse_args()
91+
UPSCALE_FACTOR = opt.upscale_factor
92+
93+
generate_dataset(data_type='train', upscale_factor=UPSCALE_FACTOR)
94+
generate_dataset(data_type='val', upscale_factor=UPSCALE_FACTOR)

0 commit comments

Comments
 (0)