Skip to content

Commit 0e3d760

Browse files
committed
Fix SAM formatting.
1 parent 493f97d commit 0e3d760

File tree

4 files changed

+533
-411
lines changed

4 files changed

+533
-411
lines changed
+139-113
Original file line numberDiff line numberDiff line change
@@ -1,113 +1,139 @@
1-
from __future__ import division
2-
import keras.backend as K
3-
import theano.tensor as T
4-
from keras.layers import Layer, InputSpec
5-
from keras import initializations, regularizers, constraints
6-
import theano
7-
import numpy as np
8-
floatX = theano.config.floatX
9-
10-
11-
class LearningPrior(Layer):
12-
def __init__(self, nb_gaussian, init='normal', weights=None,
13-
W_regularizer=None, activity_regularizer=None,
14-
W_constraint=None, **kwargs):
15-
self.nb_gaussian = nb_gaussian
16-
self.init = initializations.get(init, dim_ordering='th')
17-
18-
self.W_regularizer = regularizers.get(W_regularizer)
19-
self.activity_regularizer = regularizers.get(activity_regularizer)
20-
21-
self.W_constraint = constraints.get(W_constraint)
22-
23-
self.input_spec = [InputSpec(ndim=4)]
24-
self.initial_weights = weights
25-
super(LearningPrior, self).__init__(**kwargs)
26-
27-
def build(self, input_shape):
28-
self.W_shape = (self.nb_gaussian*4, )
29-
self.W = self.init(self.W_shape, name='{}_W'.format(self.name))
30-
31-
self.trainable_weights = [self.W]
32-
33-
self.regularizers = []
34-
if self.W_regularizer:
35-
self.W_regularizer.set_param(self.W)
36-
self.regularizers.append(self.W_regularizer)
37-
38-
if self.activity_regularizer:
39-
self.activity_regularizer.set_layer(self)
40-
self.regularizers.append(self.activity_regularizer)
41-
42-
if self.initial_weights is not None:
43-
self.set_weights(self.initial_weights)
44-
del self.initial_weights
45-
46-
self.constraints = {}
47-
if self.W_constraint:
48-
self.constraints[self.W] = self.W_constraint
49-
50-
def get_output_shape_for(self, input_shape):
51-
self.b_s = input_shape[0]
52-
self.height = input_shape[2]
53-
self.width = input_shape[3]
54-
55-
return self.b_s, self.nb_gaussian, self.height, self.width
56-
57-
def call(self, x, mask=None):
58-
mu_x = self.W[:self.nb_gaussian]
59-
mu_y = self.W[self.nb_gaussian:self.nb_gaussian*2]
60-
sigma_x = self.W[self.nb_gaussian*2:self.nb_gaussian*3]
61-
sigma_y = self.W[self.nb_gaussian*3:]
62-
63-
self.b_s = x.shape[0]
64-
self.height = x.shape[2]
65-
self.width = x.shape[3]
66-
67-
e = self.height / self.width
68-
e1 = (1 - e) / 2
69-
e2 = e1 + e
70-
71-
mu_x = K.clip(mu_x, 0.25, 0.75)
72-
mu_y = K.clip(mu_y, 0.35, 0.65)
73-
74-
sigma_x = K.clip(sigma_x, 0.1, 0.9)
75-
sigma_y = K.clip(sigma_y, 0.2, 0.8)
76-
77-
x_t = T.dot(T.ones((self.height, 1)), self._linspace(0, 1.0, self.width).dimshuffle('x', 0))
78-
y_t = T.dot(self._linspace(e1, e2, self.height).dimshuffle(0, 'x'), T.ones((1, self.width)))
79-
80-
x_t = K.repeat_elements(K.expand_dims(x_t, dim=-1), self.nb_gaussian, axis=-1)
81-
y_t = K.repeat_elements(K.expand_dims(y_t, dim=-1), self.nb_gaussian, axis=-1)
82-
83-
gaussian = 1 / (2 * np.pi * sigma_x * sigma_y + K.epsilon()) * \
84-
T.exp(-((x_t - mu_x) ** 2 / (2 * sigma_x ** 2 + K.epsilon()) +
85-
(y_t - mu_y) ** 2 / (2 * sigma_y ** 2 + K.epsilon())))
86-
87-
gaussian = K.permute_dimensions(gaussian, (2, 0, 1))
88-
max_gauss = K.repeat_elements(K.expand_dims(K.repeat_elements(K.expand_dims(K.max(K.max(gaussian, axis=1), axis=1)), self.height, axis=-1)), self.width, axis=-1)
89-
gaussian = gaussian / max_gauss
90-
91-
output = K.repeat_elements(K.expand_dims(gaussian, dim=0), self.b_s, axis=0)
92-
93-
return output
94-
95-
@staticmethod
96-
def _linspace(start, stop, num):
97-
# produces results identical to:
98-
# np.linspace(start, stop, num)
99-
start = T.cast(start, floatX)
100-
stop = T.cast(stop, floatX)
101-
num = T.cast(num, floatX)
102-
step = (stop - start) / (num - 1)
103-
return T.arange(num, dtype=floatX) * step + start
104-
105-
def get_config(self):
106-
config = {'nb_gaussian': self.nb_gaussian,
107-
'init': self.init.__name__,
108-
'W_regularizer': self.W_regularizer.get_config() if self.W_regularizer else None,
109-
'activity_regularizer': self.activity_regularizer.get_config() if self.activity_regularizer else None,
110-
'W_constraint': self.W_constraint.get_config() if self.W_constraint else None,
111-
}
112-
base_config = super(LearningPrior, self).get_config()
113-
return dict(list(base_config.items()) + list(config.items()))
1+
from __future__ import division
2+
import keras.backend as K
3+
import theano.tensor as T
4+
from keras.layers import Layer, InputSpec
5+
from keras import initializations, regularizers, constraints
6+
import theano
7+
import numpy as np
8+
floatX = theano.config.floatX
9+
10+
11+
class LearningPrior(Layer):
12+
def __init__(self,
13+
nb_gaussian,
14+
init='normal',
15+
weights=None,
16+
W_regularizer=None,
17+
activity_regularizer=None,
18+
W_constraint=None,
19+
**kwargs):
20+
self.nb_gaussian = nb_gaussian
21+
self.init = initializations.get(init, dim_ordering='th')
22+
23+
self.W_regularizer = regularizers.get(W_regularizer)
24+
self.activity_regularizer = regularizers.get(activity_regularizer)
25+
26+
self.W_constraint = constraints.get(W_constraint)
27+
28+
self.input_spec = [InputSpec(ndim=4)]
29+
self.initial_weights = weights
30+
super(LearningPrior, self).__init__(**kwargs)
31+
32+
def build(self, input_shape):
33+
self.W_shape = (self.nb_gaussian * 4, )
34+
self.W = self.init(self.W_shape, name='{}_W'.format(self.name))
35+
36+
self.trainable_weights = [self.W]
37+
38+
self.regularizers = []
39+
if self.W_regularizer:
40+
self.W_regularizer.set_param(self.W)
41+
self.regularizers.append(self.W_regularizer)
42+
43+
if self.activity_regularizer:
44+
self.activity_regularizer.set_layer(self)
45+
self.regularizers.append(self.activity_regularizer)
46+
47+
if self.initial_weights is not None:
48+
self.set_weights(self.initial_weights)
49+
del self.initial_weights
50+
51+
self.constraints = {}
52+
if self.W_constraint:
53+
self.constraints[self.W] = self.W_constraint
54+
55+
def get_output_shape_for(self, input_shape):
56+
self.b_s = input_shape[0]
57+
self.height = input_shape[2]
58+
self.width = input_shape[3]
59+
60+
return self.b_s, self.nb_gaussian, self.height, self.width
61+
62+
def call(self, x, mask=None):
63+
mu_x = self.W[:self.nb_gaussian]
64+
mu_y = self.W[self.nb_gaussian:self.nb_gaussian * 2]
65+
sigma_x = self.W[self.nb_gaussian * 2:self.nb_gaussian * 3]
66+
sigma_y = self.W[self.nb_gaussian * 3:]
67+
68+
self.b_s = x.shape[0]
69+
self.height = x.shape[2]
70+
self.width = x.shape[3]
71+
72+
e = self.height / self.width
73+
e1 = (1 - e) / 2
74+
e2 = e1 + e
75+
76+
mu_x = K.clip(mu_x, 0.25, 0.75)
77+
mu_y = K.clip(mu_y, 0.35, 0.65)
78+
79+
sigma_x = K.clip(sigma_x, 0.1, 0.9)
80+
sigma_y = K.clip(sigma_y, 0.2, 0.8)
81+
82+
x_t = T.dot(
83+
T.ones((self.height, 1)),
84+
self._linspace(0, 1.0, self.width).dimshuffle('x', 0))
85+
y_t = T.dot(
86+
self._linspace(e1, e2, self.height).dimshuffle(0, 'x'),
87+
T.ones((1, self.width)))
88+
89+
x_t = K.repeat_elements(
90+
K.expand_dims(x_t, dim=-1), self.nb_gaussian, axis=-1)
91+
y_t = K.repeat_elements(
92+
K.expand_dims(y_t, dim=-1), self.nb_gaussian, axis=-1)
93+
94+
gaussian = 1 / (2 * np.pi * sigma_x * sigma_y + K.epsilon()) * \
95+
T.exp(-((x_t - mu_x) ** 2 / (2 * sigma_x ** 2 + K.epsilon()) +
96+
(y_t - mu_y) ** 2 / (2 * sigma_y ** 2 + K.epsilon())))
97+
98+
gaussian = K.permute_dimensions(gaussian, (2, 0, 1))
99+
max_gauss = K.repeat_elements(
100+
K.expand_dims(
101+
K.repeat_elements(
102+
K.expand_dims(K.max(K.max(gaussian, axis=1), axis=1)),
103+
self.height,
104+
axis=-1)),
105+
self.width,
106+
axis=-1)
107+
gaussian = gaussian / max_gauss
108+
109+
output = K.repeat_elements(
110+
K.expand_dims(gaussian, dim=0), self.b_s, axis=0)
111+
112+
return output
113+
114+
@staticmethod
115+
def _linspace(start, stop, num):
116+
# produces results identical to:
117+
# np.linspace(start, stop, num)
118+
start = T.cast(start, floatX)
119+
stop = T.cast(stop, floatX)
120+
num = T.cast(num, floatX)
121+
step = (stop - start) / (num - 1)
122+
return T.arange(num, dtype=floatX) * step + start
123+
124+
def get_config(self):
125+
config = {
126+
'nb_gaussian':
127+
self.nb_gaussian,
128+
'init':
129+
self.init.__name__,
130+
'W_regularizer':
131+
self.W_regularizer.get_config() if self.W_regularizer else None,
132+
'activity_regularizer':
133+
self.activity_regularizer.get_config()
134+
if self.activity_regularizer else None,
135+
'W_constraint':
136+
self.W_constraint.get_config() if self.W_constraint else None,
137+
}
138+
base_config = super(LearningPrior, self).get_config()
139+
return dict(list(base_config.items()) + list(config.items()))

0 commit comments

Comments
 (0)