-
Notifications
You must be signed in to change notification settings - Fork 18
/
Copy pathutils.py
106 lines (89 loc) · 3.21 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import numpy as np
import matplotlib.pyplot as plt
import torchvision.datasets as dsets
# Standard libraries
import numpy as np
import os
# PyTorch
import torch
import torch.nn as nn
# JPEG quantization tables. The values match the example luminance
# (Table K.1) and chrominance (Table K.2) tables of ITU-T T.81, Annex K.
# Both end up transposed relative to the row-major layout printed in the
# standard; whoever consumes that layout is not in this file.
# Each table is wrapped in nn.Parameter (requires_grad=True by default), so
# gradients can flow into the table values.

# Luminance (Y) table: written row by row as in the standard, then transposed.
y_table = nn.Parameter(torch.from_numpy(np.array(
    [
        [16, 11, 10, 16, 24, 40, 51, 61],
        [12, 12, 14, 19, 26, 58, 60, 55],
        [14, 13, 16, 24, 40, 57, 69, 56],
        [14, 17, 22, 29, 51, 87, 80, 62],
        [18, 22, 37, 56, 68, 109, 103, 77],
        [24, 35, 55, 64, 81, 104, 113, 92],
        [49, 64, 78, 87, 103, 121, 120, 101],
        [72, 92, 95, 98, 112, 100, 103, 99],
    ],
    dtype=np.float32,
).T))

# Chrominance table: 99 everywhere except the (transposed) top-left 4x4 corner.
c_table = np.full((8, 8), 99, dtype=np.float32)
c_table[:4, :4] = np.array([
    [17, 18, 24, 47],
    [18, 21, 26, 66],
    [24, 26, 56, 99],
    [47, 66, 99, 99],
]).T
c_table = nn.Parameter(torch.from_numpy(c_table))
def diff_round(x):
    """ Differentiable approximation of rounding

    Returns ``round(x) + (x - round(x)) ** 3``. The forward value stays close
    to ``round(x)``. Since ``torch.round`` contributes a zero gradient, the
    gradient w.r.t. ``x`` comes from the cubic residual term,
    ``3 * (x - round(x)) ** 2``.
    Input:
        x(tensor)
    Output:
        x(tensor): same shape as the input
    """
    rounded = torch.round(x)
    residual = x - rounded
    return rounded + residual ** 3
def phi_diff(x, alpha):
    """ Soft (differentiable) rounding with a tanh staircase

    For every unit interval [n, n + 1) with n = floor(x) this computes
        n + (s * tanh(k * (x - n - 0.5)) + 1) / 2
    with k = log(2 / alpha - 1) and s = 1 / (1 - alpha). The choice of s
    makes the curve reach n at the left end of the interval and approach
    n + 1 at the right end, so the staircase is continuous.
    Input:
        x(tensor): values to round softly.
        alpha(tensor): softness parameter; expected in (0, 2).
            alpha == 1 makes s infinite and k zero, which yields NaN.
            alpha >= 2 is clamped to 2.0. That gives k = -inf, i.e. hard
            rounding, and NaN for inputs whose fractional part is exactly 0.5.
    Output:
        x_(tensor): softly rounded values, placed on CUDA when available,
            otherwise on CPU.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    x = x.to(device)
    alpha = alpha.to(device)
    # Build the clamp value on `device`. A hard-coded .cuda() would fail on
    # machines without CUDA, even though `device` falls back to CPU.
    alpha = torch.where(alpha >= 2.0, torch.tensor([2.0], device=device), alpha)
    s = 1 / (1 - alpha)
    k = torch.log(2 / alpha - 1)
    floor_x = torch.floor(x)
    phi_x = torch.tanh((x - (floor_x + 0.5)) * k) * s
    x_ = (phi_x + 1) / 2 + floor_x
    return x_
def quality_to_factor(quality):
    """ Calculate factor corresponding to quality

    Uses the libjpeg quality-scaling rule. Quality 50 gives factor 1.0.
    Lower qualities scale the quantization tables up (5000 / quality %).
    Higher qualities scale them down (200 - 2 * quality %).
    Input:
        quality(float): Quality for jpeg compression. As in libjpeg,
            values <= 0 are treated as 1 and values > 100 as 100. This
            avoids a division by zero at 0 and negative factors outside
            the valid range.
    Output:
        factor(float): Compression factor
    """
    if quality <= 0:
        quality = 1
    elif quality > 100:
        quality = 100
    if quality < 50:
        scale = 5000. / quality
    else:
        scale = 200. - quality * 2
    return scale / 100.
def imshow(img, title):
    """ Show an image tensor with matplotlib

    Input:
        img(tensor): 3-D image whose first axis is moved last before
            plotting, i.e. laid out as (C, H, W) so plt.imshow receives
            (H, W, C).
        title(str): figure title.
    """
    # detach()/cpu() let tensors that require grad or live on the GPU be
    # converted; a plain .numpy() raises for those.
    npimg = img.detach().cpu().numpy()
    plt.figure(figsize=(5, 15))
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.title(title)
    plt.show()
def image_folder_custom_label(root, transform, idx2label):
    """ Build an ImageFolder whose targets follow a caller-chosen label order

    Input:
        root(str): dataset root handed to torchvision's ImageFolder.
        transform: image transform handed to ImageFolder.
        idx2label(list): new index -> label (folder name), e.g.
            ['tench', 'goldfish', 'great_white_shark', 'tiger_shark'].
    Output:
        ImageFolder whose targets are positions in idx2label, with
        `classes` and `class_to_idx` replaced to match that order.
    """
    # First scan only to learn the folder-derived class order.
    original_classes = dsets.ImageFolder(root=root, transform=transform).classes
    # label -> new index; a repeated label keeps its last position.
    label2idx = {label: position for position, label in enumerate(idx2label)}

    def remap_target(old_index):
        # Looked up lazily per sample; raises ValueError if a folder name
        # is absent from idx2label.
        return idx2label.index(original_classes[old_index])

    relabelled = dsets.ImageFolder(root=root, transform=transform,
                                   target_transform=remap_target)
    relabelled.classes = idx2label
    relabelled.class_to_idx = label2idx
    return relabelled
def create_dir(dir, print_flag=False):
    """ Create a single directory if it does not exist yet

    Input:
        dir(str): path of the directory to create. Only the last component
            is created (os.mkdir), so the parent must already exist.
        print_flag(bool): when True, report whether the directory was
            created or was already present.
    """
    already_there = os.path.exists(dir)
    if not already_there:
        os.mkdir(dir)
    if not print_flag:
        return
    if already_there:
        print("Directory {} is already existed. ".format(dir))
    else:
        print("Create dir {} successfully!".format(dir))
def data_clean(data_dir):
    """ Delete every non-``.png`` entry from a class-per-folder dataset

    Files sitting directly in `data_dir` are removed. Inside each
    sub-directory, every entry whose name does not end in ".png" is removed.
    Input:
        data_dir(str): dataset root, cleaned in place.
    """
    for class_name in os.listdir(data_dir):
        class_path = os.path.join(data_dir, class_name)
        if os.path.isfile(class_path):
            os.remove(class_path)
            # Nothing to descend into. Without this, listdir below would
            # run on the path just deleted and raise FileNotFoundError.
            continue
        for img_name in os.listdir(class_path):
            # NOTE(review): the suffix test is case-sensitive, so "x.PNG" is
            # deleted too; nested sub-directories make os.remove raise.
            if not img_name.endswith(".png"):
                os.remove(os.path.join(class_path, img_name))