-
Notifications
You must be signed in to change notification settings - Fork 352
/
common.py
75 lines (45 loc) · 1.54 KB
/
common.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import numpy as np
import tensorflow as tf
# Per-channel (R, G, B) mean pixel values, written as fractions in [0, 1] and
# scaled by 255 to the 8-bit pixel range. Presumably the mean over the DIV2K
# dataset, as the name suggests — TODO confirm the source of these figures.
# Used as the default centering offset by normalize()/denormalize().
DIV2K_RGB_MEAN = np.array([0.4488, 0.4371, 0.4040]) * 255
def resolve_single(model, lr):
    """Super-resolve one image by wrapping it in a batch of size one.

    A leading batch axis is inserted at position 0 of ``lr``, the resulting
    one-element batch is passed through ``resolve`` and the single output
    image is unwrapped again before being returned.
    """
    batched = tf.expand_dims(lr, axis=0)
    sr_batch = resolve(model, batched)
    return sr_batch[0]
def resolve(model, lr_batch):
    """Run ``model`` on a batch of low-resolution images.

    The input is cast to float32 before inference. The model output is
    clipped to [0, 255], rounded to the nearest integer and returned as a
    uint8 tensor.
    """
    as_float = tf.cast(lr_batch, tf.float32)
    raw_output = model(as_float)
    # Clip first so rounding and the uint8 cast never see out-of-range values.
    pixels = tf.round(tf.clip_by_value(raw_output, 0, 255))
    return tf.cast(pixels, tf.uint8)
def evaluate(model, dataset):
    """Return the mean PSNR of ``model`` over every image in ``dataset``.

    Args:
        model: Super-resolution model, invoked through ``resolve``.
        dataset: Iterable of ``(lr, hr)`` pairs, each a batch of images
            (assumed 4-D ``(batch, height, width, channels)`` so that
            ``psnr`` returns one value per batch element — TODO confirm
            against callers).

    Returns:
        Scalar tensor with the mean PSNR over all evaluated images, or NaN
        when ``dataset`` yields no pairs.
    """
    per_batch_psnr = []
    for lr, hr in dataset:
        sr = resolve(model, lr)
        # psnr() yields one value per image in the batch. Keep all of them
        # (not just index 0) so batch sizes > 1 are scored in full;
        # flattening also tolerates a scalar result.
        per_batch_psnr.append(tf.reshape(psnr(hr, sr), [-1]))
    if not per_batch_psnr:
        # The mean of no values is NaN, matching a reduction over an empty list.
        return tf.constant(float("nan"), dtype=tf.float32)
    return tf.reduce_mean(tf.concat(per_batch_psnr, axis=0))
# ---------------------------------------
# Normalization
# ---------------------------------------
def normalize(x, rgb_mean=DIV2K_RGB_MEAN):
    """Center ``x`` on ``rgb_mean`` and divide the result by 127.5.

    This is the inverse of ``denormalize`` for the same ``rgb_mean``.
    """
    centered = x - rgb_mean
    return centered / 127.5
def denormalize(x, rgb_mean=DIV2K_RGB_MEAN):
    """Multiply ``x`` by 127.5 and add ``rgb_mean`` back.

    This is the inverse of ``normalize`` for the same ``rgb_mean``.
    """
    scaled = x * 127.5
    return scaled + rgb_mean
def normalize_01(x):
    """Scale pixel values from the [0, 255] range down to [0, 1]."""
    full_scale = 255.0
    return x / full_scale
def normalize_m11(x):
    """Map pixel values from the [0, 255] range onto [-1, 1]."""
    halved = x / 127.5
    return halved - 1
def denormalize_m11(x):
    """Map values from [-1, 1] back onto the [0, 255] pixel range.

    This undoes ``normalize_m11``.
    """
    shifted = x + 1
    return shifted * 127.5
# ---------------------------------------
# Metrics
# ---------------------------------------
def psnr(x1, x2):
    """Peak signal-to-noise ratio between ``x1`` and ``x2``.

    The peak value is fixed at 255, so both inputs are expected to be on an
    8-bit [0, 255] pixel scale.
    """
    peak_value = 255
    return tf.image.psnr(x1, x2, max_val=peak_value)
# ---------------------------------------
# See https://arxiv.org/abs/1609.05158
# ---------------------------------------
def pixel_shuffle(scale):
    """Build a callable that moves channel data into spatial blocks.

    The returned function applies ``tf.nn.depth_to_space`` with block size
    ``scale`` to whatever tensor it receives.
    """
    def _shuffle(x):
        return tf.nn.depth_to_space(x, scale)

    return _shuffle