Commit d548897

Ymshi authored and committed
add ReLLIE
1 parent 8713c39 commit d548897

7 files changed, +998 −3 lines changed

Diff for: README.md (+3 −3)

@@ -15,9 +15,9 @@ Welcome to join the project!
 
 
 ## Recent updates
-
+#### 2022.4.29 add ReLLIE (**need to improve**)
 #### 2022.4.19 add DALE
-#### 2022.4.17 add DRBN(stage1, stag2) & SGM
+#### 2022.4.17 add DRBN(stage1, stage2) & SGM
 #### 2022.4.16 add EnlightenGAN
 #### 2022.4.14 add ZeroDCE & ZeroDCE++
 #### 2022.4.12 start project
@@ -60,7 +60,7 @@ Then you can see the results in the folder "output".
 |⭕️ DSLR |⭕️ AGLLNet | | |
 |⭕️ StableLLVE| | | |
 |⭕️ LPNet| | | |
-|⭕️ ReLLIE| | | |
+| ReLLIE| | | |
 |⭕️ RUAS| | | |
 |⭕️ RRDNet| | | |

Diff for: ReLLIE/MyFCN.py (new file, +99)

import chainer
from chainer import Variable
import chainer.links as L
import chainer.functions as F
import numpy as np
import math
import cv2
from chainer.links.caffe import CaffeFunction
import chainerrl
from chainerrl.agents import a3c


class DilatedConvBlock(chainer.Chain):
    # 3x3 dilated convolution + ReLU; pad == dilate preserves the spatial size.

    def __init__(self, d_factor):
        super(DilatedConvBlock, self).__init__(
            diconv=L.DilatedConvolution2D(in_channels=64, out_channels=64,
                                          ksize=3, stride=1, pad=d_factor,
                                          dilate=d_factor, nobias=False),
        )
        self.train = True

    def __call__(self, x):
        h = F.relu(self.diconv(x))
        return h


class MyFcn(chainer.Chain, a3c.A3CModel):
    # A3C model for the enhancement agent: a shared dilated-conv trunk feeds
    # three per-channel (R/G/B) softmax policy heads and one value head.

    def __init__(self, n_actions):
        w = chainer.initializers.HeNormal()
        super(MyFcn, self).__init__(
            conv1=L.Convolution2D(3, 64, 3, stride=1, pad=1, nobias=False),
            diconv2=DilatedConvBlock(2),
            diconv3=DilatedConvBlock(3),
            diconv4=DilatedConvBlock(4),
            diconv5_pi=DilatedConvBlock(3),
            diconv6_pi=DilatedConvBlock(2),
            conv7_r_pi=chainerrl.policies.SoftmaxPolicy(
                L.Convolution2D(64, n_actions, 3, stride=1, pad=1, nobias=False)),
            conv7_g_pi=chainerrl.policies.SoftmaxPolicy(
                L.Convolution2D(64, n_actions, 3, stride=1, pad=1, nobias=False)),
            conv7_b_pi=chainerrl.policies.SoftmaxPolicy(
                L.Convolution2D(64, n_actions, 3, stride=1, pad=1, nobias=False)),
            diconv5_V=DilatedConvBlock(3),
            diconv6_V=DilatedConvBlock(2),
            conv7_V=L.Convolution2D(64, 1, 3, stride=1, pad=1, nobias=False),
        )
        self.train = True

    def pi_and_v(self, x):
        # shared trunk
        h = F.relu(self.conv1(x))
        h = self.diconv2(h)
        h = self.diconv3(h)
        h = self.diconv4(h)
        # policy branch: one action distribution per pixel and color channel
        h_pi = self.diconv5_pi(h)
        h_pi = self.diconv6_pi(h_pi)
        pout_r = self.conv7_r_pi(h_pi)
        pout_g = self.conv7_g_pi(h_pi)
        pout_b = self.conv7_b_pi(h_pi)
        # value branch: one state-value estimate per pixel
        h_V = self.diconv5_V(h)
        h_V = self.diconv6_V(h_V)
        vout = self.conv7_V(h_V)

        return pout_r, pout_g, pout_b, vout


class MyFcn_denoise(chainer.Chain, a3c.A3CModel):
    # A3C model for the denoising agent: same trunk, a single policy head.

    def __init__(self, n_actions):
        w = chainer.initializers.HeNormal()
        #net = CaffeFunction('../initial_weight/zhang_cvpr17_denoise_15_gray.caffemodel')
        super(MyFcn_denoise, self).__init__(
            conv1=L.Convolution2D(3, 64, 3, stride=1, pad=1, nobias=False),
            diconv2=DilatedConvBlock(2),
            diconv3=DilatedConvBlock(3),
            diconv4=DilatedConvBlock(4),
            diconv5_pi=DilatedConvBlock(3),
            diconv6_pi=DilatedConvBlock(2),
            conv7_pi=chainerrl.policies.SoftmaxPolicy(
                L.Convolution2D(64, n_actions, 3, stride=1, pad=1, nobias=False)),
            diconv5_V=DilatedConvBlock(3),
            diconv6_V=DilatedConvBlock(2),
            conv7_V=L.Convolution2D(64, 1, 3, stride=1, pad=1, nobias=False),
        )
        self.train = True

    def pi_and_v(self, x):
        h = F.relu(self.conv1(x))
        h = self.diconv2(h)
        h = self.diconv3(h)
        h = self.diconv4(h)
        h_pi = self.diconv5_pi(h)
        h_pi = self.diconv6_pi(h_pi)
        de = self.conv7_pi(h_pi)
        #pout = np.concatenate((pout_r,pout_g,pout_b), axis=1)
        h_V = self.diconv5_V(h)
        h_V = self.diconv6_V(h_V)
        vout = self.conv7_V(h_V)

        return de, vout

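For orientation, a hypothetical smoke test for MyFcn.pi_and_v follows; this is a sketch, not part of the commit, and assumes chainer and chainerrl are installed and that MyFCN.py is on the import path:

import numpy as np
from MyFCN import MyFcn

model = MyFcn(n_actions=27)
x = np.random.rand(1, 3, 64, 64).astype(np.float32)
pout_r, pout_g, pout_b, vout = model.pi_and_v(x)
# pout_r/g/b are chainerrl SoftmaxDistribution objects: one 27-way action
# distribution per pixel and color channel; vout is a chainer Variable of
# shape (1, 1, 64, 64), one value estimate per pixel.
print(vout.shape)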
Diff for: ReLLIE/convert_model.py (new file, +8)

import torch


if __name__ == '__main__':
    weight = torch.load('pre_weights/ReLLIE/net_rgb.pth')
    new_weight = {}
    for k, v in weight.items():
        # drop the 7-character 'module.' prefix that nn.DataParallel adds to keys
        new_weight[k[7:]] = v
    torch.save(new_weight, 'weights/ReLLIE/weight.pth')

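The k[7:] slice removes the 'module.' prefix that torch.nn.DataParallel prepends to parameter names when a model is saved from a wrapped module. A slightly more defensive variant, sketched below under the assumption of Python 3.9+ (for str.removeprefix), strips the prefix only where it is actually present:

import torch

weight = torch.load('pre_weights/ReLLIE/net_rgb.pth')
# strip 'module.' only where present, leaving other keys untouched
new_weight = {k.removeprefix('module.'): v for k, v in weight.items()}
torch.save(new_weight, 'weights/ReLLIE/weight.pth')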
Diff for: ReLLIE/infer.py (new file, +142)

import torch
import torchvision
from torchvision import transforms
from glob import glob
import os
from PIL import Image
import time
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).resolve().parents[1]))
from utils.util import *

import MyFCN
import pixelwise_a3c_de
import pixelwise_a3c_el

model_name = 'ReLLIE'
in_path = './input/test.png'
out_path = './output/' + model_name

EPISODE_LEN = 3  # adjust to your needs
LEARNING_RATE = 0.0005
GAMMA = 1.05  # discount factor
N_ACTIONS = 27
MOVE_RANGE = 27  # number of actions that move the pixel values; e.g. with MOVE_RANGE=3 there are three actions: pixel_value += 1, += 0, -= 1


def load_model():
    from model import Model
    # FFDNet-style denoiser (see model.py)
    model = Model()
    model.load_weight('weights/ReLLIE/weight.pth')
    model = model.eval().cuda()

    # enhancement agent
    model_el = MyFCN.MyFcn(N_ACTIONS)
    optimizer_el = pixelwise_a3c_el.chainer.optimizers.Adam(alpha=LEARNING_RATE)
    optimizer_el.setup(model_el)
    agent_el = pixelwise_a3c_el.PixelWiseA3C(model_el, optimizer_el, EPISODE_LEN, GAMMA)
    pixelwise_a3c_el.chainer.serializers.load_npz('./weights/ReLLIE/pretrained/model.npz', agent_el.model)
    agent_el.act_deterministically = True
    agent_el.model.to_gpu()

    # denoising agent
    model_de = MyFCN.MyFcn_denoise(2)
    optimizer_de = pixelwise_a3c_de.chainer.optimizers.Adam(alpha=LEARNING_RATE)
    optimizer_de.setup(model_de)
    agent_de = pixelwise_a3c_de.PixelWiseA3C(model_de, optimizer_de, EPISODE_LEN, GAMMA)
    pixelwise_a3c_de.chainer.serializers.load_npz('./weights/ReLLIE/pretrained/init_denoising.npz', agent_de.model)
    agent_de.act_deterministically = True
    agent_de.model.to_gpu()

    return model, agent_el, agent_de


def load_data_paths():
    global in_path, out_path
    input_paths = []
    if os.path.isfile(in_path):
        input_paths = [in_path]
        in_path = os.path.dirname(in_path)
    elif os.path.isdir(in_path):
        for root, dirs, files in os.walk(in_path):
            for name in files:
                for ext in ['.jpg', '.png', '.jpeg', '.bmp']:
                    if name.lower().endswith(ext):
                        input_paths.append(os.path.join(root, name))
    return input_paths


class State_de():
    def __init__(self, move_range, model):
        self.cur_img = None
        self.move_range = move_range
        self.net = model

    def reset(self, x):
        self.low_img = x
        self.cur_img = x
        torch.clamp_(self.low_img, 3 / 255, 1)

    def step_el(self, act):
        # map the discrete per-pixel actions to curve-adjustment gains
        move = torch.tensor(act, dtype=self.cur_img.dtype, device=self.cur_img.device)
        moves = (move - 6) / 20
        moved_image = self.cur_img + (0.1 * moves + 0.9 * moves[:, 0:1, :, :]) * self.cur_img * (1 - self.cur_img)
        self.cur_img = 0.8 * moved_image + 0.2 * self.cur_img

    def step_de(self, act_b):
        # noise level map
        nsigma = (self.cur_img - self.low_img) / self.low_img
        nsigma = nsigma.max() * 2 * (nsigma - nsigma.min()) / (nsigma.max() - nsigma.min())
        torch.clamp_(nsigma, 0)
        nsigma = nsigma / 255.
        nsigma = nsigma[:, :, ::2, ::2]

        # estimate the noise and subtract it from the input image
        estim_noise = self.net(self.cur_img, nsigma)
        self.cur_img = torch.clamp(self.cur_img - estim_noise, 0., 1.)


def test(raw_x, agent_el, agent_de, model):
    current_state = State_de(MOVE_RANGE, model)
    current_state.reset(raw_x)

    for t in range(EPISODE_LEN):
        action_el = agent_el.act(current_state.cur_img.cpu().detach().numpy())
        current_state.step_el(action_el)
        if t > 4:  # denoising only runs for t > 4, i.e. never when EPISODE_LEN = 3
            action_de = agent_de.act(current_state.cur_img.cpu().detach().numpy())
            current_state.step_de(action_de)

    agent_de.stop_episode()

    return current_state.cur_img


def inference(model, input_paths):
    global in_path, out_path
    total_time = 0
    ts = transforms.ToTensor()
    model, agent_el, agent_de = model

    with torch.no_grad():
        for input_path in input_paths:
            output_path = input_path.replace(in_path, out_path)
            if not os.path.exists(os.path.dirname(output_path)):
                os.makedirs(os.path.dirname(output_path))

            img = Image.open(input_path)
            img = ts(img).unsqueeze(0).cuda()

            img, h, w = padding(img, 8)
            tic = time.time()
            output = test(img, agent_el, agent_de, model)
            toc = time.time()
            output = unpadding(output, h, w)
            total_time += toc - tic

            torchvision.utils.save_image(output, output_path)
    print('{} Total time: {:.4f}s Speed: {:.4f}s/img'.format(model_name, total_time, total_time / len(input_paths)))


if __name__ == '__main__':
    model = load_model()
    input_paths = load_data_paths()
    inference(model, input_paths)

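To make the enhancement step concrete, here is a worked toy example of the curve update in State_de.step_el (a sketch, not part of the commit): with action value 10 everywhere, the gain is (10 - 6) / 20 = 0.2, a pixel at 0.5 moves by 0.2 * 0.5 * (1 - 0.5) = 0.05, and the damped blend lands at 0.8 * 0.55 + 0.2 * 0.5 = 0.54.

import torch

cur = torch.full((1, 3, 2, 2), 0.5)    # toy image, every pixel at 0.5
act = torch.full((1, 3, 2, 2), 10.0)   # one action id per pixel and channel
moves = (act - 6) / 20                 # map discrete actions to curve gains
moved = cur + (0.1 * moves + 0.9 * moves[:, 0:1]) * cur * (1 - cur)
out = 0.8 * moved + 0.2 * cur          # damped update, as in step_el
print(out.mean())                      # tensor(0.5400)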
Diff for: ReLLIE/model.py (new file, +91)

import torch
import torch.nn as nn
import torch.nn.functional as F


class IntermediateDnCNN(nn.Module):
    r"""Implements the middle part of the FFDNet architecture, which
    is basically a DnCNN net.
    """
    def __init__(self, input_features, middle_features, num_conv_layers):
        super(IntermediateDnCNN, self).__init__()
        self.kernel_size = 3
        self.padding = 1
        self.input_features = input_features
        self.num_conv_layers = num_conv_layers
        self.middle_features = middle_features
        if self.input_features == 5:
            self.output_features = 4   # grayscale image
        elif self.input_features == 15:
            self.output_features = 12  # RGB image
        else:
            raise Exception('Invalid number of input features')

        layers = []
        layers.append(nn.Conv2d(in_channels=self.input_features,
                                out_channels=self.middle_features,
                                kernel_size=self.kernel_size,
                                padding=self.padding,
                                bias=False))
        layers.append(nn.ReLU(inplace=True))
        for _ in range(self.num_conv_layers - 2):
            layers.append(nn.Conv2d(in_channels=self.middle_features,
                                    out_channels=self.middle_features,
                                    kernel_size=self.kernel_size,
                                    padding=self.padding,
                                    bias=False))
            layers.append(nn.BatchNorm2d(self.middle_features))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Conv2d(in_channels=self.middle_features,
                                out_channels=self.output_features,
                                kernel_size=self.kernel_size,
                                padding=self.padding,
                                bias=False))
        # 'itermediate' spelling kept so the pretrained state-dict keys still match
        self.itermediate_dncnn = nn.Sequential(*layers)

    def forward(self, x):
        out = self.itermediate_dncnn(x)
        return out


class Model(nn.Module):
    r"""Implements the FFDNet architecture."""
    def __init__(self, num_input_channels=3):
        super(Model, self).__init__()
        self.num_input_channels = num_input_channels
        if self.num_input_channels == 1:
            # grayscale image
            self.num_feature_maps = 64
            self.num_conv_layers = 15
            self.downsampled_channels = 5
        elif self.num_input_channels == 3:
            # RGB image
            self.num_feature_maps = 96
            self.num_conv_layers = 12
            self.downsampled_channels = 15
        else:
            raise Exception('Invalid number of input features')

        self.intermediate_dncnn = IntermediateDnCNN(
            input_features=self.downsampled_channels,
            middle_features=self.num_feature_maps,
            num_conv_layers=self.num_conv_layers)

    def forward(self, x, noise_sigma):
        x = F.pixel_unshuffle(x, 2)
        x = torch.cat([noise_sigma, x], 1)
        x = self.intermediate_dncnn(x)
        x = F.pixel_shuffle(x, 2)
        return x

    def load_weight(self, path):
        r"""Loads the weights of the FFDNet model from a file."""
        self.load_state_dict(torch.load(path))


if __name__ == '__main__':
    model = Model()
    x = torch.randn(1, 3, 128, 128)
    n = torch.randn(1, 3, 64, 64)
    y = model(x, n)
    print(y.shape)

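A shape walk-through of Model.forward for the RGB configuration (a sketch): pixel_unshuffle with factor 2 turns the 3-channel input into 12 channels at half resolution, concatenating the 3-channel noise map gives the 15 input features IntermediateDnCNN expects, and pixel_shuffle maps the 12 output channels back to a full-resolution 3-channel image.

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 128, 128)
u = F.pixel_unshuffle(x, 2)   # (1, 12, 64, 64): 4x channels, half H and W
n = torch.randn(1, 3, 64, 64) # noise level map at the downsampled size
inp = torch.cat([n, u], 1)    # (1, 15, 64, 64): matches input_features=15
y = F.pixel_shuffle(torch.randn(1, 12, 64, 64), 2)  # (1, 3, 128, 128)
print(inp.shape, y.shape)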