Commit f3de0bf: first commit

0 parents, 25 files changed, +2147 -0 lines changed

README.md (+95 lines)
# SelfIR (NeurIPS 2022)

PyTorch implementation of [**Self-Supervised Image Restoration with Blurry and Noisy Pairs**](https://arxiv.org/)

[Paper (arXiv)](https://arxiv.org/)
[Video](https://arxiv.org/)

## 1. Framework

<p align="center"><img src="overview.png" width="95%"></p>
<p align="center">Overview of our proposed SelfIR framework.</p>

- (a) Training phase of SelfIR. The sub-sampled blurry image $\mathit{g}_1(\mathbf{I}_\mathcal{B})$ and noisy image $\mathit{g}_1(\mathbf{I}_\mathcal{N})$ are taken as the inputs. $\mathit{g}_2(\mathbf{I}_\mathcal{N})$ is used for calculating the reconstruction loss $\mathcal{L}_\mathit{rec}$ and the regularization loss $\mathcal{L}_\mathit{reg}$, while $\mathit{g}_1(\mathbf{I}_\mathcal{B})$ is used for calculating the auxiliary loss $\mathcal{L}_\mathit{aux}$.
- (b) Example of the neighbor sub-sampler (a minimal sketch of this operation follows the list). In each $2\times2$ cell, two pixels are randomly selected to compose the two neighboring sub-images, respectively.
- (c) Testing phase of SelfIR. The blurry and noisy images can be taken directly as inputs for restoration.
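
For concreteness, here is a minimal NumPy sketch of the neighbor sub-sampling described in (b). It is illustrative only: the function name and array layout are assumptions, and the sub-sampler actually used for training (adapted from Neighbor2Neighbor) may operate on batched tensors elsewhere in the codebase.

```python
import numpy as np

def neighbor_subsample(img, rng=None):
    """Split a (C, H, W) image into two (C, H//2, W//2) neighboring sub-images.

    For every 2x2 cell, two different pixels are drawn at random; one goes to
    the first sub-image and the other to the second.
    """
    rng = np.random.default_rng() if rng is None else rng
    c, h, w = img.shape
    h2, w2 = h // 2, w // 2
    # Rearrange so the 4 pixels of each 2x2 cell sit on the last axis: (C, H/2, W/2, 4).
    cells = img[:, :h2 * 2, :w2 * 2].reshape(c, h2, 2, w2, 2)
    cells = cells.transpose(0, 1, 3, 2, 4).reshape(c, h2, w2, 4)
    # Pick two distinct in-cell positions per cell.
    idx1 = rng.integers(0, 4, size=(h2, w2))
    idx2 = (idx1 + rng.integers(1, 4, size=(h2, w2))) % 4
    sub1 = np.take_along_axis(cells, idx1[None, :, :, None], axis=3)[..., 0]
    sub2 = np.take_along_axis(cells, idx2[None, :, :, None], axis=3)[..., 0]
    return sub1, sub2
```

The two outputs play the roles of the paired sub-images $\mathit{g}_1(\cdot)$ and $\mathit{g}_2(\cdot)$ above.
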
## 2. Preparation and Datasets

- **Prerequisites**
    - Python 3.x and **PyTorch 1.6**.
    - OpenCV, NumPy, Pillow, tqdm, lpips, scikit-image and tensorboardX.

- **Dataset**
    - The **GoPro dataset** can be downloaded from this [link](https://drive.google.com/file/d/1y4wvPdOG3mojpFCHTqLgriexhbjoWVkK/view).
    - **Synthetic noisy images with Gaussian noise for testing** can be downloaded from this [link](https://pan.baidu.com/s/1eA8r5QoX0cLXSfikk6XlQw?pwd=vagc). Please decompress the files according to the commands provided in its `readme.txt`.
    - **Synthetic noisy images with Poisson noise for testing** can be downloaded from this [link](https://pan.baidu.com/s/1tCCMxk7mlIk-27RD2_8GaA?pwd=fdw6). Please decompress the files according to the commands provided in its `readme.txt`.
    - The directory structure of the dataset (a small layout-check sketch follows the tree below):

```
GOPRO_Large
└───train
│       GOPR0372_07_00
│       GOPR0372_07_01
│       ...
└───test
│       GOPR0372_07_00
│       GOPR0372_07_01
│       ...
└───test_noise_gauss5_50
│       GOPR0372_07_00
│       GOPR0372_07_01
│       ...
└───test_noise_poisson5_50
│       GOPR0372_07_00
│       GOPR0372_07_01
│       ...
```
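
As a convenience, the snippet below is a small sanity check, not part of this repository, that the decompressed data matches the layout above. The default `/Data/dataset/GOPRO_Large/` root and the `blur_gamma`/`sharp` sub-folder names are the ones `data/gopro_dataset.py` reads from.

```python
import os

def check_gopro_layout(dataroot):
    """Print the number of scene folders found for each expected split."""
    for split in ['train', 'test', 'test_noise_gauss5_50', 'test_noise_poisson5_50']:
        split_dir = os.path.join(dataroot, split)
        if os.path.isdir(split_dir):
            print('%s: %d scenes' % (split, len(os.listdir(split_dir))))
        else:
            print('%s: missing (%s)' % (split, split_dir))

# data/gopro_dataset.py further expects blur_gamma/ and sharp/ sub-folders in
# every train/test scene, and npy/ sub-folders in the test_noise_* scenes.
check_gopro_layout('/Data/dataset/GOPRO_Large/')
```
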
## 3. Quick Start

### 3.1 Pre-trained models

- For Gaussian noise, we provide the pre-trained models in the `./ckpt/selfir_gauss_noise/` folder.

- For Poisson noise, we provide the pre-trained models in the `./ckpt/selfir_poisson_noise/` folder.

### 3.2 Training

- Modify `dataroot`, `type` and `name` in `train.sh` and then run:

[`sh train.sh`](train.sh)

### 3.3 Testing

- Modify `dataroot`, `type`, `name` and `iter` in `test.sh` and then run:

[`sh test.sh`](test.sh)

### 3.4 Note

- You can specify which GPU to use with `--gpu_ids`, e.g., `--gpu_ids 0,1`, `--gpu_ids 3`, or `--gpu_ids -1` (for CPU mode). By default, all GPUs are used.
- Refer to [options](./options/base_options.py) for more arguments.

## 4. Citation

If you find this work useful in your research, please consider citing:

    @inproceedings{SelfIR,
        title={Self-Supervised Image Restoration with Blurry and Noisy Pairs},
        author={Zhang, Zhilu and Xu, Rongjian and Liu, Ming and Yan, Zifei and Zuo, Wangmeng},
        booktitle={NeurIPS},
        year={2022}
    }

## 5. Acknowledgement

This repo is built upon the framework of [CycleGAN](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix), and we borrow some code from [Neighbor2Neighbor](https://github.com/TaoHuang2018/Neighbor2Neighbor). Thanks for their excellent work!

Two binary files (4.13 MB each) not shown.

data/__init__.py (+57 lines)

import importlib
import torch.utils.data
from data.base_dataset import BaseDataset


def find_dataset_using_name(dataset_name, split='train'):
    # Import "data/<dataset_name>_dataset.py" and return the BaseDataset
    # subclass whose name matches "<datasetname>dataset" (case-insensitive).
    dataset_filename = "data." + dataset_name + "_dataset"
    datasetlib = importlib.import_module(dataset_filename)

    dataset = None
    target_dataset_name = dataset_name.replace('_', '') + 'dataset'
    for name, cls in datasetlib.__dict__.items():
        if name.lower() == target_dataset_name.lower() \
           and issubclass(cls, BaseDataset):
            dataset = cls

    if dataset is None:
        raise NotImplementedError("In %s.py, there should be a subclass of "
                                  "BaseDataset with class name that matches %s in "
                                  "lowercase." % (dataset_filename, target_dataset_name))
    return dataset


def create_dataset(dataset_name, split, opt):
    data_loader = CustomDatasetDataLoader(dataset_name, split, opt)
    dataset = data_loader.load_data()
    return dataset


class CustomDatasetDataLoader():
    """Wraps a dataset and a torch DataLoader behind a single iterable."""

    def __init__(self, dataset_name, split, opt):
        self.opt = opt
        dataset_class = find_dataset_using_name(dataset_name, split)
        self.dataset = dataset_class(opt, split, dataset_name)
        self.imio = self.dataset.imio
        print("dataset [%s(%s)] created" % (dataset_name, split))
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batch_size if split == 'train' else 1,
            shuffle=opt.shuffle and split == 'train',
            num_workers=int(opt.num_dataloader),
            drop_last=opt.drop_last)

    def load_data(self):
        return self

    def __len__(self):
        """Return the number of samples in the dataset."""
        return min(len(self.dataset), self.opt.max_dataset_size)

    def __iter__(self):
        """Yield batches of data, up to max_dataset_size samples."""
        for i, data in enumerate(self.dataloader):
            if i * self.opt.batch_size >= self.opt.max_dataset_size:
                break
            yield data
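
For orientation, here is a hedged sketch of how this factory might be invoked. The option fields mirror the attributes accessed above and in `data/gopro_dataset.py`, but the real option object is built by `options/base_options.py` (not shown in this commit), so the names and values here are assumptions.

```python
from types import SimpleNamespace
from data import create_dataset

# Hypothetical stand-in for the parsed options; see options/base_options.py.
opt = SimpleNamespace(
    dataroot='/Data/dataset/GOPRO_Large/',
    batch_size=4, shuffle=True, num_dataloader=4, drop_last=True,
    max_dataset_size=float('inf'),
    patch_size=128, mode='RGB', imlib='cv2', noisetype='gauss5_50',
)

# 'gopro' resolves to data/gopro_dataset.py and its GoProDataset class.
train_loader = create_dataset('gopro', 'train', opt)
for batch in train_loader:
    print(batch['blur_img'].shape, batch['fname'])
    break
```
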

data/base_dataset.py (+19 lines)

import torch.utils.data as data
from abc import ABC, abstractmethod


class BaseDataset(data.Dataset, ABC):
    """Abstract base class shared by all datasets in data/."""

    def __init__(self, opt, split, dataset_name):
        self.opt = opt
        self.split = split
        self.root = opt.dataroot
        self.dataset_name = dataset_name.lower()

    @abstractmethod
    def __len__(self):
        return 0

    @abstractmethod
    def __getitem__(self, index):
        pass
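
The base class is intentionally thin. As a hedged illustration (the class and file name below are hypothetical, not part of this commit), a new dataset would be added as `data/toy_dataset.py` so that `find_dataset_using_name('toy')` can discover it:

```python
import numpy as np
from data.base_dataset import BaseDataset


class ToyDataset(BaseDataset):
    def __init__(self, opt, split='train', dataset_name='toy'):
        super(ToyDataset, self).__init__(opt, split, dataset_name)
        # CustomDatasetDataLoader also expects an `imio` attribute.
        self.imio = None
        self.len_data = 8

    def __len__(self):
        return self.len_data

    def __getitem__(self, index):
        img = np.zeros((3, 64, 64), dtype=np.float32)  # placeholder CHW image
        return {'blur_img': img, 'gt_img': img, 'fname': '%06d.png' % index}
```
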

data/gopro_dataset.py (+149 lines)

import numpy as np
import os
from data.base_dataset import BaseDataset
from .imlib import imlib
from multiprocessing.dummy import Pool
from tqdm import tqdm
from util.util import augment
import random


# GoPro dataset
class GoProDataset(BaseDataset):
    def __init__(self, opt, split='train', dataset_name='GoPro'):
        super(GoProDataset, self).__init__(opt, split, dataset_name)

        if self.root == '':
            rootlist = ['/Data/dataset/GOPRO_Large/']
            for root in rootlist:
                if os.path.isdir(root):
                    self.root = root
                    break

        self.batch_size = opt.batch_size
        self.patch_size = opt.patch_size
        self.mode = opt.mode  # RGB, Y or L
        self.imio = imlib(self.mode, lib=opt.imlib)
        self.names, self.blur_dirs, self.gt_dirs = self._get_image_dir(self.root, split)

        if split == 'train':
            self._getitem = self._getitem_train
            self.len_data = 500 * 16  # 500 * self.batch_size
        elif split == 'val':
            self._getitem = self._getitem_test
            self.len_data = len(self.names)
        elif split == 'test':
            self._getitem = self._getitem_test
            self.len_data = len(self.names)
        else:
            raise ValueError

        self.blur_images = [0] * len(self.names)
        self.gt_images = [0] * len(self.names)
        read_images(self)

    def __getitem__(self, index):
        return self._getitem(index)

    def __len__(self):
        return self.len_data

    def _getitem_train(self, idx):
        idx = idx % len(self.names)

        blur_img = self.blur_images[idx]
        gt_img = self.gt_images[idx]
        blur_img, gt_img = self._crop_patch(blur_img, gt_img)

        blur_img, gt_img = augment(blur_img, gt_img)

        blur_img = np.float32(blur_img) / 255
        gt_img = np.float32(gt_img) / 255

        return {'gt_noise': gt_img,
                'blur_img': blur_img,
                'gt_img': gt_img,
                'fname': self.names[idx]}

    def _getitem_test(self, idx):

        blur_img = self.blur_images[idx]
        gt_img = self.gt_images[idx]

        blur_img = np.float32(blur_img) / 255
        gt_img = np.float32(gt_img) / 255

        # Load the pre-synthesized noisy counterpart from
        # <dataroot>/test_noise_<noisetype>/<scene>/npy/<image>.npy
        noise_root = self.gt_dirs[idx].replace('sharp', 'npy')
        noise_root = noise_root.replace('test', 'test_noise_' + self.opt.noisetype)
        noise_file = noise_root[:-3] + 'npy'
        gt_noise = np.float32(np.load(noise_file, allow_pickle=True))

        return {'gt_noise': gt_noise,
                'blur_img': blur_img,
                'gt_img': gt_img,
                'fname': self.names[idx]}

    def _crop_patch(self, blur, gt):
        # Crop the same random patch from the blurry and sharp images.
        ih, iw = blur.shape[-2:]
        p = self.patch_size
        pw = random.randrange(0, iw - p + 1)
        ph = random.randrange(0, ih - p + 1)
        return blur[..., ph:ph+p, pw:pw+p], \
               gt[..., ph:ph+p, pw:pw+p]

    def _get_image_dir(self, dataroot, split=None):
        blur_dirs = []
        gt_dirs = []
        image_names = []

        if split == 'train' or split == 'test':
            for scene_file in os.listdir(dataroot + split + '/'):
                for image_file in os.listdir(dataroot + split + '/' + scene_file + '/sharp/'):
                    image_names.append(scene_file + '-' + image_file)
                    blur_dirs.append(dataroot + split + '/' + scene_file + '/blur_gamma/' + image_file)
                    gt_dirs.append(dataroot + split + '/' + scene_file + '/sharp/' + image_file)
        elif split == 'val':
            for scene_file in os.listdir(dataroot + 'test/'):
                for image_file in os.listdir(dataroot + 'test/' + scene_file + '/sharp/'):
                    image_names.append(scene_file + '-' + image_file)
                    blur_dirs.append(dataroot + 'test/' + scene_file + '/blur_gamma/' + image_file)
                    gt_dirs.append(dataroot + 'test/' + scene_file + '/sharp/' + image_file)
                break  # validation uses only a single scene from the test split
        else:
            raise ValueError

        image_names = sorted(image_names)
        blur_dirs = sorted(blur_dirs)
        gt_dirs = sorted(gt_dirs)

        return image_names, blur_dirs, gt_dirs


def iter_obj(num, objs):
    for i in range(num):
        yield (i, objs)


def imreader(arg):
    # Read one blurry/sharp image pair, retrying up to 3 times on failure.
    i, obj = arg
    for _ in range(3):
        try:
            obj.blur_images[i] = obj.imio.read(obj.blur_dirs[i])
            obj.gt_images[i] = obj.imio.read(obj.gt_dirs[i])
            failed = False
            break
        except:
            failed = True
    if failed:
        print('%s fails!' % obj.names[i])


def read_images(obj):
    # `multiprocessing.dummy.Pool` is a thread pool; `multiprocessing.Pool`
    # could be used instead, but it is less efficient here and would duplicate
    # the given object for each process.
    print('Starting to load images via multiple imreaders')
    pool = Pool()  # use all threads by default
    for _ in tqdm(pool.imap(imreader, iter_obj(len(obj.names), obj)), total=len(obj.names)):
        pass
    pool.close()
    pool.join()


if __name__ == '__main__':
    pass
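
Since the commit ships only the loader, the sketch below shows one way the `test_noise_*` folders could be populated so that `_getitem_test` finds its `.npy` targets. It is an assumption-laden illustration: the Gaussian level range [5, 50] is only inferred from the folder name `gauss5_50`, the CHW layout mirrors the rest of the loader, and the officially released noisy test sets (see the README links) should be used to reproduce the paper's numbers.

```python
import os
import numpy as np
from PIL import Image


def make_gauss_noisy_testset(dataroot, noisetype='gauss5_50', seed=0):
    """Write <dataroot>/test_noise_<noisetype>/<scene>/npy/<image>.npy files."""
    rng = np.random.default_rng(seed)
    src = os.path.join(dataroot, 'test')
    dst = os.path.join(dataroot, 'test_noise_' + noisetype)
    for scene in sorted(os.listdir(src)):
        sharp_dir = os.path.join(src, scene, 'sharp')
        npy_dir = os.path.join(dst, scene, 'npy')
        os.makedirs(npy_dir, exist_ok=True)
        for name in sorted(os.listdir(sharp_dir)):
            img = np.asarray(Image.open(os.path.join(sharp_dir, name)),
                             dtype=np.float32) / 255.0        # HWC in [0, 1]
            img = img.transpose(2, 0, 1)                      # CHW (assumed layout)
            sigma = rng.uniform(5, 50) / 255.0                # per-image level (assumed)
            noisy = (img + rng.normal(0.0, sigma, img.shape)).astype(np.float32)
            np.save(os.path.join(npy_dir, os.path.splitext(name)[0] + '.npy'), noisy)
```
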
