-
Notifications
You must be signed in to change notification settings - Fork 1
/
utils.py
143 lines (112 loc) · 4.21 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import tensorflow as tf
import glob
import numpy as np
import skimage.measure
def _parse_function_CBD(example_proto):
    """Decode one serialized CBD record into ``(noisy, gt, sigma)`` tensors.

    Each of the three features ('Noisy', 'GT', 'Sigma') is stored as a raw
    byte string. It is reinterpreted as float32 and reshaped to [256, 256, 3].
    No value scaling is applied (unlike ``_parse_function``, which divides
    uint8 data by 255).

    NOTE(review): 'Sigma' is presumably a per-pixel noise-level map written
    alongside the images -- confirm against the TFRecord writer.

    Args:
        example_proto: scalar string tensor holding one serialized
            ``tf.train.Example``.

    Returns:
        Tuple of three float32 tensors of shape [256, 256, 3], in the order
        noisy, ground truth, sigma.
    """
    feature_spec = {
        'Noisy': tf.FixedLenFeature([], tf.string),
        'GT': tf.FixedLenFeature([], tf.string),
        'Sigma': tf.FixedLenFeature([], tf.string),
    }
    features = tf.parse_single_example(example_proto, feature_spec)

    def _decode(key):
        # The reshape fails at run time unless the payload is exactly
        # 256 * 256 * 3 float32 values.
        flat = tf.decode_raw(features[key], tf.float32)
        return tf.reshape(flat, [256, 256, 3])

    noisy = _decode('Noisy')
    gt = _decode('GT')
    sigma = _decode('Sigma')
    return noisy, gt, sigma
def load_tfrecords_2(tfrecords_file, n_shuffle=1000, batch_size=64):
    """Build an endless, shuffled, batched CBD input pipeline.

    Records are decoded with ``_parse_function_CBD``, shuffled with a buffer
    of ``n_shuffle`` decoded examples, repeated forever and grouped into
    batches of ``batch_size``.

    Because shuffling happens *after* decoding, the buffer holds decoded
    tensors. Each example is three float32 [256, 256, 3] arrays, about
    2.25 MiB, so the default buffer of 1000 needs roughly 2.2 GiB of host
    memory.

    Args:
        tfrecords_file: path (or list of paths) accepted by
            ``tf.data.TFRecordDataset``.
        n_shuffle: shuffle buffer size, in examples.
        batch_size: number of examples per batch.

    Returns:
        ``(noisy, gt, sigma)`` next-batch tensors from a one-shot iterator.
    """
    batches = (tf.data.TFRecordDataset(tfrecords_file)
               .map(_parse_function_CBD)
               .shuffle(n_shuffle)
               .repeat()
               .batch(batch_size))
    next_noisy, next_gt, next_sigma = batches.make_one_shot_iterator().get_next()
    return next_noisy, next_gt, next_sigma
def _parse_function(example_proto):
    """Decode one serialized record into ``(noisy, gt)`` float32 images.

    Both features ('Noisy', 'GT') are raw uint8 byte strings. They are cast
    to float32, divided by 255 (so 8-bit data lands in [0, 1]) and reshaped
    to [256, 256, 3].

    Args:
        example_proto: scalar string tensor holding one serialized
            ``tf.train.Example``.

    Returns:
        Tuple ``(noisy, gt)`` of float32 tensors of shape [256, 256, 3].
    """
    feature_spec = {
        'Noisy': tf.FixedLenFeature([], tf.string),
        'GT': tf.FixedLenFeature([], tf.string),
    }
    features = tf.parse_single_example(example_proto, feature_spec)

    def _decode(key):
        pixels = tf.cast(tf.decode_raw(features[key], tf.uint8), tf.float32)
        normalized = tf.divide(pixels, 255.)
        return tf.reshape(normalized, [256, 256, 3])

    noisy = _decode('Noisy')
    gt = _decode('GT')
    return noisy, gt
def load_tfrecords(tfrecords_file, n_shuffle=1000, batch_size=64):
    """Build an endless, shuffled, batched ``(noisy, gt)`` input pipeline.

    Records are decoded with ``_parse_function``, shuffled with a buffer of
    ``n_shuffle`` decoded examples, repeated forever and batched. The buffer
    holds decoded float32 pairs (2 x [256, 256, 3]), not the smaller
    serialized uint8 records.

    Args:
        tfrecords_file: path (or list of paths) accepted by
            ``tf.data.TFRecordDataset``.
        n_shuffle: shuffle buffer size, in examples.
        batch_size: number of examples per batch.

    Returns:
        ``(noisy, gt)`` next-batch tensors from a one-shot iterator.
    """
    batches = (tf.data.TFRecordDataset(tfrecords_file)
               .map(_parse_function)
               .shuffle(n_shuffle)
               .repeat()
               .batch(batch_size))
    next_noisy, next_gt = batches.make_one_shot_iterator().get_next()
    return next_noisy, next_gt
