1
+ import argparse
2
+ import os
3
+ from os import listdir
4
+ from os .path import join
5
+
6
+ from PIL import Image
7
+ from torch .utils .data .dataset import Dataset
8
+ from torchvision .transforms import Compose , CenterCrop , Scale
9
+ from tqdm import tqdm
10
+
11
+
12
+ def is_image_file (filename ):
13
+ return any (filename .endswith (extension ) for extension in ['.png' , '.jpg' , '.jpeg' , '.JPG' , '.JPEG' , '.PNG' ])
14
+
15
+
16
+ def calculate_valid_crop_size (crop_size , upscale_factor ):
17
+ return crop_size - (crop_size % upscale_factor )
18
+
19
+
20
+ def input_transform (crop_size , upscale_factor ):
21
+ return Compose ([
22
+ CenterCrop (crop_size ),
23
+ Scale (crop_size // upscale_factor , interpolation = Image .BICUBIC )
24
+ ])
25
+
26
+
27
+ def target_transform (crop_size ):
28
+ return Compose ([
29
+ CenterCrop (crop_size )
30
+ ])
31
+
32
+
33
+ class DatasetFromFolder (Dataset ):
34
+ def __init__ (self , dataset_dir , upscale_factor , input_transform = None , target_transform = None ):
35
+ super (DatasetFromFolder , self ).__init__ ()
36
+ self .image_dir = dataset_dir + '/SRF_' + str (upscale_factor ) + '/data'
37
+ self .target_dir = dataset_dir + '/SRF_' + str (upscale_factor ) + '/target'
38
+ self .image_filenames = [join (self .image_dir , x ) for x in listdir (self .image_dir ) if is_image_file (x )]
39
+ self .target_filenames = [join (self .target_dir , x ) for x in listdir (self .target_dir ) if is_image_file (x )]
40
+ self .input_transform = input_transform
41
+ self .target_transform = target_transform
42
+
43
+ def __getitem__ (self , index ):
44
+ image , _ , _ = Image .open (self .image_filenames [index ]).convert ('YCbCr' ).split ()
45
+ target , _ , _ = Image .open (self .target_filenames [index ]).convert ('YCbCr' ).split ()
46
+ if self .input_transform :
47
+ image = self .input_transform (image )
48
+ if self .target_transform :
49
+ target = self .target_transform (target )
50
+
51
+ return image , target
52
+
53
+ def __len__ (self ):
54
+ return len (self .image_filenames )
55
+
56
+
57
+ def generate_dataset (data_type , upscale_factor ):
58
+ images_name = [x for x in listdir ('D:\\ ALL_DataSet\\ VOCdevkit\\ VOC2012\\ VOC\\ ' + data_type ) if is_image_file (x )]
59
+ crop_size = calculate_valid_crop_size (256 , upscale_factor )
60
+ lr_transform = input_transform (crop_size , upscale_factor )
61
+ hr_transform = target_transform (crop_size )
62
+
63
+ root = 'data/' + data_type
64
+ if not os .path .exists (root ):
65
+ os .makedirs (root )
66
+ path = root + '/SRF_' + str (upscale_factor )
67
+ if not os .path .exists (path ):
68
+ os .makedirs (path )
69
+ image_path = path + '/data'
70
+ if not os .path .exists (image_path ):
71
+ os .makedirs (image_path )
72
+ target_path = path + '/target'
73
+ if not os .path .exists (target_path ):
74
+ os .makedirs (target_path )
75
+
76
+ for image_name in tqdm (images_name , desc = 'generate ' + data_type + ' dataset with upscale factor = '
77
+ + str (upscale_factor ) + ' from VOC2012' ):
78
+ image = Image .open ('D:\\ ALL_DataSet\\ VOCdevkit\\ VOC2012\\ VOC\\ ' + data_type + '/' + image_name )
79
+ target = image .copy ()
80
+ image = lr_transform (image )
81
+ target = hr_transform (target )
82
+
83
+ image .save (image_path + '/' + image_name )
84
+ target .save (target_path + '/' + image_name )
85
+
86
+
87
+ if __name__ == "__main__" :
88
+ parser = argparse .ArgumentParser (description = 'Generate Super Resolution Dataset' )
89
+ parser .add_argument ('--upscale_factor' , default = 4 , type = int , help = 'super resolution upscale factor' )
90
+ opt = parser .parse_args ()
91
+ UPSCALE_FACTOR = opt .upscale_factor
92
+
93
+ generate_dataset (data_type = 'train' , upscale_factor = UPSCALE_FACTOR )
94
+ generate_dataset (data_type = 'val' , upscale_factor = UPSCALE_FACTOR )
0 commit comments