all_in_one_block.py

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.stats import special_ortho_group


class AIO_Block(nn.Module):
    '''Coupling block to replace the standard FrEIA implementation.'''

    def __init__(self, dims_in, dims_c=[],
                 subnet_constructor=None,
                 clamp=2.,
                 gin_block=False,
                 act_norm=1.,
                 permute_soft=False):
        super().__init__()
        channels = dims_in[0][0]
        if dims_c:
            raise ValueError('does not support conditioning yet')
        # Channel split for the coupling: x1 keeps the first half (rounded up),
        # x2 the second half.
        self.split_len1 = channels - channels // 2
        self.split_len2 = channels // 2
        self.splits = [self.split_len1, self.split_len2]
        self.n_pixels = dims_in[0][1] * dims_in[0][2]
        self.in_channels = channels
        self.clamp = clamp
        self.GIN = gin_block
        # ActNorm parameters (per-channel log-scale and offset). They are
        # initialized data-dependently on the first forward pass unless a
        # fixed act_norm value is given.
        self.act_norm = nn.Parameter(torch.zeros(1, self.in_channels, 1, 1))
        self.act_offset = nn.Parameter(torch.zeros(1, self.in_channels, 1, 1))
        self.act_norm_trigger = True
        if act_norm:
            self.act_norm.data += np.log(act_norm)
            self.act_norm_trigger = False
        # Fixed channel mixing: a random rotation (special orthogonal matrix)
        # if permute_soft, otherwise a hard random permutation matrix.
        if permute_soft:
            w = special_ortho_group.rvs(channels)
        else:
            w = np.zeros((channels, channels))
            for i, j in enumerate(np.random.permutation(channels)):
                w[i, j] = 1.
        w_inv = np.linalg.inv(w)
        self.w = nn.Parameter(torch.FloatTensor(w).view(channels, channels, 1, 1),
                              requires_grad=False)
        self.w_inv = nn.Parameter(torch.FloatTensor(w_inv).view(channels, channels, 1, 1),
                                  requires_grad=False)
        self.conditional = False
        # Subnetwork predicting log-scale and shift for the affine coupling.
        self.s = subnet_constructor(self.split_len1, 2 * self.split_len2)
        self.last_jac = None

    def log_e(self, s):
        # Soft-clamp the log-scale to (-clamp, clamp) for numerical stability.
        s = self.clamp * torch.tanh(0.1 * s)
        if self.GIN:
            # GIN: subtract the per-sample mean so the coupling is
            # volume-preserving (zero log-determinant).
            s -= torch.mean(s, dim=(1, 2, 3), keepdim=True)
        return s

    def permute(self, x, rev=False):
        # Fused ActNorm and fixed 1x1 convolution (the channel permutation
        # or rotation chosen in __init__).
        if rev:
            return F.conv2d((x - self.act_offset) * (-self.act_norm).exp(), self.w_inv)
        else:
            return F.conv2d(x, self.w) * self.act_norm.exp() + self.act_offset

    def affine(self, x, a, rev=False):
        # Affine coupling transform: the first `ch` channels of `a` give the
        # (clamped) log-scale, the remaining channels the shift. Returns the
        # transformed tensor and the per-sample log-determinant.
        ch = x.shape[1]
        sub_jac = self.log_e(a[:, :ch])
        if not rev:
            return (x * torch.exp(sub_jac) + a[:, ch:],
                    torch.sum(sub_jac, dim=(1, 2, 3)))
        else:
            return ((x - a[:, ch:]) * torch.exp(-sub_jac),
                    -torch.sum(sub_jac, dim=(1, 2, 3)))

    def forward(self, x, c=[], rev=False):
        if self.act_norm_trigger:
            # Data-dependent initialization: set the ActNorm parameters so the
            # block's output has zero mean and unit variance per channel on the
            # first batch seen.
            with torch.no_grad():
                print('ActNorm triggered')
                self.act_norm_trigger = False
                x_out = self.forward(x)[0]
                x_out = x_out.transpose(0, 1).contiguous().view(self.in_channels, -1)
                self.act_norm.data -= x_out.std(dim=1, unbiased=False).log().view(1, self.in_channels, 1, 1)
                self.act_offset.data -= x_out.mean(dim=1).view(1, self.in_channels, 1, 1)
        if rev:
            x = [self.permute(x[0], rev=True)]
        x1, x2 = torch.split(x[0], self.splits, dim=1)
        if not rev:
            a1 = self.s(x1)
            x2, j2 = self.affine(x2, a1)
        else:  # names of x and y are swapped!
            a1 = self.s(x1)
            x2, j2 = self.affine(x2, a1, rev=True)
        self.last_jac = j2
        x_out = torch.cat((x1, x2), 1)
        if not rev:
            x_out = self.permute(x_out, rev=False)
        return [x_out]

    def jacobian(self, x, c=[], rev=False):
        # Total log-determinant: the coupling contribution from the last
        # forward pass plus the ActNorm scale summed over all pixels
        # (sign-flipped in the reverse direction).
        return self.last_jac + (-1)**rev * self.act_norm.sum() * self.n_pixels

    def output_dims(self, input_dims):
        return input_dims
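

# A minimal usage sketch (not part of the original file): stacking several
# AIO_Blocks into a small flow and accumulating the per-sample
# log-determinant. `build_flow` and `flow_forward` are illustrative names
# chosen here, not FrEIA API.
def build_flow(dims, subnet_constructor, n_blocks=4):
    return nn.ModuleList([AIO_Block(dims, subnet_constructor=subnet_constructor,
                                    permute_soft=True)
                          for _ in range(n_blocks)])

def flow_forward(blocks, x):
    log_det = 0.
    out = [x]
    for block in blocks:
        out = block(out)
        log_det = log_det + block.jacobian(out)
    return out[0], log_det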


if __name__ == '__main__':
    # Round-trip test: the forward pass followed by the inverse should
    # recover the input up to numerical error.
    N = 8
    c = 48
    x = torch.randn(128, c, N, N)
    constr = lambda c_in, c_out: torch.nn.Conv2d(c_in, c_out, 1)
    layer = AIO_Block([(c, N, N)],
                      subnet_constructor=constr,
                      clamp=2.,
                      gin_block=False,
                      act_norm=0,
                      permute_soft=True)
    # The first call triggers the data-dependent ActNorm initialization;
    # the second call uses the now-fixed parameters.
    transf = layer([x])
    transf = layer([x])
    x_inv = layer(transf, rev=True)[0]
    err = torch.abs(x - x_inv)
    print(err.max().item())
    print(err.mean().item())
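
    # A hedged extra check (not in the original file): compare the analytic
    # log-determinant from `jacobian()` against a brute-force autograd
    # computation on a tiny block. Assumes torch.autograd.functional.jacobian
    # (PyTorch >= 1.5) and torch.linalg.slogdet (PyTorch >= 1.8) are available.
    from torch.autograd.functional import jacobian as autograd_jacobian
    c_s, n_s = 4, 2
    layer_s = AIO_Block([(c_s, n_s, n_s)],
                        subnet_constructor=constr,
                        act_norm=1.,
                        permute_soft=True)
    x_s = torch.randn(1, c_s, n_s, n_s)
    flat = lambda inp: layer_s([inp.view(1, c_s, n_s, n_s)])[0].reshape(-1)
    J = autograd_jacobian(flat, x_s.reshape(-1))
    analytic = layer_s.jacobian(None)[0]  # last_jac was set by the call in `flat`
    _, brute_force = torch.linalg.slogdet(J)
    print(analytic.item(), brute_force.item())  # should agree closely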