def tf_psnr(pred, ref):
    """Build a graph op computing a single PSNR value (in dB) for a batch.

    ``pred`` and ``ref`` are expected to hold pixel values in [0, 1]; this is
    NOT checked. Both are rescaled to [0, 255], and one MSE is pooled over
    all of their elements. The result is therefore the PSNR of the whole
    batch, not the mean of per-image PSNRs. Identical inputs give
    ``mse == 0`` and an infinite result.

    Args:
        pred: predicted images tensor.
        ref: reference images tensor, same shape as ``pred``.

    Returns:
        Scalar float tensor: ``10 * log10(255**2 / mse)``.
    """
    # loss_collection=None: this is an evaluation metric. By default
    # tf.losses.* adds its result to GraphKeys.LOSSES, which would silently
    # fold this MSE into tf.losses.get_total_loss() for any caller using it.
    mse = tf.losses.mean_squared_error(labels=ref * 255.0,
                                       predictions=pred * 255.0,
                                       loss_collection=None)
    # TF1 has no log10 op; use log10(x) = ln(x) / ln(10).
    return 10.0 * (tf.log(255.0 ** 2 / mse) / tf.log(10.0))
def batch_PSNR(noisy, ref, data_range):
    """Return the mean per-item PSNR (in dB) of a batch against references.

    For every item i along axis 0 this computes
    ``10 * log10(data_range**2 / MSE_i)``, with ``MSE_i`` taken in float64,
    and then averages the per-item values. An item identical to its
    reference has infinite PSNR, which makes the batch mean infinite too.

    This is implemented directly in NumPy because
    ``skimage.measure.compare_psnr`` (used previously) was removed in
    scikit-image 0.18. The formula is the same one that function used.

    Args:
        noisy: array of shape (N, ...) with the images to evaluate.
        ref: array with the same shape as ``noisy`` holding the references.
        data_range: dynamic range of the pixel values (e.g. 1.0 for images
            in [0, 1], 255 for 8-bit images).

    Returns:
        The average PSNR over the N items (a NumPy float64).

    Raises:
        ValueError: if the shapes differ or the batch is empty.
    """
    noisy = np.asarray(noisy)
    ref = np.asarray(ref)
    if noisy.shape != ref.shape:
        raise ValueError('noisy and ref must have the same shape, got %s and %s'
                         % (noisy.shape, ref.shape))
    if noisy.ndim == 0 or noisy.shape[0] == 0:
        raise ValueError('batch_PSNR needs a non-empty batch along axis 0')
    n_items = noisy.shape[0]
    # Cast before subtracting so integer inputs (e.g. uint8) cannot wrap.
    diff = ref.astype(np.float64) - noisy.astype(np.float64)
    mse = np.mean(diff.reshape(n_items, -1) ** 2, axis=1)
    # Perfect reconstructions have mse == 0 -> +inf PSNR; keep the value but
    # silence the divide-by-zero warning.
    with np.errstate(divide='ignore'):
        psnr = 10.0 * np.log10(data_range ** 2 / mse)
    return np.mean(psnr)
def dataaugment(patches):
    """Apply a randomly chosen ``datatransform`` mode to every patch.

    For each index along axis 0, a mode is drawn uniformly from 0..7 with
    ``np.random.randint(0, 8)`` and passed to ``datatransform``. The result
    is written back into ``patches``, so the input array is modified in
    place. Modes 0-3 include a 90-degree rotation, so patches presumably
    need to be square in their first two dimensions -- TODO confirm with
    callers.

    Returns:
        The same ``patches`` object, after modification.
    """
    count = patches.shape[0]
    for k in range(count):
        mode = np.random.randint(0, 8)
        patches[k] = datatransform(patches[k], mode)
    return patches
def dataaugment_idx(patches, idx):
    """Apply ``datatransform`` to each patch using a caller-supplied mode.

    Patch k is transformed with mode ``idx[k]`` and written back into
    ``patches`` (in-place modification). ``idx`` must have at least
    ``patches.shape[0]`` entries; any extra entries are ignored.

    Returns:
        The same ``patches`` object, after modification.
    """
    count = patches.shape[0]
    for k in range(count):
        mode = idx[k]
        patches[k] = datatransform(patches[k], mode)
    return patches
def datatransform(img, mode):
    """Return one of 8 rotate/flip variants of ``img`` selected by ``mode``.

    If ``mode < 4``, the image is first rotated 90 degrees with
    ``np.rot90(img, 1)`` (first two axes). Then ``mode % 4`` selects a flip:
    0 none, 1 left-right, 2 up-down, 3 both (a 180-degree rotation). For
    modes 0..7 this yields the 8 distinct symmetries of a square.
    ``datatransform_inv(out, mode)`` undoes it. NumPy returns views, so the
    result may share memory with ``img``.
    """
    out = np.rot90(img, 1) if mode < 4 else img
    flip = mode % 4
    if flip == 1:
        out = np.fliplr(out)
    elif flip == 2:
        out = np.flipud(out)
    elif flip == 3:
        out = np.flipud(np.fliplr(out))
    return out
def datatransform_inv(img, mode):
    """Undo ``datatransform(..., mode)``.

    The forward transform rotates (when ``mode < 4``) and then flips
    according to ``mode % 4``. The inverse therefore flips first and then
    rotates back with ``np.rot90(..., -1)``. Each flip combination is its
    own inverse, so the same flip is reapplied.
    """
    flip = mode % 4
    if flip == 1:
        out = np.fliplr(img)
    elif flip == 2:
        out = np.flipud(img)
    elif flip == 3:
        out = np.fliplr(np.flipud(img))
    else:
        out = img
    return np.rot90(out, -1) if mode < 4 else out
def write_description(opt):
    """Write the run's option list to ``./logs_da/<opt.name>/model.txt``.

    The file is overwritten with ``str(opt._get_kwargs())``, i.e. the
    string form of the list of (name, value) pairs. The directory is not
    created here and must already exist.

    NOTE(review): ``_get_kwargs`` is a private ``argparse`` helper, so
    ``opt`` is presumably an ``argparse.Namespace`` -- confirm with callers.
    """
    target = './logs_da/' + opt.name + '/model.txt'
    with open(target, 'w') as handle:
        handle.write(str(opt._get_kwargs()))
def get_paramsnum():
    """Print and return the number of trainable parameters in the graph.

    The count is the sum, over ``tf.trainable_variables()`` (the default
    graph's trainable-variable collection), of each variable's element
    count. A scalar variable counts as 1.

    Previously the function only printed the total and returned None.
    Returning it as well is backward-compatible for callers that ignore the
    result.

    Returns:
        int: total number of trainable scalar parameters.
    """
    total_parameters = 0
    for variable in tf.trainable_variables():
        # num_elements() is the product of all dimensions (1 for scalars);
        # trainable variables are expected to have fully known shapes.
        total_parameters += variable.get_shape().num_elements()
    print(total_parameters)
    return total_parameters
def batch_PSNR_255(noisy, ref):
    """Return the mean per-item PSNR after quantizing both batches to 8 bits.

    Inputs are expected in [0, 1]. Each value is scaled by 255, rounded and
    clipped to [0, 255] before the cast to uint8. Then
    ``10 * log10(255**2 / MSE_i)`` is computed per item along axis 0 (MSE in
    float64) and averaged.

    The clip is the bug fix. The float-to-uint8 cast does not saturate, so
    values slightly outside [0, 1] -- common for network outputs --
    previously overflowed (typically wrapping, e.g. 258 -> 2) and produced
    wildly wrong scores.

    This is computed in NumPy because ``skimage.measure.compare_psnr`` was
    removed in scikit-image 0.18.

    Args:
        noisy: array of shape (N, ...) with values nominally in [0, 1].
        ref: array with the same shape as ``noisy``.

    Returns:
        The average PSNR in dB (NumPy float64). It is infinite if any item
        matches its reference exactly after quantization.

    Raises:
        ValueError: if the shapes differ or the batch is empty.
    """
    noisy = np.asarray(noisy)
    ref = np.asarray(ref)
    if noisy.shape != ref.shape:
        raise ValueError('noisy and ref must have the same shape, got %s and %s'
                         % (noisy.shape, ref.shape))
    if noisy.ndim == 0 or noisy.shape[0] == 0:
        raise ValueError('batch_PSNR_255 needs a non-empty batch along axis 0')
    n_items = noisy.shape[0]
    ref_q = np.clip(np.round(255 * ref), 0, 255).astype(np.uint8)
    noisy_q = np.clip(np.round(255 * noisy), 0, 255).astype(np.uint8)
    # Subtract in float64 so the uint8 difference cannot wrap.
    diff = ref_q.astype(np.float64) - noisy_q.astype(np.float64)
    mse = np.mean(diff.reshape(n_items, -1) ** 2, axis=1)
    with np.errstate(divide='ignore'):
        psnr = 10.0 * np.log10(255.0 ** 2 / mse)
    return np.mean(psnr)